Skip to content

Commit d0fd8f0

Browse files
[Mosaic GPU] Fix broken tests when jax is coupled with an old jax lib version.
PiperOrigin-RevId: 840181806
1 parent fdfffdc commit d0fd8f0

File tree

4 files changed

+31
-12
lines changed

4 files changed

+31
-12
lines changed

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -249,21 +249,31 @@ def _copy_smem_to_gmem_lowering(
249249
)
250250
assert not copy_params.get("gmem_transform")
251251
if reduction_op is not None:
252+
# TODO(b/415721295): Call mgpu.dialect.async_store after the if, after
253+
# the minimal jaxlib version is 0.8.2.
254+
if not hasattr(mgpu.dialect, "TMAReduction"):
255+
raise NotImplementedError("Reduction op is not supported yet.")
252256
reduction_op_attr = getattr(
253257
mgpu.dialect.TMAReduction, reduction_op.capitalize()
254258
)
259+
mgpu.dialect.async_store(
260+
src,
261+
dst,
262+
indices,
263+
slice_lengths,
264+
predicate=predicate,
265+
commit_group=commit_group, # type: ignore[call-arg]
266+
reduction_op=reduction_op_attr,
267+
)
255268
else:
256-
reduction_op_attr = None
257-
258-
mgpu.dialect.async_store(
259-
src,
260-
dst,
261-
indices,
262-
slice_lengths,
263-
predicate=predicate,
264-
commit_group=commit_group, # type: ignore[call-arg]
265-
reduction_op=reduction_op_attr,
266-
)
269+
mgpu.dialect.async_store(
270+
src,
271+
dst,
272+
indices,
273+
slice_lengths,
274+
predicate=predicate,
275+
commit_group=commit_group, # type: ignore[call-arg]
276+
)
267277
return ()
268278

269279

jax/experimental/mosaic/gpu/dialect_lowering.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -991,7 +991,8 @@ def _mgpu_async_store_op_lowering_rule(
991991
# flatten -> async_copy -> unflatted here, as long as flattened size is a
992992
# multiple of 16.
993993

994-
if store_op.reduction_op is not None:
994+
# TODO(b/415721295):Simplify, after the minimal jaxlib version is 0.8.2.
995+
if hasattr(mgpu, "TMAReduction") and store_op.reduction_op is not None:
995996
reduction_op = mgpu.TMAReduction(store_op.reduction_op.value).name.lower()
996997
else:
997998
reduction_op = None

tests/mosaic/gpu_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5348,6 +5348,10 @@ def body(
53485348

53495349
@parameterized.parameters(jnp.float32, jnp.bfloat16, jnp.float16)
53505350
def test_async_store_add_reduction(self, dtype):
5351+
# TODO(b/415721295):Remove after the minimal jaxlib version is 0.8.2.
5352+
if not hasattr(mgpu_dialect, "TMAReduction"):
5353+
self.skipTest("TMAReduction op is required.")
5354+
53515355
shape = (8, 128)
53525356

53535357
def body(ctx, src, dst, smem):

tests/pallas/mosaic_gpu_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,10 @@ def kernel(x_ref, o_ref_gmem, scratch_ref):
654654

655655
@parameterized.parameters(jnp.bfloat16, jnp.float16, jnp.float32)
656656
def test_copy_smem_to_gmem_reduction(self, dtype):
657+
# TODO(b/415721295):Remove after the minimal jaxlib version is 0.8.2.
658+
if not hasattr(mgpu.dialect, "TMAReduction"):
659+
self.skip_if_wg_semantics()
660+
657661
@functools.partial(
658662
self.pallas_call,
659663
grid=(200,),

0 commit comments

Comments
 (0)