|
15 | 15 | from __future__ import annotations |
16 | 16 |
|
17 | 17 | from collections.abc import Callable, Sequence |
18 | | -import functools |
19 | 18 | from functools import partial |
20 | 19 | import logging |
21 | 20 | from typing import Any |
@@ -203,7 +202,7 @@ def policy(prim, *args, **params): |
203 | 202 | def checkpoint(fun: Callable, *, prevent_cse: bool = True, |
204 | 203 | policy: Callable[..., bool] | None = None, |
205 | 204 | static_argnums: int | tuple[int, ...] = (), |
206 | | - ) -> Callable: |
| 205 | + concrete: bool = False) -> Callable: |
207 | 206 | """Make ``fun`` recompute internal linearization points when differentiated. |
208 | 207 |
|
209 | 208 | The :func:`jax.checkpoint` decorator, aliased to :func:`jax.remat`, provides a |
@@ -257,6 +256,8 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True, |
257 | 256 | returns a boolean indicating whether the corresponding output value(s) can |
258 | 257 | be saved as residuals (or instead must be recomputed in the (co)tangent |
259 | 258 | computation if needed). |
| 259 | + concrete: Optional boolean; deprecated. Passing a non-False value will |
| 260 | + result in NotImplementedError. |
260 | 261 |
|
261 | 262 | Returns: |
262 | 263 | A function (callable) with the same input/output behavior as ``fun`` but |
@@ -344,6 +345,11 @@ def foo(x, y): |
344 | 345 | ``jax.ensure_compile_time_eval``), it may be easier to compute some values |
345 | 346 | outside the :func:`jax.checkpoint`-decorated function and then close over them. |
346 | 347 | """ |
| 348 | + if concrete: |
| 349 | + raise NotImplementedError( |
| 350 | + "The concrete option to jax.checkpoint has been deprecated. In its" |
| 351 | + " place please use ``static_argnums``; for details refer to" |
| 352 | + " https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html.") |
347 | 353 | if isinstance(static_argnums, int): |
348 | 354 | static_argnums = static_argnums, |
349 | 355 | if isinstance(prevent_cse, list): |
@@ -373,7 +379,14 @@ def fun_remat(*args, **kwargs): |
373 | 379 | return tree_unflatten(out_tree, out_flat) |
374 | 380 | return fun_remat |
375 | 381 |
|
376 | | -remat = checkpoint # alias |
| 382 | + |
| 383 | +def remat(fun: Callable, *, prevent_cse: bool = True, |
| 384 | + policy: Callable[..., bool] | None = None, |
| 385 | + static_argnums: int | tuple[int, ...] = (), |
| 386 | + concrete: bool = False) -> Callable: |
| 387 | + """Alias of :func:`jax.checkpoint`.""" |
| 388 | + return checkpoint(fun, prevent_cse=prevent_cse, policy=policy, |
| 389 | + static_argnums=static_argnums, concrete=concrete) |
377 | 390 |
|
378 | 391 | # This function is similar to api_util.argnums_partial, except the error |
379 | 392 | # messages are specific to jax.remat (and thus more actionable), the |
@@ -894,66 +907,6 @@ def name_batcher(args, dims, *, name): |
894 | 907 | batching.primitive_batchers[name_p] = name_batcher |
895 | 908 |
|
896 | 909 |
|
897 | | -@functools.wraps(checkpoint) |
898 | | -def checkpoint_wrapper( |
899 | | - fun: Callable, |
900 | | - *, |
901 | | - concrete: bool = False, |
902 | | - prevent_cse: bool = True, |
903 | | - static_argnums: int | tuple[int, ...] = (), |
904 | | - policy: Callable[..., bool] | None = None, |
905 | | -) -> Callable: |
906 | | - if concrete: |
907 | | - msg = ("The 'concrete' option to jax.checkpoint / jax.remat is deprecated; " |
908 | | - "in its place, you can use its `static_argnums` option, and if " |
909 | | - "necessary the `jax.ensure_compile_time_eval()` context manager.\n" |
910 | | - "\n" |
911 | | - "For example, if using `concrete=True` for an `is_training` flag:\n" |
912 | | - "\n" |
913 | | - " from functools import partial\n" |
914 | | - "\n" |
915 | | - " @partial(jax.checkpoint, concrete=True)\n" |
916 | | - " def foo(x, is_training):\n" |
917 | | - " if is_training:\n" |
918 | | - " return f(x)\n" |
919 | | - " else:\n" |
920 | | - " return g(x)\n" |
921 | | - "\n" |
922 | | - "replace it with a use of `static_argnums`:\n" |
923 | | - "\n" |
924 | | - " @partial(jax.checkpoint, static_argnums=(1,))\n" |
925 | | - " def foo(x, is_training):\n" |
926 | | - " ...\n" |
927 | | - "\n" |
928 | | - "If jax.numpy operations need to be performed on static arguments, " |
929 | | - "we can use the `jax.ensure_compile_time_eval()` context manager. " |
930 | | - "For example, we can replace this use of `concrete=True`\n:" |
931 | | - "\n" |
932 | | - " @partial(jax.checkpoint, concrete=True)\n" |
933 | | - " def foo(x, y):\n" |
934 | | - " if y > 0:\n" |
935 | | - " return f(x)\n" |
936 | | - " else:\n" |
937 | | - " return g(x)\n" |
938 | | - "\n" |
939 | | - "with this combination of `static_argnums` and " |
940 | | - "`jax.ensure_compile_time_eval()`:\n" |
941 | | - "\n" |
942 | | - " @partial(jax.checkpoint, static_argnums=(1,))\n" |
943 | | - " def foo(x, y):\n" |
944 | | - " with jax.ensure_compile_time_eval():\n" |
945 | | - " y_pos = y > 0\n" |
946 | | - " if y_pos:\n" |
947 | | - " return f(x)\n" |
948 | | - " else:\n" |
949 | | - " return g(x)\n" |
950 | | - "\n" |
951 | | - "See https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html\n") |
952 | | - raise NotImplementedError(msg) |
953 | | - return checkpoint(fun, prevent_cse=prevent_cse, policy=policy, |
954 | | - static_argnums=static_argnums) |
955 | | - |
956 | | - |
957 | 910 | @discharge.register_discharge_rule(remat_p) |
958 | 911 | def _remat_state_discharge_rule( |
959 | 912 | in_avals, out_avals, *args, jaxpr, **params): |
|
0 commit comments