@@ -51,7 +51,7 @@ static bool isIntrinsicExpansion(Function &F) {
51
51
return false ;
52
52
}
53
53
54
- static bool expandAbs (CallInst *Orig) {
54
+ static Value * expandAbs (CallInst *Orig) {
55
55
Value *X = Orig->getOperand (0 );
56
56
IRBuilder<> Builder (Orig->getParent ());
57
57
Builder.SetInsertPoint (Orig);
@@ -66,12 +66,10 @@ static bool expandAbs(CallInst *Orig) {
66
66
auto *V = Builder.CreateSub (Zero, X);
67
67
auto *MaxCall =
68
68
Builder.CreateIntrinsic (Ty, Intrinsic::smax, {X, V}, nullptr , " dx.max" );
69
- Orig->replaceAllUsesWith (MaxCall);
70
- Orig->eraseFromParent ();
71
- return true ;
69
+ return MaxCall;
72
70
}
73
71
74
- static bool expandIntegerDot (CallInst *Orig, Intrinsic::ID DotIntrinsic) {
72
+ static Value * expandIntegerDot (CallInst *Orig, Intrinsic::ID DotIntrinsic) {
75
73
assert (DotIntrinsic == Intrinsic::dx_sdot ||
76
74
DotIntrinsic == Intrinsic::dx_udot);
77
75
Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
@@ -97,12 +95,10 @@ static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
97
95
ArrayRef<Value *>{Elt0, Elt1, Result},
98
96
nullptr , " dx.mad" );
99
97
}
100
- Orig->replaceAllUsesWith (Result);
101
- Orig->eraseFromParent ();
102
- return true ;
98
+ return Result;
103
99
}
104
100
105
- static bool expandExpIntrinsic (CallInst *Orig) {
101
+ static Value * expandExpIntrinsic (CallInst *Orig) {
106
102
Value *X = Orig->getOperand (0 );
107
103
IRBuilder<> Builder (Orig->getParent ());
108
104
Builder.SetInsertPoint (Orig);
@@ -119,23 +115,21 @@ static bool expandExpIntrinsic(CallInst *Orig) {
119
115
Builder.CreateIntrinsic (Ty, Intrinsic::exp2 , {NewX}, nullptr , " dx.exp2" );
120
116
Exp2Call->setTailCall (Orig->isTailCall ());
121
117
Exp2Call->setAttributes (Orig->getAttributes ());
122
- Orig->replaceAllUsesWith (Exp2Call);
123
- Orig->eraseFromParent ();
124
- return true ;
118
+ return Exp2Call;
125
119
}
126
120
127
- static bool expandAnyIntrinsic (CallInst *Orig) {
121
+ static Value * expandAnyIntrinsic (CallInst *Orig) {
128
122
Value *X = Orig->getOperand (0 );
129
123
IRBuilder<> Builder (Orig->getParent ());
130
124
Builder.SetInsertPoint (Orig);
131
125
Type *Ty = X->getType ();
132
126
Type *EltTy = Ty->getScalarType ();
133
127
128
+ Value *Result = nullptr ;
134
129
if (!Ty->isVectorTy ()) {
135
- Value *Cond = EltTy->isFloatingPointTy ()
136
- ? Builder.CreateFCmpUNE (X, ConstantFP::get (EltTy, 0 ))
137
- : Builder.CreateICmpNE (X, ConstantInt::get (EltTy, 0 ));
138
- Orig->replaceAllUsesWith (Cond);
130
+ Result = EltTy->isFloatingPointTy ()
131
+ ? Builder.CreateFCmpUNE (X, ConstantFP::get (EltTy, 0 ))
132
+ : Builder.CreateICmpNE (X, ConstantInt::get (EltTy, 0 ));
139
133
} else {
140
134
auto *XVec = dyn_cast<FixedVectorType>(Ty);
141
135
Value *Cond =
@@ -148,18 +142,16 @@ static bool expandAnyIntrinsic(CallInst *Orig) {
148
142
X, ConstantVector::getSplat (
149
143
ElementCount::getFixed (XVec->getNumElements ()),
150
144
ConstantInt::get (EltTy, 0 )));
151
- Value * Result = Builder.CreateExtractElement (Cond, (uint64_t )0 );
145
+ Result = Builder.CreateExtractElement (Cond, (uint64_t )0 );
152
146
for (unsigned I = 1 ; I < XVec->getNumElements (); I++) {
153
147
Value *Elt = Builder.CreateExtractElement (Cond, I);
154
148
Result = Builder.CreateOr (Result, Elt);
155
149
}
156
- Orig->replaceAllUsesWith (Result);
157
150
}
158
- Orig->eraseFromParent ();
159
- return true ;
151
+ return Result;
160
152
}
161
153
162
- static bool expandLengthIntrinsic (CallInst *Orig) {
154
+ static Value * expandLengthIntrinsic (CallInst *Orig) {
163
155
Value *X = Orig->getOperand (0 );
164
156
IRBuilder<> Builder (Orig->getParent ());
165
157
Builder.SetInsertPoint (Orig);
@@ -182,30 +174,23 @@ static bool expandLengthIntrinsic(CallInst *Orig) {
182
174
Value *Mul = Builder.CreateFMul (Elt, Elt);
183
175
Sum = Builder.CreateFAdd (Sum, Mul);
184
176
}
185
- Value *Result = Builder.CreateIntrinsic (
186
- EltTy, Intrinsic::sqrt , ArrayRef<Value *>{Sum}, nullptr , " elt.sqrt" );
187
-
188
- Orig->replaceAllUsesWith (Result);
189
- Orig->eraseFromParent ();
190
- return true ;
177
+ return Builder.CreateIntrinsic (EltTy, Intrinsic::sqrt , ArrayRef<Value *>{Sum},
178
+ nullptr , " elt.sqrt" );
191
179
}
192
180
193
- static bool expandLerpIntrinsic (CallInst *Orig) {
181
+ static Value * expandLerpIntrinsic (CallInst *Orig) {
194
182
Value *X = Orig->getOperand (0 );
195
183
Value *Y = Orig->getOperand (1 );
196
184
Value *S = Orig->getOperand (2 );
197
185
IRBuilder<> Builder (Orig->getParent ());
198
186
Builder.SetInsertPoint (Orig);
199
187
auto *V = Builder.CreateFSub (Y, X);
200
188
V = Builder.CreateFMul (S, V);
201
- auto *Result = Builder.CreateFAdd (X, V, " dx.lerp" );
202
- Orig->replaceAllUsesWith (Result);
203
- Orig->eraseFromParent ();
204
- return true ;
189
+ return Builder.CreateFAdd (X, V, " dx.lerp" );
205
190
}
206
191
207
- static bool expandLogIntrinsic (CallInst *Orig,
208
- float LogConstVal = numbers::ln2f) {
192
+ static Value * expandLogIntrinsic (CallInst *Orig,
193
+ float LogConstVal = numbers::ln2f) {
209
194
Value *X = Orig->getOperand (0 );
210
195
IRBuilder<> Builder (Orig->getParent ());
211
196
Builder.SetInsertPoint (Orig);
@@ -221,16 +206,13 @@ static bool expandLogIntrinsic(CallInst *Orig,
221
206
Builder.CreateIntrinsic (Ty, Intrinsic::log2 , {X}, nullptr , " elt.log2" );
222
207
Log2Call->setTailCall (Orig->isTailCall ());
223
208
Log2Call->setAttributes (Orig->getAttributes ());
224
- auto *Result = Builder.CreateFMul (Ln2Const, Log2Call);
225
- Orig->replaceAllUsesWith (Result);
226
- Orig->eraseFromParent ();
227
- return true ;
209
+ return Builder.CreateFMul (Ln2Const, Log2Call);
228
210
}
229
- static bool expandLog10Intrinsic (CallInst *Orig) {
211
+ static Value * expandLog10Intrinsic (CallInst *Orig) {
230
212
return expandLogIntrinsic (Orig, numbers::ln2f / numbers::ln10f);
231
213
}
232
214
233
- static bool expandNormalizeIntrinsic (CallInst *Orig) {
215
+ static Value * expandNormalizeIntrinsic (CallInst *Orig) {
234
216
Value *X = Orig->getOperand (0 );
235
217
Type *Ty = Orig->getType ();
236
218
Type *EltTy = Ty->getScalarType ();
@@ -245,11 +227,7 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
245
227
report_fatal_error (Twine (" Invalid input scalar: length is zero" ),
246
228
/* gen_crash_diag=*/ false );
247
229
}
248
- Value *Result = Builder.CreateFDiv (X, X);
249
-
250
- Orig->replaceAllUsesWith (Result);
251
- Orig->eraseFromParent ();
252
- return true ;
230
+ return Builder.CreateFDiv (X, X);
253
231
}
254
232
255
233
unsigned XVecSize = XVec->getNumElements ();
@@ -291,14 +269,10 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
291
269
nullptr , " dx.rsqrt" );
292
270
293
271
Value *MultiplicandVec = Builder.CreateVectorSplat (XVecSize, Multiplicand);
294
- Value *Result = Builder.CreateFMul (X, MultiplicandVec);
295
-
296
- Orig->replaceAllUsesWith (Result);
297
- Orig->eraseFromParent ();
298
- return true ;
272
+ return Builder.CreateFMul (X, MultiplicandVec);
299
273
}
300
274
301
- static bool expandPowIntrinsic (CallInst *Orig) {
275
+ static Value * expandPowIntrinsic (CallInst *Orig) {
302
276
303
277
Value *X = Orig->getOperand (0 );
304
278
Value *Y = Orig->getOperand (1 );
@@ -313,9 +287,7 @@ static bool expandPowIntrinsic(CallInst *Orig) {
313
287
Builder.CreateIntrinsic (Ty, Intrinsic::exp2 , {Mul}, nullptr , " elt.exp2" );
314
288
Exp2Call->setTailCall (Orig->isTailCall ());
315
289
Exp2Call->setAttributes (Orig->getAttributes ());
316
- Orig->replaceAllUsesWith (Exp2Call);
317
- Orig->eraseFromParent ();
318
- return true ;
290
+ return Exp2Call;
319
291
}
320
292
321
293
static Intrinsic::ID getMaxForClamp (Type *ElemTy,
@@ -344,7 +316,8 @@ static Intrinsic::ID getMinForClamp(Type *ElemTy,
344
316
return Intrinsic::minnum;
345
317
}
346
318
347
- static bool expandClampIntrinsic (CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
319
+ static Value *expandClampIntrinsic (CallInst *Orig,
320
+ Intrinsic::ID ClampIntrinsic) {
348
321
Value *X = Orig->getOperand (0 );
349
322
Value *Min = Orig->getOperand (1 );
350
323
Value *Max = Orig->getOperand (2 );
@@ -353,43 +326,55 @@ static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
353
326
Builder.SetInsertPoint (Orig);
354
327
auto *MaxCall = Builder.CreateIntrinsic (
355
328
Ty, getMaxForClamp (Ty, ClampIntrinsic), {X, Min}, nullptr , " dx.max" );
356
- auto *MinCall =
357
- Builder.CreateIntrinsic (Ty, getMinForClamp (Ty, ClampIntrinsic),
358
- {MaxCall, Max}, nullptr , " dx.min" );
359
-
360
- Orig->replaceAllUsesWith (MinCall);
361
- Orig->eraseFromParent ();
362
- return true ;
329
+ return Builder.CreateIntrinsic (Ty, getMinForClamp (Ty, ClampIntrinsic),
330
+ {MaxCall, Max}, nullptr , " dx.min" );
363
331
}
364
332
365
333
static bool expandIntrinsic (Function &F, CallInst *Orig) {
334
+ Value *Result = nullptr ;
366
335
switch (F.getIntrinsicID ()) {
367
336
case Intrinsic::abs :
368
- return expandAbs (Orig);
337
+ Result = expandAbs (Orig);
338
+ break ;
369
339
case Intrinsic::exp :
370
- return expandExpIntrinsic (Orig);
340
+ Result = expandExpIntrinsic (Orig);
341
+ break ;
371
342
case Intrinsic::log :
372
- return expandLogIntrinsic (Orig);
343
+ Result = expandLogIntrinsic (Orig);
344
+ break ;
373
345
case Intrinsic::log10 :
374
- return expandLog10Intrinsic (Orig);
346
+ Result = expandLog10Intrinsic (Orig);
347
+ break ;
375
348
case Intrinsic::pow :
376
- return expandPowIntrinsic (Orig);
349
+ Result = expandPowIntrinsic (Orig);
350
+ break ;
377
351
case Intrinsic::dx_any:
378
- return expandAnyIntrinsic (Orig);
352
+ Result = expandAnyIntrinsic (Orig);
353
+ break ;
379
354
case Intrinsic::dx_uclamp:
380
355
case Intrinsic::dx_clamp:
381
- return expandClampIntrinsic (Orig, F.getIntrinsicID ());
356
+ Result = expandClampIntrinsic (Orig, F.getIntrinsicID ());
357
+ break ;
382
358
case Intrinsic::dx_lerp:
383
- return expandLerpIntrinsic (Orig);
359
+ Result = expandLerpIntrinsic (Orig);
360
+ break ;
384
361
case Intrinsic::dx_length:
385
- return expandLengthIntrinsic (Orig);
362
+ Result = expandLengthIntrinsic (Orig);
363
+ break ;
386
364
case Intrinsic::dx_normalize:
387
- return expandNormalizeIntrinsic (Orig);
365
+ Result = expandNormalizeIntrinsic (Orig);
366
+ break ;
388
367
case Intrinsic::dx_sdot:
389
368
case Intrinsic::dx_udot:
390
- return expandIntegerDot (Orig, F.getIntrinsicID ());
369
+ Result = expandIntegerDot (Orig, F.getIntrinsicID ());
370
+ break ;
391
371
}
392
- return false ;
372
+
373
+ if (Result) {
374
+ Orig->replaceAllUsesWith (Result);
375
+ Orig->eraseFromParent ();
376
+ }
377
+ return !!Result;
393
378
}
394
379
395
380
static bool expansionIntrinsics (Module &M) {
0 commit comments