Skip to content

[Transforms] Expand optimizeTan to fold more inverse trig pairs #77799

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 6, 2024

Conversation

AZero13
Copy link
Contributor

@AZero13 AZero13 commented Jan 11, 2024

optimizeTan has been renamed to optimizeTrigInversionPairs as a result.

Sadly, this is not mathematically true that all inverse pairs fold to x. For example, asin(sin(x)) does not fold to x if x is over 2pi.

@llvmbot
Copy link
Member

llvmbot commented Jan 11, 2024

@llvm/pr-subscribers-llvm-transforms

Author: AtariDreams (AtariDreams)

Changes

It has been renamed to optimizeTrig as a result. Use a map to map functions to their inverses.


Full diff: https://github.com/llvm/llvm-project/pull/77799.diff

2 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h (+1-1)
  • (modified) llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp (+41-18)
diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
index eb10545ee149e4..b1b8b9a5b6ad6a 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -202,7 +202,7 @@ class LibCallSimplifier {
   Value *optimizeLog(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
-  Value *optimizeTan(CallInst *CI, IRBuilderBase &B);
+  Value *optimizeTrig(CallInst *CI, IRBuilderBase &B);
   // Wrapper for all floating point library call optimizations
   Value *optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func,
                                       IRBuilderBase &B);
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index a7cd68e860e467..bc09763d23f297 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -2603,13 +2603,29 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
   return copyFlags(*CI, FabsCall);
 }
 
-// TODO: Generalize to handle any trig function and its inverse.
-Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
+Value *LibCallSimplifier::optimizeTrig(CallInst *CI, IRBuilderBase &B) {
   Module *M = CI->getModule();
   Function *Callee = CI->getCalledFunction();
   Value *Ret = nullptr;
   StringRef Name = Callee->getName();
-  if (UnsafeFPShrink && Name == "tan" && hasFloatVersion(M, Name))
+
+  // Map of trigonometric functions to their inverses.
+  static const std::map<std::string, std::string> TrigFuncMap = {
+      {"sin", "asin"},     {"cos", "acos"},     {"tan", "atan"},
+      {"sinf", "asinf"},   {"cosf", "acosf"},   {"tanf", "atanf"},
+      {"sinl", "asinl"},   {"cosl", "acosl"},   {"tanl", "atanl"},
+      {"sinh", "asin"},    {"cosh", "acosh"},   {"tanh", "atanh"},
+      {"sinhf", "asinf"},  {"coshf", "acoshf"}, {"tanhf", "atanhf"},
+      {"sinhl", "asinhl"}, {"coshl", "acoshl"}, {"tanhl", "atanhl"},
+  };
+
+  // Check if the function is a trigonometric function.
+  auto It = TrigFuncMap.find(Name.str());
+  if (It == TrigFuncMap.end())
+    return Ret;
+
+  // Check if the function has a float version.
+  if (UnsafeFPShrink && hasFloatVersion(M, Name))
     Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
 
   Value *Op1 = CI->getArgOperand(0);
@@ -2621,16 +2637,12 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
   if (!CI->isFast() || !OpC->isFast())
     return Ret;
 
-  // tan(atan(x)) -> x
-  // tanf(atanf(x)) -> x
-  // tanl(atanl(x)) -> x
+  // Check if the operand is the inverse of the trigonometric function.
+  // in which case, a chain of inverses can be folded, ie: tan(atan(x)) -> x
   LibFunc Func;
   Function *F = OpC->getCalledFunction();
   if (F && TLI->getLibFunc(F->getName(), Func) &&
-      isLibFuncEmittable(M, TLI, Func) &&
-      ((Func == LibFunc_atan && Callee->getName() == "tan") ||
-       (Func == LibFunc_atanf && Callee->getName() == "tanf") ||
-       (Func == LibFunc_atanl && Callee->getName() == "tanl")))
+      isLibFuncEmittable(M, TLI, Func) && F->getName() == It->second)
     Ret = OpC->getArgOperand(0);
   return Ret;
 }
@@ -3621,10 +3633,6 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
   case LibFunc_logb:
   case LibFunc_logbl:
     return optimizeLog(CI, Builder);
-  case LibFunc_tan:
-  case LibFunc_tanf:
-  case LibFunc_tanl:
-    return optimizeTan(CI, Builder);
   case LibFunc_ceil:
     return replaceUnaryCall(CI, Builder, Intrinsic::ceil);
   case LibFunc_floor:
@@ -3646,17 +3654,32 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
   case LibFunc_atan:
   case LibFunc_atanh:
   case LibFunc_cbrt:
-  case LibFunc_cosh:
   case LibFunc_exp:
   case LibFunc_exp10:
   case LibFunc_expm1:
+    if (UnsafeFPShrink &&
+        hasFloatVersion(M, CI->getCalledFunction()->getName()))
+      return optimizeUnaryDoubleFP(CI, Builder, TLI, true);
+    return nullptr;
   case LibFunc_cos:
+  case LibFunc_cosf:
+  case LibFunc_cosl:
+  case LibFunc_cosh:
+  case LibFunc_coshf:
+  case LibFunc_coshl:
   case LibFunc_sin:
+  case LibFunc_sinf:
+  case LibFunc_sinl:
   case LibFunc_sinh:
+  case LibFunc_sinhf:
+  case LibFunc_sinhl:
+  case LibFunc_tan:
+  case LibFunc_tanf:
+  case LibFunc_tanl:
   case LibFunc_tanh:
-    if (UnsafeFPShrink && hasFloatVersion(M, CI->getCalledFunction()->getName()))
-      return optimizeUnaryDoubleFP(CI, Builder, TLI, true);
-    return nullptr;
+  case LibFunc_tanhf:
+  case LibFunc_tanhl:
+    return optimizeTrig(CI, Builder);
   case LibFunc_copysign:
     if (hasFloatVersion(M, CI->getCalledFunction()->getName()))
       return optimizeBinaryDoubleFP(CI, Builder, TLI);

@AZero13 AZero13 force-pushed the trig branch 2 times, most recently from e9875b1 to 1a5e6cc Compare January 11, 2024 18:44
@AZero13 AZero13 changed the title Resolve FIXME: Generalize optimizeTan to support other trig functions Create more optimizing functions to fold inverse pairs Jan 11, 2024
@AZero13 AZero13 force-pushed the trig branch 2 times, most recently from 6e98a99 to 146d5bc Compare January 11, 2024 19:04
@AZero13 AZero13 changed the title Create more optimizing functions to fold inverse pairs [Transforms] Create more optimizing functions to fold inverse pairs Jan 11, 2024
@AZero13 AZero13 changed the title [Transforms] Create more optimizing functions to fold inverse pairs [Transforms] Create more optimizing functions to fold inverse trig pairs Jan 11, 2024
@dtcxzyw
Copy link
Member

dtcxzyw commented Jan 11, 2024

Could you please add some regression tests?

@AZero13
Copy link
Contributor Author

AZero13 commented Jan 11, 2024

Could you please add some regression tests?

Done!

@AZero13 AZero13 changed the title [Transforms] Create more optimizing functions to fold inverse trig pairs [Transforms] Expand optimizeTan to fold more inverse trig pairs Jan 12, 2024
@AZero13 AZero13 force-pushed the trig branch 3 times, most recently from 9374771 to 4b1e3c3 Compare January 13, 2024 19:55
@AZero13 AZero13 requested a review from arsenm January 13, 2024 20:08
@AZero13 AZero13 force-pushed the trig branch 2 times, most recently from 5375707 to 5c5336b Compare January 15, 2024 00:29
@AZero13 AZero13 force-pushed the trig branch 2 times, most recently from a569d84 to f3d5ae9 Compare January 17, 2024 02:04
@AZero13
Copy link
Contributor Author

AZero13 commented Jan 20, 2024

@arsenm Can you review this?

@AZero13 AZero13 force-pushed the trig branch 2 times, most recently from 861c51c to 447ed32 Compare January 22, 2024 05:35
@AZero13
Copy link
Contributor Author

AZero13 commented Jan 22, 2024

@dtcxzyw Is this ready to merge?

@dtcxzyw
Copy link
Member

dtcxzyw commented Jan 22, 2024

@dtcxzyw Is this ready to merge?

Please wait for approval from @arsenm or @jcranmer-intel.

@AZero13
Copy link
Contributor Author

AZero13 commented Jan 24, 2024

@arsenm

@AZero13
Copy link
Contributor Author

AZero13 commented Jan 25, 2024

@dtcxzyw Ready!

@AZero13 AZero13 force-pushed the trig branch 3 times, most recently from e8b833b to cac9c2c Compare January 25, 2024 18:45
Merge tan-nofastmath.ll and tan.ll into trig.ll
optimizeTan has been renamed to optimizeTrigInversionPairs as a result.

Sadly, this is not mathematically true that all inverse pairs fold to x.

For example, asin(sin(x)) does not fold to x if x is over 2*pi.
@AZero13
Copy link
Contributor Author

AZero13 commented Jan 27, 2024

@arsenm Can we merge this please?

@arsenm arsenm merged commit c6b5ea3 into llvm:main Feb 6, 2024
@AZero13 AZero13 deleted the trig branch February 6, 2024 14:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants