@@ -78,17 +78,23 @@ def kernel(x_ref, o_ref):
7878 jax .block_until_ready (compiled_kernel (x ))
7979 self .assertIn ('It works!' , get_output ())
8080
81- def test_debug_print_with_values (self ):
81+ @parameterized .product (dtype = [jnp .int32 , jnp .float32 ])
82+ def test_debug_print_with_values (self , dtype ):
8283 @functools .partial (
8384 self .pallas_call ,
8485 in_specs = (pl .BlockSpec (memory_space = pltpu .SMEM ),),
8586 out_shape = jax .ShapeDtypeStruct ((8 , 128 ), jnp .float32 ),
8687 )
8788 def kernel (x_ref , o_ref ):
88- pl .debug_print ('BEGIN1 x[0] == {}' , x_ref [0 ])
89- pl .debug_print ('BEGIN2 x[0] == {} ; x[1] == {} ; END' , x_ref [0 ], x_ref [1 ])
90-
91- x = jnp .array ([42 , 24 ]).astype (jnp .int32 )
89+ if dtype == jnp .int32 :
90+ pl .debug_print ('BEGIN1 x[0] == {}' , x_ref [0 ])
91+ pl .debug_print (
92+ 'BEGIN2 x[0] == {} ; x[1] == {} ; END' , x_ref [0 ], x_ref [1 ]
93+ )
94+ else :
95+ pl .debug_print ('BEGIN1 x[0] == ' , x_ref [0 ])
96+
97+ x = jnp .array ([42 , 24 ], dtype = dtype )
9298 compiled_kernel = (
9399 jax .jit (kernel )
94100 .lower (x )
@@ -97,8 +103,11 @@ def kernel(x_ref, o_ref):
97103 with jtu .capture_stderr () as get_output :
98104 jax .block_until_ready (compiled_kernel (x ))
99105 output = get_output ()
100- self .assertIn ('BEGIN1 x[0] == 42' , output )
101- self .assertIn ('BEGIN2 x[0] == 42 ; x[1] == 24 ; END' , output )
106+ if dtype == jnp .int32 :
107+ self .assertIn ('BEGIN1 x[0] == 42' , output )
108+ self .assertIn ('BEGIN2 x[0] == 42 ; x[1] == 24 ; END' , output )
109+ else :
110+ self .assertIn ('BEGIN1 x[0] == f32[] 42' , output )
102111
103112 @parameterized .named_parameters (
104113 (f"{ '_' .join (map (str , shape ))} _{ dtype .__name__ } " , shape , dtype )
0 commit comments