Skip to content

Commit c6b5ea3

Browse files
authored
[Transforms] Expand optimizeTan to fold more inverse trig pairs (#77799)
optimizeTan has been renamed to optimizeTrigInversionPairs as a result. Sadly, this is not mathematically true that all inverse pairs fold to x. For example, asin(sin(x)) does not fold to x if x is over 2pi.
1 parent c9fd738 commit c6b5ea3

File tree

5 files changed

+185
-56
lines changed

5 files changed

+185
-56
lines changed

llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ class LibCallSimplifier {
203203
Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B);
204204
Value *mergeSqrtToExp(CallInst *CI, IRBuilderBase &B);
205205
Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
206-
Value *optimizeTan(CallInst *CI, IRBuilderBase &B);
206+
Value *optimizeTrigInversionPairs(CallInst *CI, IRBuilderBase &B);
207207
// Wrapper for all floating point library call optimizations
208208
Value *optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func,
209209
IRBuilderBase &B);

llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2681,13 +2681,16 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
26812681
return copyFlags(*CI, FabsCall);
26822682
}
26832683

2684-
// TODO: Generalize to handle any trig function and its inverse.
2685-
Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
2684+
Value *LibCallSimplifier::optimizeTrigInversionPairs(CallInst *CI,
2685+
IRBuilderBase &B) {
26862686
Module *M = CI->getModule();
26872687
Function *Callee = CI->getCalledFunction();
26882688
Value *Ret = nullptr;
26892689
StringRef Name = Callee->getName();
2690-
if (UnsafeFPShrink && Name == "tan" && hasFloatVersion(M, Name))
2690+
if (UnsafeFPShrink &&
2691+
(Name == "tan" || Name == "atanh" || Name == "sinh" || Name == "cosh" ||
2692+
Name == "asinh") &&
2693+
hasFloatVersion(M, Name))
26912694
Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
26922695

26932696
Value *Op1 = CI->getArgOperand(0);
@@ -2700,16 +2703,34 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
27002703
return Ret;
27012704

27022705
// tan(atan(x)) -> x
2703-
// tanf(atanf(x)) -> x
2704-
// tanl(atanl(x)) -> x
2706+
// atanh(tanh(x)) -> x
2707+
// sinh(asinh(x)) -> x
2708+
// asinh(sinh(x)) -> x
2709+
// cosh(acosh(x)) -> x
27052710
LibFunc Func;
27062711
Function *F = OpC->getCalledFunction();
27072712
if (F && TLI->getLibFunc(F->getName(), Func) &&
2708-
isLibFuncEmittable(M, TLI, Func) &&
2709-
((Func == LibFunc_atan && Callee->getName() == "tan") ||
2710-
(Func == LibFunc_atanf && Callee->getName() == "tanf") ||
2711-
(Func == LibFunc_atanl && Callee->getName() == "tanl")))
2712-
Ret = OpC->getArgOperand(0);
2713+
isLibFuncEmittable(M, TLI, Func)) {
2714+
LibFunc inverseFunc = llvm::StringSwitch<LibFunc>(Callee->getName())
2715+
.Case("tan", LibFunc_atan)
2716+
.Case("atanh", LibFunc_tanh)
2717+
.Case("sinh", LibFunc_asinh)
2718+
.Case("cosh", LibFunc_acosh)
2719+
.Case("tanf", LibFunc_atanf)
2720+
.Case("atanhf", LibFunc_tanhf)
2721+
.Case("sinhf", LibFunc_asinhf)
2722+
.Case("coshf", LibFunc_acoshf)
2723+
.Case("tanl", LibFunc_atanl)
2724+
.Case("atanhl", LibFunc_tanhl)
2725+
.Case("sinhl", LibFunc_asinhl)
2726+
.Case("coshl", LibFunc_acoshl)
2727+
.Case("asinh", LibFunc_sinh)
2728+
.Case("asinhf", LibFunc_sinhf)
2729+
.Case("asinhl", LibFunc_sinhl)
2730+
.Default(NumLibFuncs); // Used as error value
2731+
if (Func == inverseFunc)
2732+
Ret = OpC->getArgOperand(0);
2733+
}
27132734
return Ret;
27142735
}
27152736

@@ -3702,7 +3723,19 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
37023723
case LibFunc_tan:
37033724
case LibFunc_tanf:
37043725
case LibFunc_tanl:
3705-
return optimizeTan(CI, Builder);
3726+
case LibFunc_sinh:
3727+
case LibFunc_sinhf:
3728+
case LibFunc_sinhl:
3729+
case LibFunc_asinh:
3730+
case LibFunc_asinhf:
3731+
case LibFunc_asinhl:
3732+
case LibFunc_cosh:
3733+
case LibFunc_coshf:
3734+
case LibFunc_coshl:
3735+
case LibFunc_atanh:
3736+
case LibFunc_atanhf:
3737+
case LibFunc_atanhl:
3738+
return optimizeTrigInversionPairs(CI, Builder);
37063739
case LibFunc_ceil:
37073740
return replaceUnaryCall(CI, Builder, Intrinsic::ceil);
37083741
case LibFunc_floor:
@@ -3720,17 +3753,13 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
37203753
case LibFunc_acos:
37213754
case LibFunc_acosh:
37223755
case LibFunc_asin:
3723-
case LibFunc_asinh:
37243756
case LibFunc_atan:
3725-
case LibFunc_atanh:
37263757
case LibFunc_cbrt:
3727-
case LibFunc_cosh:
37283758
case LibFunc_exp:
37293759
case LibFunc_exp10:
37303760
case LibFunc_expm1:
37313761
case LibFunc_cos:
37323762
case LibFunc_sin:
3733-
case LibFunc_sinh:
37343763
case LibFunc_tanh:
37353764
if (UnsafeFPShrink && hasFloatVersion(M, CI->getCalledFunction()->getName()))
37363765
return optimizeUnaryDoubleFP(CI, Builder, TLI, true);

llvm/test/Transforms/InstCombine/tan-nofastmath.ll

Lines changed: 0 additions & 17 deletions
This file was deleted.

llvm/test/Transforms/InstCombine/tan.ll

Lines changed: 0 additions & 23 deletions
This file was deleted.
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
2+
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
3+
4+
define float @tanAtanInverseFast(float %x) {
5+
; CHECK-LABEL: define float @tanAtanInverseFast(
6+
; CHECK-SAME: float [[X:%.*]]) {
7+
; CHECK-NEXT: [[CALL:%.*]] = call fast float @atanf(float [[X]])
8+
; CHECK-NEXT: ret float [[X]]
9+
;
10+
%call = call fast float @atanf(float %x)
11+
%call1 = call fast float @tanf(float %call)
12+
ret float %call1
13+
}
14+
15+
define float @atanhTanhInverseFast(float %x) {
16+
; CHECK-LABEL: define float @atanhTanhInverseFast(
17+
; CHECK-SAME: float [[X:%.*]]) {
18+
; CHECK-NEXT: [[CALL:%.*]] = call fast float @tanhf(float [[X]])
19+
; CHECK-NEXT: ret float [[X]]
20+
;
21+
%call = call fast float @tanhf(float %x)
22+
%call1 = call fast float @atanhf(float %call)
23+
ret float %call1
24+
}
25+
26+
define float @sinhAsinhInverseFast(float %x) {
27+
; CHECK-LABEL: define float @sinhAsinhInverseFast(
28+
; CHECK-SAME: float [[X:%.*]]) {
29+
; CHECK-NEXT: [[CALL:%.*]] = call fast float @asinhf(float [[X]])
30+
; CHECK-NEXT: ret float [[X]]
31+
;
32+
%call = call fast float @asinhf(float %x)
33+
%call1 = call fast float @sinhf(float %call)
34+
ret float %call1
35+
}
36+
37+
define float @asinhSinhInverseFast(float %x) {
38+
; CHECK-LABEL: define float @asinhSinhInverseFast(
39+
; CHECK-SAME: float [[X:%.*]]) {
40+
; CHECK-NEXT: [[CALL:%.*]] = call fast float @sinhf(float [[X]])
41+
; CHECK-NEXT: ret float [[X]]
42+
;
43+
%call = call fast float @sinhf(float %x)
44+
%call1 = call fast float @asinhf(float %call)
45+
ret float %call1
46+
}
47+
48+
define float @coshAcoshInverseFast(float %x) {
49+
; CHECK-LABEL: define float @coshAcoshInverseFast(
50+
; CHECK-SAME: float [[X:%.*]]) {
51+
; CHECK-NEXT: [[CALL:%.*]] = call fast float @acoshf(float [[X]])
52+
; CHECK-NEXT: ret float [[X]]
53+
;
54+
%call = call fast float @acoshf(float %x)
55+
%call1 = call fast float @coshf(float %call)
56+
ret float %call1
57+
}
58+
59+
define float @indirectTanCall(ptr %fptr) {
60+
; CHECK-LABEL: define float @indirectTanCall(
61+
; CHECK-SAME: ptr [[FPTR:%.*]]) {
62+
; CHECK-NEXT: [[CALL1:%.*]] = call fast float [[FPTR]]()
63+
; CHECK-NEXT: [[TAN:%.*]] = call fast float @tanf(float [[CALL1]])
64+
; CHECK-NEXT: ret float [[TAN]]
65+
;
66+
%call1 = call fast float %fptr()
67+
%tan = call fast float @tanf(float %call1)
68+
ret float %tan
69+
}
70+
71+
; No fast-math.
72+
73+
define float @tanAtanInverse(float %x) {
74+
; CHECK-LABEL: define float @tanAtanInverse(
75+
; CHECK-SAME: float [[X:%.*]]) {
76+
; CHECK-NEXT: [[CALL:%.*]] = call float @atanf(float [[X]])
77+
; CHECK-NEXT: [[CALL1:%.*]] = call float @tanf(float [[CALL]])
78+
; CHECK-NEXT: ret float [[CALL1]]
79+
;
80+
%call = call float @atanf(float %x)
81+
%call1 = call float @tanf(float %call)
82+
ret float %call1
83+
}
84+
85+
define float @atanhTanhInverse(float %x) {
86+
; CHECK-LABEL: define float @atanhTanhInverse(
87+
; CHECK-SAME: float [[X:%.*]]) {
88+
; CHECK-NEXT: [[CALL:%.*]] = call float @tanhf(float [[X]])
89+
; CHECK-NEXT: [[CALL1:%.*]] = call float @atanhf(float [[CALL]])
90+
; CHECK-NEXT: ret float [[CALL1]]
91+
;
92+
%call = call float @tanhf(float %x)
93+
%call1 = call float @atanhf(float %call)
94+
ret float %call1
95+
}
96+
97+
define float @sinhAsinhInverse(float %x) {
98+
; CHECK-LABEL: define float @sinhAsinhInverse(
99+
; CHECK-SAME: float [[X:%.*]]) {
100+
; CHECK-NEXT: [[CALL:%.*]] = call float @asinhf(float [[X]])
101+
; CHECK-NEXT: [[CALL1:%.*]] = call float @sinhf(float [[CALL]])
102+
; CHECK-NEXT: ret float [[CALL1]]
103+
;
104+
%call = call float @asinhf(float %x)
105+
%call1 = call float @sinhf(float %call)
106+
ret float %call1
107+
}
108+
109+
define float @asinhSinhInverse(float %x) {
110+
; CHECK-LABEL: define float @asinhSinhInverse(
111+
; CHECK-SAME: float [[X:%.*]]) {
112+
; CHECK-NEXT: [[CALL:%.*]] = call float @sinhf(float [[X]])
113+
; CHECK-NEXT: [[CALL1:%.*]] = call float @asinhf(float [[CALL]])
114+
; CHECK-NEXT: ret float [[CALL1]]
115+
;
116+
%call = call float @sinhf(float %x)
117+
%call1 = call float @asinhf(float %call)
118+
ret float %call1
119+
}
120+
121+
define float @coshAcoshInverse(float %x) {
122+
; CHECK-LABEL: define float @coshAcoshInverse(
123+
; CHECK-SAME: float [[X:%.*]]) {
124+
; CHECK-NEXT: [[CALL:%.*]] = call float @acoshf(float [[X]])
125+
; CHECK-NEXT: [[CALL1:%.*]] = call float @coshf(float [[CALL]])
126+
; CHECK-NEXT: ret float [[CALL1]]
127+
;
128+
%call = call float @acoshf(float %x)
129+
%call1 = call float @coshf(float %call)
130+
ret float %call1
131+
}
132+
133+
declare float @asinhf(float)
134+
declare float @sinhf(float)
135+
declare float @acoshf(float)
136+
declare float @coshf(float)
137+
declare float @tanhf(float)
138+
declare float @atanhf(float)
139+
declare float @tanf(float)
140+
declare float @atanf(float)

0 commit comments

Comments
 (0)