@@ -7914,21 +7914,96 @@ class TracebackTest(jtu.JaxTestCase):
79147914 def cur_depth (self ):
79157915 return len (inspect .stack ())
79167916
7917+ def test_traceback_test (self ):
7918+ expected_depth_foo = 1
7919+ expected_depth_bar = 2
7920+ init_depth = self .cur_depth ()
7921+ def foo ():
7922+ self .assertExpectedDepth (init_depth , expected_depth_foo )
7923+ def bar ():
7924+ self .assertExpectedDepth (init_depth , expected_depth_bar )
7925+ bar ()
7926+
7927+ foo ()
7928+
79177929 def assertExpectedDepth (self , init_depth , expected_depth ):
7918- # `+ 1` is for the `assertExpectedDepth` stack frame itself
7919- self .assertEqual (self .cur_depth () - init_depth , expected_depth + 1 )
7930+ # `- 1` is for the `assertExpectedDepth` stack frame itself
7931+ self .assertEqual (self .cur_depth () - init_depth - 1 , expected_depth )
79207932
79217933 def test_scan_traceback (self ):
79227934 expected_depth = 5
79237935 init_depth = self .cur_depth ()
79247936
79257937 def f (c , x ):
7926- frames = inspect .stack ()
79277938 self .assertExpectedDepth (init_depth , expected_depth )
79287939 return (c , ())
79297940
79307941 jax .lax .scan (f , 0 , jnp .arange (4 ))
79317942
7943+ def test_cond_traceback (self ):
7944+ if sys .version_info < (3 , 14 ):
7945+ # Fails because 3.11 adds an extra stack frame due to a list comprehension
7946+ self .skipTest ("Expected failure." )
7947+ expected_depth = 8
7948+ init_depth = self .cur_depth ()
7949+
7950+ def f ():
7951+ self .assertExpectedDepth (init_depth , expected_depth )
7952+
7953+ lax .cond (True , f , lambda : None )
7954+
7955+ def test_jit_traceback (self ):
7956+ # TODO(dougalm): improve this! jit can (and should) be nested a lot.
7957+ expected_depth = 14
7958+ init_depth = self .cur_depth ()
7959+ @jit
7960+ def foo (x ):
7961+ self .assertExpectedDepth (init_depth , expected_depth )
7962+ return x
7963+ foo (1 )
7964+
7965+ def test_grad_traceback (self ):
7966+ # TODO(dougalm): improve this
7967+ expected_depth = 13
7968+ init_depth = self .cur_depth ()
7969+
7970+ def foo (x ):
7971+ self .assertExpectedDepth (init_depth , expected_depth )
7972+ return x
7973+
7974+ grad (foo )(1.0 )
7975+
7976+ def test_vmap_traceback (self ):
7977+ # TODO(dougalm): improve this
7978+ expected_depth = 8
7979+ init_depth = self .cur_depth ()
7980+
7981+ def foo (x ):
7982+ self .assertExpectedDepth (init_depth , expected_depth )
7983+ return x
7984+
7985+ jax .vmap (foo )(np .arange (3 ))
7986+
7987+ def test_custom_vjp_traceback (self ):
7988+ # TODO(dougalm): improve this
7989+ expected_depth_f = 11
7990+ expected_depth_f_fwd = 22
7991+ expected_depth_f_rev = 13
7992+ init_depth = self .cur_depth ()
7993+ @jax .custom_vjp
7994+ def f (x ):
7995+ self .assertExpectedDepth (init_depth , expected_depth_f )
7996+ return x
7997+ def f_fwd (x ):
7998+ self .assertExpectedDepth (init_depth , expected_depth_f_fwd )
7999+ return x , None
8000+ def f_rev (_ , g ):
8001+ self .assertExpectedDepth (init_depth , expected_depth_f_rev )
8002+ return (g ,)
8003+ f .defvjp (f_fwd , f_rev )
8004+
8005+ f (1.0 )
8006+ grad (f )(1.0 )
79328007
79338008if __name__ == '__main__' :
79348009 absltest .main (testLoader = jtu .JaxTestLoader ())
0 commit comments