@@ -73,7 +73,7 @@ class RISCVVectorPeephole : public MachineFunctionPass {
7373 bool isAllOnesMask (const MachineInstr *MaskDef) const ;
7474 std::optional<unsigned > getConstant (const MachineOperand &VL) const ;
7575 bool ensureDominates (const MachineOperand &Use, MachineInstr &Src) const ;
76- bool isKnownSameDefs (Register A, Register B ) const ;
76+ Register lookThruCopies (Register Reg ) const ;
7777};
7878
7979} // namespace
@@ -387,23 +387,18 @@ bool RISCVVectorPeephole::convertAllOnesVMergeToVMv(MachineInstr &MI) const {
387387 return true ;
388388}
389389
390- bool RISCVVectorPeephole::isKnownSameDefs (Register A, Register B) const {
391- if (A.isPhysical () || B.isPhysical ())
392- return false ;
393-
394- auto LookThruVirtRegCopies = [this ](Register Reg) {
395- while (MachineInstr *Def = MRI->getUniqueVRegDef (Reg)) {
396- if (!Def->isFullCopy ())
397- break ;
398- Register Src = Def->getOperand (1 ).getReg ();
399- if (!Src.isVirtual ())
400- break ;
401- Reg = Src;
402- }
403- return Reg;
404- };
405-
406- return LookThruVirtRegCopies (A) == LookThruVirtRegCopies (B);
390+ // / If \p Reg is defined by one or more COPYs of virtual registers, traverses
391+ // / the chain and returns the root non-COPY source.
392+ Register RISCVVectorPeephole::lookThruCopies (Register Reg) const {
393+ while (MachineInstr *Def = MRI->getUniqueVRegDef (Reg)) {
394+ if (!Def->isFullCopy ())
395+ break ;
396+ Register Src = Def->getOperand (1 ).getReg ();
397+ if (!Src.isVirtual ())
398+ break ;
399+ Reg = Src;
400+ }
401+ return Reg;
407402}
408403
409404// / If a PseudoVMERGE_VVM's true operand is a masked pseudo and both have the
@@ -428,10 +423,11 @@ bool RISCVVectorPeephole::convertSameMaskVMergeToVMv(MachineInstr &MI) {
428423 if (!TrueMaskedInfo || !hasSameEEW (MI, *True))
429424 return false ;
430425
431- const MachineOperand &TrueMask =
432- True->getOperand (TrueMaskedInfo->MaskOpIdx + True->getNumExplicitDefs ());
433- const MachineOperand &MIMask = MI.getOperand (4 );
434- if (!isKnownSameDefs (TrueMask.getReg (), MIMask.getReg ()))
426+ Register TrueMaskReg = lookThruCopies (
427+ True->getOperand (TrueMaskedInfo->MaskOpIdx + True->getNumExplicitDefs ())
428+ .getReg ());
429+ Register MIMaskReg = lookThruCopies (MI.getOperand (4 ).getReg ());
430+ if (!TrueMaskReg.isVirtual () || TrueMaskReg != MIMaskReg)
435431 return false ;
436432
437433 // Masked off lanes past TrueVL will come from False, and converting to vmv
@@ -717,9 +713,9 @@ bool RISCVVectorPeephole::foldVMergeToMask(MachineInstr &MI) const {
717713 if (RISCV::getRVVMCOpcode (MI.getOpcode ()) != RISCV::VMERGE_VVM)
718714 return false ;
719715
720- Register PassthruReg = MI.getOperand (1 ).getReg ();
721- Register FalseReg = MI.getOperand (2 ).getReg ();
722- Register TrueReg = MI.getOperand (3 ).getReg ();
716+ Register PassthruReg = lookThruCopies ( MI.getOperand (1 ).getReg () );
717+ Register FalseReg = lookThruCopies ( MI.getOperand (2 ).getReg () );
718+ Register TrueReg = lookThruCopies ( MI.getOperand (3 ).getReg () );
723719 if (!TrueReg.isVirtual () || !MRI->hasOneUse (TrueReg))
724720 return false ;
725721 MachineInstr &True = *MRI->getUniqueVRegDef (TrueReg);
@@ -740,16 +736,17 @@ bool RISCVVectorPeephole::foldVMergeToMask(MachineInstr &MI) const {
740736
741737 // We require that either passthru and false are the same, or that passthru
742738 // is undefined.
743- if (PassthruReg && !isKnownSameDefs (PassthruReg, FalseReg))
739+ if (PassthruReg && !(PassthruReg. isVirtual () && PassthruReg == FalseReg))
744740 return false ;
745741
746742 std::optional<std::pair<unsigned , unsigned >> NeedsCommute;
747743
748744 // If True has a passthru operand then it needs to be the same as vmerge's
749745 // False, since False will be used for the result's passthru operand.
750- Register TruePassthru = True.getOperand (True.getNumExplicitDefs ()).getReg ();
746+ Register TruePassthru =
747+ lookThruCopies (True.getOperand (True.getNumExplicitDefs ()).getReg ());
751748 if (RISCVII::isFirstDefTiedToFirstUse (True.getDesc ()) && TruePassthru &&
752- !isKnownSameDefs (TruePassthru, FalseReg)) {
749+ !(TruePassthru. isVirtual () && TruePassthru == FalseReg)) {
753750 // If True's passthru != False, check if it uses False in another operand
754751 // and try to commute it.
755752 int OtherIdx = True.findRegisterUseOperandIdx (FalseReg, TRI);
0 commit comments