Skip to content

Commit 1e826b2

Browse files
committed
[RISCV] Look through copies for True operand in vmerge fold
In llvm#170070, PseudoVMERGE_V* instructions will have copies to NoV0 reg classes in their operands. In order to continue folding them we need to look through these copies. We previously looked through copies when comparing if the false and passthru operands were equivalent, but didn't look through copies for the true operand. This looks through the copies up front for all operands, and not just when we're comparing equality.
1 parent 5f38ab2 commit 1e826b2

File tree

2 files changed

+27
-31
lines changed

2 files changed

+27
-31
lines changed

llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class RISCVVectorPeephole : public MachineFunctionPass {
7373
bool isAllOnesMask(const MachineInstr *MaskDef) const;
7474
std::optional<unsigned> getConstant(const MachineOperand &VL) const;
7575
bool ensureDominates(const MachineOperand &Use, MachineInstr &Src) const;
76-
bool isKnownSameDefs(Register A, Register B) const;
76+
Register lookThruCopies(Register Reg) const;
7777
};
7878

7979
} // namespace
@@ -387,23 +387,18 @@ bool RISCVVectorPeephole::convertAllOnesVMergeToVMv(MachineInstr &MI) const {
387387
return true;
388388
}
389389

390-
bool RISCVVectorPeephole::isKnownSameDefs(Register A, Register B) const {
391-
if (A.isPhysical() || B.isPhysical())
392-
return false;
393-
394-
auto LookThruVirtRegCopies = [this](Register Reg) {
395-
while (MachineInstr *Def = MRI->getUniqueVRegDef(Reg)) {
396-
if (!Def->isFullCopy())
397-
break;
398-
Register Src = Def->getOperand(1).getReg();
399-
if (!Src.isVirtual())
400-
break;
401-
Reg = Src;
402-
}
403-
return Reg;
404-
};
405-
406-
return LookThruVirtRegCopies(A) == LookThruVirtRegCopies(B);
390+
/// If \p Reg is defined by one or more COPYs of virtual registers, traverses
391+
/// the chain and returns the root non-COPY source.
392+
Register RISCVVectorPeephole::lookThruCopies(Register Reg) const {
393+
while (MachineInstr *Def = MRI->getUniqueVRegDef(Reg)) {
394+
if (!Def->isFullCopy())
395+
break;
396+
Register Src = Def->getOperand(1).getReg();
397+
if (!Src.isVirtual())
398+
break;
399+
Reg = Src;
400+
}
401+
return Reg;
407402
}
408403

409404
/// If a PseudoVMERGE_VVM's true operand is a masked pseudo and both have the
@@ -428,10 +423,11 @@ bool RISCVVectorPeephole::convertSameMaskVMergeToVMv(MachineInstr &MI) {
428423
if (!TrueMaskedInfo || !hasSameEEW(MI, *True))
429424
return false;
430425

431-
const MachineOperand &TrueMask =
432-
True->getOperand(TrueMaskedInfo->MaskOpIdx + True->getNumExplicitDefs());
433-
const MachineOperand &MIMask = MI.getOperand(4);
434-
if (!isKnownSameDefs(TrueMask.getReg(), MIMask.getReg()))
426+
Register TrueMaskReg = lookThruCopies(
427+
True->getOperand(TrueMaskedInfo->MaskOpIdx + True->getNumExplicitDefs())
428+
.getReg());
429+
Register MIMaskReg = lookThruCopies(MI.getOperand(4).getReg());
430+
if (!TrueMaskReg.isVirtual() || TrueMaskReg != MIMaskReg)
435431
return false;
436432

437433
// Masked off lanes past TrueVL will come from False, and converting to vmv
@@ -717,9 +713,9 @@ bool RISCVVectorPeephole::foldVMergeToMask(MachineInstr &MI) const {
717713
if (RISCV::getRVVMCOpcode(MI.getOpcode()) != RISCV::VMERGE_VVM)
718714
return false;
719715

720-
Register PassthruReg = MI.getOperand(1).getReg();
721-
Register FalseReg = MI.getOperand(2).getReg();
722-
Register TrueReg = MI.getOperand(3).getReg();
716+
Register PassthruReg = lookThruCopies(MI.getOperand(1).getReg());
717+
Register FalseReg = lookThruCopies(MI.getOperand(2).getReg());
718+
Register TrueReg = lookThruCopies(MI.getOperand(3).getReg());
723719
if (!TrueReg.isVirtual() || !MRI->hasOneUse(TrueReg))
724720
return false;
725721
MachineInstr &True = *MRI->getUniqueVRegDef(TrueReg);
@@ -740,16 +736,17 @@ bool RISCVVectorPeephole::foldVMergeToMask(MachineInstr &MI) const {
740736

741737
// We require that either passthru and false are the same, or that passthru
742738
// is undefined.
743-
if (PassthruReg && !isKnownSameDefs(PassthruReg, FalseReg))
739+
if (PassthruReg && !(PassthruReg.isVirtual() && PassthruReg == FalseReg))
744740
return false;
745741

746742
std::optional<std::pair<unsigned, unsigned>> NeedsCommute;
747743

748744
// If True has a passthru operand then it needs to be the same as vmerge's
749745
// False, since False will be used for the result's passthru operand.
750-
Register TruePassthru = True.getOperand(True.getNumExplicitDefs()).getReg();
746+
Register TruePassthru =
747+
lookThruCopies(True.getOperand(True.getNumExplicitDefs()).getReg());
751748
if (RISCVII::isFirstDefTiedToFirstUse(True.getDesc()) && TruePassthru &&
752-
!isKnownSameDefs(TruePassthru, FalseReg)) {
749+
!(TruePassthru.isVirtual() && TruePassthru == FalseReg)) {
753750
// If True's passthru != False, check if it uses False in another operand
754751
// and try to commute it.
755752
int OtherIdx = True.findRegisterUseOperandIdx(FalseReg, TRI);

llvm/test/CodeGen/RISCV/rvv/vmerge-peephole.mir

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,9 @@ body: |
126126
; CHECK-NEXT: {{ $}}
127127
; CHECK-NEXT: %avl:gprnox0 = COPY $x8
128128
; CHECK-NEXT: %passthru:vrnov0 = COPY $v8
129-
; CHECK-NEXT: %x:vr = PseudoVLE32_V_M1 $noreg, $noreg, %avl, 5 /* e32 */, 2 /* tu, ma */ :: (load unknown-size, align 1)
130129
; CHECK-NEXT: %mask:vmv0 = COPY $v0
131-
; CHECK-NEXT: %y:vrnov0 = COPY %x
132-
; CHECK-NEXT: %z:vrnov0 = PseudoVMERGE_VVM_M1 %passthru, %passthru, %y, %mask, %avl, 5 /* e32 */
130+
; CHECK-NEXT: %z:vrnov0 = PseudoVLE32_V_M1_MASK %passthru, $noreg, %mask, %avl, 5 /* e32 */, 0 /* tu, mu */ :: (load unknown-size, align 1)
131+
; CHECK-NEXT: %y:vrnov0 = COPY %z
133132
%avl:gprnox0 = COPY $x8
134133
%passthru:vrnov0 = COPY $v8
135134
%x:vr = PseudoVLE32_V_M1 $noreg, $noreg, %avl, 5 /* e32 */, 2 /* tu, ma */ :: (load unknown-size)

0 commit comments

Comments
 (0)