@@ -81,6 +81,73 @@ class AMDGPULateCodeGenPrepare
  bool visitLoadInst(LoadInst &LI);
};

+ using ValueToValueMap = DenseMap<const Value *, Value *>;
+
+ class LiveRegOptimizer {
+ private:
+   Module *Mod = nullptr;
+   const DataLayout *DL = nullptr;
+   const GCNSubtarget *ST;
+   /// The scalar type to convert to
+   Type *ConvertToScalar;
+   /// The set of visited Instructions
+   SmallPtrSet<Instruction *, 4> Visited;
+   /// The set of Instructions to be deleted
+   SmallPtrSet<Instruction *, 4> DeadInstrs;
+   /// Map of Value -> Converted Value
+   ValueToValueMap ValMap;
+   /// Per-BB map of conversions from Optimal Type -> Original Type.
+   DenseMap<BasicBlock *, ValueToValueMap> BBUseValMap;
+
+ public:
+   /// Calculate and return the type to convert to given a problematic \p
+   /// OriginalType. In some instances, we may widen the type (e.g. v2i8 -> i32).
+   Type *calculateConvertType(Type *OriginalType);
+   /// Convert the virtual register defined by \p V to the compatible vector of
+   /// legal type
+   Value *convertToOptType(Instruction *V, BasicBlock::iterator &InstPt);
+   /// Convert the virtual register defined by \p V back to the original type \p
+   /// ConvertType, stripping away the MSBs in cases where there was an imperfect
+   /// fit (e.g. v2i32 -> v7i8)
+   Value *convertFromOptType(Type *ConvertType, Instruction *V,
+                             BasicBlock::iterator &InstPt,
+                             BasicBlock *InsertBlock);
+   /// Check for problematic PHI nodes or cross-BB values based on the value
+   /// defined by \p I, and coerce to legal types if necessary. For a problematic
+   /// PHI node, we coerce all incoming values in a single invocation.
+   bool optimizeLiveType(Instruction *I);
+
+   /// Remove all instructions that have become dead (i.e. all the re-typed PHIs)
+   void removeDeadInstrs();
+
+   // Whether or not the type should be replaced to avoid inefficient
+   // legalization code
+   bool shouldReplace(Type *ITy) {
+     FixedVectorType *VTy = dyn_cast<FixedVectorType>(ITy);
+     if (!VTy)
+       return false;
+
+     auto TLI = ST->getTargetLowering();
+
+     Type *EltTy = VTy->getElementType();
+     // If the element size is not less than the convert to scalar size, then we
+     // can't do any bit packing
+     if (!EltTy->isIntegerTy() ||
+         EltTy->getScalarSizeInBits() > ConvertToScalar->getScalarSizeInBits())
+       return false;
+
+     // Only coerce illegal types
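+     // For example (illustrative, not an exhaustive rule): i8 is generally
+     // not a legal type for AMDGPU's SelectionDAG, so a <4 x i8> live value
+     // would be coerced, while a <2 x i32> is left alone since i32 is legal.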
+     TargetLoweringBase::LegalizeKind LK =
+         TLI->getTypeConversion(EltTy->getContext(), EVT::getEVT(EltTy, false));
+     return LK.first != TargetLoweringBase::TypeLegal;
+   }
+
+   LiveRegOptimizer(Module *Mod, const GCNSubtarget *ST) : Mod(Mod), ST(ST) {
+     DL = &Mod->getDataLayout();
+     ConvertToScalar = Type::getInt32Ty(Mod->getContext());
+   }
+ };
+
} // end anonymous namespace

bool AMDGPULateCodeGenPrepare::doInitialization(Module &M) {
@@ -102,14 +169,238 @@ bool AMDGPULateCodeGenPrepare::runOnFunction(Function &F) {
  AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();

+ // "Optimize" the virtual regs that cross basic block boundaries. When
173
+ // building the SelectionDAG, vectors of illegal types that cross basic blocks
174
+ // will be scalarized and widened, with each scalar living in its
175
+ // own register. To work around this, this optimization converts the
176
+ // vectors to equivalent vectors of legal type (which are converted back
177
+ // before uses in subsequent blocks), to pack the bits into fewer physical
178
+ // registers (used in CopyToReg/CopyFromReg pairs).
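+   //
+   // Illustrative sketch (the value names are hypothetical): a <4 x i8> %val
+   // defined in one block and used in another is carried across the block
+   // boundary as an i32 (%val.bc = bitcast <4 x i8> %val to i32), and the
+   // using block bitcasts it back to <4 x i8> before the original use.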
+   LiveRegOptimizer LRO(Mod, &ST);
+
  bool Changed = false;
+
  for (auto &BB : F)
-   for (Instruction &I : llvm::make_early_inc_range(BB))
+   for (Instruction &I : make_early_inc_range(BB)) {
      Changed |= visit(I);
+     Changed |= LRO.optimizeLiveType(&I);
+   }

+ LRO.removeDeadInstrs();
  return Changed;
}

+ Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
+   assert(OriginalType->getScalarSizeInBits() <=
+          ConvertToScalar->getScalarSizeInBits());
+
+   FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);
+
+   TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
+   TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
+   unsigned ConvertEltCount =
+       (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;
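+   // Worked example: a v2i8 input is 16 bits, which fits in a single i32, so
+   // the result below is i32; a v7i8 input is 56 bits, so ConvertEltCount is
+   // ceil(56 / 32) = 2 and the result is <2 x i32>.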
+
+   if (OriginalSize <= ConvertScalarSize)
+     return IntegerType::get(Mod->getContext(), ConvertScalarSize);
+
+   return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
+                          ConvertEltCount, false);
+ }
+
+ Value *LiveRegOptimizer::convertToOptType(Instruction *V,
+                                           BasicBlock::iterator &InsertPt) {
+   FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
+   Type *NewTy = calculateConvertType(V->getType());
+
+   TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
+   TypeSize NewSize = DL->getTypeSizeInBits(NewTy);
+
+   IRBuilder<> Builder(V->getParent(), InsertPt);
+   // If there is a bitsize match, we can fit the old vector into a new vector of
+   // desired type.
+   if (OriginalSize == NewSize)
+     return Builder.CreateBitCast(V, NewTy, V->getName() + ".bc");
+
+   // If there is a bitsize mismatch, we must use a wider vector.
+   assert(NewSize > OriginalSize);
+   uint64_t ExpandedVecElementCount = NewSize / VTy->getScalarSizeInBits();
+
+   SmallVector<int, 8> ShuffleMask;
+   uint64_t OriginalElementCount = VTy->getElementCount().getFixedValue();
+   for (unsigned I = 0; I < OriginalElementCount; I++)
+     ShuffleMask.push_back(I);
+
+   for (uint64_t I = OriginalElementCount; I < ExpandedVecElementCount; I++)
+     ShuffleMask.push_back(OriginalElementCount);
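+   // e.g. widening a v7i8 (56 bits) to the <2 x i32> chosen above: the mask
+   // [0..6, 7] produces 8 lanes, the extra lane coming from the implicit
+   // poison operand of the shuffle, and the bitcast below reinterprets the
+   // resulting 64 bits as <2 x i32>.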
+
+   Value *ExpandedVec = Builder.CreateShuffleVector(V, ShuffleMask);
+   return Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc");
+ }
+
+ Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
+                                             BasicBlock::iterator &InsertPt,
+                                             BasicBlock *InsertBB) {
+   FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);
+
+   TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
+   TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);
+
+   IRBuilder<> Builder(InsertBB, InsertPt);
+   // If there is a bitsize match, we simply convert back to the original type.
+   if (OriginalSize == NewSize)
+     return Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc");
+
+   // If there is a bitsize mismatch, then we must have used a wider value to
+   // hold the bits.
+   assert(OriginalSize > NewSize);
+   // For wide scalars, we can just truncate the value.
+   if (!V->getType()->isVectorTy()) {
+     Instruction *Trunc = cast<Instruction>(
+         Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
+     return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
+   }
+
+   // For wider vectors, we must strip the MSBs to convert back to the original
+   // type.
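+   // e.g. going from the <2 x i32> carrier back to v7i8: bitcast to <8 x i8>,
+   // then keep lanes [0..6] with the shuffle below, dropping the padding lane.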
+   VectorType *ExpandedVT = VectorType::get(
+       Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
+       (OriginalSize / NewVTy->getScalarSizeInBits()), false);
+   Instruction *Converted =
+       cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));
+
+   unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
+   SmallVector<int, 8> ShuffleMask(NarrowElementCount);
+   std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);
+
+   return Builder.CreateShuffleVector(Converted, ShuffleMask);
+ }
+
+ bool LiveRegOptimizer::optimizeLiveType(Instruction *I) {
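+   // The walk below proceeds in three steps: (1) traverse the def-use web
+   // rooted at \p I to collect problematic PHIs, their defs, and cross-BB
+   // uses; (2) materialize coerced defs and re-typed PHIs; (3) rewrite the
+   // collected uses through per-BB back-conversions.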
+   SmallVector<Instruction *, 4> Worklist;
+   SmallPtrSet<PHINode *, 4> PhiNodes;
+   SmallPtrSet<Instruction *, 4> Defs;
+   SmallPtrSet<Instruction *, 4> Uses;
+
+   Worklist.push_back(cast<Instruction>(I));
+   while (!Worklist.empty()) {
+     Instruction *II = Worklist.pop_back_val();
+
+     if (!Visited.insert(II).second)
+       continue;
+
+     if (!shouldReplace(II->getType()))
+       continue;
+
+     if (PHINode *Phi = dyn_cast<PHINode>(II)) {
+       PhiNodes.insert(Phi);
+       // Collect all the incoming values of problematic PHI nodes.
+       for (Value *V : Phi->incoming_values()) {
+         // Repeat the collection process for newly found PHI nodes.
+         if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
+           if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
+             Worklist.push_back(OpPhi);
+           continue;
+         }
+
+         Instruction *IncInst = dyn_cast<Instruction>(V);
+         // Other incoming value types (e.g. vector literals) are unhandled
+         if (!IncInst && !isa<ConstantAggregateZero>(V))
+           return false;
+
+         // Collect all other incoming values for coercion.
+         if (IncInst)
+           Defs.insert(IncInst);
+       }
+     }
+
+     // Collect all relevant uses.
+     for (User *V : II->users()) {
+       // Repeat the collection process for problematic PHI nodes.
+       if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
+         if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
+           Worklist.push_back(OpPhi);
+         continue;
+       }
+
+       Instruction *UseInst = cast<Instruction>(V);
+       // Collect all uses of PHI nodes and any use that crosses BB boundaries.
+       if (UseInst->getParent() != II->getParent() || isa<PHINode>(II)) {
+         Uses.insert(UseInst);
+         if (!Defs.count(II) && !isa<PHINode>(II)) {
+           Defs.insert(II);
+         }
+       }
+     }
+   }
+
+   // Coerce and track the defs.
+   for (Instruction *D : Defs) {
+     if (!ValMap.contains(D)) {
+       BasicBlock::iterator InsertPt = std::next(D->getIterator());
+       Value *ConvertVal = convertToOptType(D, InsertPt);
+       assert(ConvertVal);
+       ValMap[D] = ConvertVal;
+     }
+   }
+
+   // Construct new-typed PHI nodes.
+   for (PHINode *Phi : PhiNodes) {
+     ValMap[Phi] = PHINode::Create(calculateConvertType(Phi->getType()),
+                                   Phi->getNumIncomingValues(),
+                                   Phi->getName() + ".tc", Phi->getIterator());
+   }
353
+
354
+ // Connect all the PHI nodes with their new incoming values.
355
+ for (PHINode *Phi : PhiNodes) {
356
+ PHINode *NewPhi = cast<PHINode>(ValMap[Phi]);
357
+ bool MissingIncVal = false ;
358
+ for (int I = 0 , E = Phi->getNumIncomingValues (); I < E; I++) {
359
+ Value *IncVal = Phi->getIncomingValue (I);
360
+ if (isa<ConstantAggregateZero>(IncVal)) {
361
+ Type *NewType = calculateConvertType (Phi->getType ());
362
+ NewPhi->addIncoming (ConstantInt::get (NewType, 0 , false ),
363
+ Phi->getIncomingBlock (I));
364
+ } else if (ValMap.contains (IncVal))
365
+ NewPhi->addIncoming (ValMap[IncVal], Phi->getIncomingBlock (I));
366
+ else
367
+ MissingIncVal = true ;
368
+ }
369
+ DeadInstrs.insert (MissingIncVal ? cast<Instruction>(ValMap[Phi]) : Phi);
370
+ }
+   // Coerce back to the original type and replace the uses.
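+   // A converted value is materialized back at most once per use block: the
+   // back-conversion is cached in BBUseValMap and reused for later uses below.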
+   for (Instruction *U : Uses) {
+     // Replace all converted operands for a use.
+     for (auto [OpIdx, Op] : enumerate(U->operands())) {
+       if (ValMap.contains(Op)) {
+         Value *NewVal = nullptr;
+         if (BBUseValMap.contains(U->getParent()) &&
+             BBUseValMap[U->getParent()].contains(ValMap[Op]))
+           NewVal = BBUseValMap[U->getParent()][ValMap[Op]];
+         else {
+           BasicBlock::iterator InsertPt = U->getParent()->getFirstNonPHIIt();
+           NewVal =
+               convertFromOptType(Op->getType(), cast<Instruction>(ValMap[Op]),
+                                  InsertPt, U->getParent());
+           BBUseValMap[U->getParent()][ValMap[Op]] = NewVal;
+         }
+         assert(NewVal);
+         U->setOperand(OpIdx, NewVal);
+       }
+     }
+   }
+
+   return true;
+ }
+
+ void LiveRegOptimizer::removeDeadInstrs() {
+   // Remove instrs that have been marked dead after type-coercion.
+   for (auto *I : DeadInstrs) {
+     I->replaceAllUsesWith(PoisonValue::get(I->getType()));
+     I->eraseFromParent();
+   }
+ }
+
bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
  unsigned AS = LI.getPointerAddressSpace();
  // Skip non-constant address space.
@@ -119,7 +410,7 @@ bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
  // Skip non-simple loads.
  if (!LI.isSimple())
    return false;
- auto *Ty = LI.getType();
+ Type *Ty = LI.getType();
  // Skip aggregate types.
  if (Ty->isAggregateType())
    return false;