@@ -55,12 +55,11 @@ class RISCVFoldMasks : public MachineFunctionPass {
55
55
StringRef getPassName () const override { return " RISC-V Fold Masks" ; }
56
56
57
57
private:
58
+ bool convertToUnmasked (MachineInstr &MI, MachineInstr *MaskDef);
58
59
bool foldVMergeIntoOps (MachineInstr &MI, MachineInstr *MaskDef);
59
60
bool convertVMergeToVMv (MachineInstr &MI, MachineInstr *MaskDef);
60
- bool convertToUnmasked (MachineInstr &MI, MachineInstr *MaskDef);
61
61
62
62
bool isAllOnesMask (MachineInstr *MaskDef);
63
- bool isOpSameAs (const MachineOperand &LHS, const MachineOperand &RHS);
64
63
};
65
64
66
65
} // namespace
@@ -119,17 +118,6 @@ static unsigned getVMSetForLMul(RISCVII::VLMUL LMUL) {
119
118
llvm_unreachable (" Unknown VLMUL enum" );
120
119
}
121
120
122
- // Returns true if LHS is the same register as RHS, or if LHS is undefined.
123
- bool RISCVFoldMasks::isOpSameAs (const MachineOperand &LHS,
124
- const MachineOperand &RHS) {
125
- if (LHS.getReg () == RISCV::NoRegister)
126
- return true ;
127
- if (RHS.getReg () == RISCV::NoRegister)
128
- return false ;
129
- return TRI->lookThruCopyLike (LHS.getReg (), MRI) ==
130
- TRI->lookThruCopyLike (RHS.getReg (), MRI);
131
- }
132
-
133
121
// Try to fold away VMERGE_VVM instructions. We handle these cases:
134
122
// -Masked TU VMERGE_VVM combined with an unmasked TA instruction instruction
135
123
// folds to a masked TU instruction. VMERGE_VVM must have have merge operand
@@ -163,10 +151,14 @@ bool RISCVFoldMasks::foldVMergeIntoOps(MachineInstr &MI,
163
151
return false ;
164
152
165
153
MachineInstr &TrueMI = *MRI->getVRegDef (True->getReg ());
154
+ if (TrueMI.getParent () != MI.getParent ())
155
+ return false ;
166
156
167
157
// We require that either merge and false are the same, or that merge
168
158
// is undefined.
169
- if (!isOpSameAs (*Merge, *False))
159
+ if (Merge->getReg () != RISCV::NoRegister &&
160
+ TRI->lookThruCopyLike (Merge->getReg (), MRI) !=
161
+ TRI->lookThruCopyLike (False->getReg (), MRI))
170
162
return false ;
171
163
172
164
// N must be the only user of True.
@@ -177,21 +169,22 @@ bool RISCVFoldMasks::foldVMergeIntoOps(MachineInstr &MI,
177
169
const MCInstrDesc &TrueMCID = TrueMI.getDesc ();
178
170
bool HasTiedDest = RISCVII::isFirstDefTiedToFirstUse (TrueMCID);
179
171
180
- bool IsMasked = false ;
172
+ const bool MIIsMasked =
173
+ BaseOpc == RISCV::VMERGE_VVM && !isAllOnesMask (MaskDef);
174
+ bool TrueIsMasked = false ;
181
175
const RISCV::RISCVMaskedPseudoInfo *Info =
182
176
RISCV::lookupMaskedIntrinsicByUnmasked (TrueOpc);
183
177
if (!Info && HasTiedDest) {
184
178
Info = RISCV::getMaskedPseudoInfo (TrueOpc);
185
- IsMasked = true ;
179
+ TrueIsMasked = true ;
186
180
}
187
181
188
182
if (!Info)
189
183
return false ;
190
184
191
185
// When Mask is not a true mask, this transformation is illegal for some
192
186
// operations whose results are affected by mask, like viota.m.
193
- if (Info->MaskAffectsResult && BaseOpc == RISCV::VMERGE_VVM &&
194
- !isAllOnesMask (MaskDef))
187
+ if (Info->MaskAffectsResult && MIIsMasked)
195
188
return false ;
196
189
197
190
MachineOperand &TrueMergeOp = TrueMI.getOperand (1 );
@@ -203,20 +196,21 @@ bool RISCVFoldMasks::foldVMergeIntoOps(MachineInstr &MI,
203
196
return false ;
204
197
// Both the vmerge instruction and the True instruction must have the same
205
198
// merge operand.
206
- if (!isOpSameAs (TrueMergeOp, *False))
199
+ if (TrueMergeOp.getReg () != RISCV::NoRegister &&
200
+ TrueMergeOp.getReg () != False->getReg ())
207
201
return false ;
208
202
}
209
203
210
- if (IsMasked ) {
204
+ if (TrueIsMasked ) {
211
205
assert (HasTiedDest && " Expected tied dest" );
212
206
// The vmerge instruction must be TU.
213
207
if (Merge->getReg () == RISCV::NoRegister)
214
208
return false ;
215
- // The vmerge instruction must have an all 1s mask since we're going to keep
216
- // the mask from the True instruction.
217
- // FIXME: Support mask agnostic True instruction which would have an
218
- // undef merge operand.
219
- if (BaseOpc == RISCV::VMERGE_VVM && ! isAllOnesMask (MaskDef) )
209
+ // MI must have an all 1s mask since we're going to keep the mask from the
210
+ // True instruction.
211
+ // FIXME: Support mask agnostic True instruction which would have an undef
212
+ // merge operand.
213
+ if (MIIsMasked )
220
214
return false ;
221
215
}
222
216
@@ -225,10 +219,6 @@ bool RISCVFoldMasks::foldVMergeIntoOps(MachineInstr &MI,
225
219
if (TII->get (TrueOpc).hasUnmodeledSideEffects ())
226
220
return false ;
227
221
228
- // The vector policy operand may be present for masked intrinsics
229
- const MachineOperand &TrueVL =
230
- TrueMI.getOperand (RISCVII::getVLOpNum (TrueMCID));
231
-
232
222
auto GetMinVL =
233
223
[](const MachineOperand &LHS,
234
224
const MachineOperand &RHS) -> std::optional<MachineOperand> {
@@ -246,7 +236,9 @@ bool RISCVFoldMasks::foldVMergeIntoOps(MachineInstr &MI,
246
236
247
237
// Because MI and True must have the same merge operand (or True's operand is
248
238
// implicit_def), the "effective" body is the minimum of their VLs.
249
- const MachineOperand VL = MI.getOperand (RISCVII::getVLOpNum (MI.getDesc ()));
239
+ const MachineOperand &TrueVL =
240
+ TrueMI.getOperand (RISCVII::getVLOpNum (TrueMCID));
241
+ const MachineOperand &VL = MI.getOperand (RISCVII::getVLOpNum (MI.getDesc ()));
250
242
auto MinVL = GetMinVL (TrueVL, VL);
251
243
if (!MinVL)
252
244
return false ;
@@ -255,7 +247,7 @@ bool RISCVFoldMasks::foldVMergeIntoOps(MachineInstr &MI,
255
247
// If we end up changing the VL or mask of True, then we need to make sure it
256
248
// doesn't raise any observable fp exceptions, since changing the active
257
249
// elements will affect how fflags is set.
258
- if (VLChanged || !IsMasked )
250
+ if (VLChanged || !TrueIsMasked )
259
251
if (TrueMCID.mayRaiseFPException () &&
260
252
!TrueMI.getFlag (MachineInstr::MIFlag::NoFPExcept))
261
253
return false ;
@@ -287,8 +279,9 @@ bool RISCVFoldMasks::foldVMergeIntoOps(MachineInstr &MI,
287
279
// Set the merge to the false operand of the merge.
288
280
TrueMI.getOperand (1 ).setReg (False->getReg ());
289
281
282
+ bool NeedToMoveOldMask = TrueIsMasked;
290
283
// If we're converting it to a masked pseudo, reuse MI's mask.
291
- if (!IsMasked ) {
284
+ if (!TrueIsMasked ) {
292
285
if (BaseOpc == RISCV::VMV_V_V) {
293
286
// If MI is a vmv.v.v, it won't have a mask operand. So insert an all-ones
294
287
// mask just before True.
@@ -302,6 +295,7 @@ bool RISCVFoldMasks::foldVMergeIntoOps(MachineInstr &MI,
302
295
BuildMI (*MI.getParent (), TrueMI, MI.getDebugLoc (), TII->get (RISCV::COPY),
303
296
RISCV::V0)
304
297
.addReg (Dest);
298
+ NeedToMoveOldMask = true ;
305
299
}
306
300
307
301
TrueMI.setDesc (MaskedMCID);
@@ -342,9 +336,17 @@ bool RISCVFoldMasks::foldVMergeIntoOps(MachineInstr &MI,
342
336
MRI->constrainRegClass (PassthruReg, V0RC);
343
337
344
338
MRI->replaceRegWith (MI.getOperand (0 ).getReg (), TrueMI.getOperand (0 ).getReg ());
339
+
340
+ // We need to move the old mask copy to after MI if:
341
+ // - TrueMI is masked and we are using its mask instead
342
+ // - We created a new all ones mask that clobbers V0
343
+ if (NeedToMoveOldMask && MaskDef) {
344
+ assert (MaskDef->getParent () == MI.getParent ());
345
+ MaskDef->removeFromParent ();
346
+ MI.getParent ()->insertAfter (MI.getIterator (), MaskDef);
347
+ }
348
+
345
349
MI.eraseFromParent ();
346
- if (IsMasked)
347
- MaskDef->eraseFromParent ();
348
350
349
351
return true ;
350
352
}
@@ -369,8 +371,11 @@ bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI, MachineInstr *V0Def) {
369
371
CASE_VMERGE_TO_VMV (M8)
370
372
}
371
373
374
+ Register MergeReg = MI.getOperand (1 ).getReg ();
375
+ Register FalseReg = MI.getOperand (2 ).getReg ();
372
376
// Check merge == false (or merge == undef)
373
- if (!isOpSameAs (MI.getOperand (1 ), MI.getOperand (2 )))
377
+ if (MergeReg != RISCV::NoRegister && TRI->lookThruCopyLike (MergeReg, MRI) !=
378
+ TRI->lookThruCopyLike (FalseReg, MRI))
374
379
return false ;
375
380
376
381
assert (MI.getOperand (4 ).isReg () && MI.getOperand (4 ).getReg () == RISCV::V0);
@@ -468,8 +473,9 @@ bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
468
473
469
474
CurrentV0Def = nullptr ;
470
475
for (MachineInstr &MI : MBB) {
471
- Changed |= convertVMergeToVMv (MI, CurrentV0Def);
472
476
Changed |= convertToUnmasked (MI, CurrentV0Def);
477
+ Changed |= convertVMergeToVMv (MI, CurrentV0Def);
478
+
473
479
if (MI.definesRegister (RISCV::V0, TRI))
474
480
CurrentV0Def = &MI;
475
481
}
0 commit comments