[clang][NVPTX] Add support for mixed-precision FP arithmetic #168359
@llvm/pr-subscribers-backend-nvptx @llvm/pr-subscribers-clang

Author: Srinivasa Ravi (Wolfram70)

Changes

This change adds NVVM intrinsics and clang builtins for mixed-precision FP arithmetic instructions. Tests are added in `mixed-precision-fp.ll` and `builtins-nvptx.c` and verified through `ptxas-13.0`.

PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#mixed-precision-floating-point-instructions

Patch is 37.10 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168359.diff

6 Files Affected:
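For orientation, here is a minimal, self-contained LLVM IR sketch of what clang is expected to emit for two of the new builtins. The function and value names are illustrative; the intrinsic names are taken from the test updates in the diff below.

```llvm
; Illustrative IR only; mirrors the CHECK lines added to builtins-nvptx.c.
declare float @llvm.nvvm.add.mixed.rn.f32.f16(half, float)
declare float @llvm.nvvm.add.mixed.rn.sat.f32.bf16(bfloat, float)

define float @example(half %h, bfloat %b, float %acc) {
  ; f16 + f32 add, round-to-nearest-even.
  %r0 = call float @llvm.nvvm.add.mixed.rn.f32.f16(half %h, float %acc)
  ; bf16 + f32 add, round-to-nearest-even with saturation.
  %r1 = call float @llvm.nvvm.add.mixed.rn.sat.f32.bf16(bfloat %b, float %r0)
  ret float %r1
}
```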
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td
index d923d2a90e908..47ba12bef058c 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.td
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.td
@@ -401,6 +401,24 @@ def __nvvm_fma_rz_d : NVPTXBuiltin<"double(double, double, double)">;
def __nvvm_fma_rm_d : NVPTXBuiltin<"double(double, double, double)">;
def __nvvm_fma_rp_d : NVPTXBuiltin<"double(double, double, double)">;
+def __nvvm_fma_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+
+def __nvvm_fma_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+
// Rcp
def __nvvm_rcp_rn_ftz_f : NVPTXBuiltin<"float(float)">;
@@ -460,6 +478,52 @@ def __nvvm_add_rz_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rm_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">;
+def __nvvm_add_mixed_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+
+def __nvvm_add_mixed_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+
+// Sub
+
+def __nvvm_sub_mixed_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+
+def __nvvm_sub_mixed_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+
// Convert
def __nvvm_d2f_rn_ftz : NVPTXBuiltin<"float(double)">;
diff --git a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
index 8a1cab3417d98..6f57620f0fb00 100644
--- a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
@@ -415,6 +415,17 @@ static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID,
return MakeHalfType(CGF.CGM.getIntrinsic(IntrinsicID), BuiltinID, E, CGF);
}
+static Value *MakeMixedPrecisionFPArithmetic(unsigned IntrinsicID,
+ const CallExpr *E,
+ CodeGenFunction &CGF) {
+ SmallVector<llvm::Value *, 3> Args;
+ for (unsigned i = 0; i < E->getNumArgs(); ++i) {
+ Args.push_back(CGF.EmitScalarExpr(E->getArg(i)));
+ }
+ return CGF.Builder.CreateCall(
+ CGF.CGM.getIntrinsic(IntrinsicID, {Args[0]->getType()}), Args);
+}
+
} // namespace
Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
@@ -1197,6 +1208,118 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
return Builder.CreateCall(
CGM.getIntrinsic(Intrinsic::nvvm_barrier_cta_sync_count),
{EmitScalarExpr(E->getArg(0)), EmitScalarExpr(E->getArg(1))});
+ case NVPTX::BI__nvvm_add_mixed_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_rn_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rn_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rn_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_rz_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rz_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rz_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_rm_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rm_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rm_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_rp_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rp_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rp_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_sat_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_sat_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_rn_sat_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rn_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rn_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_add_mixed_rz_sat_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rz_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rz_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_add_mixed_rm_sat_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rm_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rm_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_add_mixed_rp_sat_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rp_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rp_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_sub_mixed_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_rn_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rn_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rn_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_rz_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rz_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rz_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_rm_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rm_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rm_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_rp_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rp_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rp_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_sat_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_sat_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_rn_sat_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rn_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rn_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_sub_mixed_rz_sat_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rz_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rz_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_sub_mixed_rm_sat_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rm_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rm_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_sub_mixed_rp_sat_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rp_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rp_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_fma_mixed_rn_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rn_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rn_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_fma_mixed_rz_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rz_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rz_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_fma_mixed_rm_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rm_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rm_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_fma_mixed_rp_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rp_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rp_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_fma_mixed_rn_sat_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rn_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rn_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_fma_mixed_rz_sat_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rz_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rz_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_fma_mixed_rm_sat_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rm_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rm_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_fma_mixed_rp_sat_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rp_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rp_sat_f32,
+ E, *this);
default:
return nullptr;
}
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index e3be262622844..1753b4c7767e9 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -1466,3 +1466,136 @@ __device__ void nvvm_min_max_sm86() {
#endif
// CHECK: ret void
}
+
+#define F16 (__fp16)0.1f
+#define F16_2 (__fp16)0.2f
+
+__device__ void nvvm_add_mixed_precision_sm100() {
+#if __CUDA_ARCH__ >= 1000
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rn_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rz_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rm_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rp_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rn_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rz_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rm_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rp_sat_f16_f32(F16, 1.0f);
+
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rn_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rz_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rm_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rp_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rn_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rz_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rm_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rp_sat_bf16_f32(BF16, 1.0f);
+#endif
+}
+
+__device__ void nvvm_sub_mixed_precision_sm100() {
+#if __CUDA_ARCH__ >= 1000
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rn_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rz_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rm_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rp_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rn_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rz_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rm_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rp_sat_f16_f32(F16, 1.0f);
+
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rn_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rz_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rm_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rp_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rn_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rz_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ _...
[truncated]
🐧 Linux x64 Test Results

✅ The build succeeded and all tests passed.
Ping @AlexMaclean for review.
| foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in { | ||
| foreach ftz = ["", "_ftz"] in { | ||
| foreach sat = ["", "_sat"] in { | ||
| def int_nvvm_sub # rnd # ftz # sat # _f : NVVMBuiltin, |
Is there a motivating case for sub intrinsics? Can we just fold the add variants with fneg? Given that intrinsics are pretty easy to add and very difficult to remove, I'm generally in favor of being conservative, since each intrinsic needs documentation, constant folding, etc.
Added a DAG combine to fold fneg with the add intrinsics to sub similar to #170079. Please take a look, thanks!
Looks like the DAG combine is only needed for the fp16 instructions, since only those variants of neg were introduced later (and were thus being lowered to xor). I've moved the pattern matching to TableGen now, which looks much cleaner.
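To illustrate the idea discussed in this thread, here is a rough LLVM IR sketch: instead of a dedicated sub intrinsic, one operand of an add intrinsic is negated and the backend is expected to fold the fneg into the corresponding sub instruction. The function and value names are illustrative, the intrinsic is one instantiation of the `llvm.nvvm.add.(rn/rz/rm/rp){.ftz}.sat.f` family added by this PR, and the exact operand placement and selected instruction may differ from the final patch.

```llvm
; Illustrative only: a - b expressed as a + (-b) using one of the new
; saturating add intrinsics; the backend is expected to select the
; matching sub instruction rather than emitting a separate negate.
declare float @llvm.nvvm.add.rn.sat.f(float, float)

define float @sub_via_add(float %a, float %b) {
  %nb = fneg float %b
  %r = call float @llvm.nvvm.add.rn.sat.f(float %a, float %nb)
  ret float %r
}
```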
[clang][NVPTX] Add support for mixed-precision FP arithmetic

This change adds NVVM intrinsics and clang builtins for mixed-precision FP arithmetic instructions. Tests are added in `mixed-precision-fp.ll` and `builtins-nvptx.c` and verified through `ptxas-13.0`.

PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#mixed-precision-floating-point-instructions
Force-pushed from f72ed2f to 992685c.
Force-pushed from 277e36f to e4133d2.
This change adds support for mixed-precision floating-point arithmetic for `f16` and `bf16`, where patterns that convert an `f16`/`bf16` operand to `f32` and then perform the arithmetic are lowered to the corresponding mixed-precision instructions, which combine the conversion and the operation into a single instruction from `sm_100` onwards.

This also adds the following intrinsics to complete support for all variants of the floating-point `add`/`fma` operations, in order to support the corresponding mixed-precision instructions:

- `llvm.nvvm.add.(rn/rz/rm/rp){.ftz}.sat.f`
- `llvm.nvvm.fma.(rn/rz/rm/rp){.ftz}.sat.f`

We lower `fneg` followed by one of the above addition intrinsics to the corresponding `sub` instruction.

Tests are added in `fp-arith-sat.ll`, `fp-fold-sub.ll`, and `builtins-nvptx.c` for the newly added intrinsics and builtins, and in `mixed-precision-fp.ll` for the mixed-precision instructions.

PTX spec reference for mixed-precision instructions: https://docs.nvidia.com/cuda/parallel-thread-execution/#mixed-precision-floating-point-instructions
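As a concrete illustration of the lowering described above, here is a hedged LLVM IR sketch (function and value names are illustrative, and this is presumably the shape of pattern the description refers to): an `f16` operand is extended to `f32` and then added to an `f32` value, and from `sm_100`/PTX 8.6 onwards the backend is expected to select a single mixed-precision add instead of a separate convert followed by an `f32` add.

```llvm
; Illustrative input pattern: extend-then-add on an f16 operand.
define float @ext_then_add(half %a, float %b) {
  %ext = fpext half %a to float   ; conversion half -> float
  %r = fadd float %ext, %b        ; f32 add using the converted value
  ret float %r
}
```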