Skip to content

Commit 7a395f3

Browse files
Merge pull request #33634 from jakevdp:ad-checkpoint-doc
PiperOrigin-RevId: 839347694
2 parents c1adda8 + 3f0d0f8 commit 7a395f3

File tree

3 files changed

+13
-1
lines changed

3 files changed

+13
-1
lines changed

docs/jax.ad_checkpoint.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
``jax.ad_checkpoint`` module
2+
============================
3+
4+
.. currentmodule:: jax.ad_checkpoint
5+
6+
.. automodule:: jax.ad_checkpoint
7+
8+
.. autosummary::
9+
:toctree: _autosummary
10+
11+
checkpoint_name

docs/jax.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Subpackages
1414
jax.lax
1515
jax.random
1616
jax.sharding
17+
jax.ad_checkpoint
1718
jax.debug
1819
jax.dlpack
1920
jax.distributed

tests/documentation_coverage_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def jax_docs_dir() -> str:
5252

5353
UNDOCUMENTED_APIS = {
5454
'jax': ['NamedSharding', 'P', 'Ref', 'Shard', 'ad_checkpoint', 'api_util', 'checkpoint_policies', 'core', 'custom_derivatives', 'custom_transpose', 'debug_key_reuse', 'device_put_replicated', 'device_put_sharded', 'effects_barrier', 'example_libraries', 'explain_cache_misses', 'experimental', 'extend', 'float0', 'freeze', 'fwd_and_bwd', 'host_count', 'host_id', 'host_ids', 'interpreters', 'jax', 'jax2tf_associative_scan_reductions', 'legacy_prng_key', 'lib', 'make_user_context', 'new_ref', 'no_execution', 'numpy_dtype_promotion', 'remat', 'remove_size_one_mesh_axis_from_type', 'softmax_custom_jvp', 'threefry_partitionable', 'tools', 'transfer_guard_device_to_device', 'transfer_guard_device_to_host', 'transfer_guard_host_to_device', 'version'],
55+
'jax.ad_checkpoint': ['checkpoint', 'checkpoint_policies', 'print_saved_residuals', 'remat', 'Offloadable', 'Recompute', 'Saveable'],
5556
'jax.custom_batching': ['custom_vmap', 'sequential_vmap'],
5657
'jax.custom_derivatives': ['CustomVJPPrimal', 'SymbolicZero', 'closure_convert', 'custom_gradient', 'custom_jvp', 'custom_jvp_call_p', 'custom_vjp', 'custom_vjp_call_p', 'custom_vjp_primal_tree_values', 'linear_call', 'remat_opt_p', 'zero_from_primal'],
5758
'jax.custom_transpose': ['custom_transpose'],
@@ -75,7 +76,6 @@ def jax_docs_dir() -> str:
7576
# A list of modules to skip entirely, either because they cannot be imported
7677
# or because they are not expected to be documented.
7778
MODULES_TO_SKIP = [
78-
"jax.ad_checkpoint", # internal tools, not documented.
7979
"jax.api_util", # internal tools, not documented.
8080
"jax.cloud_tpu_init", # deprecated in JAX v0.8.1
8181
"jax.collect_profile", # fails when xprof is not available.

0 commit comments

Comments
 (0)