Skip to content

Commit bd9a335

Browse files
committed
[RISCV] Move vmv.v.v peephole from SelectionDAG to RISCVVectorPeephole
This is split off from llvm#71764, and moves only the vmv.v.v part of performCombineVMergeAndVOps to work on MachineInstrs. In retrospect trying to handle PseudoVMV_V_V and PseudoVMERGE_VVM in the same function makes the code quite hard to read, so this just does it in a separate peephole. This turns out to be simpler since for PseudoVMV_V_V we don't need to convert the Src instruction to a masked variant, and we don't need to create a fake all ones mask.
1 parent 7231776 commit bd9a335

File tree

2 files changed

+154
-71
lines changed

2 files changed

+154
-71
lines changed

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp

Lines changed: 14 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -3663,32 +3663,6 @@ static bool IsVMerge(SDNode *N) {
36633663
return RISCV::getRVVMCOpcode(N->getMachineOpcode()) == RISCV::VMERGE_VVM;
36643664
}
36653665

3666-
static bool IsVMv(SDNode *N) {
3667-
return RISCV::getRVVMCOpcode(N->getMachineOpcode()) == RISCV::VMV_V_V;
3668-
}
3669-
3670-
static unsigned GetVMSetForLMul(RISCVII::VLMUL LMUL) {
3671-
switch (LMUL) {
3672-
case RISCVII::LMUL_F8:
3673-
return RISCV::PseudoVMSET_M_B1;
3674-
case RISCVII::LMUL_F4:
3675-
return RISCV::PseudoVMSET_M_B2;
3676-
case RISCVII::LMUL_F2:
3677-
return RISCV::PseudoVMSET_M_B4;
3678-
case RISCVII::LMUL_1:
3679-
return RISCV::PseudoVMSET_M_B8;
3680-
case RISCVII::LMUL_2:
3681-
return RISCV::PseudoVMSET_M_B16;
3682-
case RISCVII::LMUL_4:
3683-
return RISCV::PseudoVMSET_M_B32;
3684-
case RISCVII::LMUL_8:
3685-
return RISCV::PseudoVMSET_M_B64;
3686-
case RISCVII::LMUL_RESERVED:
3687-
llvm_unreachable("Unexpected LMUL");
3688-
}
3689-
llvm_unreachable("Unknown VLMUL enum");
3690-
}
3691-
36923666
// Try to fold away VMERGE_VVM instructions into their true operands:
36933667
//
36943668
// %true = PseudoVADD_VV ...
@@ -3703,35 +3677,22 @@ static unsigned GetVMSetForLMul(RISCVII::VLMUL LMUL) {
37033677
// If %true is masked, then we can use its mask instead of vmerge's if vmerge's
37043678
// mask is all ones.
37053679
//
3706-
// We can also fold a VMV_V_V into its true operand, since it is equivalent to a
3707-
// VMERGE_VVM with an all ones mask.
3708-
//
37093680
// The resulting VL is the minimum of the two VLs.
37103681
//
37113682
// The resulting policy is the effective policy the vmerge would have had,
37123683
// i.e. whether or not it's passthru operand was implicit-def.
37133684
bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
37143685
SDValue Passthru, False, True, VL, Mask, Glue;
3715-
// A vmv.v.v is equivalent to a vmerge with an all-ones mask.
3716-
if (IsVMv(N)) {
3717-
Passthru = N->getOperand(0);
3718-
False = N->getOperand(0);
3719-
True = N->getOperand(1);
3720-
VL = N->getOperand(2);
3721-
// A vmv.v.v won't have a Mask or Glue, instead we'll construct an all-ones
3722-
// mask later below.
3723-
} else {
3724-
assert(IsVMerge(N));
3725-
Passthru = N->getOperand(0);
3726-
False = N->getOperand(1);
3727-
True = N->getOperand(2);
3728-
Mask = N->getOperand(3);
3729-
VL = N->getOperand(4);
3730-
// We always have a glue node for the mask at v0.
3731-
Glue = N->getOperand(N->getNumOperands() - 1);
3732-
}
3733-
assert(!Mask || cast<RegisterSDNode>(Mask)->getReg() == RISCV::V0);
3734-
assert(!Glue || Glue.getValueType() == MVT::Glue);
3686+
assert(IsVMerge(N));
3687+
Passthru = N->getOperand(0);
3688+
False = N->getOperand(1);
3689+
True = N->getOperand(2);
3690+
Mask = N->getOperand(3);
3691+
VL = N->getOperand(4);
3692+
// We always have a glue node for the mask at v0.
3693+
Glue = N->getOperand(N->getNumOperands() - 1);
3694+
assert(cast<RegisterSDNode>(Mask)->getReg() == RISCV::V0);
3695+
assert(Glue.getValueType() == MVT::Glue);
37353696

37363697
// If the EEW of True is different from vmerge's SEW, then we can't fold.
37373698
if (True.getSimpleValueType() != N->getSimpleValueType(0))
@@ -3779,7 +3740,7 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
37793740

37803741
// If True is masked then the vmerge must have either the same mask or an all
37813742
// 1s mask, since we're going to keep the mask from True.
3782-
if (IsMasked && Mask) {
3743+
if (IsMasked) {
37833744
// FIXME: Support mask agnostic True instruction which would have an
37843745
// undef passthru operand.
37853746
SDValue TrueMask =
@@ -3809,11 +3770,9 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
38093770
SmallVector<const SDNode *, 4> LoopWorklist;
38103771
SmallPtrSet<const SDNode *, 16> Visited;
38113772
LoopWorklist.push_back(False.getNode());
3812-
if (Mask)
3813-
LoopWorklist.push_back(Mask.getNode());
3773+
LoopWorklist.push_back(Mask.getNode());
38143774
LoopWorklist.push_back(VL.getNode());
3815-
if (Glue)
3816-
LoopWorklist.push_back(Glue.getNode());
3775+
LoopWorklist.push_back(Glue.getNode());
38173776
if (SDNode::hasPredecessorHelper(True.getNode(), Visited, LoopWorklist))
38183777
return false;
38193778
}
@@ -3873,21 +3832,6 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
38733832
Glue = True->getOperand(True->getNumOperands() - 1);
38743833
assert(Glue.getValueType() == MVT::Glue);
38753834
}
3876-
// If we end up using the vmerge mask the vmerge is actually a vmv.v.v, create
3877-
// an all-ones mask to use.
3878-
else if (IsVMv(N)) {
3879-
unsigned TSFlags = TII->get(N->getMachineOpcode()).TSFlags;
3880-
unsigned VMSetOpc = GetVMSetForLMul(RISCVII::getLMul(TSFlags));
3881-
ElementCount EC = N->getValueType(0).getVectorElementCount();
3882-
MVT MaskVT = MVT::getVectorVT(MVT::i1, EC);
3883-
3884-
SDValue AllOnesMask =
3885-
SDValue(CurDAG->getMachineNode(VMSetOpc, DL, MaskVT, VL, SEW), 0);
3886-
SDValue MaskCopy = CurDAG->getCopyToReg(CurDAG->getEntryNode(), DL,
3887-
RISCV::V0, AllOnesMask, SDValue());
3888-
Mask = CurDAG->getRegister(RISCV::V0, MaskVT);
3889-
Glue = MaskCopy.getValue(1);
3890-
}
38913835

38923836
unsigned MaskedOpc = Info->MaskedPseudo;
38933837
#ifndef NDEBUG
@@ -3966,7 +3910,7 @@ bool RISCVDAGToDAGISel::doPeepholeMergeVVMFold() {
39663910
if (N->use_empty() || !N->isMachineOpcode())
39673911
continue;
39683912

3969-
if (IsVMerge(N) || IsVMv(N))
3913+
if (IsVMerge(N))
39703914
MadeChange |= performCombineVMergeAndVOps(N);
39713915
}
39723916
return MadeChange;

llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp

Lines changed: 140 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ class RISCVVectorPeephole : public MachineFunctionPass {
6565
bool convertToWholeRegister(MachineInstr &MI) const;
6666
bool convertToUnmasked(MachineInstr &MI) const;
6767
bool convertVMergeToVMv(MachineInstr &MI) const;
68+
bool foldVMV_V_V(MachineInstr &MI);
6869

6970
bool isAllOnesMask(const MachineInstr *MaskDef) const;
7071
std::optional<unsigned> getConstant(const MachineOperand &VL) const;
@@ -324,6 +325,143 @@ bool RISCVVectorPeephole::convertToUnmasked(MachineInstr &MI) const {
324325
return true;
325326
}
326327

328+
/// Given two VL operands, returns the one known to be the smallest or nullptr
329+
/// if unknown.
330+
static const MachineOperand *getKnownMinVL(const MachineOperand *LHS,
331+
const MachineOperand *RHS) {
332+
if (LHS->isReg() && RHS->isReg() && LHS->getReg().isVirtual() &&
333+
LHS->getReg() == RHS->getReg())
334+
return LHS;
335+
if (LHS->isImm() && LHS->getImm() == RISCV::VLMaxSentinel)
336+
return RHS;
337+
if (RHS->isImm() && RHS->getImm() == RISCV::VLMaxSentinel)
338+
return LHS;
339+
if (!LHS->isImm() || !RHS->isImm())
340+
return nullptr;
341+
return LHS->getImm() <= RHS->getImm() ? LHS : RHS;
342+
}
343+
344+
/// Check if it's safe to move From down to To, checking that no physical
345+
/// registers are clobbered.
346+
static bool isSafeToMove(const MachineInstr &From, const MachineInstr &To) {
347+
assert(From.getParent() == To.getParent() && !From.hasImplicitDef());
348+
SmallVector<Register> PhysUses;
349+
for (const MachineOperand &MO : From.all_uses())
350+
if (MO.getReg().isPhysical())
351+
PhysUses.push_back(MO.getReg());
352+
bool SawStore = false;
353+
for (auto II = From.getIterator(); II != To.getIterator(); II++) {
354+
for (Register PhysReg : PhysUses)
355+
if (II->definesRegister(PhysReg, nullptr))
356+
return false;
357+
if (II->mayStore()) {
358+
SawStore = true;
359+
break;
360+
}
361+
}
362+
return From.isSafeToMove(nullptr, SawStore);
363+
}
364+
365+
static const RISCV::RISCVMaskedPseudoInfo *
366+
lookupMaskedPseudoInfo(const MachineInstr &MI) {
367+
const RISCV::RISCVMaskedPseudoInfo *Info =
368+
RISCV::lookupMaskedIntrinsicByUnmasked(MI.getOpcode());
369+
if (!Info)
370+
Info = RISCV::getMaskedPseudoInfo(MI.getOpcode());
371+
return Info;
372+
}
373+
374+
/// If a PseudoVMV_V_V is the only user of it's input, fold its passthru and VL
375+
/// into it.
376+
///
377+
/// %x = PseudoVADD_V_V_M1 %passthru, %a, %b, %vl, sew, policy
378+
/// %y = PseudoVMV_V_V_M1 %passthru, %x, %vl, sew, policy
379+
///
380+
/// ->
381+
///
382+
/// %y = PseudoVADD_V_V_M1 %passthru, %a, %b, %vl, sew, policy
383+
bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
384+
if (RISCV::getRVVMCOpcode(MI.getOpcode()) != RISCV::VMV_V_V)
385+
return false;
386+
387+
MachineOperand &Passthru = MI.getOperand(1);
388+
MachineInstr *Src = MRI->getVRegDef(MI.getOperand(2).getReg());
389+
390+
if (!MRI->hasOneUse(MI.getOperand(2).getReg()))
391+
return false;
392+
393+
if (!Src || Src->hasUnmodeledSideEffects() ||
394+
Src->getParent() != MI.getParent())
395+
return false;
396+
397+
// Src needs to be a pseudo that's opted into this transform.
398+
const RISCV::RISCVMaskedPseudoInfo *Info = lookupMaskedPseudoInfo(*Src);
399+
if (!Info)
400+
return false;
401+
402+
assert(Src->getNumDefs() == 1 &&
403+
RISCVII::isFirstDefTiedToFirstUse(Src->getDesc()) &&
404+
RISCVII::hasVLOp(Src->getDesc().TSFlags) &&
405+
RISCVII::hasVecPolicyOp(Src->getDesc().TSFlags));
406+
407+
// Src needs to have the same passthru as VMV_V_V
408+
if (Src->getOperand(1).getReg() != RISCV::NoRegister &&
409+
Src->getOperand(1).getReg() != Passthru.getReg())
410+
return false;
411+
412+
// Because Src and MI have the same passthru, we can use either AVL as long as
413+
// it's the smaller of the two.
414+
//
415+
// (src pt, ..., vl=5) x x x x x|. . .
416+
// (vmv.v.v pt, src, vl=3) x x x|. . . . .
417+
// ->
418+
// (src pt, ..., vl=3) x x x|. . . . .
419+
//
420+
// (src pt, ..., vl=3) x x x|. . . . .
421+
// (vmv.v.v pt, src, vl=6) x x x . . .|. .
422+
// ->
423+
// (src pt, ..., vl=3) x x x|. . . . .
424+
MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
425+
const MachineOperand *MinVL = getKnownMinVL(&MI.getOperand(3), &SrcVL);
426+
if (!MinVL)
427+
return false;
428+
429+
bool VLChanged = !MinVL->isIdenticalTo(SrcVL);
430+
bool RaisesFPExceptions = MI.getDesc().mayRaiseFPException() &&
431+
!MI.getFlag(MachineInstr::MIFlag::NoFPExcept);
432+
if (VLChanged && (Info->ActiveElementsAffectResult || RaisesFPExceptions))
433+
return false;
434+
435+
if (!isSafeToMove(*Src, MI))
436+
return false;
437+
438+
// Move Src down to MI, then replace all uses of MI with it.
439+
Src->moveBefore(&MI);
440+
441+
Src->getOperand(1).setReg(Passthru.getReg());
442+
// If Src is masked then its passthru needs to be in VRNoV0.
443+
if (Passthru.getReg() != RISCV::NoRegister)
444+
MRI->constrainRegClass(Passthru.getReg(),
445+
TII->getRegClass(Src->getDesc(), 1, TRI,
446+
*Src->getParent()->getParent()));
447+
448+
if (MinVL->isImm())
449+
SrcVL.ChangeToImmediate(MinVL->getImm());
450+
else if (MinVL->isReg())
451+
SrcVL.ChangeToRegister(MinVL->getReg(), false);
452+
453+
// Use a conservative tu,mu policy, RISCVInsertVSETVLI will relax it if
454+
// passthru is undef.
455+
Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc()))
456+
.setImm(RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED);
457+
458+
MRI->replaceRegWith(MI.getOperand(0).getReg(), Src->getOperand(0).getReg());
459+
MI.eraseFromParent();
460+
V0Defs.erase(&MI);
461+
462+
return true;
463+
}
464+
327465
bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
328466
if (skipFunction(MF.getFunction()))
329467
return false;
@@ -358,11 +496,12 @@ bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
358496
}
359497

360498
for (MachineBasicBlock &MBB : MF) {
361-
for (MachineInstr &MI : MBB) {
499+
for (MachineInstr &MI : make_early_inc_range(MBB)) {
362500
Changed |= convertToVLMAX(MI);
363501
Changed |= convertToUnmasked(MI);
364502
Changed |= convertToWholeRegister(MI);
365503
Changed |= convertVMergeToVMv(MI);
504+
Changed |= foldVMV_V_V(MI);
366505
}
367506
}
368507

0 commit comments

Comments
 (0)