Skip to content

Commit 334ef7e

Browse files
IvyZXGoogle-ML-Automation
authored andcommitted
Add test case to debug_print scalar in Pallas.
PiperOrigin-RevId: 840367743
1 parent 2e006a9 commit 334ef7e

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

tests/pallas/tpu_pallas_call_print_test.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,17 +78,23 @@ def kernel(x_ref, o_ref):
7878
jax.block_until_ready(compiled_kernel(x))
7979
self.assertIn('It works!', get_output())
8080

81-
def test_debug_print_with_values(self):
81+
@parameterized.product(dtype=[jnp.int32, jnp.float32])
82+
def test_debug_print_with_values(self, dtype):
8283
@functools.partial(
8384
self.pallas_call,
8485
in_specs=(pl.BlockSpec(memory_space=pltpu.SMEM),),
8586
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
8687
)
8788
def kernel(x_ref, o_ref):
88-
pl.debug_print('BEGIN1 x[0] == {}', x_ref[0])
89-
pl.debug_print('BEGIN2 x[0] == {} ; x[1] == {} ; END', x_ref[0], x_ref[1])
90-
91-
x = jnp.array([42, 24]).astype(jnp.int32)
89+
if dtype == jnp.int32:
90+
pl.debug_print('BEGIN1 x[0] == {}', x_ref[0])
91+
pl.debug_print(
92+
'BEGIN2 x[0] == {} ; x[1] == {} ; END', x_ref[0], x_ref[1]
93+
)
94+
else:
95+
pl.debug_print('BEGIN1 x[0] == ', x_ref[0])
96+
97+
x = jnp.array([42, 24], dtype=dtype)
9298
compiled_kernel = (
9399
jax.jit(kernel)
94100
.lower(x)
@@ -97,8 +103,11 @@ def kernel(x_ref, o_ref):
97103
with jtu.capture_stderr() as get_output:
98104
jax.block_until_ready(compiled_kernel(x))
99105
output = get_output()
100-
self.assertIn('BEGIN1 x[0] == 42', output)
101-
self.assertIn('BEGIN2 x[0] == 42 ; x[1] == 24 ; END', output)
106+
if dtype == jnp.int32:
107+
self.assertIn('BEGIN1 x[0] == 42', output)
108+
self.assertIn('BEGIN2 x[0] == 42 ; x[1] == 24 ; END', output)
109+
else:
110+
self.assertIn('BEGIN1 x[0] == f32[] 42', output)
102111

103112
@parameterized.named_parameters(
104113
(f"{'_'.join(map(str, shape))}_{dtype.__name__}", shape, dtype)

0 commit comments

Comments
 (0)