@@ -1704,20 +1704,30 @@ defm INT_NVVM_FMA : FMA_INST;
17041704
// Mixed-precision FMA instruction records.
// The three nested foreach loops emit one record for every combination of
// rounding suffix (_rn/_rz/_rm/_rp), saturation suffix ("" or _sat) and
// 16-bit source type (f16/bf16). Each record:
//   - is named INT_NVVM_MIXED_FMA<rnd><sat>_f32_<type>;
//   - takes B16 $a, B16 $b, B32 $c and produces B32 $dst;
//   - uses the asm string "fma<rnd><sat>_f32_<type>" with every "_" replaced by ".";
//   - matches the int_nvvm_fma<rnd><sat>_f intrinsic when $a and $b are
//     fpextend'ed to f32 and $c is already f32;
//   - requires hasSM<100> and hasPTX<86>.
// In this diff the '-' lines iterate over string type names and convert them
// with !cast<ValueType>; the '+' lines iterate over f16/bf16 directly.
// NOTE(review): `type` is still pasted with `#` into the record name and the
// asm string. The later patterns refer to ..._f32_f16 / ..._f32_bf16, so
// confirm that pasting the non-string `type` yields exactly "f16"/"bf16".
17051705foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
17061706 foreach sat = ["", "_sat"] in {
1707- foreach type = [" f16", " bf16" ] in {
1707+ foreach type = [f16, bf16] in {
17081708 def INT_NVVM_MIXED_FMA # rnd # sat # _f32_ # type :
17091709 BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b, B32:$c),
17101710 !subst("_", ".", "fma" # rnd # sat # "_f32_" # type),
17111711 [(set f32:$dst,
17121712 (!cast<Intrinsic>("int_nvvm_fma" # rnd # sat # "_f")
1713- (f32 (fpextend !cast<ValueType>( type) :$a)),
1714- (f32 (fpextend !cast<ValueType>( type) :$b)),
1713+ (f32 (fpextend type:$a)),
1714+ (f32 (fpextend type:$b)),
17151715 f32:$c))]>,
17161716 Requires<[hasSM<100>, hasPTX<86>]>;
17171717 }
17181718 }
17191719}
17201720
// Two anonymous selection patterns, active only when hasSM<100>, hasPTX<86>
// and doNoF32FTZ all hold. Each matches an f32 `fma` node whose first two
// operands are fpextend'ed from the same 16-bit type (f16 or bf16) and whose
// third operand is f32, and selects the round-to-nearest, non-saturating
// record INT_NVVM_MIXED_FMA_rn_f32_<type>. No other rounding or _sat variant
// is selected by these patterns.
1721+ // Pattern for llvm.fma.f32 intrinsic when there is no FTZ flag
1722+ let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
1723+ def : Pat<(f32 (fma (f32 (fpextend f16:$a)),
1724+ (f32 (fpextend f16:$b)), f32:$c)),
1725+ (INT_NVVM_MIXED_FMA_rn_f32_f16 B16:$a, B16:$b, B32:$c)>;
1726+ def : Pat<(f32 (fma (f32 (fpextend bf16:$a)),
1727+ (f32 (fpextend bf16:$b)), f32:$c)),
1728+ (INT_NVVM_MIXED_FMA_rn_f32_bf16 B16:$a, B16:$b, B32:$c)>;
1729+ }
1730+
17211731//
17221732// Rcp
17231733//
@@ -1840,19 +1850,28 @@ def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d>
18401850
// Mixed-precision ADD instruction records.
// One record is emitted per (rounding suffix, saturation suffix, 16-bit type)
// combination, named INT_NVVM_MIXED_ADD<rnd><sat>_f32_<type>. Each record:
//   - takes B16 $a and B32 $b and produces B32 $dst;
//   - uses the asm string "add<rnd><sat>_f32_<type>" with every "_" replaced by ".";
//   - matches the int_nvvm_add<rnd><sat>_f intrinsic when $a is fpextend'ed
//     to f32 and $b is already f32;
//   - requires hasSM<100> and hasPTX<86>.
// In this diff the '-' lines use string type names plus !cast<ValueType>;
// the '+' lines iterate over f16/bf16 directly.
// NOTE(review): `type` is still pasted with `#` into the record name and the
// asm string; confirm the pasted text is exactly "f16"/"bf16".
18411851foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
18421852 foreach sat = ["", "_sat"] in {
1843- foreach type = [" f16", " bf16" ] in {
1853+ foreach type = [f16, bf16] in {
18441854 def INT_NVVM_MIXED_ADD # rnd # sat # _f32_ # type :
18451855 BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
18461856 !subst("_", ".", "add" # rnd # sat # "_f32_" # type),
18471857 [(set f32:$dst,
18481858 (!cast<Intrinsic>("int_nvvm_add" # rnd # sat # "_f")
1849- (f32 (fpextend !cast<ValueType>( type) :$a)),
1859+ (f32 (fpextend type:$a)),
18501860 f32:$b))]>,
18511861 Requires<[hasSM<100>, hasPTX<86>]>;
18521862 }
18531863 }
18541864}
18551865
// Two anonymous selection patterns, active only when hasSM<100>, hasPTX<86>
// and doNoF32FTZ all hold. Each matches an f32 `fadd` whose first operand is
// fpextend'ed from f16 or bf16 and whose second operand is f32, and selects
// the round-to-nearest, non-saturating record INT_NVVM_MIXED_ADD_rn_f32_<type>.
// NOTE(review): only the (fpextend, f32) operand order is written here;
// presumably the commuted form is covered by fadd being commutative — verify.
1866+ // Pattern for fadd when there is no FTZ flag
1867+ let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
1868+ def : Pat<(f32 (fadd (f32 (fpextend f16:$a)), f32:$b)),
1869+ (INT_NVVM_MIXED_ADD_rn_f32_f16 B16:$a, B32:$b)>;
1870+ def : Pat<(f32 (fadd (f32 (fpextend bf16:$a)), f32:$b)),
1871+ (INT_NVVM_MIXED_ADD_rn_f32_bf16 B16:$a, B32:$b)>;
1872+ }
1873+
1874+ //
18561875// Sub
18571876//
18581877
@@ -1880,18 +1899,27 @@ def INT_NVVM_SUB_RP_D : F_MATH_2<"sub.rp.f64", B64, B64, B64, int_nvvm_sub_rp_d>
18801899
// Mixed-precision SUB instruction records.
// One record is emitted per (rounding suffix, saturation suffix, 16-bit type)
// combination, named INT_NVVM_MIXED_SUB<rnd><sat>_f32_<type>. Each record:
//   - takes B16 $a and B32 $b and produces B32 $dst;
//   - uses the asm string "sub<rnd><sat>_f32_<type>" with every "_" replaced by ".";
//   - matches the int_nvvm_sub<rnd><sat>_f intrinsic when $a is fpextend'ed
//     to f32 and $b is already f32;
//   - requires hasSM<100> and hasPTX<86>.
// In this diff the '-' lines use string type names plus !cast<ValueType>;
// the '+' lines iterate over f16/bf16 directly.
// NOTE(review): `type` is still pasted with `#` into the record name and the
// asm string; confirm the pasted text is exactly "f16"/"bf16".
18811900foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
18821901 foreach sat = ["", "_sat"] in {
1883- foreach type = [" f16", " bf16" ] in {
1902+ foreach type = [f16, bf16] in {
18841903 def INT_NVVM_MIXED_SUB # rnd # sat # _f32_ # type :
18851904 BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
18861905 !subst("_", ".", "sub" # rnd # sat # "_f32_" # type),
18871906 [(set f32:$dst,
18881907 (!cast<Intrinsic>("int_nvvm_sub" # rnd # sat # "_f")
1889- (f32 (fpextend !cast<ValueType>( type) :$a)),
1908+ (f32 (fpextend type:$a)),
18901909 f32:$b))]>,
18911910 Requires<[hasSM<100>, hasPTX<86>]>;
18921911 }
18931912 }
18941913}
1914+
// Two anonymous selection patterns, active only when hasSM<100>, hasPTX<86>
// and doNoF32FTZ all hold. Each matches an f32 `fsub` whose first operand
// (the minuend) is fpextend'ed from f16 or bf16 and whose second operand is
// f32, and selects the round-to-nearest, non-saturating record
// INT_NVVM_MIXED_SUB_rn_f32_<type>. A subtraction whose extended value is the
// second operand is not matched by these patterns.
1915+ // Pattern for fsub when there is no FTZ flag
1916+ let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
1917+ def : Pat<(f32 (fsub (f32 (fpextend f16:$a)), f32:$b)),
1918+ (INT_NVVM_MIXED_SUB_rn_f32_f16 B16:$a, B32:$b)>;
1919+ def : Pat<(f32 (fsub (f32 (fpextend bf16:$a)), f32:$b)),
1920+ (INT_NVVM_MIXED_SUB_rn_f32_bf16 B16:$a, B32:$b)>;
1921+ }
1922+
18951923//
18961924// BFIND
18971925//
0 commit comments