Conversation

@durga4github
Contributor

This patch adds support for shared::cta as the destination space in
the TMA non-tensor copy Op (from global to shared::cta).

  • Appropriate verifier checks are added.
  • Unit tests are added to verify the lowering.

The related intrinsic changes were merged through PR #167508.
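
For illustration, a minimal sketch of both destination modes, mirroring the unit tests in this patch (the SSA values %dst_cta, %dst_cluster, %src, %mbar, %size, %ch, and %mask are assumed to be in scope with the indicated types):

// Destination in shared::cta (addrspace 3); multicast_mask is not allowed here.
nvvm.cp.async.bulk.shared.cluster.global %dst_cta, %src, %mbar, %size l2_cache_hint = %ch : !llvm.ptr<3>, !llvm.ptr<1>

// Destination in shared::cluster (addrspace 7); multicast_mask remains valid.
nvvm.cp.async.bulk.shared.cluster.global %dst_cluster, %src, %mbar, %size multicast_mask = %mask l2_cache_hint = %ch : !llvm.ptr<7>, !llvm.ptr<1>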

@llvmbot
Member

llvmbot commented Nov 14, 2025

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Durgadoss R (durga4github)

Changes

This patch adds support for shared::cta as the destination space in
the TMA non-tensor copy Op (from global to shared::cta).

  • Appropriate verifier checks are added.
  • Unit tests are added to verify the lowering.

The related intrinsic changes were merged through PR #167508.


Full diff: https://github.com/llvm/llvm-project/pull/168056.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+9-6)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+20-5)
  • (modified) mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir (+11)
  • (added) mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir (+8)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 995ade5c9b033..a10db1648887b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3342,16 +3342,17 @@ def NVVM_CpAsyncBulkTensorReduceOp :
 
 def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
   NVVM_Op<"cp.async.bulk.shared.cluster.global", [AttrSizedOperandSegments]> {
-  let summary = "Async bulk copy from global memory to Shared cluster memory";
+  let summary = "Async bulk copy from global to Shared {cta or cluster} memory";
   let description = [{
-    Initiates an asynchronous copy operation from global memory to cluster's
-    shared memory.
+    Initiates an asynchronous copy operation from global memory to shared::cta
+    or shared::cluster memory.
 
-    The `multicastMask` operand is optional. When it is present, the Op copies
+    The `multicastMask` operand is optional and can be used only when the
+    destination is shared::cluster memory. When it is present, this Op copies
     data from global memory to shared memory of multiple CTAs in the cluster.
     Operand `multicastMask` specifies the destination CTAs in the cluster such
     that each bit position in the 16-bit `multicastMask` operand corresponds to
-    the `nvvm.read.ptx.sreg.ctaid` of the destination CTA.
+    the `nvvm.read.ptx.sreg.ctaid` of the destination CTA. 
 
     The `l2CacheHint` operand is optional, and it is used to specify cache
     eviction policy that may be used during the memory access.
@@ -3360,7 +3361,7 @@ def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
   }];
 
   let arguments = (ins
-    LLVM_PointerSharedCluster:$dstMem,
+    AnyTypeOf<[LLVM_PointerShared, LLVM_PointerSharedCluster]>:$dstMem,
     LLVM_PointerGlobal:$srcMem,
     LLVM_PointerShared:$mbar,
     I32:$size,
@@ -3374,6 +3375,8 @@ def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
     attr-dict  `:` type($dstMem) `,` type($srcMem)
   }];
 
+  let hasVerifier = 1;
+
   let extraClassDeclaration = [{
     static mlir::NVVM::IDArgPair
     getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 0f7b3638fb30d..7ac427dbe3941 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -212,6 +212,14 @@ LogicalResult CpAsyncBulkTensorReduceOp::verify() {
   return success();
 }
 
+LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
+  bool isSharedCTA = isPtrInSharedCTASpace(getDstMem());
+  if (isSharedCTA && getMulticastMask())
+    return emitError("Multicast is not supported with shared::cta mode.");
+
+  return success();
+}
+
 LogicalResult ConvertFloatToTF32Op::verify() {
   using RndMode = NVVM::FPRoundingMode;
   switch (getRnd()) {
@@ -1980,11 +1988,15 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
   args.push_back(mt.lookupValue(thisOp.getSrcMem()));
   args.push_back(mt.lookupValue(thisOp.getSize()));
 
-  // Multicast mask, if available.
+  // Multicast mask for shared::cluster only, if available.
   mlir::Value multicastMask = thisOp.getMulticastMask();
   const bool hasMulticastMask = static_cast<bool>(multicastMask);
-  llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
-  args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused);
+  const bool isSharedCTA = isPtrInSharedCTASpace(thisOp.getDstMem());
+  if (!isSharedCTA) {
+    llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
+    args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask)
+                                    : i16Unused);
+  }
 
   // Cache hint, if available.
   mlir::Value cacheHint = thisOp.getL2CacheHint();
@@ -1993,11 +2005,14 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
   args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
 
   // Flag arguments for multicast and cachehint.
-  args.push_back(builder.getInt1(hasMulticastMask));
+  if (!isSharedCTA)
+    args.push_back(builder.getInt1(hasMulticastMask));
   args.push_back(builder.getInt1(hasCacheHint));
 
   llvm::Intrinsic::ID id =
-      llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
+      isSharedCTA
+          ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
+          : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
 
   return {id, std::move(args)};
 }
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
index 0daf24536a672..240fab5b63908 100644
--- a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
@@ -16,6 +16,17 @@ llvm.func @llvm_nvvm_cp_async_bulk_global_to_shared_cluster(%dst : !llvm.ptr<7>,
   llvm.return
 }
 
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_global_to_shared_cta
+llvm.func @llvm_nvvm_cp_async_bulk_global_to_shared_cta(%dst : !llvm.ptr<3>, %src : !llvm.ptr<1>, %mbar : !llvm.ptr<3>, %size : i32, %ch : i64) {
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cta(ptr addrspace(3) %[[DST:.*]], ptr addrspace(3) %[[MBAR:.*]], ptr addrspace(1) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cta(ptr addrspace(3) %[[DST]], ptr addrspace(3) %[[MBAR]], ptr addrspace(1) %[[SRC]], i32 %[[SIZE]], i64 %[[CH:.*]], i1 true)
+  nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size : !llvm.ptr<3>, !llvm.ptr<1>
+
+  nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size l2_cache_hint = %ch : !llvm.ptr<3>, !llvm.ptr<1>
+
+  llvm.return
+}
+
 // CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster
 llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster(%dst : !llvm.ptr<7>, %src : !llvm.ptr<3>, %mbar : !llvm.ptr<3>, %size : i32) {
   // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.cluster(ptr addrspace(7) %0, ptr addrspace(3) %2, ptr addrspace(3) %1, i32 %3)
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir
new file mode 100644
index 0000000000000..d762ff3ff1e76
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
+llvm.func @tma_bulk_copy_g2s_mc(%src : !llvm.ptr<1>, %dest : !llvm.ptr<3>, %bar : !llvm.ptr<3>, %size : i32, %ctamask : i16) {
+  // expected-error @below {{Multicast is not supported with shared::cta mode.}}
+  nvvm.cp.async.bulk.shared.cluster.global %dest, %src, %bar, %size multicast_mask = %ctamask : !llvm.ptr<3>, !llvm.ptr<1>
+
+  llvm.return
+}

@durga4github durga4github force-pushed the durgadossr/mlir_tma_g2s_update branch from ef2cbf5 to e6c8013 on November 14, 2025 18:49
@llvmbot llvmbot added the flang and flang:fir-hlfir labels on Nov 14, 2025
@durga4github durga4github force-pushed the durgadossr/mlir_tma_g2s_update branch from e6c8013 to bd293f8 on November 14, 2025 18:52
@durga4github
Contributor Author

durga4github commented Nov 14, 2025

@clementval, when you have a moment, could you please take a look at the file-check update I have in flang/test/Lower/CUDA/cuda-device-proc.cuf?

It seems the printer now emits a fully qualified type after I updated dstMem's type to use AnyTypeOf<[shared, shared_cluster]>.
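
For context, a minimal hypothetical sketch of the form the test now has to match (placeholder SSA names; the point is that the destination pointer type is spelled out in full after the colon):

nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size : !llvm.ptr<7>, !llvm.ptr<1>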

@clementval
Contributor

> @clementval, when you have a moment, could you please take a look at the file-check update I have in flang/test/Lower/CUDA/cuda-device-proc.cuf?
>
> It seems the printer now emits a fully qualified type after I updated dstMem's type to use AnyTypeOf<[shared, shared_cluster]>.

That looks ok to me.

@durga4github durga4github merged commit 95aa70c into llvm:main Nov 17, 2025
10 checks passed
@durga4github durga4github deleted the durgadossr/mlir_tma_g2s_update branch November 17, 2025 08:28