Skip to content

Commit 1b1ec73

Browse files
committed
address comments
1 parent 31fda74 commit 1b1ec73

File tree

3 files changed

+268
-46
lines changed

3 files changed

+268
-46
lines changed

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1585,18 +1585,16 @@ let TargetPrefix = "nvvm" in {
15851585
//
15861586
// Sub
15871587
//
1588-
let IntrProperties = [IntrNoMem, IntrSpeculatable] in {
1589-
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
1590-
foreach ftz = ["", "_ftz"] in {
1591-
foreach sat = ["", "_sat"] in {
1592-
def int_nvvm_sub # rnd # ftz # sat # _f : NVVMBuiltin,
1593-
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty]>;
1594-
} // sat
1595-
} // ftz
1596-
def int_nvvm_sub # rnd # _d : NVVMBuiltin,
1597-
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
1598-
} // rnd
1599-
}
1588+
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
1589+
foreach ftz = ["", "_ftz"] in {
1590+
foreach sat = ["", "_sat"] in {
1591+
def int_nvvm_sub # rnd # ftz # sat # _f : NVVMBuiltin,
1592+
PureIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty]>;
1593+
} // sat
1594+
} // ftz
1595+
def int_nvvm_sub # rnd # _d : NVVMBuiltin,
1596+
PureIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
1597+
} // rnd
16001598

16011599
//
16021600
// Dot Product

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1739,20 +1739,30 @@ defm INT_NVVM_FMA : FMA_INST;
17391739

17401740
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
17411741
foreach sat = ["", "_sat"] in {
1742-
foreach type = ["f16", "bf16"] in {
1742+
foreach type = [f16, bf16] in {
17431743
def INT_NVVM_MIXED_FMA # rnd # sat # _f32_ # type :
17441744
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b, B32:$c),
17451745
!subst("_", ".", "fma" # rnd # sat # "_f32_" # type),
17461746
[(set f32:$dst,
17471747
(!cast<Intrinsic>("int_nvvm_fma" # rnd # sat # "_f")
1748-
(f32 (fpextend !cast<ValueType>(type):$a)),
1749-
(f32 (fpextend !cast<ValueType>(type):$b)),
1748+
(f32 (fpextend type:$a)),
1749+
(f32 (fpextend type:$b)),
17501750
f32:$c))]>,
17511751
Requires<[hasSM<100>, hasPTX<86>]>;
17521752
}
17531753
}
17541754
}
17551755

1756+
// Pattern for llvm.fma.f32 intrinsic when there is no FTZ flag
1757+
let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
1758+
def : Pat<(f32 (fma (f32 (fpextend f16:$a)),
1759+
(f32 (fpextend f16:$b)), f32:$c)),
1760+
(INT_NVVM_MIXED_FMA_rn_f32_f16 B16:$a, B16:$b, B32:$c)>;
1761+
def : Pat<(f32 (fma (f32 (fpextend bf16:$a)),
1762+
(f32 (fpextend bf16:$b)), f32:$c)),
1763+
(INT_NVVM_MIXED_FMA_rn_f32_bf16 B16:$a, B16:$b, B32:$c)>;
1764+
}
1765+
17561766
//
17571767
// Rcp
17581768
//
@@ -1875,19 +1885,28 @@ def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d>
18751885

18761886
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
18771887
foreach sat = ["", "_sat"] in {
1878-
foreach type = ["f16", "bf16"] in {
1888+
foreach type = [f16, bf16] in {
18791889
def INT_NVVM_MIXED_ADD # rnd # sat # _f32_ # type :
18801890
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
18811891
!subst("_", ".", "add" # rnd # sat # "_f32_" # type),
18821892
[(set f32:$dst,
18831893
(!cast<Intrinsic>("int_nvvm_add" # rnd # sat # "_f")
1884-
(f32 (fpextend !cast<ValueType>(type):$a)),
1894+
(f32 (fpextend type:$a)),
18851895
f32:$b))]>,
18861896
Requires<[hasSM<100>, hasPTX<86>]>;
18871897
}
18881898
}
18891899
}
18901900

1901+
// Pattern for fadd when there is no FTZ flag
1902+
let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
1903+
def : Pat<(f32 (fadd (f32 (fpextend f16:$a)), f32:$b)),
1904+
(INT_NVVM_MIXED_ADD_rn_f32_f16 B16:$a, B32:$b)>;
1905+
def : Pat<(f32 (fadd (f32 (fpextend bf16:$a)), f32:$b)),
1906+
(INT_NVVM_MIXED_ADD_rn_f32_bf16 B16:$a, B32:$b)>;
1907+
}
1908+
1909+
//
18911910
// Sub
18921911
//
18931912

@@ -1915,18 +1934,27 @@ def INT_NVVM_SUB_RP_D : F_MATH_2<"sub.rp.f64", B64, B64, B64, int_nvvm_sub_rp_d>
19151934

19161935
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
19171936
foreach sat = ["", "_sat"] in {
1918-
foreach type = ["f16", "bf16"] in {
1937+
foreach type = [f16, bf16] in {
19191938
def INT_NVVM_MIXED_SUB # rnd # sat # _f32_ # type :
19201939
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
19211940
!subst("_", ".", "sub" # rnd # sat # "_f32_" # type),
19221941
[(set f32:$dst,
19231942
(!cast<Intrinsic>("int_nvvm_sub" # rnd # sat # "_f")
1924-
(f32 (fpextend !cast<ValueType>(type):$a)),
1943+
(f32 (fpextend type:$a)),
19251944
f32:$b))]>,
19261945
Requires<[hasSM<100>, hasPTX<86>]>;
19271946
}
19281947
}
19291948
}
1949+
1950+
// Pattern for fsub when there is no FTZ flag
1951+
let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
1952+
def : Pat<(f32 (fsub (f32 (fpextend f16:$a)), f32:$b)),
1953+
(INT_NVVM_MIXED_SUB_rn_f32_f16 B16:$a, B32:$b)>;
1954+
def : Pat<(f32 (fsub (f32 (fpextend bf16:$a)), f32:$b)),
1955+
(INT_NVVM_MIXED_SUB_rn_f32_bf16 B16:$a, B32:$b)>;
1956+
}
1957+
19301958
//
19311959
// BFIND
19321960
//

0 commit comments

Comments
 (0)