Skip to content

[RISCV] Select mask operands as virtual registers and eliminate uses of vmv0 #125026

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/lib/Target/RISCV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ add_llvm_target(RISCVCodeGen
RISCVVectorMaskDAGMutation.cpp
RISCVVectorPeephole.cpp
RISCVVLOptimizer.cpp
RISCVVMV0Elimination.cpp
RISCVZacasABIFix.cpp
GISel/RISCVCallLowering.cpp
GISel/RISCVInstructionSelector.cpp
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/RISCV/RISCV.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ void initializeRISCVPreLegalizerCombinerPass(PassRegistry &);

FunctionPass *createRISCVVLOptimizerPass();
void initializeRISCVVLOptimizerPass(PassRegistry &);

FunctionPass *createRISCVVMV0EliminationPass();
void initializeRISCVVMV0EliminationPass(PassRegistry &);
} // namespace llvm

#endif
107 changes: 15 additions & 92 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ void RISCVDAGToDAGISel::addVectorLoadStoreOperands(
bool IsMasked, bool IsStridedOrIndexed, SmallVectorImpl<SDValue> &Operands,
bool IsLoad, MVT *IndexVT) {
SDValue Chain = Node->getOperand(0);
SDValue Glue;

Operands.push_back(Node->getOperand(CurOp++)); // Base pointer.

Expand All @@ -265,11 +264,8 @@ void RISCVDAGToDAGISel::addVectorLoadStoreOperands(
}

if (IsMasked) {
// Mask needs to be copied to V0.
SDValue Mask = Node->getOperand(CurOp++);
Chain = CurDAG->getCopyToReg(Chain, DL, RISCV::V0, Mask, SDValue());
Glue = Chain.getValue(1);
Operands.push_back(CurDAG->getRegister(RISCV::V0, Mask.getValueType()));
Operands.push_back(Mask);
}
SDValue VL;
selectVLOp(Node->getOperand(CurOp++), VL);
Expand All @@ -291,8 +287,6 @@ void RISCVDAGToDAGISel::addVectorLoadStoreOperands(
}

Operands.push_back(Chain); // Chain.
if (Glue)
Operands.push_back(Glue);
}

void RISCVDAGToDAGISel::selectVLSEG(SDNode *Node, unsigned NF, bool IsMasked,
Expand Down Expand Up @@ -1844,19 +1838,13 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
return;
}

// Mask needs to be copied to V0.
SDValue Chain = CurDAG->getCopyToReg(CurDAG->getEntryNode(), DL,
RISCV::V0, Mask, SDValue());
SDValue Glue = Chain.getValue(1);
SDValue V0 = CurDAG->getRegister(RISCV::V0, VT);

if (IsCmpConstant) {
SDValue Imm =
selectImm(CurDAG, SDLoc(Src2), XLenVT, CVal - 1, *Subtarget);

ReplaceNode(Node, CurDAG->getMachineNode(
VMSGTMaskOpcode, DL, VT,
{MaskedOff, Src1, Imm, V0, VL, SEW, Glue}));
{MaskedOff, Src1, Imm, Mask, VL, SEW}));
return;
}

Expand All @@ -1867,7 +1855,7 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
// the agnostic result can be either undisturbed or all 1.
SDValue Cmp = SDValue(
CurDAG->getMachineNode(VMSLTMaskOpcode, DL, VT,
{MaskedOff, Src1, Src2, V0, VL, SEW, Glue}),
{MaskedOff, Src1, Src2, Mask, VL, SEW}),
0);
// vmxor.mm vd, vd, v0 is used to update active value.
ReplaceNode(Node, CurDAG->getMachineNode(VMXOROpcode, DL, VT,
Expand Down Expand Up @@ -3287,12 +3275,10 @@ static bool vectorPseudoHasAllNBitUsers(SDNode *User, unsigned UserOpNo,
return false;
assert(RISCVII::hasVLOp(TSFlags));

bool HasGlueOp = User->getGluedNode() != nullptr;
unsigned ChainOpIdx = User->getNumOperands() - HasGlueOp - 1;
unsigned ChainOpIdx = User->getNumOperands() - 1;
bool HasChainOp = User->getOperand(ChainOpIdx).getValueType() == MVT::Other;
bool HasVecPolicyOp = RISCVII::hasVecPolicyOp(TSFlags);
unsigned VLIdx =
User->getNumOperands() - HasVecPolicyOp - HasChainOp - HasGlueOp - 2;
unsigned VLIdx = User->getNumOperands() - HasVecPolicyOp - HasChainOp - 2;
const unsigned Log2SEW = User->getConstantOperandVal(VLIdx + 1);

if (UserOpNo == VLIdx)
Expand Down Expand Up @@ -3759,43 +3745,7 @@ bool RISCVDAGToDAGISel::doPeepholeSExtW(SDNode *N) {
return false;
}

// After ISel, a vector pseudo's mask will be copied to V0 via a CopyToReg
// that's glued to the pseudo. This tries to look up the value that was copied
// to V0.
//
// MaskOp is the pseudo's mask operand and GlueOp is the operand expected to
// carry the glue from that CopyToReg.
//
// Returns an empty SDValue if MaskOp is not the V0 register, or if GlueOp's
// node is not a CopyToReg whose destination register is V0. Otherwise returns
// the value being copied into V0, looking through one COPY_TO_REGCLASS.
static SDValue getMaskSetter(SDValue MaskOp, SDValue GlueOp) {
// Check that we're using V0 as a mask register.
if (!isa<RegisterSDNode>(MaskOp) ||
cast<RegisterSDNode>(MaskOp)->getReg() != RISCV::V0)
return SDValue();

// The node producing the glue is expected to be the CopyToReg that defines
// V0.
const auto *Glued = GlueOp.getNode();

if (!Glued || Glued->getOpcode() != ISD::CopyToReg)
return SDValue();

// Check that we're defining V0 as a mask register. Operand 1 of the
// CopyToReg is inspected as the destination register.
if (!isa<RegisterSDNode>(Glued->getOperand(1)) ||
cast<RegisterSDNode>(Glued->getOperand(1))->getReg() != RISCV::V0)
return SDValue();

// Operand 2 of the CopyToReg is the value written into V0.
SDValue MaskSetter = Glued->getOperand(2);

// Sometimes the VMSET is wrapped in a COPY_TO_REGCLASS, e.g. if the mask came
// from an extract_subvector or insert_subvector. Only a single level of
// wrapping is stripped here.
if (MaskSetter->isMachineOpcode() &&
MaskSetter->getMachineOpcode() == RISCV::COPY_TO_REGCLASS)
MaskSetter = MaskSetter->getOperand(0);

return MaskSetter;
}

static bool usesAllOnesMask(SDValue MaskOp, SDValue GlueOp) {
// Check the instruction defining V0; it needs to be a VMSET pseudo.
SDValue MaskSetter = getMaskSetter(MaskOp, GlueOp);
if (!MaskSetter)
return false;

static bool usesAllOnesMask(SDValue MaskOp) {
const auto IsVMSet = [](unsigned Opc) {
return Opc == RISCV::PseudoVMSET_M_B1 || Opc == RISCV::PseudoVMSET_M_B16 ||
Opc == RISCV::PseudoVMSET_M_B2 || Opc == RISCV::PseudoVMSET_M_B32 ||
Expand All @@ -3806,14 +3756,7 @@ static bool usesAllOnesMask(SDValue MaskOp, SDValue GlueOp) {
// TODO: Check that the VMSET is the expected bitwidth? The pseudo has
// undefined behaviour if it's the wrong bitwidth, so we could choose to
// assume that it's all-ones? Same applies to its VL.
return MaskSetter->isMachineOpcode() &&
IsVMSet(MaskSetter.getMachineOpcode());
}

// Return true if we can make sure mask of N is all-ones mask.
static bool usesAllOnesMask(SDNode *N, unsigned MaskOpIdx) {
return usesAllOnesMask(N->getOperand(MaskOpIdx),
N->getOperand(N->getNumOperands() - 1));
return MaskOp->isMachineOpcode() && IsVMSet(MaskOp.getMachineOpcode());
}

static bool isImplicitDef(SDValue V) {
Expand All @@ -3829,17 +3772,15 @@ static bool isImplicitDef(SDValue V) {
}

// Optimize masked RVV pseudo instructions with a known all-ones mask to their
// corresponding "unmasked" pseudo versions. The mask we're interested in will
// take the form of a V0 physical register operand, with a glued
// register-setting instruction.
// corresponding "unmasked" pseudo versions.
bool RISCVDAGToDAGISel::doPeepholeMaskedRVV(MachineSDNode *N) {
const RISCV::RISCVMaskedPseudoInfo *I =
RISCV::getMaskedPseudoInfo(N->getMachineOpcode());
if (!I)
return false;

unsigned MaskOpIdx = I->MaskOpIdx;
if (!usesAllOnesMask(N, MaskOpIdx))
if (!usesAllOnesMask(N->getOperand(MaskOpIdx)))
return false;

// There are two classes of pseudos in the table - compares and
Expand All @@ -3863,18 +3804,13 @@ bool RISCVDAGToDAGISel::doPeepholeMaskedRVV(MachineSDNode *N) {
// Skip the passthru operand at index 0 if the unmasked pseudo doesn't have one.
bool ShouldSkip = !HasPassthru && MaskedHasPassthru;
for (unsigned I = ShouldSkip, E = N->getNumOperands(); I != E; I++) {
// Skip the mask, and the Glue.
// Skip the mask
SDValue Op = N->getOperand(I);
if (I == MaskOpIdx || Op.getValueType() == MVT::Glue)
if (I == MaskOpIdx)
continue;
Ops.push_back(Op);
}

// Transitively apply any node glued to our new node.
const auto *Glued = N->getGluedNode();
if (auto *TGlued = Glued->getGluedNode())
Ops.push_back(SDValue(TGlued, TGlued->getNumValues() - 1));

MachineSDNode *Result =
CurDAG->getMachineNode(Opc, SDLoc(N), N->getVTList(), Ops);

Expand Down Expand Up @@ -3910,17 +3846,13 @@ static bool IsVMerge(SDNode *N) {
// The resulting policy is the effective policy the vmerge would have had,
// i.e. whether or not its passthru operand was implicit-def.
bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
SDValue Passthru, False, True, VL, Mask, Glue;
SDValue Passthru, False, True, VL, Mask;
assert(IsVMerge(N));
Passthru = N->getOperand(0);
False = N->getOperand(1);
True = N->getOperand(2);
Mask = N->getOperand(3);
VL = N->getOperand(4);
// We always have a glue node for the mask at v0.
Glue = N->getOperand(N->getNumOperands() - 1);
assert(cast<RegisterSDNode>(Mask)->getReg() == RISCV::V0);
assert(Glue.getValueType() == MVT::Glue);

// If the EEW of True is different from vmerge's SEW, then we can't fold.
if (True.getSimpleValueType() != N->getSimpleValueType(0))
Expand Down Expand Up @@ -3963,12 +3895,7 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
if (TII->get(TrueOpc).hasUnmodeledSideEffects())
return false;

// The last operand of a masked instruction may be glued.
bool HasGlueOp = True->getGluedNode() != nullptr;

// The chain operand may exist either before the glued operands or in the last
// position.
unsigned TrueChainOpIdx = True.getNumOperands() - HasGlueOp - 1;
unsigned TrueChainOpIdx = True.getNumOperands() - 1;
bool HasChainOp =
True.getOperand(TrueChainOpIdx).getValueType() == MVT::Other;

Expand All @@ -3980,15 +3907,14 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
LoopWorklist.push_back(False.getNode());
LoopWorklist.push_back(Mask.getNode());
LoopWorklist.push_back(VL.getNode());
LoopWorklist.push_back(Glue.getNode());
if (SDNode::hasPredecessorHelper(True.getNode(), Visited, LoopWorklist))
return false;
}

// The vector policy operand may be present for masked intrinsics
bool HasVecPolicyOp = RISCVII::hasVecPolicyOp(TrueTSFlags);
unsigned TrueVLIndex =
True.getNumOperands() - HasVecPolicyOp - HasChainOp - HasGlueOp - 2;
True.getNumOperands() - HasVecPolicyOp - HasChainOp - 2;
SDValue TrueVL = True.getOperand(TrueVLIndex);
SDValue SEW = True.getOperand(TrueVLIndex + 1);

Expand Down Expand Up @@ -4020,7 +3946,7 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
if (RISCVII::elementsDependOnVL(TrueBaseMCID.TSFlags) && (TrueVL != VL))
return false;
if (RISCVII::elementsDependOnMask(TrueBaseMCID.TSFlags) &&
(Mask && !usesAllOnesMask(Mask, Glue)))
(Mask && !usesAllOnesMask(Mask)))
return false;

// Make sure it doesn't raise any observable fp exceptions, since changing the
Expand Down Expand Up @@ -4077,9 +4003,6 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
if (HasChainOp)
Ops.push_back(True.getOperand(TrueChainOpIdx));

// Add the glue for the CopyToReg of mask->v0.
Ops.push_back(Glue);

MachineSDNode *Result =
CurDAG->getMachineNode(MaskedOpc, DL, True->getVTList(), Ops);
Result->setFlags(True->getFlags());
Expand Down
Loading