Conversation

@Wolfram70 (Contributor) commented Dec 1, 2025

This change adds the following missing half-precision
add/mul/fma intrinsics for the NVPTX target:

  • llvm.nvvm.add.rn{.ftz}.sat.f16
  • llvm.nvvm.add.rn{.ftz}.sat.f16x2
  • llvm.nvvm.mul.rn{.ftz}.sat.f16
  • llvm.nvvm.mul.rn{.ftz}.sat.f16x2
  • llvm.nvvm.fma.rn.oob.*

We lower an fneg feeding one of the above add intrinsics
to the corresponding PTX sub instruction.
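
For concreteness, a minimal IR sketch of this combine (the function name is illustrative; the expected PTX selection is shown as a comment):

define half @sub_via_fneg(half %a, half %b) {
  %neg = fneg half %b
  %r = call half @llvm.nvvm.add.rn.sat.f16(half %a, half %neg)
  ; expected to select to a single instruction: sub.rn.sat.f16
  ret half %r
}

declare half @llvm.nvvm.add.rn.sat.f16(half, half)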

This also removes some incorrect bf16 fma intrinsics with no
valid lowering.

PTX spec reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions

@Wolfram70 self-assigned this Dec 1, 2025
@llvmbot added the clang:frontend, backend:NVPTX, and llvm:ir labels Dec 1, 2025
@llvmbot (Member) commented Dec 1, 2025

@llvm/pr-subscribers-clang-codegen
@llvm/pr-subscribers-backend-nvptx

@llvm/pr-subscribers-llvm-ir

Author: Srinivasa Ravi (Wolfram70)

Changes

This change adds the following missing half-precision
add/sub/fma intrinsics for the NVPTX target:

  • llvm.nvvm.add.rn{.ftz}.sat.f16
  • llvm.nvvm.add.rn{.ftz}.sat.f16x2
  • llvm.nvvm.sub.rn{.ftz}.sat.f16
  • llvm.nvvm.sub.rn{.ftz}.sat.f16x2
  • llvm.nvvm.fma.rn.oob.*

This also removes some incorrect bf16 fma intrinsics with no
valid lowering.

PTX spec reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions


Patch is 31.78 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/170079.diff

10 Files Affected:

  • (modified) clang/include/clang/Basic/BuiltinsNVPTX.td (+27)
  • (modified) clang/test/CodeGen/builtins-nvptx.c (+64)
  • (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+61-8)
  • (modified) llvm/lib/IR/AutoUpgrade.cpp (-8)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+41-9)
  • (modified) llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp (-4)
  • (added) llvm/test/CodeGen/NVPTX/f16-add-sat.ll (+63)
  • (added) llvm/test/CodeGen/NVPTX/f16-mul-sat.ll (+63)
  • (added) llvm/test/CodeGen/NVPTX/f16-sub-sat.ll (+63)
  • (added) llvm/test/CodeGen/NVPTX/fma-oob.ll (+131)
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td
index ad448766e665f..052c20455b373 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.td
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.td
@@ -378,16 +378,24 @@ def __nvvm_fma_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)
 def __nvvm_fma_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_53, PTX42>;
 def __nvvm_fma_rn_relu_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_80, PTX70>;
 def __nvvm_fma_rn_ftz_relu_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_80, PTX70>;
+def __nvvm_fma_rn_oob_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_90, PTX81>;
+def __nvvm_fma_rn_oob_relu_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_90, PTX81>;
 def __nvvm_fma_rn_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
 def __nvvm_fma_rn_ftz_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
 def __nvvm_fma_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
 def __nvvm_fma_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
 def __nvvm_fma_rn_relu_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_80, PTX70>;
 def __nvvm_fma_rn_ftz_relu_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_80, PTX70>;
+def __nvvm_fma_rn_oob_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_90, PTX81>;
+def __nvvm_fma_rn_oob_relu_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_90, PTX81>;
 def __nvvm_fma_rn_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_80, PTX70>;
 def __nvvm_fma_rn_relu_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_80, PTX70>;
+def __nvvm_fma_rn_oob_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_90, PTX81>;
+def __nvvm_fma_rn_oob_relu_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_90, PTX81>;
 def __nvvm_fma_rn_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_80, PTX70>;
 def __nvvm_fma_rn_relu_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_80, PTX70>;
+def __nvvm_fma_rn_oob_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_90, PTX81>;
+def __nvvm_fma_rn_oob_relu_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_90, PTX81>;
 def __nvvm_fma_rn_ftz_f : NVPTXBuiltin<"float(float, float, float)">;
 def __nvvm_fma_rn_f : NVPTXBuiltin<"float(float, float, float)">;
 def __nvvm_fma_rz_ftz_f : NVPTXBuiltin<"float(float, float, float)">;
@@ -446,6 +454,11 @@ def __nvvm_rsqrt_approx_d : NVPTXBuiltin<"double(double)">;
 
 // Add
 
+def __nvvm_add_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
+def __nvvm_add_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
+def __nvvm_add_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+def __nvvm_add_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+
 def __nvvm_add_rn_ftz_f : NVPTXBuiltin<"float(float, float)">;
 def __nvvm_add_rn_f : NVPTXBuiltin<"float(float, float)">;
 def __nvvm_add_rz_ftz_f : NVPTXBuiltin<"float(float, float)">;
@@ -460,6 +473,20 @@ def __nvvm_add_rz_d : NVPTXBuiltin<"double(double, double)">;
 def __nvvm_add_rm_d : NVPTXBuiltin<"double(double, double)">;
 def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">;
 
+// Sub
+
+def __nvvm_sub_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
+def __nvvm_sub_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
+def __nvvm_sub_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+def __nvvm_sub_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+
+// Mul
+
+def __nvvm_mul_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
+def __nvvm_mul_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
+def __nvvm_mul_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+def __nvvm_mul_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
+
 // Convert
 
 def __nvvm_d2f_rn_ftz : NVPTXBuiltin<"float(double)">;
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index c0ed799970122..d705bcbe208d1 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -31,6 +31,9 @@
 // RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_80 -target-feature +ptx81 -DPTX=81 \
 // RUN:            -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
 // RUN:   | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM80 %s
+// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_90 -target-feature +ptx81 -DPTX=81\
+// RUN:            -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
+// RUN:   | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM90 %s
 // RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_90 -target-feature +ptx78 -DPTX=78 \
 // RUN:            -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
 // RUN:   | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX78_SM90 %s
@@ -1470,3 +1473,64 @@ __device__ void nvvm_min_max_sm86() {
 #endif
   // CHECK: ret void
 }
+
+#define F16 (__fp16)0.1f
+#define F16_2 (__fp16)0.2f
+#define F16X2 {(__fp16)0.1f, (__fp16)0.1f}
+#define F16X2_2 {(__fp16)0.2f, (__fp16)0.2f}
+
+// CHECK-LABEL: nvvm_add_sub_mul_f16_sat
+__device__ void nvvm_add_sub_mul_f16_sat() {
+  // CHECK: call half @llvm.nvvm.add.rn.sat.f16
+  __nvvm_add_rn_sat_f16(F16, F16_2);
+  // CHECK: call half @llvm.nvvm.add.rn.ftz.sat.f16
+  __nvvm_add_rn_ftz_sat_f16(F16, F16_2);
+  // CHECK: call <2 x half> @llvm.nvvm.add.rn.sat.f16x2
+  __nvvm_add_rn_sat_f16x2(F16X2, F16X2_2);
+  // CHECK: call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2
+  __nvvm_add_rn_ftz_sat_f16x2(F16X2, F16X2_2);
+
+  // CHECK: call half @llvm.nvvm.sub.rn.sat.f16
+  __nvvm_sub_rn_sat_f16(F16, F16_2);
+  // CHECK: call half @llvm.nvvm.sub.rn.ftz.sat.f16
+  __nvvm_sub_rn_ftz_sat_f16(F16, F16_2);
+  // CHECK: call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2
+  __nvvm_sub_rn_sat_f16x2(F16X2, F16X2_2);
+  // CHECK: call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2
+  __nvvm_sub_rn_ftz_sat_f16x2(F16X2, F16X2_2);
+
+  // CHECK: call half @llvm.nvvm.mul.rn.sat.f16
+  __nvvm_mul_rn_sat_f16(F16, F16_2);
+  // CHECK: call half @llvm.nvvm.mul.rn.ftz.sat.f16
+  __nvvm_mul_rn_ftz_sat_f16(F16, F16_2);
+  // CHECK: call <2 x half> @llvm.nvvm.mul.rn.sat.f16x2
+  __nvvm_mul_rn_sat_f16x2(F16X2, F16X2_2);
+  // CHECK: call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.f16x2
+  __nvvm_mul_rn_ftz_sat_f16x2(F16X2, F16X2_2);
+  
+  // CHECK: ret void
+}
+
+// CHECK-LABEL: nvvm_fma_oob
+__device__ void nvvm_fma_oob() {
+#if __CUDA_ARCH__ >= 900 && (PTX >= 81)
+  // CHECK_PTX81_SM90: call half @llvm.nvvm.fma.rn.oob.f16
+  __nvvm_fma_rn_oob_f16(F16, F16_2, F16_2);
+  // CHECK_PTX81_SM90: call half @llvm.nvvm.fma.rn.oob.relu.f16
+  __nvvm_fma_rn_oob_relu_f16(F16, F16_2, F16_2);
+  // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.f16x2
+  __nvvm_fma_rn_oob_f16x2(F16X2, F16X2_2, F16X2_2);
+  // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.relu.f16x2
+  __nvvm_fma_rn_oob_relu_f16x2(F16X2, F16X2_2, F16X2_2);
+
+  // CHECK_PTX81_SM90: call bfloat @llvm.nvvm.fma.rn.oob.bf16
+  __nvvm_fma_rn_oob_bf16(BF16, BF16_2, BF16_2);
+  // CHECK_PTX81_SM90: call bfloat @llvm.nvvm.fma.rn.oob.relu.bf16
+  __nvvm_fma_rn_oob_relu_bf16(BF16, BF16_2, BF16_2);
+  // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.bf16x2
+  __nvvm_fma_rn_oob_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
+  // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.bf16x2
+  __nvvm_fma_rn_oob_relu_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
+#endif
+  // CHECK: ret void
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 1b485dc8ccd1e..65303ecb48dd8 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1365,16 +1365,38 @@ let TargetPrefix = "nvvm" in {
       def int_nvvm_fma_rn # ftz # variant # _f16x2 :
         PureIntrinsic<[llvm_v2f16_ty],
           [llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty]>;
-
-      def int_nvvm_fma_rn # ftz # variant # _bf16 : NVVMBuiltin,
-        PureIntrinsic<[llvm_bfloat_ty],
-          [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty]>;
-
-      def int_nvvm_fma_rn # ftz # variant # _bf16x2 : NVVMBuiltin,
-        PureIntrinsic<[llvm_v2bf16_ty],
-          [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>;
     } // ftz
   } // variant
+  
+  foreach relu = ["", "_relu"] in { 
+    def int_nvvm_fma_rn # relu # _bf16 : NVVMBuiltin,
+      PureIntrinsic<[llvm_bfloat_ty],
+        [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty]>;
+
+    def int_nvvm_fma_rn # relu # _bf16x2 : NVVMBuiltin,
+      PureIntrinsic<[llvm_v2bf16_ty],
+        [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>;
+  } // relu
+
+  // oob (out-of-bounds) - clamps the result to 0 if either of the operands is 
+  // OOB NaN value.
+  foreach relu = ["", "_relu"] in {
+    def int_nvvm_fma_rn_oob # relu # _f16 : NVVMBuiltin,
+      PureIntrinsic<[llvm_half_ty],
+        [llvm_half_ty, llvm_half_ty, llvm_half_ty]>;
+
+    def int_nvvm_fma_rn_oob # relu # _f16x2 : NVVMBuiltin,
+      PureIntrinsic<[llvm_v2f16_ty],
+        [llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty]>;
+    
+    def int_nvvm_fma_rn_oob # relu # _bf16 : NVVMBuiltin,
+      PureIntrinsic<[llvm_bfloat_ty],
+        [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty]>;
+
+    def int_nvvm_fma_rn_oob # relu # _bf16x2 : NVVMBuiltin,
+      PureIntrinsic<[llvm_v2bf16_ty],
+        [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>;
+  } // relu
 
   foreach rnd = ["rn", "rz", "rm", "rp"] in {
     foreach ftz = ["", "_ftz"] in
@@ -1442,6 +1464,15 @@ let TargetPrefix = "nvvm" in {
   //
   // Add
   //
+  foreach ftz = ["", "_ftz"] in {
+    def int_nvvm_add_rn # ftz # _sat_f16 : NVVMBuiltin,
+      PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
+
+    def int_nvvm_add_rn # ftz # _sat_f16x2 : NVVMBuiltin,
+      PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
+        
+  } // ftz
+
   let IntrProperties = [IntrNoMem, IntrSpeculatable, Commutative] in {
     foreach rnd = ["rn", "rz", "rm", "rp"] in {
       foreach ftz = ["", "_ftz"] in
@@ -1452,6 +1483,28 @@ let TargetPrefix = "nvvm" in {
           DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
     }
   }
+  
+  //
+  // Sub
+  //
+  foreach ftz = ["", "_ftz"] in {
+    def int_nvvm_sub_rn # ftz # _sat_f16 : NVVMBuiltin,
+      PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
+
+    def int_nvvm_sub_rn # ftz # _sat_f16x2 : NVVMBuiltin,
+      PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
+  } // ftz
+
+  //
+  // Mul
+  //
+  foreach ftz = ["", "_ftz"] in {
+    def int_nvvm_mul_rn # ftz # _sat_f16 : NVVMBuiltin,
+      PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
+
+    def int_nvvm_mul_rn # ftz # _sat_f16x2 : NVVMBuiltin,
+      PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
+  } // ftz
 
   //
   // Dot Product
diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp
index 58b7ddd0381e5..1e40242213b99 100644
--- a/llvm/lib/IR/AutoUpgrade.cpp
+++ b/llvm/lib/IR/AutoUpgrade.cpp
@@ -1098,16 +1098,8 @@ static Intrinsic::ID shouldUpgradeNVPTXBF16Intrinsic(StringRef Name) {
     return StringSwitch<Intrinsic::ID>(Name)
         .Case("bf16", Intrinsic::nvvm_fma_rn_bf16)
         .Case("bf16x2", Intrinsic::nvvm_fma_rn_bf16x2)
-        .Case("ftz.bf16", Intrinsic::nvvm_fma_rn_ftz_bf16)
-        .Case("ftz.bf16x2", Intrinsic::nvvm_fma_rn_ftz_bf16x2)
-        .Case("ftz.relu.bf16", Intrinsic::nvvm_fma_rn_ftz_relu_bf16)
-        .Case("ftz.relu.bf16x2", Intrinsic::nvvm_fma_rn_ftz_relu_bf16x2)
-        .Case("ftz.sat.bf16", Intrinsic::nvvm_fma_rn_ftz_sat_bf16)
-        .Case("ftz.sat.bf16x2", Intrinsic::nvvm_fma_rn_ftz_sat_bf16x2)
         .Case("relu.bf16", Intrinsic::nvvm_fma_rn_relu_bf16)
         .Case("relu.bf16x2", Intrinsic::nvvm_fma_rn_relu_bf16x2)
-        .Case("sat.bf16", Intrinsic::nvvm_fma_rn_sat_bf16)
-        .Case("sat.bf16x2", Intrinsic::nvvm_fma_rn_sat_bf16x2)
         .Default(Intrinsic::not_intrinsic);
 
   if (Name.consume_front("fmax."))
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index ea69a54e6db37..57fdd4dc3c388 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1656,18 +1656,18 @@ multiclass FMA_INST {
       [hasPTX<70>, hasSM<80>]>,
     FMA_TUPLE<"_rn_ftz_relu_f16", int_nvvm_fma_rn_ftz_relu_f16, B16,
       [hasPTX<70>, hasSM<80>]>,
+    FMA_TUPLE<"_rn_oob_f16", int_nvvm_fma_rn_oob_f16, B16,
+      [hasPTX<81>, hasSM<90>]>,
+    FMA_TUPLE<"_rn_oob_relu_f16", int_nvvm_fma_rn_oob_relu_f16, B16,
+      [hasPTX<81>, hasSM<90>]>,
 
     FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, B16, [hasPTX<70>, hasSM<80>]>,
-    FMA_TUPLE<"_rn_ftz_bf16", int_nvvm_fma_rn_ftz_bf16, B16,
-      [hasPTX<70>, hasSM<80>]>,
-    FMA_TUPLE<"_rn_sat_bf16", int_nvvm_fma_rn_sat_bf16, B16,
-      [hasPTX<70>, hasSM<80>]>,
-    FMA_TUPLE<"_rn_ftz_sat_bf16", int_nvvm_fma_rn_ftz_sat_bf16, B16,
-      [hasPTX<70>, hasSM<80>]>,
     FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, B16,
       [hasPTX<70>, hasSM<80>]>,
-    FMA_TUPLE<"_rn_ftz_relu_bf16", int_nvvm_fma_rn_ftz_relu_bf16, B16,
-      [hasPTX<70>, hasSM<80>]>,
+    FMA_TUPLE<"_rn_oob_bf16", int_nvvm_fma_rn_oob_bf16, B16,
+      [hasPTX<81>, hasSM<90>]>,
+    FMA_TUPLE<"_rn_oob_relu_bf16", int_nvvm_fma_rn_oob_relu_bf16, B16,
+      [hasPTX<81>, hasSM<90>]>,
 
     FMA_TUPLE<"_rn_f16x2", int_nvvm_fma_rn_f16x2, B32,
       [hasPTX<42>, hasSM<53>]>,
@@ -1681,10 +1681,19 @@ multiclass FMA_INST {
       [hasPTX<70>, hasSM<80>]>,
     FMA_TUPLE<"_rn_ftz_relu_f16x2", int_nvvm_fma_rn_ftz_relu_f16x2,
       B32, [hasPTX<70>, hasSM<80>]>,
+    FMA_TUPLE<"_rn_oob_f16x2", int_nvvm_fma_rn_oob_f16x2, B32,
+      [hasPTX<81>, hasSM<90>]>,
+    FMA_TUPLE<"_rn_oob_relu_f16x2", int_nvvm_fma_rn_oob_relu_f16x2, B32,
+      [hasPTX<81>, hasSM<90>]>,
+
     FMA_TUPLE<"_rn_bf16x2", int_nvvm_fma_rn_bf16x2, B32,
       [hasPTX<70>, hasSM<80>]>,
     FMA_TUPLE<"_rn_relu_bf16x2", int_nvvm_fma_rn_relu_bf16x2, B32,
-      [hasPTX<70>, hasSM<80>]>
+      [hasPTX<70>, hasSM<80>]>,
+    FMA_TUPLE<"_rn_oob_bf16x2", int_nvvm_fma_rn_oob_bf16x2, B32,
+      [hasPTX<81>, hasSM<90>]>,
+    FMA_TUPLE<"_rn_oob_relu_bf16x2", int_nvvm_fma_rn_oob_relu_bf16x2, B32,
+      [hasPTX<81>, hasSM<90>]>,
   ] in {
     def P.Variant :
       F_MATH_3<!strconcat("fma", !subst("_", ".", P.Variant)),
@@ -1792,6 +1801,11 @@ let Predicates = [doRsqrtOpt] in {
 // Add
 //
 
+def INT_NVVM_ADD_RN_SAT_F16 : F_MATH_2<"add.rn.sat.f16", B16, B16, B16, int_nvvm_add_rn_sat_f16>;
+def INT_NVVM_ADD_RN_FTZ_SAT_F16 : F_MATH_2<"add.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_add_rn_ftz_sat_f16>;
+def INT_NVVM_ADD_RN_SAT_F16X2 : F_MATH_2<"add.rn.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_sat_f16x2>;
+def INT_NVVM_ADD_RN_FTZ_SAT_F16X2 : F_MATH_2<"add.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_ftz_sat_f16x2>;
+
 def INT_NVVM_ADD_RN_FTZ_F : F_MATH_2<"add.rn.ftz.f32", B32, B32, B32, int_nvvm_add_rn_ftz_f>;
 def INT_NVVM_ADD_RN_F : F_MATH_2<"add.rn.f32", B32, B32, B32, int_nvvm_add_rn_f>;
 def INT_NVVM_ADD_RZ_FTZ_F : F_MATH_2<"add.rz.ftz.f32", B32, B32, B32, int_nvvm_add_rz_ftz_f>;
@@ -1806,6 +1820,24 @@ def INT_NVVM_ADD_RZ_D : F_MATH_2<"add.rz.f64", B64, B64, B64, int_nvvm_add_rz_d>
 def INT_NVVM_ADD_RM_D : F_MATH_2<"add.rm.f64", B64, B64, B64, int_nvvm_add_rm_d>;
 def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d>;
 
+//
+// Sub
+//
+
+def INT_NVVM_SUB_RN_SAT_F16 : F_MATH_2<"sub.rn.sat.f16", B16, B16, B16, int_nvvm_sub_rn_sat_f16>;
+def INT_NVVM_SUB_RN_FTZ_SAT_F16 : F_MATH_2<"sub.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_sub_rn_ftz_sat_f16>;
+def INT_NVVM_SUB_RN_SAT_F16X2 : F_MATH_2<"sub.rn.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_sat_f16x2>;
+def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 : F_MATH_2<"sub.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_ftz_sat_f16x2>;
+
+//
+// Mul
+//
+
+def INT_NVVM_MUL_RN_SAT_F16 : F_MATH_2<"mul.rn.sat.f16", B16, B16, B16, int_nvvm_mul_rn_sat_f16>;
+def INT_NVVM_MUL_RN_FTZ_SAT_F16 : F_MATH_2<"mul.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_mul_rn_ftz_sat_f16>;
+def INT_NVVM_MUL_RN_SAT_F16X2 : F_MATH_2<"mul.rn.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_sat_f16x2>;
+def INT_NVVM_MUL_RN_FTZ_SAT_F16X2 : F_MATH_2<"mul.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_ftz_sat_f16x2>;
+
 //
 // BFIND
 //
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 64593e6439184..29a81c04395e3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -207,12 +207,8 @@ static Instruction *convertNvvmIntrinsicToLlvm(InstCombiner &IC,
       return {Intrinsic::fma, FTZ_MustBeOn, true};
     case Intrinsic::nvvm_fma_rn_bf16:
       return {Intrinsic::fma, FTZ_MustBeOff, true};
-    case Intrinsic::nvvm_fma_rn_ftz_bf16:
-      return {Intrinsic::fma, FTZ_MustBeOn, true};
     case Intrinsic::nvvm_fma_rn_bf16x2:
       return {Intrinsic::fma, FTZ_MustBeOff, true};
-    case Intrinsic::nvvm_fma_rn_ftz_bf16x2:
-      return {Intrinsic::fma, FTZ_MustBeOn, true};
     case Intrinsic::nvvm_fmax_d:
       return {Intrinsic::maxnum, FTZ_Any};
     case Intrinsic::nvvm_fmax_f:
diff --git a/llvm/test/CodeGen/NVPTX/f16-add-sat.ll b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll
new file mode 100644
index 0000000000000..a623d6e5351ab
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll
@@ -0,0 +1,63 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | FileCheck %s
+; RUN: %if ptxas-isa-4.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | %ptxas-verify%}
+
+define half @add_rn_sat_f16(half %a, half %b) {
+; CHECK-LABEL: add_rn_sat_f16(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [add_rn_sat_f16_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [add_rn_sat_f16_param_1];
+; CHECK-NEXT:    add.rn.sat.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs3;
+; CHECK-NEXT:    ret;
+  %1 = call half @llvm.nvvm.add.rn.sat.f16(half %a, half %b)
+  ret half %1
+}
+
+define <2 x half> @add_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
+; CHECK-LABEL: add_rn_sat_f16x2(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [add_rn_sat_f16x2_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [add_rn_sat_f16x2_param_1];
+; CHECK-NEXT:    add.rn.sat.f16x2 %r3, %r1, %r2;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT:    ret;
+  %1 = call <2 x half> @llvm.nvvm.add.rn.sat.f16x2(<2 x half> %a, <2 x half> %b)
+  ret <2 x half> %1
+}
+
+define half @add_rn_ftz_sat_f16(half %a, half %b) {
+; CHECK-LABEL: add_rn_ftz_sat_f16(
+...
[truncated]

@Artem-B (Member) left a comment

LGTM

@AlexMaclean (Member) left a comment

Can all these be made overloaded intrinsics? Also, as with #168359, I don't see a case for adding explicit sub intrinsics, given that they can be represented with an fneg.

@Artem-B (Member) commented Dec 2, 2025

Can all these be made overloaded intrinsics?

That may work, though the problem with overloaded intrinsics is that the set of types we overload on is somewhat awkward to control if we need intrinsics for only a subset of types. Then we need to deal with overloads that are nominally accepted but can't be lowered.

As for minimizing the set of intrinsics, I'm OK either way, though I'm biased towards keeping the fsub intrinsics around. We have plenty of intrinsics that are used approximately never. Plus, I do not have 100% confidence that, with all the NVPTX-specific rounding/saturation/etc. modes, add-with-quirky-corner-cases(fneg(x)) will always be equivalent to sub-with-quirky-corner-cases(x). A straightforward intrinsic-to-instruction mapping for fsub keeps things simple.
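
As a sketch of the two forms in question, using the intrinsics from this patch (whether %s1 and %s2 agree for every input, including NaNs and saturated values, is exactly the open question):

define half @two_forms(half %a, half %b) {
  ; explicit sub intrinsic: one straightforward mapping
  %s1 = call half @llvm.nvvm.sub.rn.sat.f16(half %a, half %b)
  ; fneg + add form that a combine would have to recognize instead
  %n = fneg half %b
  %s2 = call half @llvm.nvvm.add.rn.sat.f16(half %a, half %n)
  ret half %s2
}

declare half @llvm.nvvm.sub.rn.sat.f16(half, half)
declare half @llvm.nvvm.add.rn.sat.f16(half, half)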

This change adds the following missing half-precision
add/sub/fma intrinsics for the NVPTX target:
- `llvm.nvvm.add.rn{.ftz}.sat.f16`
- `llvm.nvvm.add.rn{.ftz}.sat.f16x2`
- `llvm.nvvm.sub.rn{.ftz}.sat.f16`
- `llvm.nvvm.sub.rn{.ftz}.sat.f16x2`
- `llvm.nvvm.fma.rn.oob.*`

This also removes some incorrect `bf16` fma intrinsics with no
valid lowering.

PTX spec reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions
@Wolfram70 force-pushed the dev/Wolfram70/packed-fp-intr branch from 4765b02 to 210c875 (December 2, 2025, 13:15)
@github-actions bot commented Dec 2, 2025

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff origin/main HEAD --extensions c,cpp -- clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp clang/test/CodeGen/builtins-nvptx.c llvm/lib/IR/AutoUpgrade.cpp llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp --diff_from_common_commit

⚠️ The reproduction instructions above might return results for more than one PR in a stack if you are using a stacked PR workflow. You can limit the results by changing origin/main to the base branch/commit you want to compare against. ⚠️

View the diff from clang-format:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 43e0e8b43..f453c34f7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -6541,7 +6541,7 @@ static SDValue combineF16AddWithNeg(SDNode *N, SelectionDAG &DAG,
 
   SDValue SubOp1, SubOp2;
 
-  if(Op1.getOpcode() == ISD::FNEG) {
+  if (Op1.getOpcode() == ISD::FNEG) {
     SubOp1 = Op2;
     SubOp2 = Op1.getOperand(0);
   } else if (Op2.getOpcode() == ISD::FNEG) {

@Wolfram70 (Contributor, Author) commented Dec 2, 2025

It looks like sub is indeed equivalent to an add with a negated operand, judging by this Compiler Explorer example. Since fneg is converted to an xor that flips the sign bit for PTX versions below 6.0 (the neg instruction was only introduced in PTX 6.0), whereas sub is supported from PTX 4.2 onwards, I couldn't find a clean way to do this entirely in TableGen. If this is okay, we could do the same for #168359 as well.
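
For reference, the sign-bit flip that fneg amounts to below PTX 6.0, written out as an IR-level sketch (LLVM's fneg is defined as flipping the sign bit, so this is behaviorally equivalent):

define half @flip_sign(half %b) {
  %bits = bitcast half %b to i16
  %nbits = xor i16 %bits, -32768   ; -32768 == 0x8000, the f16 sign bit
  %n = bitcast i16 %nbits to half
  ret half %n
}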

As for overloading, I'm not entirely sure about it, since overlapping variants with the same modifiers across the different floating-point types are kind of sparse. For example, any rounding mode other than rn can't be overloaded for f16(x2) and bf16(x2), and anything with sat or ftz can't be overloaded for bf16(x2).
But on the other hand, there are also some variants like fma.rn.oob which I think could be overloaded, since it supports all fp16 types. Is it okay to have only some of the intrinsic variants be overloaded with generic names while other similar ones stay tied to a single type (which could be renamed to remove the type from the intrinsic name if we want uniformity in the naming)?

@Wolfram70 changed the title from "[clang][NVPTX] Add missing half-precision add/sub/fma intrinsics" to "[clang][NVPTX] Add missing half-precision add/mul/fma intrinsics" Dec 2, 2025
@AlexMaclean (Member) commented

In general, I think PTX has lots of instructions which are essentially syntactic sugar or can easily be represented by a couple of existing instructions. While they may make hand-writing PTX easier, we should probably not represent these as distinct intrinsics in LLVM IR as it will make adding peephole optimizations for these more difficult and make it harder to get a canonical form. We haven't been good about this in the past, but I think it's probably smart to be more judicious about adding new intrinsics going forward.

That may work, though the problem with overloaded intrinsics is that the set of types we overload on is somewhat awkward to control if we need intrinsics for only a subset of types. Then we need to deal with overloads that are nominally accepted but can't be lowered.

With regard to overloaded intrinsics, I think that whenever we have a case where an intrinsic supports multiple types, we should use an overloaded intrinsic. It's true this allows frontends to generate malformed IR that the verifier won't complain about but that we cannot actually lower. However, I think there are already many, many cases where LLVM IR is technically valid but the NVPTX backend cannot lower it due to an unsupported SM or type. I think the implicit understanding already is that creators of IR for the NVPTX backend need to be careful about what they generate and confirm it can be selected. Using overloaded intrinsics when some types are not supported seems fine within that current status quo. It would be nice if we could specify a supported set of types for overloaded intrinsics though; perhaps the intrinsic records could be extended to support something like this in the future.

As for overloading, I'm not entirely sure about it, since overlapping variants with the same modifiers across the different floating-point types are kind of sparse. For example, any rounding mode other than rn can't be overloaded for f16(x2) and bf16(x2), and anything with sat or ftz can't be overloaded for bf16(x2).
But on the other hand, there are also some variants like fma.rn.oob which I think could be overloaded, since it supports all fp16 types. Is it okay to have only some of the intrinsic variants be overloaded with generic names while other similar ones stay tied to a single type (which could be renamed to remove the type from the intrinsic name if we want uniformity in the naming)?

I think we should either use an overloaded intrinsic (in which case the type will be automatically added as a suffix), or, if only one type is supported, add a type suffix that is consistent with the suffixes used for overloading (i.e. v2f16, not f16x2). This way, if future hardware supports more variants, we can switch to an overloaded intrinsic without needing to auto-upgrade. For this PR, I think only fma.rn.oob should be overloaded; the rest should use v2 suffixes.
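
A sketch of the declarations that scheme would produce (the suffixes follow LLVM's standard overload mangling; the exact final set of names depends on the revision):

; overloaded: one definition, the operand type is mangled into the name
declare half @llvm.nvvm.fma.rn.oob.f16(half, half, half)
declare <2 x half> @llvm.nvvm.fma.rn.oob.v2f16(<2 x half>, <2 x half>, <2 x half>)
declare bfloat @llvm.nvvm.fma.rn.oob.bf16(bfloat, bfloat, bfloat)
declare <2 x bfloat> @llvm.nvvm.fma.rn.oob.v2bf16(<2 x bfloat>, <2 x bfloat>, <2 x bfloat>)

; single-type intrinsic named to match that convention (v2f16, not f16x2)
declare <2 x half> @llvm.nvvm.add.rn.sat.v2f16(<2 x half>, <2 x half>)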

@Artem-B (Member) commented Dec 2, 2025

we should probably not represent these as distinct intrinsics in LLVM IR as it will make adding peephole optimizations for these more difficult

That's a good point, but there are gray areas where compiler convenience has to be balanced against user convenience. In this case, I agree that the sub variants do not buy us much.

I think there are already many, many cases where LLVM IR is technically valid but the NVPTX backend cannot lower it due to an unsupported SM or type.

I would rather work towards having those diagnosed early on, which would give users a chance to recover and properly diagnose the issue within the user app, rather than crashing the compiler, which is fatal for the application and often complicates things on the debuggability front.

It would be nice if we could specify a supported set of types for overloaded intrinsics though; perhaps the intrinsic records could be extended to support something like this in the future.

Agreed. It should be possible to augment the verifier so it asks the backend whether a particular overload is valid. If we had that in place, then using overloads for a sparsely populated parameter space would be a pretty obvious choice.

For now, I'm OK with converting to overloaded intrinsics, despite the current shortcomings.

@llvmbot added the clang:codegen label Dec 3, 2025
@Wolfram70 (Contributor, Author) commented

I think we should either use an overloaded intrinsic (in which case the type will be automatically added as a suffix), or, if only one type is supported, add a type suffix that is consistent with the suffixes used for overloading (i.e. v2f16, not f16x2).

Nice, I think this is a good way to keep our options open. I've converted the fma.rn.oob variants into overloaded intrinsics and renamed the other new ones in this manner in the latest revision. Please take a look, thanks!
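
For reference, a call to one of the overloaded variants would then look something like this in IR (a sketch; the mangled suffix follows from the operand type):

%r = call <2 x half> @llvm.nvvm.fma.rn.oob.v2f16(<2 x half> %a, <2 x half> %b, <2 x half> %c)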

@AlexMaclean (Member) left a comment

NVPTX/LLVM portion LGTM modulo a few final nits.

It would be good to update the NVPTXUsage.rst doc for any new intrinsics as well.

// Sub
//

def sub_rn_sat_node : SDNode<"NVPTXISD::SUB_RN_SAT", SDTFPBinOp>;

nit: I think node is redundant in these names

Comment on lines +1874 to +1876
//
// Mul
//

nit: aren't there other mul intrinsics you can put these with?

}
}

static std::optional<unsigned> getF16SubOpc(Intrinsic::ID AddIntrinsicID) {

nit: I don't think this needs to be optional. You should also be able to remove the return nullopt after the llvm_unreachable.

static SDValue combineIntrinsicWOChain(SDNode *N,
                                       TargetLowering::DAGCombinerInfo &DCI,
                                       const NVPTXSubtarget &STI) {
  unsigned IntID = N->getConstantOperandVal(0);

nit: IID is the more common abbreviation. This one seems like it might have to do with integers.
