|
15 | 15 | from __future__ import annotations |
16 | 16 |
|
17 | 17 | from collections import Counter |
| 18 | +import glob |
18 | 19 | import logging |
19 | 20 | import math |
20 | 21 | import os |
@@ -648,6 +649,35 @@ def test_persistent_cache_enable_xla_caches(self): |
648 | 649 | self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, f"jax-cache{s}xla_gpu_per_fusion_autotune_cache_dir") |
649 | 650 | self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) |
650 | 651 |
|
| 652 | + @jtu.skip_on_devices("tpu") # TPU backend does not dump on deserialize |
| 653 | + def test_dump_on_cache_hit(self): |
| 654 | + previous_counts = Counter(_counts) |
| 655 | + with ( |
| 656 | + config.persistent_cache_min_compile_time_secs(0), |
| 657 | + config.persistent_cache_min_entry_size_bytes(0), |
| 658 | + tempfile.TemporaryDirectory() as dump_dir1, |
| 659 | + tempfile.TemporaryDirectory() as dump_dir2 |
| 660 | + ): |
| 661 | + jit(lambda x: x + 1, compiler_options={"xla_dump_to": dump_dir1})(1) |
| 662 | + self.assertEqual( |
| 663 | + _counts["/jax/compilation_cache/cache_hits"], |
| 664 | + previous_counts["/jax/compilation_cache/cache_hits"], |
| 665 | + ) |
| 666 | + jit(lambda x: x + 1, compiler_options={"xla_dump_to": dump_dir2, "xla_dump_hlo_as_proto": True, "xla_dump_hlo_as_text": True})(1) |
| 667 | + self.assertEqual( |
| 668 | + _counts["/jax/compilation_cache/cache_hits"], |
| 669 | + previous_counts["/jax/compilation_cache/cache_hits"] + 1, |
| 670 | + 1) |
| 671 | + dump1_files = glob.glob(os.path.join(dump_dir1, "*after_optimizations.txt")) |
| 672 | + dump2_files = glob.glob(os.path.join(dump_dir2, "*after_optimizations.txt")) |
| 673 | + self.assertEqual(len(dump1_files), 1) |
| 674 | + self.assertEqual(len(dump2_files), 1) |
| 675 | + with (open(dump1_files[0]) as file1, open(dump2_files[0]) as file2): |
| 676 | + self.assertEqual(file1.read(), file2.read()) |
| 677 | + dump2_pbs = glob.glob(os.path.join(dump_dir2, "*after_optimizations.hlo.pb")) |
| 678 | + self.assertEqual(len(dump2_pbs), 1) |
| 679 | + |
| 680 | + |
651 | 681 | @jtu.with_config( |
652 | 682 | jax_enable_compilation_cache=False, |
653 | 683 | jax_persistent_cache_min_compile_time_secs=0, |
|
0 commit comments