@@ -73,7 +73,7 @@ class RISCVVectorPeephole : public MachineFunctionPass {
7373 bool isAllOnesMask (const MachineInstr *MaskDef) const ;
7474 std::optional<unsigned > getConstant (const MachineOperand &VL) const ;
7575 bool ensureDominates (const MachineOperand &Use, MachineInstr &Src) const ;
76- bool isKnownSameDefs (Register A, Register B ) const ;
76+ Register lookThruCopies (Register Reg ) const ;
7777};
7878
7979} // namespace
@@ -387,23 +387,18 @@ bool RISCVVectorPeephole::convertAllOnesVMergeToVMv(MachineInstr &MI) const {
387387 return true ;
388388}
389389
390- bool RISCVVectorPeephole::isKnownSameDefs (Register A, Register B) const {
391- if (A.isPhysical () || B.isPhysical ())
392- return false ;
393-
394- auto LookThruVirtRegCopies = [this ](Register Reg) {
395- while (MachineInstr *Def = MRI->getUniqueVRegDef (Reg)) {
396- if (!Def->isFullCopy ())
397- break ;
398- Register Src = Def->getOperand (1 ).getReg ();
399- if (!Src.isVirtual ())
400- break ;
401- Reg = Src;
402- }
403- return Reg;
404- };
405-
406- return LookThruVirtRegCopies (A) == LookThruVirtRegCopies (B);
390+ // If \p Reg is defined by one or more COPYs of virtual registers, traverses
391+ // / the chain and returns the root non-COPY source.
392+ Register RISCVVectorPeephole::lookThruCopies (Register Reg) const {
393+ while (MachineInstr *Def = MRI->getUniqueVRegDef (Reg)) {
394+ if (!Def->isFullCopy ())
395+ break ;
396+ Register Src = Def->getOperand (1 ).getReg ();
397+ if (!Src.isVirtual ())
398+ break ;
399+ Reg = Src;
400+ }
401+ return Reg;
407402}
408403
409404// / If a PseudoVMERGE_VVM's true operand is a masked pseudo and both have the
@@ -421,23 +416,18 @@ bool RISCVVectorPeephole::convertSameMaskVMergeToVMv(MachineInstr &MI) {
421416 return false ;
422417 MachineInstr *True = MRI->getVRegDef (MI.getOperand (3 ).getReg ());
423418
424- // Peek through COPY.
425- if (True && True->isCopy ()) {
426- if (Register TrueReg = True->getOperand (1 ).getReg (); TrueReg.isVirtual ())
427- True = MRI->getVRegDef (TrueReg);
428- }
429-
430419 if (!True || True->getParent () != MI.getParent ())
431420 return false ;
432421
433422 auto *TrueMaskedInfo = RISCV::getMaskedPseudoInfo (True->getOpcode ());
434423 if (!TrueMaskedInfo || !hasSameEEW (MI, *True))
435424 return false ;
436425
437- const MachineOperand &TrueMask =
438- True->getOperand (TrueMaskedInfo->MaskOpIdx + True->getNumExplicitDefs ());
439- const MachineOperand &MIMask = MI.getOperand (4 );
440- if (!isKnownSameDefs (TrueMask.getReg (), MIMask.getReg ()))
426+ Register TrueMaskReg = lookThruCopies (
427+ True->getOperand (TrueMaskedInfo->MaskOpIdx + True->getNumExplicitDefs ())
428+ .getReg ());
429+ Register MIMaskReg = lookThruCopies (MI.getOperand (4 ).getReg ());
430+ if (!TrueMaskReg.isVirtual () || TrueMaskReg != MIMaskReg)
441431 return false ;
442432
443433 // Masked off lanes past TrueVL will come from False, and converting to vmv
@@ -723,20 +713,12 @@ bool RISCVVectorPeephole::foldVMergeToMask(MachineInstr &MI) const {
723713 if (RISCV::getRVVMCOpcode (MI.getOpcode ()) != RISCV::VMERGE_VVM)
724714 return false ;
725715
726- Register PassthruReg = MI.getOperand (1 ).getReg ();
727- Register FalseReg = MI.getOperand (2 ).getReg ();
728- Register TrueReg = MI.getOperand (3 ).getReg ();
716+ Register PassthruReg = lookThruCopies ( MI.getOperand (1 ).getReg () );
717+ Register FalseReg = lookThruCopies ( MI.getOperand (2 ).getReg () );
718+ Register TrueReg = lookThruCopies ( MI.getOperand (3 ).getReg () );
729719 if (!TrueReg.isVirtual () || !MRI->hasOneUse (TrueReg))
730720 return false ;
731- MachineInstr *TrueMI = MRI->getUniqueVRegDef (TrueReg);
732- // Peek through COPY.
733- if (TrueMI->isCopy ()) {
734- if (TrueReg = TrueMI->getOperand (1 ).getReg ();
735- TrueReg.isVirtual () && MRI->hasOneUse (TrueReg))
736- TrueMI = MRI->getVRegDef (TrueReg);
737- }
738-
739- MachineInstr &True = *TrueMI;
721+ MachineInstr &True = *MRI->getUniqueVRegDef (TrueReg);
740722 if (True.getParent () != MI.getParent ())
741723 return false ;
742724 const MachineOperand &MaskOp = MI.getOperand (4 );
@@ -754,16 +736,17 @@ bool RISCVVectorPeephole::foldVMergeToMask(MachineInstr &MI) const {
754736
755737 // We require that either passthru and false are the same, or that passthru
756738 // is undefined.
757- if (PassthruReg && !isKnownSameDefs (PassthruReg, FalseReg))
739+ if (PassthruReg && !(PassthruReg. isVirtual () && PassthruReg == FalseReg))
758740 return false ;
759741
760742 std::optional<std::pair<unsigned , unsigned >> NeedsCommute;
761743
762744 // If True has a passthru operand then it needs to be the same as vmerge's
763745 // False, since False will be used for the result's passthru operand.
764- Register TruePassthru = True.getOperand (True.getNumExplicitDefs ()).getReg ();
746+ Register TruePassthru =
747+ lookThruCopies (True.getOperand (True.getNumExplicitDefs ()).getReg ());
765748 if (RISCVII::isFirstDefTiedToFirstUse (True.getDesc ()) && TruePassthru &&
766- !isKnownSameDefs (TruePassthru, FalseReg)) {
749+ !(TruePassthru. isVirtual () && TruePassthru == FalseReg)) {
767750 // If True's passthru != False, check if it uses False in another operand
768751 // and try to commute it.
769752 int OtherIdx = True.findRegisterUseOperandIdx (FalseReg, TRI);
0 commit comments