
Commit f9230e2

Rifur13 authored and Google-ML-Automation committed
[Pallas MGPU] Use the cp.async.bulk instruction for large contiguous copies.

Currently we’re limited to 256 elements per dimension when using the tensormap in cp.async.bulk.tensor.
PiperOrigin-RevId: 838501119
1 parent 4592cfb commit f9230e2
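
The trade-off behind this change: cp.async.bulk.tensor drives the copy through a TMA tensormap, which caps each slice dimension at 256 elements, while plain cp.async.bulk moves a flat byte range with no per-dimension cap, so contiguous slices of any size can take the simpler path. A minimal standalone sketch of the resulting dispatch (plain Python; is_contiguous and pick_copy_instruction are illustrative names, not part of the commit):

def is_contiguous(slice_shape, strides):
  # A slice is contiguous iff each dimension's stride equals the product of
  # all faster-varying dimensions (size-1 dimensions may have any stride).
  expected_stride = 1
  for dim, stride in zip(reversed(slice_shape), reversed(strides)):
    if dim != 1 and stride != expected_stride:
      return False
    expected_stride *= dim
  return True

def pick_copy_instruction(slice_shape, strides):
  # The real check also requires no reduction op, no swizzle, and a
  # non-collective, non-partitioned copy (see launch_context.py below).
  if is_contiguous(slice_shape, strides):
    return "cp.async.bulk"  # flat byte copy, no per-dimension size limit
  if max(slice_shape) > 256:
    raise ValueError("Async copies only support copying <=256 elements along each dimension")
  return "cp.async.bulk.tensor"  # tiled TMA copy through a tensormap

# (8192,) is contiguous, so it now takes the bulk path instead of raising:
assert pick_copy_instruction((8192,), (1,)) == "cp.async.bulk"
# A two-row slice of a row-major (64, 128) array is strided, so it still
# goes through the tensormap:
assert pick_copy_instruction((2, 64), (128, 1)) == "cp.async.bulk.tensor"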

File tree

3 files changed: +193 -16 lines changed

jax/experimental/mosaic/gpu/launch_context.py

Lines changed: 103 additions & 10 deletions
@@ -859,11 +859,6 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
           f" {collective_size}"
       )

-    if max(slice_shape) > 256:
-      raise ValueError(
-          "Async copies only support copying <=256 elements along each"
-          " dimension"
-      )
     if (zeroth_bw := slice_shape[-1] * element_bitwidth) % 128 != 0:
       raise ValueError(
           "Async copies require the number of bits copied along the last"
@@ -1264,6 +1259,109 @@ def async_copy(
       return

     assert gather_indices is None  # Only tiled TMA handled below.
+
+    def check_contiguous_slice(slice_shape, strides):
+      assert strides[-1] == 1
+
+      expected_stride = 1
+      for dim, stride in zip(reversed(slice_shape), reversed(strides), strict=True):
+        if dim != 1 and stride != expected_stride:
+          return False
+        expected_stride *= dim
+
+      return True
+
+    gmem_ref = _find_kernel_argument_for_gmem_ref(gmem_ref)
+    ref = gmem_ref
+    for t in gmem_transform:
+      ref = t.apply(ref)
+    ref_ty = ir.MemRefType(ref.type)
+    strides, _ = ref_ty.get_strides_and_offset()
+
+    # Use the simpler copy instruction for contiguous transfers.
+    is_raw_contiguous_copy = (
+        check_contiguous_slice(slice_shape, strides)
+        and reduction_op is None
+        and (
+            swizzle is None or swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle
+        )
+        and collective_size == 1
+        and partitioned is None
+    )
+    if isinstance(predicate, _DefaultPredicate):
+      predicate = utils.single_thread_predicate(utils.ThreadSubset.WARPGROUP)
+    if predicate is None:
+      predicate = c(1, ir.IntegerType.get_signless(1))
+
+    smem_ptr = utils.memref_ptr(smem_ref, memory_space=3)
+    if is_raw_contiguous_copy:
+      index = ir.IndexType.get()
+      i64 = ir.IntegerType.get_signless(64)
+      base, base_offset, *_ = memref.extract_strided_metadata(gmem_ref)
+
+      dyn_offset = base_offset
+      for dyn_idx, stride in zip(dyn_base_indices, strides):
+        step = arith.muli(dyn_idx, c(stride, index))
+        dyn_offset = arith.addi(dyn_offset, step)
+      dyn_offset_i64 = arith.index_cast(i64, dyn_offset)
+
+      gmem_base_ptr = utils.getelementptr(
+          utils.memref_ptr(base), [dyn_offset_i64], src_ref_ty.element_type
+      )
+
+      if gmem_peer_id is not None:
+        assert gmem_peer_id is not GLOBAL_BROADCAST
+        self._ensure_nvshmem_decls()
+        if not isinstance(gmem_peer_id, ir.Value):
+          gmem_peer_id = c(gmem_peer_id, i32)
+
+        gmem_base_ptr = llvm.call(
+            gmem_base_ptr.type,
+            [gmem_base_ptr, gmem_peer_id],
+            [],
+            [],
+            callee="nvshmem_ptr",
+        )
+        gmem_base_ptr = llvm.addrspacecast(
+            ir.Type.parse("!llvm.ptr<1>"), gmem_base_ptr
+        )
+
+      if gmem_ref is src_ref:
+        assert barrier is not None  # for pytype
+        barrier_ptr = barrier.get_ptr()
+        if arrive:
+          nvvm.mbarrier_arrive_expect_tx(
+              barrier_ptr, transfer_bytes, predicate=predicate
+          )
+        llvm.inline_asm(
+            ir.Type.parse("!llvm.void"),
+            [predicate, smem_ptr, gmem_base_ptr, transfer_bytes, barrier_ptr],
+            """
+            @$0 cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [$1], [$2], $3, [$4];
+            """,
+            "b,l,l,r,l",
+            has_side_effects=True,
+        )
+      else:
+        llvm.inline_asm(
+            ir.Type.parse("!llvm.void"),
+            [predicate, gmem_base_ptr, smem_ptr, transfer_bytes],
+            """
+            @$0 cp.async.bulk.global.shared::cta.bulk_group [$1], [$2], $3;
+            """,
+            "b,l,l,r",
+            has_side_effects=True,
+        )
+      if arrive:
+        nvvm.cp_async_bulk_commit_group()
+      return
+
+    # Below are tiled TMA copies using a tensormap.
+    if max(slice_shape) > 256:
+      raise ValueError(
+          "Async copies only support copying <=256 elements along each"
+          " dimension"
+      )
     tma_desc = self._get_tma_desc(
         gmem_ref, gmem_transform, gmem_peer_id,
         tuple(slice_shape), swizzle, reduction_op,
@@ -1272,11 +1370,6 @@ def async_copy(
     rev_dyn_base_indices = [
         arith.index_cast(i32, idx) for idx in reversed(dyn_base_indices)
     ]
-    if isinstance(predicate, _DefaultPredicate):
-      predicate = utils.single_thread_predicate(utils.ThreadSubset.WARPGROUP)
-    if predicate is None:
-      predicate = c(1, ir.IntegerType.get_signless(1))
-    smem_ptr = utils.memref_ptr(smem_ref, memory_space=3)
     if gmem_ref is src_ref:
       assert barrier is not None  # for pytype
       barrier_ptr = barrier.get_ptr()

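Two details of the bulk path above are worth calling out. The two inline-asm strings complete differently: the GMEM-to-SMEM instruction (cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes) posts the transferred byte count to an mbarrier, which is why mbarrier_arrive_expect_tx precedes it, while the SMEM-to-GMEM instruction (cp.async.bulk.global.shared::cta.bulk_group) is tracked through a bulk async-group that cp_async_bulk_commit_group closes. And since no tensormap is involved, the GMEM address is computed by hand: the element offset is the memref's static offset plus the dot product of the dynamic base indices with the strides, which getelementptr then scales by the element type. A small sketch of that arithmetic (gmem_element_offset is an illustrative name; row-major strides assumed):

def gmem_element_offset(base_offset, base_indices, strides):
  # Mirrors the dyn_offset loop above: the offset of the slice's first
  # element, counted in elements rather than bytes.
  offset = base_offset
  for idx, stride in zip(base_indices, strides):
    offset += idx * stride
  return offset

# For the (10, 10, 512) test case below, the slice x.at[4, 4] starts at
# element 4 * 5120 + 4 * 512 = 22528 of the flat buffer.
assert gmem_element_offset(0, (4, 4, 0), (5120, 512, 1)) == 22528
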
tests/pallas/gpu_pallas_distributed_test.py

Lines changed: 47 additions & 0 deletions
@@ -29,6 +29,7 @@
 from jax.experimental.pallas.ops.gpu.reduce_scatter_mgpu import reduce_scatter
 from jax.experimental.pallas.ops.gpu.all_gather_mgpu import all_gather
 import jax.numpy as jnp
+import math
 import numpy as np


@@ -306,6 +307,52 @@ def _store():
     ref = lax.broadcasted_iota(jnp.int32, (128, 128), 1)
     np.testing.assert_array_equal(y, np.concat([ref, ref], axis=0))

+  def test_contiguous_copy_tma(self):
+    if jax.process_index() >= 2:
+      return  # Only 2 processes needed.
+
+    shape = (512,)
+
+    def kernel(y_ref, smem_ref, sem):
+      dev_id = lax.axis_index("y")
+      other_dev_id = 1 - dev_id
+
+      # Device ID must be an int32.
+      zero = jnp.int32(0)
+
+      @pl.when(dev_id == zero)
+      def _store():
+        output = plgpu.layout_cast(
+            jnp.arange(math.prod(shape)).reshape(shape),
+            plgpu.Layout.WG_STRIDED(shape, vec_size=1),
+        )
+        smem_ref[...] = output
+        plgpu.commit_smem()
+        plgpu.copy_smem_to_gmem(smem_ref, plgpu.remote_ref(y_ref, (zero, dev_id)))
+        plgpu.copy_smem_to_gmem(smem_ref, plgpu.remote_ref(y_ref, (zero, other_dev_id)))
+        plgpu.wait_smem_to_gmem(0)
+      pl.semaphore_signal(sem, 1, device_id=(zero, other_dev_id))
+      pl.semaphore_wait(sem)
+
+    kernel_call = pl.pallas_call(
+        kernel,
+        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
+        out_shape=jax.ShapeDtypeStruct(shape, jnp.int32),
+        scratch_shapes=[
+            plgpu.SMEM(shape, jnp.int32),
+            plgpu.SemaphoreType.REGULAR,
+        ],
+    )
+    mesh = jtu.create_mesh((1, 2), ("x", "y"))
+    y = jax.jit(
+        jax.shard_map(
+            kernel_call, mesh=mesh, in_specs=(), out_specs=P("y"), check_vma=False,
+        )
+    )()
+    y = multihost_utils.process_allgather(y, tiled=True)
+    ref = jnp.arange(math.prod(shape)).reshape(shape)
+    np.testing.assert_array_equal(y, np.concat([ref, ref], axis=0))
+

 class PallasCallMultimemTest(TestCase):

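The same bulk path is also reachable on a single device whenever a contiguous copy exceeds the old 256-elements-per-dimension limit. A minimal single-device sketch mirroring the test_copy_gmem_to_smem_contiguous cases in mosaic_gpu_test.py below; the tests run through a self.pallas_call wrapper that targets the Mosaic GPU backend, so using plain pl.pallas_call here is an assumption and your JAX version may need explicit backend selection:

import functools
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

shape = (8192,)  # One contiguous dimension, far beyond the old 256-element cap.

@functools.partial(
    pl.pallas_call,  # assumed configured for Mosaic GPU, as self.pallas_call is
    out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
    out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
    in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
    scratch_shapes=[plgpu.SMEM(shape, jnp.float32), plgpu.Barrier()],
)
def copy_kernel(x_ref, o_ref, scratch_ref, barrier_ref):
  # Contiguous GMEM -> SMEM transfer: eligible for cp.async.bulk after this commit.
  plgpu.copy_gmem_to_smem(x_ref, scratch_ref, barrier_ref)
  plgpu.barrier_wait(barrier_ref)
  plgpu.commit_smem()
  # Contiguous SMEM -> GMEM transfer back out, tracked via the bulk async-group.
  plgpu.copy_smem_to_gmem(scratch_ref, o_ref)
  plgpu.wait_smem_to_gmem(0)

x = jnp.arange(8192, dtype=jnp.float32)
np.testing.assert_array_equal(copy_kernel(x), x)
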
tests/pallas/mosaic_gpu_test.py

Lines changed: 43 additions & 6 deletions
@@ -676,6 +676,41 @@ def kernel(x_ref, o_ref_gmem, o_ref_gmem_alias, scratch_ref):
     output_val = x.reshape(-1, 128).sum(axis=0)
     np.testing.assert_array_equal(output, output_val)

+  @parameterized.parameters(
+      ((64, 128,), (slice(2, 3), slice(0, 128)), jnp.bfloat16),
+      ((256,), (...,), jnp.bfloat16),
+      ((64, 128,), (...,), jnp.bfloat16),
+      ((3, 64, 1, 128), (0, slice(0, 32), 0, slice(0, 128)), jnp.float32),
+      ((3, 64, 1, 128), (...,), jnp.float32),
+      ((3, 64, 128), (...,), jnp.float32),
+      ((10, 10, 512,), (4, 4), jnp.bfloat16),
+      ((10, 1024,), (4,), jnp.bfloat16),
+      ((8192,), (...,), jnp.bfloat16),
+      ((8192,), (slice(4096, 8192),), jnp.bfloat16),
+      ((8192,), (slice(4096, 8192),), jnp.float32),
+  )
+  def test_copy_gmem_to_smem_contiguous(self, shape, indexer, dtype):
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct(shape, dtype),
+        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
+        in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
+        scratch_shapes=[plgpu.SMEM(shape, dtype), plgpu.Barrier()],
+        grid=(1,),
+    )
+    def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
+      plgpu.copy_gmem_to_smem(
+          x_ref_gmem.at[indexer], scratch_ref.at[indexer], barrier_ref
+      )
+      plgpu.barrier_wait(barrier_ref)
+      scratch_ref[indexer] = scratch_ref[indexer] + 1
+      plgpu.commit_smem()
+      plgpu.copy_smem_to_gmem(scratch_ref.at[indexer], o_ref.at[indexer])
+      plgpu.wait_smem_to_gmem(0)
+
+    x = jax.random.normal(jax.random.key(0), shape, dtype=dtype)
+    np.testing.assert_allclose(kernel(x)[indexer], x[indexer] + 1.0)
+
   @parameterized.named_parameters(
       {"testcase_name": "1d_none",
        "shape": (256,), "indexers": (slice(0, 128), slice(None, 32))},
14821517
def test_program_id_in_block_spec(self):
14831518
@functools.partial(
14841519
self.pallas_call,
1485-
in_specs=(pl.BlockSpec((2, 128), lambda i: (pl.program_id(0), i)),),
1486-
out_specs=pl.BlockSpec((2, 128), lambda i: (pl.program_id(0), i)),
1487-
out_shape=jax.ShapeDtypeStruct([2, 128], jnp.int32),
1520+
in_specs=(pl.BlockSpec((1, 128), lambda i: (pl.program_id(0), i)),),
1521+
out_specs=pl.BlockSpec((1, 128), lambda i: (pl.program_id(0), i)),
1522+
out_shape=jax.ShapeDtypeStruct([2, 256], jnp.int32),
14881523
grid=2,
14891524
)
14901525
def kernel(x_ref, o_ref):
14911526
o_ref[...] = x_ref[...]
14921527

1493-
x = jnp.arange(2 * 128, dtype=jnp.int32).reshape([2, 128])
1528+
x = jnp.arange(2 * 256, dtype=jnp.int32).reshape([2, 256])
14941529
np.testing.assert_array_equal(kernel(x), x)
14951530

14961531
def test_num_programs(self):
@@ -2528,8 +2563,10 @@ def kernel(x_ref, o_ref):
25282563
ptx = output()
25292564
self.assertIn(".file", ptx)
25302565
self.assertIn(".loc", ptx)
2531-
[path] = re.findall(r'.file\s+\d+\s+"(.+)"', ptx)
2532-
self.assertEndsWith(__file__, path)
2566+
paths = re.findall(r'.file\s+\d+\s+"(.+)"', ptx)
2567+
paths = [p for p in paths if p != "-"]
2568+
self.assertLen(paths, 1)
2569+
self.assertEndsWith(__file__, paths[0])
25332570

25342571
def test_collective_arrival_count(self):
25352572
def kernel(dst, collective_barrier):
