Skip to content

Commit 174d514

Browse files
Merge pull request #31328 from olupton:test-dump-on-deserialize
PiperOrigin-RevId: 840268767
2 parents eae3b49 + 1588d16 commit 174d514

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

tests/compilation_cache_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
from collections import Counter
18+
import glob
1819
import logging
1920
import math
2021
import os
@@ -648,6 +649,35 @@ def test_persistent_cache_enable_xla_caches(self):
648649
self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, f"jax-cache{s}xla_gpu_per_fusion_autotune_cache_dir")
649650
self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE)
650651

652+
@jtu.skip_on_devices("tpu") # TPU backend does not dump on deserialize
653+
def test_dump_on_cache_hit(self):
654+
previous_counts = Counter(_counts)
655+
with (
656+
config.persistent_cache_min_compile_time_secs(0),
657+
config.persistent_cache_min_entry_size_bytes(0),
658+
tempfile.TemporaryDirectory() as dump_dir1,
659+
tempfile.TemporaryDirectory() as dump_dir2
660+
):
661+
jit(lambda x: x + 1, compiler_options={"xla_dump_to": dump_dir1})(1)
662+
self.assertEqual(
663+
_counts["/jax/compilation_cache/cache_hits"],
664+
previous_counts["/jax/compilation_cache/cache_hits"],
665+
)
666+
jit(lambda x: x + 1, compiler_options={"xla_dump_to": dump_dir2, "xla_dump_hlo_as_proto": True, "xla_dump_hlo_as_text": True})(1)
667+
self.assertEqual(
668+
_counts["/jax/compilation_cache/cache_hits"],
669+
previous_counts["/jax/compilation_cache/cache_hits"] + 1,
670+
1)
671+
dump1_files = glob.glob(os.path.join(dump_dir1, "*after_optimizations.txt"))
672+
dump2_files = glob.glob(os.path.join(dump_dir2, "*after_optimizations.txt"))
673+
self.assertEqual(len(dump1_files), 1)
674+
self.assertEqual(len(dump2_files), 1)
675+
with (open(dump1_files[0]) as file1, open(dump2_files[0]) as file2):
676+
self.assertEqual(file1.read(), file2.read())
677+
dump2_pbs = glob.glob(os.path.join(dump_dir2, "*after_optimizations.hlo.pb"))
678+
self.assertEqual(len(dump2_pbs), 1)
679+
680+
651681
@jtu.with_config(
652682
jax_enable_compilation_cache=False,
653683
jax_persistent_cache_min_compile_time_secs=0,

0 commit comments

Comments
 (0)