-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][NVVM] Add support for shared::cta destination #168056
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][NVVM] Add support for shared::cta destination #168056
Conversation
|
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-mlir Author: Durgadoss R (durga4github)

Changes: This patch adds support for shared::cta as destination space in the TMA non-tensor copy Op (from global to shared::cta).

The related intrinsic changes were merged through PR #167508. Full diff: https://github.com/llvm/llvm-project/pull/168056.diff — 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 995ade5c9b033..a10db1648887b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3342,16 +3342,17 @@ def NVVM_CpAsyncBulkTensorReduceOp :
def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
NVVM_Op<"cp.async.bulk.shared.cluster.global", [AttrSizedOperandSegments]> {
- let summary = "Async bulk copy from global memory to Shared cluster memory";
+ let summary = "Async bulk copy from global to Shared {cta or cluster} memory";
let description = [{
- Initiates an asynchronous copy operation from global memory to cluster's
- shared memory.
+ Initiates an asynchronous copy operation from global memory to shared::cta
+ or shared::cluster memory.
- The `multicastMask` operand is optional. When it is present, the Op copies
+ The `multicastMask` operand is optional and can be used only when the
+ destination is shared::cluster memory. When it is present, this Op copies
data from global memory to shared memory of multiple CTAs in the cluster.
Operand `multicastMask` specifies the destination CTAs in the cluster such
that each bit position in the 16-bit `multicastMask` operand corresponds to
- the `nvvm.read.ptx.sreg.ctaid` of the destination CTA.
+ the `nvvm.read.ptx.sreg.ctaid` of the destination CTA.
The `l2CacheHint` operand is optional, and it is used to specify cache
eviction policy that may be used during the memory access.
@@ -3360,7 +3361,7 @@ def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
}];
let arguments = (ins
- LLVM_PointerSharedCluster:$dstMem,
+ AnyTypeOf<[LLVM_PointerShared, LLVM_PointerSharedCluster]>:$dstMem,
LLVM_PointerGlobal:$srcMem,
LLVM_PointerShared:$mbar,
I32:$size,
@@ -3374,6 +3375,8 @@ def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
attr-dict `:` type($dstMem) `,` type($srcMem)
}];
+ let hasVerifier = 1;
+
let extraClassDeclaration = [{
static mlir::NVVM::IDArgPair
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 0f7b3638fb30d..7ac427dbe3941 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -212,6 +212,14 @@ LogicalResult CpAsyncBulkTensorReduceOp::verify() {
return success();
}
+LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
+ bool isSharedCTA = isPtrInSharedCTASpace(getDstMem());
+ if (isSharedCTA && getMulticastMask())
+ return emitError("Multicast is not supported with shared::cta mode.");
+
+ return success();
+}
+
LogicalResult ConvertFloatToTF32Op::verify() {
using RndMode = NVVM::FPRoundingMode;
switch (getRnd()) {
@@ -1980,11 +1988,15 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
args.push_back(mt.lookupValue(thisOp.getSrcMem()));
args.push_back(mt.lookupValue(thisOp.getSize()));
- // Multicast mask, if available.
+ // Multicast mask for shared::cluster only, if available.
mlir::Value multicastMask = thisOp.getMulticastMask();
const bool hasMulticastMask = static_cast<bool>(multicastMask);
- llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
- args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused);
+ const bool isSharedCTA = isPtrInSharedCTASpace(thisOp.getDstMem());
+ if (!isSharedCTA) {
+ llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
+ args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask)
+ : i16Unused);
+ }
// Cache hint, if available.
mlir::Value cacheHint = thisOp.getL2CacheHint();
@@ -1993,11 +2005,14 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
// Flag arguments for multicast and cachehint.
- args.push_back(builder.getInt1(hasMulticastMask));
+ if (!isSharedCTA)
+ args.push_back(builder.getInt1(hasMulticastMask));
args.push_back(builder.getInt1(hasCacheHint));
llvm::Intrinsic::ID id =
- llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
+ isSharedCTA
+ ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
+ : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
return {id, std::move(args)};
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
index 0daf24536a672..240fab5b63908 100644
--- a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
@@ -16,6 +16,17 @@ llvm.func @llvm_nvvm_cp_async_bulk_global_to_shared_cluster(%dst : !llvm.ptr<7>,
llvm.return
}
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_global_to_shared_cta
+llvm.func @llvm_nvvm_cp_async_bulk_global_to_shared_cta(%dst : !llvm.ptr<3>, %src : !llvm.ptr<1>, %mbar : !llvm.ptr<3>, %size : i32, %ch : i64) {
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cta(ptr addrspace(3) %[[DST:.*]], ptr addrspace(3) %[[MBAR:.*]], ptr addrspace(1) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false)
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cta(ptr addrspace(3) %[[DST]], ptr addrspace(3) %[[MBAR]], ptr addrspace(1) %[[SRC]], i32 %[[SIZE]], i64 %[[CH:.*]], i1 true)
+ nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size : !llvm.ptr<3>, !llvm.ptr<1>
+
+ nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size l2_cache_hint = %ch : !llvm.ptr<3>, !llvm.ptr<1>
+
+ llvm.return
+}
+
// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster
llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster(%dst : !llvm.ptr<7>, %src : !llvm.ptr<3>, %mbar : !llvm.ptr<3>, %size : i32) {
// CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.cluster(ptr addrspace(7) %0, ptr addrspace(3) %2, ptr addrspace(3) %1, i32 %3)
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir
new file mode 100644
index 0000000000000..d762ff3ff1e76
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
+llvm.func @tma_bulk_copy_g2s_mc(%src : !llvm.ptr<1>, %dest : !llvm.ptr<3>, %bar : !llvm.ptr<3>, %size : i32, %ctamask : i16) {
+ // expected-error @below {{Multicast is not supported with shared::cta mode.}}
+ nvvm.cp.async.bulk.shared.cluster.global %dest, %src, %bar, %size multicast_mask = %ctamask : !llvm.ptr<3>, !llvm.ptr<1>
+
+ llvm.return
+}
|
ef2cbf5 to
e6c8013
Compare
This patch adds support for shared::cta as destination space in the TMA non-tensor copy Op (from global to shared::cta). * Appropriate verifier checks are added. * Unit tests are added to verify the lowering. The related intrinsic changes were merged through PR llvm#167508. Signed-off-by: Durgadoss R <[email protected]>
e6c8013 to
bd293f8
Compare
|
@clementval, when you have a moment, could you please take a look at the file-check update I have in this PR? It seems the printer now looks for a fully qualified type after I updated the destination pointer type. [NOTE: this comment was truncated during page extraction — the original presumably named the specific test file and the updated operand type; verify against the PR thread.]
That looks ok to me. |
This patch adds support for shared::cta as destination space in
the TMA non-tensor copy Op (from global to shared::cta).
The related intrinsic changes were merged through PR #167508.