@@ -866,14 +866,28 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
866866 setOperationAction (ISD::UMUL_LOHI, MVT::i64 , Expand);
867867
868868 // We have some custom DAG combine patterns for these nodes
869- setTargetDAGCombine (
870- {ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT,
871- ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM,
872- ISD::FMAXIMUM, ISD::FMINIMUM, ISD::FMAXIMUMNUM,
873- ISD::FMINIMUMNUM, ISD::MUL, ISD::SHL,
874- ISD::SREM, ISD::UREM, ISD::VSELECT,
875- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
876- ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
869+ setTargetDAGCombine ({ISD::ADD,
870+ ISD::AND,
871+ ISD::EXTRACT_VECTOR_ELT,
872+ ISD::FADD,
873+ ISD::FMAXNUM,
874+ ISD::FMINNUM,
875+ ISD::FMAXIMUM,
876+ ISD::FMINIMUM,
877+ ISD::FMAXIMUMNUM,
878+ ISD::FMINIMUMNUM,
879+ ISD::MUL,
880+ ISD::SHL,
881+ ISD::SREM,
882+ ISD::UREM,
883+ ISD::VSELECT,
884+ ISD::BUILD_VECTOR,
885+ ISD::ADDRSPACECAST,
886+ ISD::LOAD,
887+ ISD::STORE,
888+ ISD::ZERO_EXTEND,
889+ ISD::SIGN_EXTEND,
890+ ISD::INTRINSIC_WO_CHAIN});
877891
878892 // setcc for f16x2 and bf16x2 needs special handling to prevent
879893 // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -6504,6 +6518,143 @@ static SDValue sinkProxyReg(SDValue R, SDValue Chain,
65046518 }
65056519}
65066520
6521+ static std::optional<unsigned > getSubF32Opc (Intrinsic::ID AddIntrinsicID) {
6522+ switch (AddIntrinsicID) {
6523+ default :
6524+ break ;
6525+ case Intrinsic::nvvm_add_rn_f:
6526+ return NVPTXISD::SUB_RN_F;
6527+ case Intrinsic::nvvm_add_rn_sat_f:
6528+ return NVPTXISD::SUB_RN_SAT_F;
6529+ case Intrinsic::nvvm_add_rn_ftz_f:
6530+ return NVPTXISD::SUB_RN_FTZ_F;
6531+ case Intrinsic::nvvm_add_rn_ftz_sat_f:
6532+ return NVPTXISD::SUB_RN_FTZ_SAT_F;
6533+ case Intrinsic::nvvm_add_rz_f:
6534+ return NVPTXISD::SUB_RZ_F;
6535+ case Intrinsic::nvvm_add_rz_sat_f:
6536+ return NVPTXISD::SUB_RZ_SAT_F;
6537+ case Intrinsic::nvvm_add_rz_ftz_f:
6538+ return NVPTXISD::SUB_RZ_FTZ_F;
6539+ case Intrinsic::nvvm_add_rz_ftz_sat_f:
6540+ return NVPTXISD::SUB_RZ_FTZ_SAT_F;
6541+ case Intrinsic::nvvm_add_rm_f:
6542+ return NVPTXISD::SUB_RM_F;
6543+ case Intrinsic::nvvm_add_rm_sat_f:
6544+ return NVPTXISD::SUB_RM_SAT_F;
6545+ case Intrinsic::nvvm_add_rm_ftz_f:
6546+ return NVPTXISD::SUB_RM_FTZ_F;
6547+ case Intrinsic::nvvm_add_rm_ftz_sat_f:
6548+ return NVPTXISD::SUB_RM_FTZ_SAT_F;
6549+ case Intrinsic::nvvm_add_rp_f:
6550+ return NVPTXISD::SUB_RP_F;
6551+ case Intrinsic::nvvm_add_rp_sat_f:
6552+ return NVPTXISD::SUB_RP_SAT_F;
6553+ case Intrinsic::nvvm_add_rp_ftz_f:
6554+ return NVPTXISD::SUB_RP_FTZ_F;
6555+ case Intrinsic::nvvm_add_rp_ftz_sat_f:
6556+ return NVPTXISD::SUB_RP_FTZ_SAT_F;
6557+ }
6558+ llvm_unreachable (" Invalid add intrinsic ID" );
6559+ return std::nullopt ;
6560+ }
6561+
6562+ static std::optional<unsigned > getSubF64Opc (Intrinsic::ID AddIntrinsicID) {
6563+ switch (AddIntrinsicID) {
6564+ default :
6565+ return std::nullopt ;
6566+ case Intrinsic::nvvm_add_rn_d:
6567+ return NVPTXISD::SUB_RN_D;
6568+ case Intrinsic::nvvm_add_rz_d:
6569+ return NVPTXISD::SUB_RZ_D;
6570+ case Intrinsic::nvvm_add_rm_d:
6571+ return NVPTXISD::SUB_RM_D;
6572+ case Intrinsic::nvvm_add_rp_d:
6573+ return NVPTXISD::SUB_RP_D;
6574+ }
6575+ llvm_unreachable (" Invalid add intrinsic ID" );
6576+ return std::nullopt ;
6577+ }
6578+
6579+ static SDValue combineF32AddWithNeg (SDNode *N, SelectionDAG &DAG,
6580+ Intrinsic::ID AddIntrinsicID,
6581+ unsigned PTXVersion, unsigned SmVersion) {
6582+ SDValue Op2 = N->getOperand (2 );
6583+
6584+ if (Op2.getOpcode () != ISD::FNEG)
6585+ return SDValue ();
6586+
6587+ // If PTX > 8.6 and SM >= 100, when Op1 is a fpextend from f16 or bf16, don't
6588+ // fold this pattern as this will be folded to a mixed precision instruction
6589+ // later on.
6590+ SDValue Op1 = N->getOperand (1 );
6591+ if (PTXVersion >= 86 && SmVersion >= 100 &&
6592+ Op1.getOpcode () == ISD::FP_EXTEND) {
6593+ if (Op1.getOperand (0 ).getSimpleValueType () == MVT::f16 ||
6594+ Op1.getOperand (0 ).getSimpleValueType () == MVT::bf16 )
6595+ return SDValue ();
6596+ }
6597+
6598+ std::optional<unsigned > Opc = getSubF32Opc (AddIntrinsicID);
6599+ if (!Opc)
6600+ return SDValue ();
6601+
6602+ SDLoc DL (N);
6603+ return DAG.getNode (*Opc, DL, N->getValueType (0 ), N->getOperand (1 ),
6604+ Op2.getOperand (0 ));
6605+ }
6606+
6607+ static SDValue combineF64AddWithNeg (SDNode *N, SelectionDAG &DAG,
6608+ Intrinsic::ID AddIntrinsicID) {
6609+ SDValue Op2 = N->getOperand (2 );
6610+
6611+ if (Op2.getOpcode () != ISD::FNEG)
6612+ return SDValue ();
6613+
6614+ std::optional<unsigned > Opc = getSubF64Opc (AddIntrinsicID);
6615+ if (!Opc)
6616+ return SDValue ();
6617+
6618+ SDLoc DL (N);
6619+ return DAG.getNode (*Opc, DL, N->getValueType (0 ), N->getOperand (1 ),
6620+ Op2.getOperand (0 ));
6621+ }
6622+
6623+ static SDValue combineIntrinsicWOChain (SDNode *N,
6624+ TargetLowering::DAGCombinerInfo &DCI,
6625+ const NVPTXSubtarget &STI) {
6626+ unsigned IntID = N->getConstantOperandVal (0 );
6627+
6628+ switch (IntID) {
6629+ default :
6630+ break ;
6631+ case Intrinsic::nvvm_add_rn_f:
6632+ case Intrinsic::nvvm_add_rn_sat_f:
6633+ case Intrinsic::nvvm_add_rn_ftz_f:
6634+ case Intrinsic::nvvm_add_rn_ftz_sat_f:
6635+ case Intrinsic::nvvm_add_rz_f:
6636+ case Intrinsic::nvvm_add_rz_sat_f:
6637+ case Intrinsic::nvvm_add_rz_ftz_f:
6638+ case Intrinsic::nvvm_add_rz_ftz_sat_f:
6639+ case Intrinsic::nvvm_add_rm_f:
6640+ case Intrinsic::nvvm_add_rm_sat_f:
6641+ case Intrinsic::nvvm_add_rm_ftz_f:
6642+ case Intrinsic::nvvm_add_rm_ftz_sat_f:
6643+ case Intrinsic::nvvm_add_rp_f:
6644+ case Intrinsic::nvvm_add_rp_sat_f:
6645+ case Intrinsic::nvvm_add_rp_ftz_f:
6646+ case Intrinsic::nvvm_add_rp_ftz_sat_f:
6647+ return combineF32AddWithNeg (N, DCI.DAG , IntID, STI.getPTXVersion (),
6648+ STI.getSmVersion ());
6649+ case Intrinsic::nvvm_add_rn_d:
6650+ case Intrinsic::nvvm_add_rz_d:
6651+ case Intrinsic::nvvm_add_rm_d:
6652+ case Intrinsic::nvvm_add_rp_d:
6653+ return combineF64AddWithNeg (N, DCI.DAG , IntID);
6654+ }
6655+ return SDValue ();
6656+ }
6657+
65076658static SDValue combineProxyReg (SDNode *N,
65086659 TargetLowering::DAGCombinerInfo &DCI) {
65096660
@@ -6570,6 +6721,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
65706721 return combineSTORE (N, DCI, STI);
65716722 case ISD::VSELECT:
65726723 return PerformVSELECTCombine (N, DCI);
6724+ case ISD::INTRINSIC_WO_CHAIN:
6725+ return combineIntrinsicWOChain (N, DCI, STI);
65736726 }
65746727 return SDValue ();
65756728}
0 commit comments