Skip to content

Commit f72ed2f

Browse files
committed
address comments
1 parent 98876d8 commit f72ed2f

File tree

3 files changed

+268
-46
lines changed

3 files changed

+268
-46
lines changed

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,18 +1460,16 @@ let TargetPrefix = "nvvm" in {
14601460
//
14611461
// Sub
14621462
//
1463-
let IntrProperties = [IntrNoMem, IntrSpeculatable] in {
1464-
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
1465-
foreach ftz = ["", "_ftz"] in {
1466-
foreach sat = ["", "_sat"] in {
1467-
def int_nvvm_sub # rnd # ftz # sat # _f : NVVMBuiltin,
1468-
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty]>;
1469-
} // sat
1470-
} // ftz
1471-
def int_nvvm_sub # rnd # _d : NVVMBuiltin,
1472-
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
1473-
} // rnd
1474-
}
1463+
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
1464+
foreach ftz = ["", "_ftz"] in {
1465+
foreach sat = ["", "_sat"] in {
1466+
def int_nvvm_sub # rnd # ftz # sat # _f : NVVMBuiltin,
1467+
PureIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty]>;
1468+
} // sat
1469+
} // ftz
1470+
def int_nvvm_sub # rnd # _d : NVVMBuiltin,
1471+
PureIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
1472+
} // rnd
14751473

14761474
//
14771475
// Dot Product

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1704,20 +1704,30 @@ defm INT_NVVM_FMA : FMA_INST;
17041704

17051705
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
17061706
foreach sat = ["", "_sat"] in {
1707-
foreach type = ["f16", "bf16"] in {
1707+
foreach type = [f16, bf16] in {
17081708
def INT_NVVM_MIXED_FMA # rnd # sat # _f32_ # type :
17091709
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b, B32:$c),
17101710
!subst("_", ".", "fma" # rnd # sat # "_f32_" # type),
17111711
[(set f32:$dst,
17121712
(!cast<Intrinsic>("int_nvvm_fma" # rnd # sat # "_f")
1713-
(f32 (fpextend !cast<ValueType>(type):$a)),
1714-
(f32 (fpextend !cast<ValueType>(type):$b)),
1713+
(f32 (fpextend type:$a)),
1714+
(f32 (fpextend type:$b)),
17151715
f32:$c))]>,
17161716
Requires<[hasSM<100>, hasPTX<86>]>;
17171717
}
17181718
}
17191719
}
17201720

1721+
// Pattern for llvm.fma.f32 intrinsic when there is no FTZ flag
1722+
let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
1723+
def : Pat<(f32 (fma (f32 (fpextend f16:$a)),
1724+
(f32 (fpextend f16:$b)), f32:$c)),
1725+
(INT_NVVM_MIXED_FMA_rn_f32_f16 B16:$a, B16:$b, B32:$c)>;
1726+
def : Pat<(f32 (fma (f32 (fpextend bf16:$a)),
1727+
(f32 (fpextend bf16:$b)), f32:$c)),
1728+
(INT_NVVM_MIXED_FMA_rn_f32_bf16 B16:$a, B16:$b, B32:$c)>;
1729+
}
1730+
17211731
//
17221732
// Rcp
17231733
//
@@ -1840,19 +1850,28 @@ def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d>
18401850

18411851
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
18421852
foreach sat = ["", "_sat"] in {
1843-
foreach type = ["f16", "bf16"] in {
1853+
foreach type = [f16, bf16] in {
18441854
def INT_NVVM_MIXED_ADD # rnd # sat # _f32_ # type :
18451855
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
18461856
!subst("_", ".", "add" # rnd # sat # "_f32_" # type),
18471857
[(set f32:$dst,
18481858
(!cast<Intrinsic>("int_nvvm_add" # rnd # sat # "_f")
1849-
(f32 (fpextend !cast<ValueType>(type):$a)),
1859+
(f32 (fpextend type:$a)),
18501860
f32:$b))]>,
18511861
Requires<[hasSM<100>, hasPTX<86>]>;
18521862
}
18531863
}
18541864
}
18551865

1866+
// Pattern for fadd when there is no FTZ flag
1867+
let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
1868+
def : Pat<(f32 (fadd (f32 (fpextend f16:$a)), f32:$b)),
1869+
(INT_NVVM_MIXED_ADD_rn_f32_f16 B16:$a, B32:$b)>;
1870+
def : Pat<(f32 (fadd (f32 (fpextend bf16:$a)), f32:$b)),
1871+
(INT_NVVM_MIXED_ADD_rn_f32_bf16 B16:$a, B32:$b)>;
1872+
}
1873+
1874+
//
18561875
// Sub
18571876
//
18581877

@@ -1880,18 +1899,27 @@ def INT_NVVM_SUB_RP_D : F_MATH_2<"sub.rp.f64", B64, B64, B64, int_nvvm_sub_rp_d>
18801899

18811900
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
18821901
foreach sat = ["", "_sat"] in {
1883-
foreach type = ["f16", "bf16"] in {
1902+
foreach type = [f16, bf16] in {
18841903
def INT_NVVM_MIXED_SUB # rnd # sat # _f32_ # type :
18851904
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
18861905
!subst("_", ".", "sub" # rnd # sat # "_f32_" # type),
18871906
[(set f32:$dst,
18881907
(!cast<Intrinsic>("int_nvvm_sub" # rnd # sat # "_f")
1889-
(f32 (fpextend !cast<ValueType>(type):$a)),
1908+
(f32 (fpextend type:$a)),
18901909
f32:$b))]>,
18911910
Requires<[hasSM<100>, hasPTX<86>]>;
18921911
}
18931912
}
18941913
}
1914+
1915+
// Pattern for fsub when there is no FTZ flag
1916+
let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
1917+
def : Pat<(f32 (fsub (f32 (fpextend f16:$a)), f32:$b)),
1918+
(INT_NVVM_MIXED_SUB_rn_f32_f16 B16:$a, B32:$b)>;
1919+
def : Pat<(f32 (fsub (f32 (fpextend bf16:$a)), f32:$b)),
1920+
(INT_NVVM_MIXED_SUB_rn_f32_bf16 B16:$a, B32:$b)>;
1921+
}
1922+
18951923
//
18961924
// BFIND
18971925
//

0 commit comments

Comments
 (0)