Skip to content

Commit b24bf27

Browse files
committed
[NFC][DXIL] move replace/erase in DXIL intrinsic expansion to caller
All expansions end with replacing the previous inrinsic with the new expansion and erasing the old one. By moving this operation to the caller, these expansion functions can be called in more contexts and a small amount of duplicated code is consolidated. Pre-req for llvm#88056
1 parent 9f89d31 commit b24bf27

File tree

1 file changed

+60
-75
lines changed

1 file changed

+60
-75
lines changed

llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp

Lines changed: 60 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ static bool isIntrinsicExpansion(Function &F) {
5151
return false;
5252
}
5353

54-
static bool expandAbs(CallInst *Orig) {
54+
static Value *expandAbs(CallInst *Orig) {
5555
Value *X = Orig->getOperand(0);
5656
IRBuilder<> Builder(Orig->getParent());
5757
Builder.SetInsertPoint(Orig);
@@ -66,12 +66,10 @@ static bool expandAbs(CallInst *Orig) {
6666
auto *V = Builder.CreateSub(Zero, X);
6767
auto *MaxCall =
6868
Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr, "dx.max");
69-
Orig->replaceAllUsesWith(MaxCall);
70-
Orig->eraseFromParent();
71-
return true;
69+
return MaxCall;
7270
}
7371

74-
static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
72+
static Value *expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
7573
assert(DotIntrinsic == Intrinsic::dx_sdot ||
7674
DotIntrinsic == Intrinsic::dx_udot);
7775
Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
@@ -97,12 +95,10 @@ static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
9795
ArrayRef<Value *>{Elt0, Elt1, Result},
9896
nullptr, "dx.mad");
9997
}
100-
Orig->replaceAllUsesWith(Result);
101-
Orig->eraseFromParent();
102-
return true;
98+
return Result;
10399
}
104100

105-
static bool expandExpIntrinsic(CallInst *Orig) {
101+
static Value *expandExpIntrinsic(CallInst *Orig) {
106102
Value *X = Orig->getOperand(0);
107103
IRBuilder<> Builder(Orig->getParent());
108104
Builder.SetInsertPoint(Orig);
@@ -119,23 +115,21 @@ static bool expandExpIntrinsic(CallInst *Orig) {
119115
Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");
120116
Exp2Call->setTailCall(Orig->isTailCall());
121117
Exp2Call->setAttributes(Orig->getAttributes());
122-
Orig->replaceAllUsesWith(Exp2Call);
123-
Orig->eraseFromParent();
124-
return true;
118+
return Exp2Call;
125119
}
126120

127-
static bool expandAnyIntrinsic(CallInst *Orig) {
121+
static Value *expandAnyIntrinsic(CallInst *Orig) {
128122
Value *X = Orig->getOperand(0);
129123
IRBuilder<> Builder(Orig->getParent());
130124
Builder.SetInsertPoint(Orig);
131125
Type *Ty = X->getType();
132126
Type *EltTy = Ty->getScalarType();
133127

128+
Value *Result = nullptr;
134129
if (!Ty->isVectorTy()) {
135-
Value *Cond = EltTy->isFloatingPointTy()
136-
? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
137-
: Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
138-
Orig->replaceAllUsesWith(Cond);
130+
Result = EltTy->isFloatingPointTy()
131+
? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
132+
: Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
139133
} else {
140134
auto *XVec = dyn_cast<FixedVectorType>(Ty);
141135
Value *Cond =
@@ -148,18 +142,16 @@ static bool expandAnyIntrinsic(CallInst *Orig) {
148142
X, ConstantVector::getSplat(
149143
ElementCount::getFixed(XVec->getNumElements()),
150144
ConstantInt::get(EltTy, 0)));
151-
Value *Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
145+
Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
152146
for (unsigned I = 1; I < XVec->getNumElements(); I++) {
153147
Value *Elt = Builder.CreateExtractElement(Cond, I);
154148
Result = Builder.CreateOr(Result, Elt);
155149
}
156-
Orig->replaceAllUsesWith(Result);
157150
}
158-
Orig->eraseFromParent();
159-
return true;
151+
return Result;
160152
}
161153

162-
static bool expandLengthIntrinsic(CallInst *Orig) {
154+
static Value *expandLengthIntrinsic(CallInst *Orig) {
163155
Value *X = Orig->getOperand(0);
164156
IRBuilder<> Builder(Orig->getParent());
165157
Builder.SetInsertPoint(Orig);
@@ -182,30 +174,23 @@ static bool expandLengthIntrinsic(CallInst *Orig) {
182174
Value *Mul = Builder.CreateFMul(Elt, Elt);
183175
Sum = Builder.CreateFAdd(Sum, Mul);
184176
}
185-
Value *Result = Builder.CreateIntrinsic(
186-
EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum}, nullptr, "elt.sqrt");
187-
188-
Orig->replaceAllUsesWith(Result);
189-
Orig->eraseFromParent();
190-
return true;
177+
return Builder.CreateIntrinsic(EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum},
178+
nullptr, "elt.sqrt");
191179
}
192180

193-
static bool expandLerpIntrinsic(CallInst *Orig) {
181+
static Value *expandLerpIntrinsic(CallInst *Orig) {
194182
Value *X = Orig->getOperand(0);
195183
Value *Y = Orig->getOperand(1);
196184
Value *S = Orig->getOperand(2);
197185
IRBuilder<> Builder(Orig->getParent());
198186
Builder.SetInsertPoint(Orig);
199187
auto *V = Builder.CreateFSub(Y, X);
200188
V = Builder.CreateFMul(S, V);
201-
auto *Result = Builder.CreateFAdd(X, V, "dx.lerp");
202-
Orig->replaceAllUsesWith(Result);
203-
Orig->eraseFromParent();
204-
return true;
189+
return Builder.CreateFAdd(X, V, "dx.lerp");
205190
}
206191

207-
static bool expandLogIntrinsic(CallInst *Orig,
208-
float LogConstVal = numbers::ln2f) {
192+
static Value *expandLogIntrinsic(CallInst *Orig,
193+
float LogConstVal = numbers::ln2f) {
209194
Value *X = Orig->getOperand(0);
210195
IRBuilder<> Builder(Orig->getParent());
211196
Builder.SetInsertPoint(Orig);
@@ -221,16 +206,13 @@ static bool expandLogIntrinsic(CallInst *Orig,
221206
Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
222207
Log2Call->setTailCall(Orig->isTailCall());
223208
Log2Call->setAttributes(Orig->getAttributes());
224-
auto *Result = Builder.CreateFMul(Ln2Const, Log2Call);
225-
Orig->replaceAllUsesWith(Result);
226-
Orig->eraseFromParent();
227-
return true;
209+
return Builder.CreateFMul(Ln2Const, Log2Call);
228210
}
229-
static bool expandLog10Intrinsic(CallInst *Orig) {
211+
static Value *expandLog10Intrinsic(CallInst *Orig) {
230212
return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
231213
}
232214

233-
static bool expandNormalizeIntrinsic(CallInst *Orig) {
215+
static Value *expandNormalizeIntrinsic(CallInst *Orig) {
234216
Value *X = Orig->getOperand(0);
235217
Type *Ty = Orig->getType();
236218
Type *EltTy = Ty->getScalarType();
@@ -245,11 +227,7 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
245227
report_fatal_error(Twine("Invalid input scalar: length is zero"),
246228
/* gen_crash_diag=*/false);
247229
}
248-
Value *Result = Builder.CreateFDiv(X, X);
249-
250-
Orig->replaceAllUsesWith(Result);
251-
Orig->eraseFromParent();
252-
return true;
230+
return Builder.CreateFDiv(X, X);
253231
}
254232

255233
unsigned XVecSize = XVec->getNumElements();
@@ -291,14 +269,10 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
291269
nullptr, "dx.rsqrt");
292270

293271
Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
294-
Value *Result = Builder.CreateFMul(X, MultiplicandVec);
295-
296-
Orig->replaceAllUsesWith(Result);
297-
Orig->eraseFromParent();
298-
return true;
272+
return Builder.CreateFMul(X, MultiplicandVec);
299273
}
300274

301-
static bool expandPowIntrinsic(CallInst *Orig) {
275+
static Value *expandPowIntrinsic(CallInst *Orig) {
302276

303277
Value *X = Orig->getOperand(0);
304278
Value *Y = Orig->getOperand(1);
@@ -313,9 +287,7 @@ static bool expandPowIntrinsic(CallInst *Orig) {
313287
Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2");
314288
Exp2Call->setTailCall(Orig->isTailCall());
315289
Exp2Call->setAttributes(Orig->getAttributes());
316-
Orig->replaceAllUsesWith(Exp2Call);
317-
Orig->eraseFromParent();
318-
return true;
290+
return Exp2Call;
319291
}
320292

321293
static Intrinsic::ID getMaxForClamp(Type *ElemTy,
@@ -344,7 +316,8 @@ static Intrinsic::ID getMinForClamp(Type *ElemTy,
344316
return Intrinsic::minnum;
345317
}
346318

347-
static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
319+
static Value *expandClampIntrinsic(CallInst *Orig,
320+
Intrinsic::ID ClampIntrinsic) {
348321
Value *X = Orig->getOperand(0);
349322
Value *Min = Orig->getOperand(1);
350323
Value *Max = Orig->getOperand(2);
@@ -353,43 +326,55 @@ static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
353326
Builder.SetInsertPoint(Orig);
354327
auto *MaxCall = Builder.CreateIntrinsic(
355328
Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max");
356-
auto *MinCall =
357-
Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
358-
{MaxCall, Max}, nullptr, "dx.min");
359-
360-
Orig->replaceAllUsesWith(MinCall);
361-
Orig->eraseFromParent();
362-
return true;
329+
return Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
330+
{MaxCall, Max}, nullptr, "dx.min");
363331
}
364332

365333
static bool expandIntrinsic(Function &F, CallInst *Orig) {
334+
Value *Result = nullptr;
366335
switch (F.getIntrinsicID()) {
367336
case Intrinsic::abs:
368-
return expandAbs(Orig);
337+
Result = expandAbs(Orig);
338+
break;
369339
case Intrinsic::exp:
370-
return expandExpIntrinsic(Orig);
340+
Result = expandExpIntrinsic(Orig);
341+
break;
371342
case Intrinsic::log:
372-
return expandLogIntrinsic(Orig);
343+
Result = expandLogIntrinsic(Orig);
344+
break;
373345
case Intrinsic::log10:
374-
return expandLog10Intrinsic(Orig);
346+
Result = expandLog10Intrinsic(Orig);
347+
break;
375348
case Intrinsic::pow:
376-
return expandPowIntrinsic(Orig);
349+
Result = expandPowIntrinsic(Orig);
350+
break;
377351
case Intrinsic::dx_any:
378-
return expandAnyIntrinsic(Orig);
352+
Result = expandAnyIntrinsic(Orig);
353+
break;
379354
case Intrinsic::dx_uclamp:
380355
case Intrinsic::dx_clamp:
381-
return expandClampIntrinsic(Orig, F.getIntrinsicID());
356+
Result = expandClampIntrinsic(Orig, F.getIntrinsicID());
357+
break;
382358
case Intrinsic::dx_lerp:
383-
return expandLerpIntrinsic(Orig);
359+
Result = expandLerpIntrinsic(Orig);
360+
break;
384361
case Intrinsic::dx_length:
385-
return expandLengthIntrinsic(Orig);
362+
Result = expandLengthIntrinsic(Orig);
363+
break;
386364
case Intrinsic::dx_normalize:
387-
return expandNormalizeIntrinsic(Orig);
365+
Result = expandNormalizeIntrinsic(Orig);
366+
break;
388367
case Intrinsic::dx_sdot:
389368
case Intrinsic::dx_udot:
390-
return expandIntegerDot(Orig, F.getIntrinsicID());
369+
Result = expandIntegerDot(Orig, F.getIntrinsicID());
370+
break;
391371
}
392-
return false;
372+
373+
if (Result) {
374+
Orig->replaceAllUsesWith(Result);
375+
Orig->eraseFromParent();
376+
}
377+
return !!Result;
393378
}
394379

395380
static bool expansionIntrinsics(Module &M) {

0 commit comments

Comments
 (0)