@@ -1739,20 +1739,30 @@ defm INT_NVVM_FMA : FMA_INST;
17391739
17401740foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
17411741 foreach sat = ["", "_sat"] in {
1742- foreach type = [" f16", " bf16" ] in {
1742+ foreach type = [f16, bf16] in {
17431743 def INT_NVVM_MIXED_FMA # rnd # sat # _f32_ # type :
17441744 BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b, B32:$c),
17451745 !subst("_", ".", "fma" # rnd # sat # "_f32_" # type),
17461746 [(set f32:$dst,
17471747 (!cast<Intrinsic>("int_nvvm_fma" # rnd # sat # "_f")
1748- (f32 (fpextend !cast<ValueType>( type) :$a)),
1749- (f32 (fpextend !cast<ValueType>( type) :$b)),
1748+ (f32 (fpextend type:$a)),
1749+ (f32 (fpextend type:$b)),
17501750 f32:$c))]>,
17511751 Requires<[hasSM<100>, hasPTX<86>]>;
17521752 }
17531753 }
17541754}
17551755
1756+ // Pattern for llvm.fma.f32 intrinsic when there is no FTZ flag
1757+ let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
1758+ def : Pat<(f32 (fma (f32 (fpextend f16:$a)),
1759+ (f32 (fpextend f16:$b)), f32:$c)),
1760+ (INT_NVVM_MIXED_FMA_rn_f32_f16 B16:$a, B16:$b, B32:$c)>;
1761+ def : Pat<(f32 (fma (f32 (fpextend bf16:$a)),
1762+ (f32 (fpextend bf16:$b)), f32:$c)),
1763+ (INT_NVVM_MIXED_FMA_rn_f32_bf16 B16:$a, B16:$b, B32:$c)>;
1764+ }
1765+
17561766//
17571767// Rcp
17581768//
@@ -1875,19 +1885,28 @@ def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d>
18751885
18761886foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
18771887 foreach sat = ["", "_sat"] in {
1878- foreach type = [" f16", " bf16" ] in {
1888+ foreach type = [f16, bf16] in {
18791889 def INT_NVVM_MIXED_ADD # rnd # sat # _f32_ # type :
18801890 BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
18811891 !subst("_", ".", "add" # rnd # sat # "_f32_" # type),
18821892 [(set f32:$dst,
18831893 (!cast<Intrinsic>("int_nvvm_add" # rnd # sat # "_f")
1884- (f32 (fpextend !cast<ValueType>( type) :$a)),
1894+ (f32 (fpextend type:$a)),
18851895 f32:$b))]>,
18861896 Requires<[hasSM<100>, hasPTX<86>]>;
18871897 }
18881898 }
18891899}
18901900
1901+ // Pattern for fadd when there is no FTZ flag
1902+ let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
1903+ def : Pat<(f32 (fadd (f32 (fpextend f16:$a)), f32:$b)),
1904+ (INT_NVVM_MIXED_ADD_rn_f32_f16 B16:$a, B32:$b)>;
1905+ def : Pat<(f32 (fadd (f32 (fpextend bf16:$a)), f32:$b)),
1906+ (INT_NVVM_MIXED_ADD_rn_f32_bf16 B16:$a, B32:$b)>;
1907+ }
1908+
1909+ //
18911910// Sub
18921911//
18931912
@@ -1915,18 +1934,27 @@ def INT_NVVM_SUB_RP_D : F_MATH_2<"sub.rp.f64", B64, B64, B64, int_nvvm_sub_rp_d>
19151934
19161935foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
19171936 foreach sat = ["", "_sat"] in {
1918- foreach type = [" f16", " bf16" ] in {
1937+ foreach type = [f16, bf16] in {
19191938 def INT_NVVM_MIXED_SUB # rnd # sat # _f32_ # type :
19201939 BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
19211940 !subst("_", ".", "sub" # rnd # sat # "_f32_" # type),
19221941 [(set f32:$dst,
19231942 (!cast<Intrinsic>("int_nvvm_sub" # rnd # sat # "_f")
1924- (f32 (fpextend !cast<ValueType>( type) :$a)),
1943+ (f32 (fpextend type:$a)),
19251944 f32:$b))]>,
19261945 Requires<[hasSM<100>, hasPTX<86>]>;
19271946 }
19281947 }
19291948}
1949+
1950+ // Pattern for fsub when there is no FTZ flag
1951+ let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
1952+ def : Pat<(f32 (fsub (f32 (fpextend f16:$a)), f32:$b)),
1953+ (INT_NVVM_MIXED_SUB_rn_f32_f16 B16:$a, B32:$b)>;
1954+ def : Pat<(f32 (fsub (f32 (fpextend bf16:$a)), f32:$b)),
1955+ (INT_NVVM_MIXED_SUB_rn_f32_bf16 B16:$a, B32:$b)>;
1956+ }
1957+
19301958//
19311959// BFIND
19321960//
0 commit comments