Skip to content

Commit e704639

Browse files
Merge pull request #33661 from jax-ml:more-traceback-regression-tests
PiperOrigin-RevId: 839382284
2 parents e4cadda + 4871fca commit e704639

File tree

1 file changed

+78
-3
lines changed

1 file changed

+78
-3
lines changed

tests/api_test.py

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7914,21 +7914,96 @@ class TracebackTest(jtu.JaxTestCase):
79147914
def cur_depth(self):
79157915
return len(inspect.stack())
79167916

7917+
def test_traceback_test(self):
7918+
expected_depth_foo = 1
7919+
expected_depth_bar = 2
7920+
init_depth = self.cur_depth()
7921+
def foo():
7922+
self.assertExpectedDepth(init_depth, expected_depth_foo)
7923+
def bar():
7924+
self.assertExpectedDepth(init_depth, expected_depth_bar)
7925+
bar()
7926+
7927+
foo()
7928+
79177929
def assertExpectedDepth(self, init_depth, expected_depth):
7918-
# `+ 1` is for the `assertExpectedDepth` stack frame itself
7919-
self.assertEqual(self.cur_depth() - init_depth, expected_depth + 1)
7930+
# `- 1` is for the `assertExpectedDepth` stack frame itself
7931+
self.assertEqual(self.cur_depth() - init_depth - 1, expected_depth)
79207932

79217933
def test_scan_traceback(self):
79227934
expected_depth = 5
79237935
init_depth = self.cur_depth()
79247936

79257937
def f(c, x):
7926-
frames = inspect.stack()
79277938
self.assertExpectedDepth(init_depth, expected_depth)
79287939
return (c, ())
79297940

79307941
jax.lax.scan(f, 0, jnp.arange(4))
79317942

7943+
def test_cond_traceback(self):
7944+
if sys.version_info < (3, 14):
7945+
# Fails because 3.11 adds an extra stack frame due to a list comprehension
7946+
self.skipTest("Expected failure.")
7947+
expected_depth = 8
7948+
init_depth = self.cur_depth()
7949+
7950+
def f():
7951+
self.assertExpectedDepth(init_depth, expected_depth)
7952+
7953+
lax.cond(True, f, lambda: None)
7954+
7955+
def test_jit_traceback(self):
7956+
# TODO(dougalm): improve this! jit can (and should) be nested a lot.
7957+
expected_depth = 14
7958+
init_depth = self.cur_depth()
7959+
@jit
7960+
def foo(x):
7961+
self.assertExpectedDepth(init_depth, expected_depth)
7962+
return x
7963+
foo(1)
7964+
7965+
def test_grad_traceback(self):
7966+
# TODO(dougalm): improve this
7967+
expected_depth = 13
7968+
init_depth = self.cur_depth()
7969+
7970+
def foo(x):
7971+
self.assertExpectedDepth(init_depth, expected_depth)
7972+
return x
7973+
7974+
grad(foo)(1.0)
7975+
7976+
def test_vmap_traceback(self):
7977+
# TODO(dougalm): improve this
7978+
expected_depth = 8
7979+
init_depth = self.cur_depth()
7980+
7981+
def foo(x):
7982+
self.assertExpectedDepth(init_depth, expected_depth)
7983+
return x
7984+
7985+
jax.vmap(foo)(np.arange(3))
7986+
7987+
def test_custom_vjp_traceback(self):
7988+
# TODO(dougalm): improve this
7989+
expected_depth_f = 11
7990+
expected_depth_f_fwd = 22
7991+
expected_depth_f_rev = 13
7992+
init_depth = self.cur_depth()
7993+
@jax.custom_vjp
7994+
def f(x):
7995+
self.assertExpectedDepth(init_depth, expected_depth_f)
7996+
return x
7997+
def f_fwd(x):
7998+
self.assertExpectedDepth(init_depth, expected_depth_f_fwd)
7999+
return x, None
8000+
def f_rev(_, g):
8001+
self.assertExpectedDepth(init_depth, expected_depth_f_rev)
8002+
return (g,)
8003+
f.defvjp(f_fwd, f_rev)
8004+
8005+
f(1.0)
8006+
grad(f)(1.0)
79328007

79338008
if __name__ == '__main__':
79348009
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)