Skip to content

Commit e4cadda

Browse files
Merge pull request #33660 from jakevdp:ad_checkpoint
PiperOrigin-RevId: 839358516
2 parents 39f36f5 + 360a571 commit e4cadda

15 files changed

+148
-177
lines changed

docs/gradient-checkpointing.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ Another policy which refers to names is `jax.checkpoint_policies.save_only_these
359359
You may consider offloading to CPU memory instead of recomputing when checkpointing to save accelerator memory. `jax.checkpoint_policies.offload_dot_with_no_batch_dims` can offload the results of matrix multiplications with no batch dimensions to the CPU.
360360

361361
```{code-cell}
362-
from jax.ad_checkpoint import checkpoint
362+
from jax import checkpoint
363363
364364
def checkpoint_offload_dot_with_no_batch_dims(self):
365365
policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
@@ -380,7 +380,8 @@ def checkpoint_offload_dot_with_no_batch_dims(self):
380380
One of JAX's checkpoint policies allows specified checkpoint names to be offloaded to CPUs. This policy is implemented through `jax.checkpoint_policies.save_and_offload_only_these_names`, which has four arguments: `names_which_can_be_saved`, `names_which_can_be_offloaded`, the offloading source, and destination. Names listed in `names_which_can_be_saved` are kept on the device, names listed in `names_which_can_be_offloaded` are moved to CPU memory, and other names or operations without names are recomputed. For example, if we have checkpoint names `y`, `z`, and `w`, `y` can be saved on the device, `z` can be offloaded to CPU memory, and `w` can be recomputed.
381381

382382
```{code-cell}
383-
from jax.ad_checkpoint import checkpoint, checkpoint_name
383+
from jax import checkpoint
384+
from jax.ad_checkpoint import checkpoint_name
384385
from jax._src import test_util as jtu
385386
386387
def checkpoint_names_saved_offloaded_recomputed(self):

jax/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,9 @@
8585
from jax._src.core import typeof as typeof
8686
from jax._src.api import effects_barrier as effects_barrier
8787
from jax._src.api import block_until_ready as block_until_ready
88-
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401
88+
from jax._src.ad_checkpoint import checkpoint as checkpoint
8989
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
90+
from jax._src.ad_checkpoint import remat as remat
9091
from jax._src.api import clear_caches as clear_caches
9192
from jax._src.api import copy_to_host_async as copy_to_host_async
9293
from jax._src.custom_derivatives import closure_convert as closure_convert
@@ -127,7 +128,6 @@
127128
from jax._src.xla_bridge import process_index as process_index
128129
from jax._src.xla_bridge import process_indices as process_indices
129130
from jax._src.callback import pure_callback as pure_callback
130-
from jax._src.ad_checkpoint import checkpoint_wrapper as remat # noqa: F401
131131
from jax._src.core import ShapeDtypeStruct as ShapeDtypeStruct
132132
from jax._src.api import value_and_grad as value_and_grad
133133
from jax._src.api import vjp as vjp

jax/_src/ad_checkpoint.py

Lines changed: 16 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from __future__ import annotations
1616

1717
from collections.abc import Callable, Sequence
18-
import functools
1918
from functools import partial
2019
import logging
2120
from typing import Any
@@ -203,7 +202,7 @@ def policy(prim, *args, **params):
203202
def checkpoint(fun: Callable, *, prevent_cse: bool = True,
204203
policy: Callable[..., bool] | None = None,
205204
static_argnums: int | tuple[int, ...] = (),
206-
) -> Callable:
205+
concrete: bool = False) -> Callable:
207206
"""Make ``fun`` recompute internal linearization points when differentiated.
208207
209208
The :func:`jax.checkpoint` decorator, aliased to :func:`jax.remat`, provides a
@@ -257,6 +256,8 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
257256
returns a boolean indicating whether the corresponding output value(s) can
258257
be saved as residuals (or instead must be recomputed in the (co)tangent
259258
computation if needed).
259+
concrete: Optional boolean; deprecated. Passing a non-False value will
260+
result in NotImplementedError.
260261
261262
Returns:
262263
A function (callable) with the same input/output behavior as ``fun`` but
@@ -344,6 +345,11 @@ def foo(x, y):
344345
``jax.ensure_compile_time_eval``), it may be easier to compute some values
345346
outside the :func:`jax.checkpoint`-decorated function and then close over them.
346347
"""
348+
if concrete:
349+
raise NotImplementedError(
350+
"The concrete option to jax.checkpoint has been deprecated. In its"
351+
" place please use ``static_argnums``; for details refer to"
352+
" https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html.")
347353
if isinstance(static_argnums, int):
348354
static_argnums = static_argnums,
349355
if isinstance(prevent_cse, list):
@@ -373,7 +379,14 @@ def fun_remat(*args, **kwargs):
373379
return tree_unflatten(out_tree, out_flat)
374380
return fun_remat
375381

376-
remat = checkpoint # alias
382+
383+
def remat(fun: Callable, *, prevent_cse: bool = True,
384+
policy: Callable[..., bool] | None = None,
385+
static_argnums: int | tuple[int, ...] = (),
386+
concrete: bool = False) -> Callable:
387+
"""Alias of :func:`jax.checkpoint`."""
388+
return checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
389+
static_argnums=static_argnums, concrete=concrete)
377390

378391
# This function is similar to api_util.argnums_partial, except the error
379392
# messages are specific to jax.remat (and thus more actionable), the
@@ -894,66 +907,6 @@ def name_batcher(args, dims, *, name):
894907
batching.primitive_batchers[name_p] = name_batcher
895908

896909

897-
@functools.wraps(checkpoint)
898-
def checkpoint_wrapper(
899-
fun: Callable,
900-
*,
901-
concrete: bool = False,
902-
prevent_cse: bool = True,
903-
static_argnums: int | tuple[int, ...] = (),
904-
policy: Callable[..., bool] | None = None,
905-
) -> Callable:
906-
if concrete:
907-
msg = ("The 'concrete' option to jax.checkpoint / jax.remat is deprecated; "
908-
"in its place, you can use its `static_argnums` option, and if "
909-
"necessary the `jax.ensure_compile_time_eval()` context manager.\n"
910-
"\n"
911-
"For example, if using `concrete=True` for an `is_training` flag:\n"
912-
"\n"
913-
" from functools import partial\n"
914-
"\n"
915-
" @partial(jax.checkpoint, concrete=True)\n"
916-
" def foo(x, is_training):\n"
917-
" if is_training:\n"
918-
" return f(x)\n"
919-
" else:\n"
920-
" return g(x)\n"
921-
"\n"
922-
"replace it with a use of `static_argnums`:\n"
923-
"\n"
924-
" @partial(jax.checkpoint, static_argnums=(1,))\n"
925-
" def foo(x, is_training):\n"
926-
" ...\n"
927-
"\n"
928-
"If jax.numpy operations need to be performed on static arguments, "
929-
"we can use the `jax.ensure_compile_time_eval()` context manager. "
930-
"For example, we can replace this use of `concrete=True`\n:"
931-
"\n"
932-
" @partial(jax.checkpoint, concrete=True)\n"
933-
" def foo(x, y):\n"
934-
" if y > 0:\n"
935-
" return f(x)\n"
936-
" else:\n"
937-
" return g(x)\n"
938-
"\n"
939-
"with this combination of `static_argnums` and "
940-
"`jax.ensure_compile_time_eval()`:\n"
941-
"\n"
942-
" @partial(jax.checkpoint, static_argnums=(1,))\n"
943-
" def foo(x, y):\n"
944-
" with jax.ensure_compile_time_eval():\n"
945-
" y_pos = y > 0\n"
946-
" if y_pos:\n"
947-
" return f(x)\n"
948-
" else:\n"
949-
" return g(x)\n"
950-
"\n"
951-
"See https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html\n")
952-
raise NotImplementedError(msg)
953-
return checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
954-
static_argnums=static_argnums)
955-
956-
957910
@discharge.register_discharge_rule(remat_p)
958911
def _remat_state_discharge_rule(
959912
in_avals, out_avals, *args, jaxpr, **params):

jax/ad_checkpoint.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,36 @@
1313
# limitations under the License.
1414

1515
from jax._src.ad_checkpoint import (
16-
checkpoint as checkpoint,
16+
checkpoint as _deprecated_checkpoint,
1717
checkpoint_policies as checkpoint_policies,
1818
checkpoint_name as checkpoint_name,
1919
print_saved_residuals as print_saved_residuals,
20-
remat as remat,
2120
)
2221
from jax._src.interpreters.partial_eval import (
2322
Recompute as Recompute,
2423
Saveable as Saveable,
2524
Offloadable as Offloadable,
2625
)
26+
27+
_deprecations = {
28+
# Added for v0.8.2
29+
"checkpoint": (
30+
"jax.ad_checkpoint.checkpoint is deprecated; use jax.checkpoint instead.",
31+
_deprecated_checkpoint
32+
),
33+
"remat": (
34+
"jax.ad_checkpoint.remat is deprecated; use jax.remat instead.",
35+
_deprecated_checkpoint
36+
),
37+
}
38+
39+
import typing as _typing
40+
if _typing.TYPE_CHECKING:
41+
checkpoint = _deprecated_checkpoint
42+
remat = _deprecated_checkpoint
43+
else:
44+
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
45+
__getattr__ = _deprecation_getattr(__name__, _deprecations)
46+
del _deprecation_getattr
47+
del _typing
48+
del _deprecated_checkpoint

jax/experimental/jax2tf/tests/jax2tf_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -831,7 +831,7 @@ def f(x1):
831831
x3 = jnp.sin(x2)
832832
x4 = jnp.sin(x3)
833833
return x4
834-
remat_f = ad_checkpoint.checkpoint(f)
834+
remat_f = jax.checkpoint(f)
835835

836836
# The computation of grad_f computes "sin" 5 times, 3 for the forward pass
837837
# and then to rematerialize "x2" and "x3" in the backward pass.
@@ -844,7 +844,7 @@ def test_remat_free_var(self):
844844
def f(x):
845845
y = 2 * x
846846

847-
@ad_checkpoint.checkpoint
847+
@jax.checkpoint
848848
def g():
849849
return y
850850

0 commit comments

Comments
 (0)