-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[Transforms] Expand optimizeTan to fold more inverse trig pairs #77799
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-transforms Author: AtariDreams (AtariDreams) ChangesIt has been renamed to optimizeTrig as a result. Use a map to map functions to their inverses. Full diff: https://github.com/llvm/llvm-project/pull/77799.diff 2 Files Affected:
diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
index eb10545ee149e4..b1b8b9a5b6ad6a 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -202,7 +202,7 @@ class LibCallSimplifier {
Value *optimizeLog(CallInst *CI, IRBuilderBase &B);
Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B);
Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
- Value *optimizeTan(CallInst *CI, IRBuilderBase &B);
+ Value *optimizeTrig(CallInst *CI, IRBuilderBase &B);
// Wrapper for all floating point library call optimizations
Value *optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func,
IRBuilderBase &B);
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index a7cd68e860e467..bc09763d23f297 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -2603,13 +2603,29 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
return copyFlags(*CI, FabsCall);
}
-// TODO: Generalize to handle any trig function and its inverse.
-Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
+Value *LibCallSimplifier::optimizeTrig(CallInst *CI, IRBuilderBase &B) {
Module *M = CI->getModule();
Function *Callee = CI->getCalledFunction();
Value *Ret = nullptr;
StringRef Name = Callee->getName();
- if (UnsafeFPShrink && Name == "tan" && hasFloatVersion(M, Name))
+
+ // Map of trigonometric functions to their inverses.
+ static const std::map<std::string, std::string> TrigFuncMap = {
+ {"sin", "asin"}, {"cos", "acos"}, {"tan", "atan"},
+ {"sinf", "asinf"}, {"cosf", "acosf"}, {"tanf", "atanf"},
+ {"sinl", "asinl"}, {"cosl", "acosl"}, {"tanl", "atanl"},
+ {"sinh", "asin"}, {"cosh", "acosh"}, {"tanh", "atanh"},
+ {"sinhf", "asinf"}, {"coshf", "acoshf"}, {"tanhf", "atanhf"},
+ {"sinhl", "asinhl"}, {"coshl", "acoshl"}, {"tanhl", "atanhl"},
+ };
+
+ // Check if the function is a trigonometric function.
+ auto It = TrigFuncMap.find(Name.str());
+ if (It == TrigFuncMap.end())
+ return Ret;
+
+ // Check if the function has a float version.
+ if (UnsafeFPShrink && hasFloatVersion(M, Name))
Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
Value *Op1 = CI->getArgOperand(0);
@@ -2621,16 +2637,12 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
if (!CI->isFast() || !OpC->isFast())
return Ret;
- // tan(atan(x)) -> x
- // tanf(atanf(x)) -> x
- // tanl(atanl(x)) -> x
+ // Check if the operand is the inverse of the trigonometric function.
+ // in which case, a chain of inverses can be folded, ie: tan(atan(x)) -> x
LibFunc Func;
Function *F = OpC->getCalledFunction();
if (F && TLI->getLibFunc(F->getName(), Func) &&
- isLibFuncEmittable(M, TLI, Func) &&
- ((Func == LibFunc_atan && Callee->getName() == "tan") ||
- (Func == LibFunc_atanf && Callee->getName() == "tanf") ||
- (Func == LibFunc_atanl && Callee->getName() == "tanl")))
+ isLibFuncEmittable(M, TLI, Func) && F->getName() == It->second)
Ret = OpC->getArgOperand(0);
return Ret;
}
@@ -3621,10 +3633,6 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
case LibFunc_logb:
case LibFunc_logbl:
return optimizeLog(CI, Builder);
- case LibFunc_tan:
- case LibFunc_tanf:
- case LibFunc_tanl:
- return optimizeTan(CI, Builder);
case LibFunc_ceil:
return replaceUnaryCall(CI, Builder, Intrinsic::ceil);
case LibFunc_floor:
@@ -3646,17 +3654,32 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
case LibFunc_atan:
case LibFunc_atanh:
case LibFunc_cbrt:
- case LibFunc_cosh:
case LibFunc_exp:
case LibFunc_exp10:
case LibFunc_expm1:
+ if (UnsafeFPShrink &&
+ hasFloatVersion(M, CI->getCalledFunction()->getName()))
+ return optimizeUnaryDoubleFP(CI, Builder, TLI, true);
+ return nullptr;
case LibFunc_cos:
+ case LibFunc_cosf:
+ case LibFunc_cosl:
+ case LibFunc_cosh:
+ case LibFunc_coshf:
+ case LibFunc_coshl:
case LibFunc_sin:
+ case LibFunc_sinf:
+ case LibFunc_sinl:
case LibFunc_sinh:
+ case LibFunc_sinhf:
+ case LibFunc_sinhl:
+ case LibFunc_tan:
+ case LibFunc_tanf:
+ case LibFunc_tanl:
case LibFunc_tanh:
- if (UnsafeFPShrink && hasFloatVersion(M, CI->getCalledFunction()->getName()))
- return optimizeUnaryDoubleFP(CI, Builder, TLI, true);
- return nullptr;
+ case LibFunc_tanhf:
+ case LibFunc_tanhl:
+ return optimizeTrig(CI, Builder);
case LibFunc_copysign:
if (hasFloatVersion(M, CI->getCalledFunction()->getName()))
return optimizeBinaryDoubleFP(CI, Builder, TLI);
|
e9875b1
to
1a5e6cc
Compare
6e98a99
to
146d5bc
Compare
Could you please add some regression tests? |
Done! |
3495930
to
8c19541
Compare
9374771
to
4b1e3c3
Compare
5375707
to
5c5336b
Compare
a569d84
to
f3d5ae9
Compare
c63a5ba
to
d01d533
Compare
@arsenm Can you review this? |
861c51c
to
447ed32
Compare
@dtcxzyw Is this ready to merge? |
Please wait for approval from @arsenm or @jcranmer-intel. |
@dtcxzyw Ready! |
e8b833b
to
cac9c2c
Compare
Merge tan-nofastmath.ll and tan.ll into trig.ll
optimizeTan has been renamed to optimizeTrigInversionPairs as a result. Sadly, this is not mathematically true that all inverse pairs fold to x. For example, asin(sin(x)) does not fold to x if x is over 2*pi.
@arsenm Can we merge this please? |
optimizeTan has been renamed to optimizeTrigInversionPairs as a result.
Sadly, this is not mathematically true that all inverse pairs fold to x. For example, asin(sin(x)) does not fold to x if x is over 2pi.