Skip to content

Commit 540b857

Browse files
justinjfuGoogle-ML-Automation
authored andcommitted
[Pallas] Make rng tests inherit from JaxTestCase
PiperOrigin-RevId: 839410599
1 parent 5b9cfa3 commit 540b857

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/pallas/tpu_pallas_random_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def f(rng_key):
243243
self.assertGreaterEqual(jnp.max(y), jnp.min(y))
244244

245245

246-
class BlockInvarianceTest(parameterized.TestCase):
246+
class BlockInvarianceTest(jtu.JaxTestCase):
247247

248248
def setUp(self):
249249
if not jtu.test_device_matches(["tpu"]):
@@ -290,7 +290,7 @@ def body(key_ref, o_ref):
290290
np.testing.assert_array_equal(result_16x128, result_32x256)
291291

292292

293-
class ThreefryTest(parameterized.TestCase):
293+
class ThreefryTest(jtu.JaxTestCase):
294294

295295
def setUp(self):
296296
if not jtu.test_device_matches(["tpu"]):
@@ -373,7 +373,7 @@ def test_threefry_kernel_matches_jax_threefry_sharded(self, shape):
373373
np.testing.assert_array_equal(jax_gen, pl_gen)
374374

375375

376-
class PhiloxTest(parameterized.TestCase):
376+
class PhiloxTest(jtu.JaxTestCase):
377377

378378
def setUp(self):
379379
if not jtu.test_device_matches(["tpu"]):

0 commit comments

Comments
 (0)