
Commit ded9564

[AMDGPU] Add IR LiveReg type-based optimization
Change-Id: Ia0d11b79b8302e79247fe193ccabc0dad2d359a0
1 parent 9f10252 commit ded9564

11 files changed: +2565 −2028 lines

llvm/lib/Target/AMDGPU/AMDGPULateCodeGenPrepare.cpp

Lines changed: 293 additions & 2 deletions
@@ -81,6 +81,73 @@ class AMDGPULateCodeGenPrepare
   bool visitLoadInst(LoadInst &LI);
 };
 
+using ValueToValueMap = DenseMap<const Value *, Value *>;
+
+class LiveRegOptimizer {
+private:
+  Module *Mod = nullptr;
+  const DataLayout *DL = nullptr;
+  const GCNSubtarget *ST;
+  /// The scalar type to convert to.
+  Type *ConvertToScalar;
+  /// The set of visited Instructions.
+  SmallPtrSet<Instruction *, 4> Visited;
+  /// The set of Instructions to be deleted.
+  SmallPtrSet<Instruction *, 4> DeadInstrs;
+  /// Map of Value -> Converted Value.
+  ValueToValueMap ValMap;
+  /// Map containing conversions from Optimal Type -> Original Type per BB.
+  DenseMap<BasicBlock *, ValueToValueMap> BBUseValMap;
+
+public:
+  /// Calculate and return the type to convert to given a problematic \p
+  /// OriginalType. In some instances, we may widen the type (e.g. v2i8 -> i32).
+  Type *calculateConvertType(Type *OriginalType);
+  /// Convert the virtual register defined by \p V to the compatible vector of
+  /// legal type.
+  Value *convertToOptType(Instruction *V, BasicBlock::iterator &InstPt);
+  /// Convert the virtual register defined by \p V back to the original type \p
+  /// ConvertType, stripping away the MSBs in cases where there was an imperfect
+  /// fit (e.g. v2i32 -> v7i8).
+  Value *convertFromOptType(Type *ConvertType, Instruction *V,
+                            BasicBlock::iterator &InstPt,
+                            BasicBlock *InsertBlock);
+  /// Check for problematic PHI nodes or cross-bb values based on the value
+  /// defined by \p I, and coerce to legal types if necessary. For problematic
+  /// PHI nodes, we coerce all incoming values in a single invocation.
+  bool optimizeLiveType(Instruction *I);
+
+  /// Remove all instructions that have become dead (i.e. all the re-typed PHIs).
+  void removeDeadInstrs();
+
+  // Whether or not the type should be replaced to avoid inefficient
+  // legalization code.
+  bool shouldReplace(Type *ITy) {
+    FixedVectorType *VTy = dyn_cast<FixedVectorType>(ITy);
+    if (!VTy)
+      return false;
+
+    auto TLI = ST->getTargetLowering();
+
+    Type *EltTy = VTy->getElementType();
+    // If the element size is not less than the convert-to-scalar size, then we
+    // can't do any bit packing.
+    if (!EltTy->isIntegerTy() ||
+        EltTy->getScalarSizeInBits() > ConvertToScalar->getScalarSizeInBits())
+      return false;
+
+    // Only coerce illegal types.
+    TargetLoweringBase::LegalizeKind LK =
+        TLI->getTypeConversion(EltTy->getContext(), EVT::getEVT(EltTy, false));
+    return LK.first != TargetLoweringBase::TypeLegal;
+  }
+
+  LiveRegOptimizer(Module *Mod, const GCNSubtarget *ST) : Mod(Mod), ST(ST) {
+    DL = &Mod->getDataLayout();
+    ConvertToScalar = Type::getInt32Ty(Mod->getContext());
+  }
+};
+
 } // end anonymous namespace
 
 bool AMDGPULateCodeGenPrepare::doInitialization(Module &M) {
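
Note on the conversion target: ConvertToScalar is fixed to i32, so calculateConvertType packs a problematic vector into a single i32 when its bits fit and into an <N x i32> otherwise, with N rounded up. A minimal standalone sketch of that size arithmetic (the helper name and plain-integer interface are illustrative only, not part of the patch):

    #include <cassert>

    // Illustrative stand-in for the packing arithmetic in calculateConvertType:
    // how many 32-bit words are needed for a vector of NumElts elements of
    // EltBits bits each. A result of 1 corresponds to a plain i32; anything
    // larger corresponds to an <N x i32> vector.
    static unsigned numI32Words(unsigned NumElts, unsigned EltBits) {
      assert(EltBits <= 32 && "wider elements are never packed");
      unsigned TotalBits = NumElts * EltBits;
      return (TotalBits + 31) / 32; // round up to whole 32-bit words
    }

    // Examples: v4i8 (32 bits) packs into one i32 (numI32Words(4, 8) == 1);
    // v7i8 (56 bits) needs two words and becomes <2 x i32>
    // (numI32Words(7, 8) == 2), leaving 8 unused MSBs that convertFromOptType
    // strips on the way back.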
@@ -102,14 +169,238 @@ bool AMDGPULateCodeGenPrepare::runOnFunction(Function &F) {
   AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
   UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
 
+  // "Optimize" the virtual regs that cross basic block boundaries. When
+  // building the SelectionDAG, vectors of illegal types that cross basic blocks
+  // will be scalarized and widened, with each scalar living in its
+  // own register. To work around this, this optimization converts the
+  // vectors to equivalent vectors of legal type (which are converted back
+  // before uses in subsequent blocks), to pack the bits into fewer physical
+  // registers (used in CopyToReg/CopyFromReg pairs).
+  LiveRegOptimizer LRO(Mod, &ST);
+
   bool Changed = false;
+
   for (auto &BB : F)
-    for (Instruction &I : llvm::make_early_inc_range(BB))
+    for (Instruction &I : make_early_inc_range(BB)) {
       Changed |= visit(I);
+      Changed |= LRO.optimizeLiveType(&I);
+    }
 
+  LRO.removeDeadInstrs();
   return Changed;
 }
 
+Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
+  assert(OriginalType->getScalarSizeInBits() <=
+         ConvertToScalar->getScalarSizeInBits());
+
+  FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);
+
+  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
+  TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
+  unsigned ConvertEltCount =
+      (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;
+
+  if (OriginalSize <= ConvertScalarSize)
+    return IntegerType::get(Mod->getContext(), ConvertScalarSize);
+
+  return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
+                         ConvertEltCount, false);
+}
+
+Value *LiveRegOptimizer::convertToOptType(Instruction *V,
+                                          BasicBlock::iterator &InsertPt) {
+  FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
+  Type *NewTy = calculateConvertType(V->getType());
+
+  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
+  TypeSize NewSize = DL->getTypeSizeInBits(NewTy);
+
+  IRBuilder<> Builder(V->getParent(), InsertPt);
+  // If there is a bitsize match, we can fit the old vector into a new vector of
+  // desired type.
+  if (OriginalSize == NewSize)
+    return Builder.CreateBitCast(V, NewTy, V->getName() + ".bc");
+
+  // If there is a bitsize mismatch, we must use a wider vector.
+  assert(NewSize > OriginalSize);
+  uint64_t ExpandedVecElementCount = NewSize / VTy->getScalarSizeInBits();
+
+  SmallVector<int, 8> ShuffleMask;
+  uint64_t OriginalElementCount = VTy->getElementCount().getFixedValue();
+  for (unsigned I = 0; I < OriginalElementCount; I++)
+    ShuffleMask.push_back(I);
+
+  for (uint64_t I = OriginalElementCount; I < ExpandedVecElementCount; I++)
+    ShuffleMask.push_back(OriginalElementCount);
+
+  Value *ExpandedVec = Builder.CreateShuffleVector(V, ShuffleMask);
+  return Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc");
+}
+
+Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
+                                            BasicBlock::iterator &InsertPt,
+                                            BasicBlock *InsertBB) {
+  FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);
+
+  TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
+  TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);
+
+  IRBuilder<> Builder(InsertBB, InsertPt);
+  // If there is a bitsize match, we simply convert back to the original type.
+  if (OriginalSize == NewSize)
+    return Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc");
+
+  // If there is a bitsize mismatch, then we must have used a wider value to
+  // hold the bits.
+  assert(OriginalSize > NewSize);
+  // For wide scalars, we can just truncate the value.
+  if (!V->getType()->isVectorTy()) {
+    Instruction *Trunc = cast<Instruction>(
+        Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
+    return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
+  }
+
+  // For wider vectors, we must strip the MSBs to convert back to the original
+  // type.
+  VectorType *ExpandedVT = VectorType::get(
+      Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
+      (OriginalSize / NewVTy->getScalarSizeInBits()), false);
+  Instruction *Converted =
+      cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));
+
+  unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
+  SmallVector<int, 8> ShuffleMask(NarrowElementCount);
+  std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);
+
+  return Builder.CreateShuffleVector(Converted, ShuffleMask);
+}
+
+bool LiveRegOptimizer::optimizeLiveType(Instruction *I) {
+  SmallVector<Instruction *, 4> Worklist;
+  SmallPtrSet<PHINode *, 4> PhiNodes;
+  SmallPtrSet<Instruction *, 4> Defs;
+  SmallPtrSet<Instruction *, 4> Uses;
+
+  Worklist.push_back(cast<Instruction>(I));
+  while (!Worklist.empty()) {
+    Instruction *II = Worklist.pop_back_val();
+
+    if (!Visited.insert(II).second)
+      continue;
+
+    if (!shouldReplace(II->getType()))
+      continue;
+
+    if (PHINode *Phi = dyn_cast<PHINode>(II)) {
+      PhiNodes.insert(Phi);
+      // Collect all the incoming values of problematic PHI nodes.
+      for (Value *V : Phi->incoming_values()) {
+        // Repeat the collection process for newly found PHI nodes.
+        if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
+          if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
+            Worklist.push_back(OpPhi);
+          continue;
+        }
+
+        Instruction *IncInst = dyn_cast<Instruction>(V);
+        // Other incoming value types (e.g. vector literals) are unhandled.
+        if (!IncInst && !isa<ConstantAggregateZero>(V))
+          return false;
+
+        // Collect all other incoming values for coercion.
+        if (IncInst)
+          Defs.insert(IncInst);
+      }
+    }
+
+    // Collect all relevant uses.
+    for (User *V : II->users()) {
+      // Repeat the collection process for problematic PHI nodes.
+      if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
+        if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
+          Worklist.push_back(OpPhi);
+        continue;
+      }
+
+      Instruction *UseInst = cast<Instruction>(V);
+      // Collect all uses of PHI nodes and any use that crosses BB boundaries.
+      if (UseInst->getParent() != II->getParent() || isa<PHINode>(II)) {
+        Uses.insert(UseInst);
+        if (!Defs.count(II) && !isa<PHINode>(II)) {
+          Defs.insert(II);
+        }
+      }
+    }
+  }
+
+  // Coerce and track the defs.
+  for (Instruction *D : Defs) {
+    if (!ValMap.contains(D)) {
+      BasicBlock::iterator InsertPt = std::next(D->getIterator());
+      Value *ConvertVal = convertToOptType(D, InsertPt);
+      assert(ConvertVal);
+      ValMap[D] = ConvertVal;
+    }
+  }
+
+  // Construct new-typed PHI nodes.
+  for (PHINode *Phi : PhiNodes) {
+    ValMap[Phi] = PHINode::Create(calculateConvertType(Phi->getType()),
+                                  Phi->getNumIncomingValues(),
+                                  Phi->getName() + ".tc", Phi->getIterator());
+  }
+
+  // Connect all the PHI nodes with their new incoming values.
+  for (PHINode *Phi : PhiNodes) {
+    PHINode *NewPhi = cast<PHINode>(ValMap[Phi]);
+    bool MissingIncVal = false;
+    for (int I = 0, E = Phi->getNumIncomingValues(); I < E; I++) {
+      Value *IncVal = Phi->getIncomingValue(I);
+      if (isa<ConstantAggregateZero>(IncVal)) {
+        Type *NewType = calculateConvertType(Phi->getType());
+        NewPhi->addIncoming(ConstantInt::get(NewType, 0, false),
+                            Phi->getIncomingBlock(I));
+      } else if (ValMap.contains(IncVal))
+        NewPhi->addIncoming(ValMap[IncVal], Phi->getIncomingBlock(I));
+      else
+        MissingIncVal = true;
+    }
+    DeadInstrs.insert(MissingIncVal ? cast<Instruction>(ValMap[Phi]) : Phi);
+  }
+  // Coerce back to the original type and replace the uses.
+  for (Instruction *U : Uses) {
+    // Replace all converted operands for a use.
+    for (auto [OpIdx, Op] : enumerate(U->operands())) {
+      if (ValMap.contains(Op)) {
+        Value *NewVal = nullptr;
+        if (BBUseValMap.contains(U->getParent()) &&
+            BBUseValMap[U->getParent()].contains(ValMap[Op]))
+          NewVal = BBUseValMap[U->getParent()][ValMap[Op]];
+        else {
+          BasicBlock::iterator InsertPt = U->getParent()->getFirstNonPHIIt();
+          NewVal =
+              convertFromOptType(Op->getType(), cast<Instruction>(ValMap[Op]),
+                                 InsertPt, U->getParent());
+          BBUseValMap[U->getParent()][ValMap[Op]] = NewVal;
+        }
+        assert(NewVal);
+        U->setOperand(OpIdx, NewVal);
+      }
+    }
+  }
+
+  return true;
+}
+
+void LiveRegOptimizer::removeDeadInstrs() {
+  // Remove instrs that have been marked dead after type-coercion.
+  for (auto *I : DeadInstrs) {
+    I->replaceAllUsesWith(PoisonValue::get(I->getType()));
+    I->eraseFromParent();
+  }
+}
+
 bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
   unsigned AS = LI.getPointerAddressSpace();
   // Skip non-constant address space.
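
Where the bit widths do not match exactly, convertToOptType and convertFromOptType above bracket the bitcasts with shufflevectors: widening pads the original lanes with don't-care lanes drawn from the implicit poison operand of a one-operand shuffle, and narrowing keeps only the leading lanes. A standalone sketch of the two masks (helper names are illustrative; it mirrors the loops above under the assumption that mask index OrigElts selects the first lane of the poison operand):

    #include <numeric>
    #include <vector>

    // Widening mask used before the bitcast to the i32-based type: keep lanes
    // 0..OrigElts-1 from the original vector and fill the tail with index
    // OrigElts, i.e. padding lanes taken from the second shuffle operand.
    static std::vector<int> wideningMask(unsigned OrigElts, unsigned WideElts) {
      std::vector<int> Mask;
      for (unsigned I = 0; I < OrigElts; ++I)
        Mask.push_back(static_cast<int>(I));        // original lanes
      for (unsigned I = OrigElts; I < WideElts; ++I)
        Mask.push_back(static_cast<int>(OrigElts)); // padding lanes
      return Mask;
    }

    // Narrowing mask used after the bitcast back: take only the leading lanes
    // 0, 1, ..., OrigElts-1, dropping the padding again.
    static std::vector<int> narrowingMask(unsigned OrigElts) {
      std::vector<int> Mask(OrigElts);
      std::iota(Mask.begin(), Mask.end(), 0);
      return Mask;
    }

    // For v7i8 packed through <2 x i32>: wideningMask(7, 8) == {0,1,2,3,4,5,6,7}
    // (the trailing 7 is a padding lane) and narrowingMask(7) == {0,1,2,3,4,5,6}.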
@@ -119,7 +410,7 @@ bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
   // Skip non-simple loads.
   if (!LI.isSimple())
     return false;
-  auto *Ty = LI.getType();
+  Type *Ty = LI.getType();
   // Skip aggregate types.
   if (Ty->isAggregateType())
     return false;
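
On the use side, optimizeLiveType materializes at most one convert-back per basic block for each packed value: the first use in a block creates the conversion at the block's first non-PHI insertion point, and BBUseValMap returns that same value for every later use in the block. A compact sketch of that memoization discipline (hypothetical key and value types, purely to illustrate the lookup-then-create pattern, not the pass's actual data structures):

    #include <functional>
    #include <map>
    #include <string>
    #include <utility>

    // Key is (basic block id, packed value id); the factory runs only on a
    // miss, so repeated uses of one packed value inside a block share a single
    // conversion back to the original type.
    using BlockValueKey = std::pair<int, int>;

    static std::string
    getOrCreateUnpack(std::map<BlockValueKey, std::string> &Cache,
                      BlockValueKey Key,
                      const std::function<std::string()> &MakeUnpack) {
      auto It = Cache.find(Key);
      if (It != Cache.end())
        return It->second;              // reuse the conversion emitted earlier
      std::string Unpacked = MakeUnpack(); // emit the convert-back once per block
      Cache.emplace(Key, Unpacked);
      return Unpacked;
    }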

llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp

Lines changed: 2 additions & 2 deletions
@@ -1197,10 +1197,10 @@ bool GCNPassConfig::addPreISel() {
   AMDGPUPassConfig::addPreISel();
 
   if (TM->getOptLevel() > CodeGenOptLevel::None)
-    addPass(createAMDGPULateCodeGenPreparePass());
+    addPass(createSinkingPass());
 
   if (TM->getOptLevel() > CodeGenOptLevel::None)
-    addPass(createSinkingPass());
+    addPass(createAMDGPULateCodeGenPreparePass());
 
   // Merge divergent exit nodes. StructurizeCFG won't recognize the multi-exit
   // regions formed by them.
