Skip to content

Commit 6238b99

Browse files
committed
Cherry-pick Luke's changes
1 parent 5f32862 commit 6238b99

File tree

9 files changed

+428
-635
lines changed

9 files changed

+428
-635
lines changed

llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp

Lines changed: 26 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class RISCVVectorPeephole : public MachineFunctionPass {
7373
bool isAllOnesMask(const MachineInstr *MaskDef) const;
7474
std::optional<unsigned> getConstant(const MachineOperand &VL) const;
7575
bool ensureDominates(const MachineOperand &Use, MachineInstr &Src) const;
76-
bool isKnownSameDefs(Register A, Register B) const;
76+
Register lookThruCopies(Register Reg) const;
7777
};
7878

7979
} // namespace
@@ -387,23 +387,18 @@ bool RISCVVectorPeephole::convertAllOnesVMergeToVMv(MachineInstr &MI) const {
387387
return true;
388388
}
389389

390-
bool RISCVVectorPeephole::isKnownSameDefs(Register A, Register B) const {
391-
if (A.isPhysical() || B.isPhysical())
392-
return false;
393-
394-
auto LookThruVirtRegCopies = [this](Register Reg) {
395-
while (MachineInstr *Def = MRI->getUniqueVRegDef(Reg)) {
396-
if (!Def->isFullCopy())
397-
break;
398-
Register Src = Def->getOperand(1).getReg();
399-
if (!Src.isVirtual())
400-
break;
401-
Reg = Src;
402-
}
403-
return Reg;
404-
};
405-
406-
return LookThruVirtRegCopies(A) == LookThruVirtRegCopies(B);
390+
// If \p Reg is defined by one or more COPYs of virtual registers, traverses
391+
/// the chain and returns the root non-COPY source.
392+
Register RISCVVectorPeephole::lookThruCopies(Register Reg) const {
393+
while (MachineInstr *Def = MRI->getUniqueVRegDef(Reg)) {
394+
if (!Def->isFullCopy())
395+
break;
396+
Register Src = Def->getOperand(1).getReg();
397+
if (!Src.isVirtual())
398+
break;
399+
Reg = Src;
400+
}
401+
return Reg;
407402
}
408403

409404
/// If a PseudoVMERGE_VVM's true operand is a masked pseudo and both have the
@@ -421,23 +416,18 @@ bool RISCVVectorPeephole::convertSameMaskVMergeToVMv(MachineInstr &MI) {
421416
return false;
422417
MachineInstr *True = MRI->getVRegDef(MI.getOperand(3).getReg());
423418

424-
// Peek through COPY.
425-
if (True && True->isCopy()) {
426-
if (Register TrueReg = True->getOperand(1).getReg(); TrueReg.isVirtual())
427-
True = MRI->getVRegDef(TrueReg);
428-
}
429-
430419
if (!True || True->getParent() != MI.getParent())
431420
return false;
432421

433422
auto *TrueMaskedInfo = RISCV::getMaskedPseudoInfo(True->getOpcode());
434423
if (!TrueMaskedInfo || !hasSameEEW(MI, *True))
435424
return false;
436425

437-
const MachineOperand &TrueMask =
438-
True->getOperand(TrueMaskedInfo->MaskOpIdx + True->getNumExplicitDefs());
439-
const MachineOperand &MIMask = MI.getOperand(4);
440-
if (!isKnownSameDefs(TrueMask.getReg(), MIMask.getReg()))
426+
Register TrueMaskReg = lookThruCopies(
427+
True->getOperand(TrueMaskedInfo->MaskOpIdx + True->getNumExplicitDefs())
428+
.getReg());
429+
Register MIMaskReg = lookThruCopies(MI.getOperand(4).getReg());
430+
if (!TrueMaskReg.isVirtual() || TrueMaskReg != MIMaskReg)
441431
return false;
442432

443433
// Masked off lanes past TrueVL will come from False, and converting to vmv
@@ -723,20 +713,12 @@ bool RISCVVectorPeephole::foldVMergeToMask(MachineInstr &MI) const {
723713
if (RISCV::getRVVMCOpcode(MI.getOpcode()) != RISCV::VMERGE_VVM)
724714
return false;
725715

726-
Register PassthruReg = MI.getOperand(1).getReg();
727-
Register FalseReg = MI.getOperand(2).getReg();
728-
Register TrueReg = MI.getOperand(3).getReg();
716+
Register PassthruReg = lookThruCopies(MI.getOperand(1).getReg());
717+
Register FalseReg = lookThruCopies(MI.getOperand(2).getReg());
718+
Register TrueReg = lookThruCopies(MI.getOperand(3).getReg());
729719
if (!TrueReg.isVirtual() || !MRI->hasOneUse(TrueReg))
730720
return false;
731-
MachineInstr *TrueMI = MRI->getUniqueVRegDef(TrueReg);
732-
// Peek through COPY.
733-
if (TrueMI->isCopy()) {
734-
if (TrueReg = TrueMI->getOperand(1).getReg();
735-
TrueReg.isVirtual() && MRI->hasOneUse(TrueReg))
736-
TrueMI = MRI->getVRegDef(TrueReg);
737-
}
738-
739-
MachineInstr &True = *TrueMI;
721+
MachineInstr &True = *MRI->getUniqueVRegDef(TrueReg);
740722
if (True.getParent() != MI.getParent())
741723
return false;
742724
const MachineOperand &MaskOp = MI.getOperand(4);
@@ -754,16 +736,17 @@ bool RISCVVectorPeephole::foldVMergeToMask(MachineInstr &MI) const {
754736

755737
// We require that either passthru and false are the same, or that passthru
756738
// is undefined.
757-
if (PassthruReg && !isKnownSameDefs(PassthruReg, FalseReg))
739+
if (PassthruReg && !(PassthruReg.isVirtual() && PassthruReg == FalseReg))
758740
return false;
759741

760742
std::optional<std::pair<unsigned, unsigned>> NeedsCommute;
761743

762744
// If True has a passthru operand then it needs to be the same as vmerge's
763745
// False, since False will be used for the result's passthru operand.
764-
Register TruePassthru = True.getOperand(True.getNumExplicitDefs()).getReg();
746+
Register TruePassthru =
747+
lookThruCopies(True.getOperand(True.getNumExplicitDefs()).getReg());
765748
if (RISCVII::isFirstDefTiedToFirstUse(True.getDesc()) && TruePassthru &&
766-
!isKnownSameDefs(TruePassthru, FalseReg)) {
749+
!(TruePassthru.isVirtual() && TruePassthru == FalseReg)) {
767750
// If True's passthru != False, check if it uses False in another operand
768751
// and try to commute it.
769752
int OtherIdx = True.findRegisterUseOperandIdx(FalseReg, TRI);

0 commit comments

Comments
 (0)