@@ -47,10 +47,13 @@ class RISCVFoldMasks : public MachineFunctionPass {
47
47
StringRef getPassName () const override { return " RISC-V Fold Masks" ; }
48
48
49
49
private:
50
- bool convertToUnmasked (MachineInstr &MI, MachineInstr *MaskDef ) const ;
51
- bool convertVMergeToVMv (MachineInstr &MI, MachineInstr *MaskDef ) const ;
50
+ bool convertToUnmasked (MachineInstr &MI) const ;
51
+ bool convertVMergeToVMv (MachineInstr &MI) const ;
52
52
53
- bool isAllOnesMask (MachineInstr *MaskDef) const ;
53
+ bool isAllOnesMask (const MachineInstr *MaskDef) const ;
54
+
55
+ // / Maps uses of V0 to the corresponding def of V0.
56
+ DenseMap<const MachineInstr *, const MachineInstr *> V0Defs;
54
57
};
55
58
56
59
} // namespace
@@ -59,10 +62,9 @@ char RISCVFoldMasks::ID = 0;
59
62
60
63
INITIALIZE_PASS (RISCVFoldMasks, DEBUG_TYPE, " RISC-V Fold Masks" , false , false )
61
64
62
- bool RISCVFoldMasks::isAllOnesMask(MachineInstr *MaskDef) const {
63
- if (!MaskDef)
64
- return false ;
65
- assert (MaskDef->isCopy () && MaskDef->getOperand (0 ).getReg () == RISCV::V0);
65
+ bool RISCVFoldMasks::isAllOnesMask(const MachineInstr *MaskDef) const {
66
+ assert (MaskDef && MaskDef->isCopy () &&
67
+ MaskDef->getOperand (0 ).getReg () == RISCV::V0);
66
68
Register SrcReg = TRI->lookThruCopyLike (MaskDef->getOperand (1 ).getReg (), MRI);
67
69
if (!SrcReg.isVirtual ())
68
70
return false ;
@@ -89,8 +91,7 @@ bool RISCVFoldMasks::isAllOnesMask(MachineInstr *MaskDef) const {
89
91
90
92
// Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to
91
93
// (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET.
92
- bool RISCVFoldMasks::convertVMergeToVMv (MachineInstr &MI,
93
- MachineInstr *V0Def) const {
94
+ bool RISCVFoldMasks::convertVMergeToVMv (MachineInstr &MI) const {
94
95
#define CASE_VMERGE_TO_VMV (lmul ) \
95
96
case RISCV::PseudoVMERGE_VVM_##lmul: \
96
97
NewOpc = RISCV::PseudoVMV_V_V_##lmul; \
@@ -116,7 +117,7 @@ bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI,
116
117
return false ;
117
118
118
119
assert (MI.getOperand (4 ).isReg () && MI.getOperand (4 ).getReg () == RISCV::V0);
119
- if (!isAllOnesMask (V0Def ))
120
+ if (!isAllOnesMask (V0Defs. lookup (&MI) ))
120
121
return false ;
121
122
122
123
MI.setDesc (TII->get (NewOpc));
@@ -133,14 +134,13 @@ bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI,
133
134
return true ;
134
135
}
135
136
136
- bool RISCVFoldMasks::convertToUnmasked (MachineInstr &MI,
137
- MachineInstr *MaskDef) const {
137
+ bool RISCVFoldMasks::convertToUnmasked (MachineInstr &MI) const {
138
138
const RISCV::RISCVMaskedPseudoInfo *I =
139
139
RISCV::getMaskedPseudoInfo (MI.getOpcode ());
140
140
if (!I)
141
141
return false ;
142
142
143
- if (!isAllOnesMask (MaskDef ))
143
+ if (!isAllOnesMask (V0Defs. lookup (&MI) ))
144
144
return false ;
145
145
146
146
// There are two classes of pseudos in the table - compares and
@@ -198,20 +198,26 @@ bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
198
198
// $v0:vr = COPY %mask:vr
199
199
// %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr
200
200
//
201
- // Because $v0 isn't in SSA, keep track of it so we can check the mask operand
202
- // on each pseudo.
203
- MachineInstr *CurrentV0Def;
204
- for (MachineBasicBlock &MBB : MF) {
205
- CurrentV0Def = nullptr ;
206
- for (MachineInstr &MI : MBB) {
207
- Changed |= convertToUnmasked (MI, CurrentV0Def);
208
- Changed |= convertVMergeToVMv (MI, CurrentV0Def);
201
+ // Because $v0 isn't in SSA, keep track of its definition at each use so we
202
+ // can check mask operands.
203
+ for (const MachineBasicBlock &MBB : MF) {
204
+ const MachineInstr *CurrentV0Def = nullptr ;
205
+ for (const MachineInstr &MI : MBB) {
206
+ if (MI.readsRegister (RISCV::V0, TRI))
207
+ V0Defs[&MI] = CurrentV0Def;
209
208
210
209
if (MI.definesRegister (RISCV::V0, TRI))
211
210
CurrentV0Def = &MI;
212
211
}
213
212
}
214
213
214
+ for (MachineBasicBlock &MBB : MF) {
215
+ for (MachineInstr &MI : MBB) {
216
+ Changed |= convertToUnmasked (MI);
217
+ Changed |= convertVMergeToVMv (MI);
218
+ }
219
+ }
220
+
215
221
return Changed;
216
222
}
217
223
0 commit comments