Skip to content

Commit 64f74e6

Browse files
committed
[Transforms] Expand optimizeTan to fold more inverse trig pairs
optimizeTan has been renamed to optimizeTrigInversionPairs as a result. Sadly, this is not mathematically true that all inverse pairs fold to x. For example, asin(sin(x)) does not fold to x if x is over 2*pi.
1 parent b163238 commit 64f74e6

File tree

3 files changed

+49
-24
lines changed

3 files changed

+49
-24
lines changed

llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ class LibCallSimplifier {
202202
Value *optimizeLog(CallInst *CI, IRBuilderBase &B);
203203
Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B);
204204
Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
205-
Value *optimizeTan(CallInst *CI, IRBuilderBase &B);
205+
Value *optimizeTrigInversionPairs(CallInst *CI, IRBuilderBase &B);
206206
// Wrapper for all floating point library call optimizations
207207
Value *optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func,
208208
IRBuilderBase &B);

llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2607,13 +2607,16 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
26072607
return copyFlags(*CI, FabsCall);
26082608
}
26092609

2610-
// TODO: Generalize to handle any trig function and its inverse.
2611-
Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
2610+
Value *LibCallSimplifier::optimizeTrigInversionPairs(CallInst *CI,
2611+
IRBuilderBase &B) {
26122612
Module *M = CI->getModule();
26132613
Function *Callee = CI->getCalledFunction();
26142614
Value *Ret = nullptr;
26152615
StringRef Name = Callee->getName();
2616-
if (UnsafeFPShrink && Name == "tan" && hasFloatVersion(M, Name))
2616+
if (UnsafeFPShrink &&
2617+
(Name == "tan" || Name == "atanh" || Name == "sinh" || Name == "cosh" ||
2618+
Name == "asinh") &&
2619+
hasFloatVersion(M, Name))
26172620
Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
26182621

26192622
Value *Op1 = CI->getArgOperand(0);
@@ -2626,16 +2629,34 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
26262629
return Ret;
26272630

26282631
// tan(atan(x)) -> x
2629-
// tanf(atanf(x)) -> x
2630-
// tanl(atanl(x)) -> x
2632+
// atanh(tanh(x)) -> x
2633+
// sinh(asinh(x)) -> x
2634+
// asinh(sinh(x)) -> x
2635+
// cosh(acosh(x)) -> x
26312636
LibFunc Func;
26322637
Function *F = OpC->getCalledFunction();
26332638
if (F && TLI->getLibFunc(F->getName(), Func) &&
2634-
isLibFuncEmittable(M, TLI, Func) &&
2635-
((Func == LibFunc_atan && Callee->getName() == "tan") ||
2636-
(Func == LibFunc_atanf && Callee->getName() == "tanf") ||
2637-
(Func == LibFunc_atanl && Callee->getName() == "tanl")))
2638-
Ret = OpC->getArgOperand(0);
2639+
isLibFuncEmittable(M, TLI, Func)) {
2640+
LibFunc inverseFunc = llvm::StringSwitch<LibFunc>(Callee->getName())
2641+
.Case("tan", LibFunc_atan)
2642+
.Case("atanh", LibFunc_tanh)
2643+
.Case("sinh", LibFunc_asinh)
2644+
.Case("cosh", LibFunc_acosh)
2645+
.Case("tanf", LibFunc_atanf)
2646+
.Case("atanhf", LibFunc_tanhf)
2647+
.Case("sinhf", LibFunc_asinhf)
2648+
.Case("coshf", LibFunc_acoshf)
2649+
.Case("tanl", LibFunc_atanl)
2650+
.Case("atanhl", LibFunc_tanhl)
2651+
.Case("sinhl", LibFunc_asinhl)
2652+
.Case("coshl", LibFunc_acoshl)
2653+
.Case("asinh", LibFunc_sinh)
2654+
.Case("asinhf", LibFunc_sinhf)
2655+
.Case("asinhl", LibFunc_sinhl)
2656+
.Default(NumLibFuncs); // Used as error value
2657+
if (Func == inverseFunc)
2658+
Ret = OpC->getArgOperand(0);
2659+
}
26392660
return Ret;
26402661
}
26412662

@@ -3628,7 +3649,19 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
36283649
case LibFunc_tan:
36293650
case LibFunc_tanf:
36303651
case LibFunc_tanl:
3631-
return optimizeTan(CI, Builder);
3652+
case LibFunc_sinh:
3653+
case LibFunc_sinhf:
3654+
case LibFunc_sinhl:
3655+
case LibFunc_asinh:
3656+
case LibFunc_asinhf:
3657+
case LibFunc_asinhl:
3658+
case LibFunc_cosh:
3659+
case LibFunc_coshf:
3660+
case LibFunc_coshl:
3661+
case LibFunc_atanh:
3662+
case LibFunc_atanhf:
3663+
case LibFunc_atanhl:
3664+
return optimizeTrigInversionPairs(CI, Builder);
36323665
case LibFunc_ceil:
36333666
return replaceUnaryCall(CI, Builder, Intrinsic::ceil);
36343667
case LibFunc_floor:
@@ -3646,17 +3679,13 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
36463679
case LibFunc_acos:
36473680
case LibFunc_acosh:
36483681
case LibFunc_asin:
3649-
case LibFunc_asinh:
36503682
case LibFunc_atan:
3651-
case LibFunc_atanh:
36523683
case LibFunc_cbrt:
3653-
case LibFunc_cosh:
36543684
case LibFunc_exp:
36553685
case LibFunc_exp10:
36563686
case LibFunc_expm1:
36573687
case LibFunc_cos:
36583688
case LibFunc_sin:
3659-
case LibFunc_sinh:
36603689
case LibFunc_tanh:
36613690
if (UnsafeFPShrink && hasFloatVersion(M, CI->getCalledFunction()->getName()))
36623691
return optimizeUnaryDoubleFP(CI, Builder, TLI, true);

llvm/test/Transforms/InstCombine/trig.ll

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ define float @atanhTanhInverseFast(float %x) {
1616
; CHECK-LABEL: define float @atanhTanhInverseFast(
1717
; CHECK-SAME: float [[X:%.*]]) {
1818
; CHECK-NEXT: [[CALL:%.*]] = call fast float @tanhf(float [[X]])
19-
; CHECK-NEXT: [[CALL1:%.*]] = call fast float @atanhf(float [[CALL]])
20-
; CHECK-NEXT: ret float [[CALL1]]
19+
; CHECK-NEXT: ret float [[X]]
2120
;
2221
%call = call fast float @tanhf(float %x)
2322
%call1 = call fast float @atanhf(float %call)
@@ -28,8 +27,7 @@ define float @sinhAsinhInverseFast(float %x) {
2827
; CHECK-LABEL: define float @sinhAsinhInverseFast(
2928
; CHECK-SAME: float [[X:%.*]]) {
3029
; CHECK-NEXT: [[CALL:%.*]] = call fast float @asinhf(float [[X]])
31-
; CHECK-NEXT: [[CALL1:%.*]] = call fast float @sinhf(float [[CALL]])
32-
; CHECK-NEXT: ret float [[CALL1]]
30+
; CHECK-NEXT: ret float [[X]]
3331
;
3432
%call = call fast float @asinhf(float %x)
3533
%call1 = call fast float @sinhf(float %call)
@@ -40,8 +38,7 @@ define float @asinhSinhInverseFast(float %x) {
4038
; CHECK-LABEL: define float @asinhSinhInverseFast(
4139
; CHECK-SAME: float [[X:%.*]]) {
4240
; CHECK-NEXT: [[CALL:%.*]] = call fast float @sinhf(float [[X]])
43-
; CHECK-NEXT: [[CALL1:%.*]] = call fast float @asinhf(float [[CALL]])
44-
; CHECK-NEXT: ret float [[CALL1]]
41+
; CHECK-NEXT: ret float [[X]]
4542
;
4643
%call = call fast float @sinhf(float %x)
4744
%call1 = call fast float @asinhf(float %call)
@@ -52,8 +49,7 @@ define float @coshAcoshInverseFast(float %x) {
5249
; CHECK-LABEL: define float @coshAcoshInverseFast(
5350
; CHECK-SAME: float [[X:%.*]]) {
5451
; CHECK-NEXT: [[CALL:%.*]] = call fast float @acoshf(float [[X]])
55-
; CHECK-NEXT: [[CALL1:%.*]] = call fast float @coshf(float [[CALL]])
56-
; CHECK-NEXT: ret float [[CALL1]]
52+
; CHECK-NEXT: ret float [[X]]
5753
;
5854
%call = call fast float @acoshf(float %x)
5955
%call1 = call fast float @coshf(float %call)

0 commit comments

Comments
 (0)