Skip to content

Commit 210c875

Browse files
committed
fold add with fneg to sub
1 parent 22fb84a commit 210c875

File tree

6 files changed

+79
-40
lines changed

6 files changed

+79
-40
lines changed

clang/include/clang/Basic/BuiltinsNVPTX.td

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -473,13 +473,6 @@ def __nvvm_add_rz_d : NVPTXBuiltin<"double(double, double)">;
473473
def __nvvm_add_rm_d : NVPTXBuiltin<"double(double, double)">;
474474
def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">;
475475

476-
// Sub
477-
478-
def __nvvm_sub_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
479-
def __nvvm_sub_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;
480-
def __nvvm_sub_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
481-
def __nvvm_sub_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>;
482-
483476
// Mul
484477

485478
def __nvvm_mul_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>;

clang/test/CodeGen/builtins-nvptx.c

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1539,15 +1539,6 @@ __device__ void nvvm_add_sub_mul_f16_sat() {
15391539
// CHECK: call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2
15401540
__nvvm_add_rn_ftz_sat_f16x2(F16X2, F16X2_2);
15411541

1542-
// CHECK: call half @llvm.nvvm.sub.rn.sat.f16
1543-
__nvvm_sub_rn_sat_f16(F16, F16_2);
1544-
// CHECK: call half @llvm.nvvm.sub.rn.ftz.sat.f16
1545-
__nvvm_sub_rn_ftz_sat_f16(F16, F16_2);
1546-
// CHECK: call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2
1547-
__nvvm_sub_rn_sat_f16x2(F16X2, F16X2_2);
1548-
// CHECK: call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2
1549-
__nvvm_sub_rn_ftz_sat_f16x2(F16X2, F16X2_2);
1550-
15511542
// CHECK: call half @llvm.nvvm.mul.rn.sat.f16
15521543
__nvvm_mul_rn_sat_f16(F16, F16_2);
15531544
// CHECK: call half @llvm.nvvm.mul.rn.ftz.sat.f16

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1608,17 +1608,6 @@ let TargetPrefix = "nvvm" in {
16081608
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
16091609
}
16101610
}
1611-
1612-
//
1613-
// Sub
1614-
//
1615-
foreach ftz = ["", "_ftz"] in {
1616-
def int_nvvm_sub_rn # ftz # _sat_f16 : NVVMBuiltin,
1617-
PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>;
1618-
1619-
def int_nvvm_sub_rn # ftz # _sat_f16x2 : NVVMBuiltin,
1620-
PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>;
1621-
} // ftz
16221611

16231612
//
16241613
// Mul

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -873,7 +873,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
873873
ISD::FMINIMUMNUM, ISD::MUL, ISD::SHL,
874874
ISD::SREM, ISD::UREM, ISD::VSELECT,
875875
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
876-
ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
876+
ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND,
877+
ISD::INTRINSIC_WO_CHAIN});
877878

878879
// setcc for f16x2 and bf16x2 needs special handling to prevent
879880
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -6504,6 +6505,38 @@ static SDValue sinkProxyReg(SDValue R, SDValue Chain,
65046505
}
65056506
}
65066507

6508+
// Combine add.sat(a, fneg(b)) -> sub.sat(a, b)
6509+
static SDValue combineAddSatWithNeg(SDNode *N, SelectionDAG &DAG,
6510+
unsigned SubOpc) {
6511+
SDValue Op2 = N->getOperand(2);
6512+
6513+
if (Op2.getOpcode() != ISD::FNEG)
6514+
return SDValue();
6515+
6516+
SDLoc DL(N);
6517+
return DAG.getNode(SubOpc, DL, N->getValueType(0), N->getOperand(1),
6518+
Op2.getOperand(0));
6519+
}
6520+
6521+
static SDValue combineIntrinsicWOChain(SDNode *N,
6522+
TargetLowering::DAGCombinerInfo &DCI,
6523+
const NVPTXSubtarget &STI) {
6524+
unsigned IntID = N->getConstantOperandVal(0);
6525+
6526+
switch (IntID) {
6527+
case Intrinsic::nvvm_add_rn_sat_f16:
6528+
return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_SAT_F16);
6529+
case Intrinsic::nvvm_add_rn_ftz_sat_f16:
6530+
return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_FTZ_SAT_F16);
6531+
case Intrinsic::nvvm_add_rn_sat_f16x2:
6532+
return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_SAT_F16X2);
6533+
case Intrinsic::nvvm_add_rn_ftz_sat_f16x2:
6534+
return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_FTZ_SAT_F16X2);
6535+
default:
6536+
return SDValue();
6537+
}
6538+
}
6539+
65076540
static SDValue combineProxyReg(SDNode *N,
65086541
TargetLowering::DAGCombinerInfo &DCI) {
65096542

@@ -6570,6 +6603,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
65706603
return combineSTORE(N, DCI, STI);
65716604
case ISD::VSELECT:
65726605
return PerformVSELECTCombine(N, DCI);
6606+
case ISD::INTRINSIC_WO_CHAIN:
6607+
return combineIntrinsicWOChain(N, DCI, STI);
65736608
}
65746609
return SDValue();
65756610
}

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1859,10 +1859,34 @@ def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d>
18591859
// Sub
18601860
//
18611861

1862-
def INT_NVVM_SUB_RN_SAT_F16 : F_MATH_2<"sub.rn.sat.f16", B16, B16, B16, int_nvvm_sub_rn_sat_f16>;
1863-
def INT_NVVM_SUB_RN_FTZ_SAT_F16 : F_MATH_2<"sub.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_sub_rn_ftz_sat_f16>;
1864-
def INT_NVVM_SUB_RN_SAT_F16X2 : F_MATH_2<"sub.rn.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_sat_f16x2>;
1865-
def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 : F_MATH_2<"sub.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_ftz_sat_f16x2>;
1862+
def SUB_RN_SAT_F16_NODE : SDNode<"NVPTXISD::SUB_RN_SAT_F16", SDTFPBinOp>;
1863+
def SUB_RN_FTZ_SAT_F16_NODE :
1864+
SDNode<"NVPTXISD::SUB_RN_FTZ_SAT_F16", SDTFPBinOp>;
1865+
def SUB_RN_SAT_F16X2_NODE :
1866+
SDNode<"NVPTXISD::SUB_RN_SAT_F16X2", SDTFPBinOp>;
1867+
def SUB_RN_FTZ_SAT_F16X2_NODE :
1868+
SDNode<"NVPTXISD::SUB_RN_FTZ_SAT_F16X2", SDTFPBinOp>;
1869+
1870+
def INT_NVVM_SUB_RN_SAT_F16 :
1871+
BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, B16:$b),
1872+
"sub.rn.sat.f16",
1873+
[(set f16:$dst, (SUB_RN_SAT_F16_NODE f16:$a, f16:$b))]>;
1874+
1875+
def INT_NVVM_SUB_RN_FTZ_SAT_F16 :
1876+
BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, B16:$b),
1877+
"sub.rn.ftz.sat.f16",
1878+
[(set f16:$dst, (SUB_RN_FTZ_SAT_F16_NODE f16:$a, f16:$b))]>;
1879+
1880+
def INT_NVVM_SUB_RN_SAT_F16X2 :
1881+
BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
1882+
"sub.rn.sat.f16x2",
1883+
[(set v2f16:$dst, (SUB_RN_SAT_F16X2_NODE v2f16:$a, v2f16:$b))]>;
1884+
1885+
def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 :
1886+
BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
1887+
"sub.rn.ftz.sat.f16x2",
1888+
[(set v2f16:$dst, (SUB_RN_FTZ_SAT_F16X2_NODE v2f16:$a, v2f16:$b))]>;
1889+
18661890

18671891
//
18681892
// Mul
@@ -6154,3 +6178,4 @@ foreach sp = [0, 1] in {
61546178
}
61556179
}
61566180
}
6181+

llvm/test/CodeGen/NVPTX/f16-sub-sat.ll

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
22
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | FileCheck %s
3+
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx60 | FileCheck %s
34
; RUN: %if ptxas-isa-4.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | %ptxas-verify%}
5+
; RUN: %if ptxas-isa-6.0 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx60 | %ptxas-verify%}
46

57
define half @sub_rn_sat_f16(half %a, half %b) {
68
; CHECK-LABEL: sub_rn_sat_f16(
@@ -13,8 +15,9 @@ define half @sub_rn_sat_f16(half %a, half %b) {
1315
; CHECK-NEXT: sub.rn.sat.f16 %rs3, %rs1, %rs2;
1416
; CHECK-NEXT: st.param.b16 [func_retval0], %rs3;
1517
; CHECK-NEXT: ret;
16-
%1 = call half @llvm.nvvm.sub.rn.sat.f16(half %a, half %b)
17-
ret half %1
18+
%1 = fneg half %b
19+
%res = call half @llvm.nvvm.add.rn.sat.f16(half %a, half %1)
20+
ret half %res
1821
}
1922

2023
define <2 x half> @sub_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
@@ -28,8 +31,9 @@ define <2 x half> @sub_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) {
2831
; CHECK-NEXT: sub.rn.sat.f16x2 %r3, %r1, %r2;
2932
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
3033
; CHECK-NEXT: ret;
31-
%1 = call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2(<2 x half> %a, <2 x half> %b)
32-
ret <2 x half> %1
34+
%1 = fneg <2 x half> %b
35+
%res = call <2 x half> @llvm.nvvm.add.rn.sat.f16x2(<2 x half> %a, <2 x half> %1)
36+
ret <2 x half> %res
3337
}
3438

3539
define half @sub_rn_ftz_sat_f16(half %a, half %b) {
@@ -43,8 +47,9 @@ define half @sub_rn_ftz_sat_f16(half %a, half %b) {
4347
; CHECK-NEXT: sub.rn.ftz.sat.f16 %rs3, %rs1, %rs2;
4448
; CHECK-NEXT: st.param.b16 [func_retval0], %rs3;
4549
; CHECK-NEXT: ret;
46-
%1 = call half @llvm.nvvm.sub.rn.ftz.sat.f16(half %a, half %b)
47-
ret half %1
50+
%1 = fneg half %b
51+
%res = call half @llvm.nvvm.add.rn.ftz.sat.f16(half %a, half %1)
52+
ret half %res
4853
}
4954

5055
define <2 x half> @sub_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
@@ -58,6 +63,7 @@ define <2 x half> @sub_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) {
5863
; CHECK-NEXT: sub.rn.ftz.sat.f16x2 %r3, %r1, %r2;
5964
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
6065
; CHECK-NEXT: ret;
61-
%1 = call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b)
62-
ret <2 x half> %1
66+
%1 = fneg <2 x half> %b
67+
%res = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %1)
68+
ret <2 x half> %res
6369
}

0 commit comments

Comments
 (0)