Skip to content

Commit 8fb2160

Browse files
authored
[RISCV] Use DenseMap to track V0 definition. NFC (#84465)
Reviving some of the progress on #71764. To recap, we explored removing the V0 register copies to simplify the pass, but hit a limitation with the register allocator due to our use of the vmv0 singleton reg class and early-clobber constraints. So since we will have to continue to track the definition of V0 ourselves, this patch simplifies it by storing it in a map. It will allow us to move about copies to V0 in #71764 without having to do extra bookkeeping.
1 parent cbcdf12 commit 8fb2160

File tree

1 file changed

+27
-21
lines changed

1 file changed

+27
-21
lines changed

llvm/lib/Target/RISCV/RISCVFoldMasks.cpp

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,13 @@ class RISCVFoldMasks : public MachineFunctionPass {
4747
StringRef getPassName() const override { return "RISC-V Fold Masks"; }
4848

4949
private:
50-
bool convertToUnmasked(MachineInstr &MI, MachineInstr *MaskDef) const;
51-
bool convertVMergeToVMv(MachineInstr &MI, MachineInstr *MaskDef) const;
50+
bool convertToUnmasked(MachineInstr &MI) const;
51+
bool convertVMergeToVMv(MachineInstr &MI) const;
5252

53-
bool isAllOnesMask(MachineInstr *MaskDef) const;
53+
bool isAllOnesMask(const MachineInstr *MaskDef) const;
54+
55+
/// Maps uses of V0 to the corresponding def of V0.
56+
DenseMap<const MachineInstr *, const MachineInstr *> V0Defs;
5457
};
5558

5659
} // namespace
@@ -59,10 +62,9 @@ char RISCVFoldMasks::ID = 0;
5962

6063
INITIALIZE_PASS(RISCVFoldMasks, DEBUG_TYPE, "RISC-V Fold Masks", false, false)
6164

62-
bool RISCVFoldMasks::isAllOnesMask(MachineInstr *MaskDef) const {
63-
if (!MaskDef)
64-
return false;
65-
assert(MaskDef->isCopy() && MaskDef->getOperand(0).getReg() == RISCV::V0);
65+
bool RISCVFoldMasks::isAllOnesMask(const MachineInstr *MaskDef) const {
66+
assert(MaskDef && MaskDef->isCopy() &&
67+
MaskDef->getOperand(0).getReg() == RISCV::V0);
6668
Register SrcReg = TRI->lookThruCopyLike(MaskDef->getOperand(1).getReg(), MRI);
6769
if (!SrcReg.isVirtual())
6870
return false;
@@ -89,8 +91,7 @@ bool RISCVFoldMasks::isAllOnesMask(MachineInstr *MaskDef) const {
8991

9092
// Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to
9193
// (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET.
92-
bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI,
93-
MachineInstr *V0Def) const {
94+
bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI) const {
9495
#define CASE_VMERGE_TO_VMV(lmul) \
9596
case RISCV::PseudoVMERGE_VVM_##lmul: \
9697
NewOpc = RISCV::PseudoVMV_V_V_##lmul; \
@@ -116,7 +117,7 @@ bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI,
116117
return false;
117118

118119
assert(MI.getOperand(4).isReg() && MI.getOperand(4).getReg() == RISCV::V0);
119-
if (!isAllOnesMask(V0Def))
120+
if (!isAllOnesMask(V0Defs.lookup(&MI)))
120121
return false;
121122

122123
MI.setDesc(TII->get(NewOpc));
@@ -133,14 +134,13 @@ bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI,
133134
return true;
134135
}
135136

136-
bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI,
137-
MachineInstr *MaskDef) const {
137+
bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI) const {
138138
const RISCV::RISCVMaskedPseudoInfo *I =
139139
RISCV::getMaskedPseudoInfo(MI.getOpcode());
140140
if (!I)
141141
return false;
142142

143-
if (!isAllOnesMask(MaskDef))
143+
if (!isAllOnesMask(V0Defs.lookup(&MI)))
144144
return false;
145145

146146
// There are two classes of pseudos in the table - compares and
@@ -198,20 +198,26 @@ bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
198198
// $v0:vr = COPY %mask:vr
199199
// %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr
200200
//
201-
// Because $v0 isn't in SSA, keep track of it so we can check the mask operand
202-
// on each pseudo.
203-
MachineInstr *CurrentV0Def;
204-
for (MachineBasicBlock &MBB : MF) {
205-
CurrentV0Def = nullptr;
206-
for (MachineInstr &MI : MBB) {
207-
Changed |= convertToUnmasked(MI, CurrentV0Def);
208-
Changed |= convertVMergeToVMv(MI, CurrentV0Def);
201+
// Because $v0 isn't in SSA, keep track of its definition at each use so we
202+
// can check mask operands.
203+
for (const MachineBasicBlock &MBB : MF) {
204+
const MachineInstr *CurrentV0Def = nullptr;
205+
for (const MachineInstr &MI : MBB) {
206+
if (MI.readsRegister(RISCV::V0, TRI))
207+
V0Defs[&MI] = CurrentV0Def;
209208

210209
if (MI.definesRegister(RISCV::V0, TRI))
211210
CurrentV0Def = &MI;
212211
}
213212
}
214213

214+
for (MachineBasicBlock &MBB : MF) {
215+
for (MachineInstr &MI : MBB) {
216+
Changed |= convertToUnmasked(MI);
217+
Changed |= convertVMergeToVMv(MI);
218+
}
219+
}
220+
215221
return Changed;
216222
}
217223

0 commit comments

Comments
 (0)