Skip to content

[NFC][TargetTransformInfo][VectorUtils] Consolidate isVectorIntrinsic... api #117635

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Dec 19, 2024

Conversation

inbelic
Copy link
Contributor

@inbelic inbelic commented Nov 25, 2024

  • update VectorUtils:isVectorIntrinsicWithScalarOpAtArg to use TTI for
    all uses, to allow specifiction of target specific intrinsics

  • add TTI to the isVectorIntrinsicWithStructReturnOverloadAtField api

  • update TTI api to provide isTargetIntrinsicWith... functions and
    consistently name them

  • update all uses of the api and provide the TTI parameter

Resolves #117030

@llvmbot
Copy link
Member

llvmbot commented Nov 26, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Finn Plummer (inbelic)

Changes
  • update VectorUtils:isVectorIntrinsicWithScalarOpAtArg to use TTI for
    all uses, to allow specifiction of target specific intrinsics

  • add TTI to the isVectorIntrinsicWithStructReturnOverloadAtField api

  • update TTI api to provide isTargetIntrinsicWith... functions and
    consistently name them

  • update all uses of the api and provide the TTI parameter

Resolves #117030


Patch is 25.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/117635.diff

13 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+21-7)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+8-3)
  • (modified) llvm/include/llvm/Analysis/VectorUtils.h (+9-5)
  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+8-3)
  • (modified) llvm/lib/Analysis/ConstantFolding.cpp (+1-1)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+8-3)
  • (modified) llvm/lib/Analysis/VectorUtils.cpp (+13-4)
  • (modified) llvm/lib/CodeGen/ReplaceWithVeclib.cpp (+1-1)
  • (modified) llvm/lib/Transforms/Scalar/Scalarizer.cpp (+2-3)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp (+1-1)
  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+23-21)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+2-1)
  • (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+13-10)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 985ca1532e0149..cb4793e85e462e 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -898,14 +898,20 @@ class TargetTransformInfo {
 
   bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
 
+  /// Identifies if the vector form of the intrinsic has a scalar operand.
   bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                           unsigned ScalarOpdIdx) const;
 
   /// Identifies if the vector form of the intrinsic is overloaded on the type
   /// of the operand at index \p OpdIdx, or on the return type if \p OpdIdx is
   /// -1.
-  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                              int ScalarOpdIdx) const;
+  bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              int OpdIdx) const;
+
+  /// Identifies if the vector form of the intrinsic that returns a struct is
+  /// overloaded at the struct element index \p RetIdx.
+  bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
+                                                        int RetIdx) const;
 
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
@@ -1999,8 +2005,11 @@ class TargetTransformInfo::Concept {
   virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
   virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                                   unsigned ScalarOpdIdx) = 0;
-  virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                                      int ScalarOpdIdx) = 0;
+  virtual bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                                      int OpdIdx) = 0;
+  virtual bool
+  isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
+                                                   int RetIdx) = 0;
   virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
                                                    const APInt &DemandedElts,
                                                    bool Insert, bool Extract,
@@ -2577,9 +2586,14 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
   }
 
-  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                              int ScalarOpdIdx) override {
-    return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
+  bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              int OpdIdx) override {
+    return Impl.isTargetIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
+  }
+
+  bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
+                                                        int RetIdx) override {
+    return Impl.isTargetIntrinsicWithStructReturnOverloadAtField(ID, RetIdx);
   }
 
   InstructionCost getScalarizationOverhead(VectorType *Ty,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 38aba183f6a173..9499c660822179 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -396,9 +396,14 @@ class TargetTransformInfoImplBase {
     return false;
   }
 
-  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                              int ScalarOpdIdx) const {
-    return ScalarOpdIdx == -1;
+  bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              int OpdIdx) const {
+    return OpdIdx == -1;
+  }
+
+  bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
+                                                        int RetIdx) const {
+    return RetIdx == 0;
   }
 
   InstructionCost getScalarizationOverhead(VectorType *Ty,
diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index c1016dd7bdddbd..f7febf3e82b125 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -147,8 +147,10 @@ inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
 bool isTriviallyVectorizable(Intrinsic::ID ID);
 
 /// Identifies if the vector form of the intrinsic has a scalar operand.
-bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
-                                        unsigned ScalarOpdIdx);
+/// \p TTI is used to consider target specific intrinsics, if no target specific
+/// intrinsics will be considered then it is appropriate to pass in nullptr.
+bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, unsigned ScalarOpdIdx,
+                                        const TargetTransformInfo *TTI);
 
 /// Identifies if the vector form of the intrinsic is overloaded on the type of
 /// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
@@ -158,9 +160,11 @@ bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx,
                                             const TargetTransformInfo *TTI);
 
 /// Identifies if the vector form of the intrinsic that returns a struct is
-/// overloaded at the struct element index \p RetIdx.
-bool isVectorIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
-                                                      int RetIdx);
+/// overloaded at the struct element index \p RetIdx. /// \p TTI is used to
+/// consider target specific intrinsics, if no target specific intrinsics
+/// will be considered then it is appropriate to pass in nullptr.
+bool isVectorIntrinsicWithStructReturnOverloadAtField(
+    Intrinsic::ID ID, int RetIdx, const TargetTransformInfo *TTI);
 
 /// Returns intrinsic ID for call.
 /// For the input call instruction it finds mapping intrinsic and returns
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index b3583e2819ee4c..1bad9c70223b8d 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -801,9 +801,14 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return false;
   }
 
-  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                              int ScalarOpdIdx) const {
-    return ScalarOpdIdx == -1;
+  bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              int OpdIdx) const {
+    return OpdIdx == -1;
+  }
+
+  bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
+                                                        int RetIdx) const {
+    return RetIdx == 0;
   }
 
   /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 1971c28fc4c4de..caeaec70e3967c 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -3438,7 +3438,7 @@ static Constant *ConstantFoldFixedVectorCall(
     // Gather a column of constants.
     for (unsigned J = 0, JE = Operands.size(); J != JE; ++J) {
       // Some intrinsics use a scalar type for certain arguments.
-      if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, J)) {
+      if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, J, /*TTI=*/nullptr)) {
         Lane[J] = Operands[J];
         continue;
       }
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 1fb2b9836de0cc..5cd09207bfc56a 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -615,9 +615,14 @@ bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
   return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
 }
 
-bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
-    Intrinsic::ID ID, int ScalarOpdIdx) const {
-  return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
+bool TargetTransformInfo::isTargetIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, int OpdIdx) const {
+  return TTIImpl->isTargetIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
+}
+
+bool TargetTransformInfo::isTargetIntrinsicWithStructReturnOverloadAtField(
+    Intrinsic::ID ID, int RetIdx) const {
+  return TTIImpl->isTargetIntrinsicWithOverloadTypeAtArg(ID, RetIdx);
 }
 
 InstructionCost TargetTransformInfo::getScalarizationOverhead(
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 1789671276ffaf..15edaa0a47d037 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -115,7 +115,12 @@ bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) {
 
 /// Identifies if the vector form of the intrinsic has a scalar operand.
 bool llvm::isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
-                                              unsigned ScalarOpdIdx) {
+                                              unsigned ScalarOpdIdx,
+                                              const TargetTransformInfo *TTI) {
+
+  if (TTI && Intrinsic::isTargetIntrinsic(ID))
+    return TTI->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
+
   switch (ID) {
   case Intrinsic::abs:
   case Intrinsic::ctlz:
@@ -138,7 +143,7 @@ bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(
   assert(ID != Intrinsic::not_intrinsic && "Not an intrinsic!");
 
   if (TTI && Intrinsic::isTargetIntrinsic(ID))
-    return TTI->isVectorIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
+    return TTI->isTargetIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
 
   switch (ID) {
   case Intrinsic::fptosi_sat:
@@ -157,8 +162,12 @@ bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(
   }
 }
 
-bool llvm::isVectorIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
-                                                            int RetIdx) {
+bool llvm::isVectorIntrinsicWithStructReturnOverloadAtField(
+    Intrinsic::ID ID, int RetIdx, const TargetTransformInfo *TTI) {
+
+  if (TTI && Intrinsic::isTargetIntrinsic(ID))
+    return TTI->isTargetIntrinsicWithStructReturnOverloadAtField(ID, RetIdx);
+
   switch (ID) {
   case Intrinsic::frexp:
     return RetIdx == 0 || RetIdx == 1;
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 8d457f58e6eede..a87c2063b1e353 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -121,7 +121,7 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
     auto *ArgTy = Arg.value()->getType();
     bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index(),
                                                             /*TTI=*/nullptr);
-    if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
+    if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index(), /*TTI=*/nullptr)) {
       ScalarArgTypes.push_back(ArgTy);
       if (IsOloadTy)
         OloadTys.push_back(ArgTy);
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 3b701e6ca09761..2ba8389c3a30c3 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -743,7 +743,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
       // will only scalarize when the struct elements have the same bitness.
       if (!CurrVS || CurrVS->NumPacked != VS->NumPacked)
         return false;
-      if (isVectorIntrinsicWithStructReturnOverloadAtField(ID, I))
+      if (isVectorIntrinsicWithStructReturnOverloadAtField(ID, I, TTI))
         Tys.push_back(CurrVS->SplitTy);
     }
   }
@@ -794,8 +794,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
       Tys[0] = VS->RemainderTy;
 
     for (unsigned J = 0; J != NumArgs; ++J) {
-      if (isVectorIntrinsicWithScalarOpAtArg(ID, J) ||
-          TTI->isTargetIntrinsicWithScalarOpAtArg(ID, J)) {
+      if (isVectorIntrinsicWithScalarOpAtArg(ID, J, TTI)) {
         ScalarCallOps.push_back(ScalarOperands[J]);
       } else {
         ScalarCallOps.push_back(Scattered[J][I]);
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index f1568781252c06..567d6ca2e5a319 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -927,7 +927,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
         auto *SE = PSE.getSE();
         Intrinsic::ID IntrinID = getVectorIntrinsicIDForCall(CI, TLI);
         for (unsigned Idx = 0; Idx < CI->arg_size(); ++Idx)
-          if (isVectorIntrinsicWithScalarOpAtArg(IntrinID, Idx)) {
+          if (isVectorIntrinsicWithScalarOpAtArg(IntrinID, Idx, TTI)) {
             if (!SE->isLoopInvariant(PSE.getSCEV(CI->getOperand(Idx)),
                                      TheLoop)) {
               reportVectorizationFailure("Found unvectorizable intrinsic",
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index f13d0d80d382a4..a54cddd69d35d3 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1060,7 +1060,8 @@ static bool allSameType(ArrayRef<Value *> VL) {
 /// \returns True if in-tree use also needs extract. This refers to
 /// possible scalar operand in vectorized instruction.
 static bool doesInTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst,
-                                        TargetLibraryInfo *TLI) {
+                                        TargetLibraryInfo *TLI,
+                                        const TargetTransformInfo *TTI) {
   if (!UserInst)
     return false;
   unsigned Opcode = UserInst->getOpcode();
@@ -1077,7 +1078,7 @@ static bool doesInTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst,
     CallInst *CI = cast<CallInst>(UserInst);
     Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
     return any_of(enumerate(CI->args()), [&](auto &&Arg) {
-      return isVectorIntrinsicWithScalarOpAtArg(ID, Arg.index()) &&
+      return isVectorIntrinsicWithScalarOpAtArg(ID, Arg.index(), TTI) &&
              Arg.value().get() == Scalar;
     });
   }
@@ -6417,7 +6418,7 @@ void BoUpSLP::buildExternalUses(
           // be used.
           if (UseEntry->State == TreeEntry::ScatterVectorize ||
               !doesInTreeUserNeedToExtract(
-                  Scalar, getRootEntryInstruction(*UseEntry), TLI)) {
+                  Scalar, getRootEntryInstruction(*UseEntry), TLI, TTI)) {
             LLVM_DEBUG(dbgs() << "SLP: \tInternal user will be removed:" << *U
                               << ".\n");
             assert(!UseEntry->isGather() && "Bad state");
@@ -7724,7 +7725,7 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
     unsigned NumArgs = CI->arg_size();
     SmallVector<Value *, 4> ScalarArgs(NumArgs, nullptr);
     for (unsigned J = 0; J != NumArgs; ++J)
-      if (isVectorIntrinsicWithScalarOpAtArg(ID, J))
+      if (isVectorIntrinsicWithScalarOpAtArg(ID, J, TTI))
         ScalarArgs[J] = CI->getArgOperand(J);
     for (Value *V : VL) {
       CallInst *CI2 = dyn_cast<CallInst>(V);
@@ -7740,7 +7741,7 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
       // Some intrinsics have scalar arguments and should be same in order for
       // them to be vectorized.
       for (unsigned J = 0; J != NumArgs; ++J) {
-        if (isVectorIntrinsicWithScalarOpAtArg(ID, J)) {
+        if (isVectorIntrinsicWithScalarOpAtArg(ID, J, TTI)) {
           Value *A1J = CI2->getArgOperand(J);
           if (ScalarArgs[J] != A1J) {
             LLVM_DEBUG(dbgs()
@@ -8613,7 +8614,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
         SmallVector<ValueList> Operands;
         for (unsigned I : seq<unsigned>(2, CI->arg_size())) {
           Operands.emplace_back();
-          if (isVectorIntrinsicWithScalarOpAtArg(ID, I))
+          if (isVectorIntrinsicWithScalarOpAtArg(ID, I, TTI))
             continue;
           for (Value *V : VL) {
             auto *CI2 = cast<CallInst>(V);
@@ -8634,7 +8635,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
       for (unsigned I : seq<unsigned>(0, CI->arg_size())) {
         // For scalar operands no need to create an entry since no need to
         // vectorize it.
-        if (isVectorIntrinsicWithScalarOpAtArg(ID, I))
+        if (isVectorIntrinsicWithScalarOpAtArg(ID, I, TTI))
           continue;
         ValueList Operands;
         // Prepare the operand vector.
@@ -10834,14 +10835,14 @@ TTI::CastContextHint BoUpSLP::getCastContextHint(const TreeEntry &TE) const {
 
 /// Builds the arguments types vector for the given call instruction with the
 /// given \p ID for the specified vector factor.
-static SmallVector<Type *> buildIntrinsicArgTypes(const CallInst *CI,
-                                                  const Intrinsic::ID ID,
-                                                  const unsigned VF,
-                                                  unsigned MinBW) {
+static SmallVector<Type *>
+buildIntrinsicArgTypes(const CallInst *CI, const Intrinsic::ID ID,
+                       const unsigned VF, unsigned MinBW,
+                       const TargetTransformInfo *TTI) {
   SmallVector<Type *> ArgTys;
   for (auto [Idx, Arg] : enumerate(CI->args())) {
     if (ID != Intrinsic::not_intrinsic) {
-      if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx)) {
+      if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx, TTI)) {
         ArgTys.push_back(Arg->getType());
         continue;
       }
@@ -11520,9 +11521,9 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     auto GetVectorCost = [=](InstructionCost CommonCost) {
       auto *CI = cast<CallInst>(VL0);
       Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
-      SmallVector<Type *> ArgTys =
-          buildIntrinsicArgTypes(CI, ID, VecTy->getNumElements(),
-                                 It != MinBWs.end() ? It->second.first : 0);
+      SmallVector<Type *> ArgTys = buildIntrinsicArgTypes(
+          CI, ID, VecTy->getNumElements(),
+          It != MinBWs.end() ? It->second.first : 0, TTI);
       auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, ArgTys);
       return std::min(VecCallCosts.first, VecCallCosts.second) + CommonCost;
     };
@@ -15644,9 +15645,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
 
       Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
 
-      SmallVector<Type *> ArgTys =
-          buildIntrinsicArgTypes(CI, ID, VecTy->getNumElements(),
-                                 It != MinBWs.end() ? It->second.first : 0);
+      SmallVector<Type *> ArgTys = buildIntrinsicArgTypes(
+          CI, ID, VecTy->getNumElements(),
+          It != MinBWs.end() ? It->second.first : 0, TTI);
       auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, ArgTys);
       bool UseIntrinsic = ID != Intrinsic::not_intrinsic &&
                           VecCallCosts.first <= VecCallCosts.second;
@@ -15662,7 +15663,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         ValueList OpVL;
         // Some intrinsics have scalar arguments. This argument should not be
         // vectorized.
-        if (UseIntrinsic && isVectorIntrinsicWithScalarOpAtArg(ID, I)) {
+        if (UseIntrinsic && isVectorIntrinsicWithScalarOpAtArg(ID, I, TTI)) {
           ScalarArg = CEI->getArgOperand(I);
           // if decided to reduce bitwidth of abs intrinsic, it second argument
           // must be set false (do not return poison, if value issigned min).
@@ -16195,7 +16196,7 @@ BoUpSLP::vectorizeTree(const ExtraValueToDebugLocsMap &ExternallyUsedValues,
                                E->State == TreeEntry::StridedVectorize) &&
                               doesInTreeUserNeedToExtract(
    ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 26, 2024

@llvm/pr-subscribers-vectorizers

Author: Finn Plummer (inbelic)

Changes
  • update VectorUtils:isVectorIntrinsicWithScalarOpAtArg to use TTI for
    all uses, to allow specifiction of target specific intrinsics

  • add TTI to the isVectorIntrinsicWithStructReturnOverloadAtField api

  • update TTI api to provide isTargetIntrinsicWith... functions and
    consistently name them

  • update all uses of the api and provide the TTI parameter

Resolves #117030


Patch is 25.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/117635.diff

13 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+21-7)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+8-3)
  • (modified) llvm/include/llvm/Analysis/VectorUtils.h (+9-5)
  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+8-3)
  • (modified) llvm/lib/Analysis/ConstantFolding.cpp (+1-1)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+8-3)
  • (modified) llvm/lib/Analysis/VectorUtils.cpp (+13-4)
  • (modified) llvm/lib/CodeGen/ReplaceWithVeclib.cpp (+1-1)
  • (modified) llvm/lib/Transforms/Scalar/Scalarizer.cpp (+2-3)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp (+1-1)
  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+23-21)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+2-1)
  • (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+13-10)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 985ca1532e0149..cb4793e85e462e 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -898,14 +898,20 @@ class TargetTransformInfo {
 
   bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
 
+  /// Identifies if the vector form of the intrinsic has a scalar operand.
   bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                           unsigned ScalarOpdIdx) const;
 
   /// Identifies if the vector form of the intrinsic is overloaded on the type
   /// of the operand at index \p OpdIdx, or on the return type if \p OpdIdx is
   /// -1.
-  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                              int ScalarOpdIdx) const;
+  bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              int OpdIdx) const;
+
+  /// Identifies if the vector form of the intrinsic that returns a struct is
+  /// overloaded at the struct element index \p RetIdx.
+  bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
+                                                        int RetIdx) const;
 
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
@@ -1999,8 +2005,11 @@ class TargetTransformInfo::Concept {
   virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
   virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                                   unsigned ScalarOpdIdx) = 0;
-  virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                                      int ScalarOpdIdx) = 0;
+  virtual bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                                      int OpdIdx) = 0;
+  virtual bool
+  isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
+                                                   int RetIdx) = 0;
   virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
                                                    const APInt &DemandedElts,
                                                    bool Insert, bool Extract,
@@ -2577,9 +2586,14 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
   }
 
-  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                              int ScalarOpdIdx) override {
-    return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
+  bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              int OpdIdx) override {
+    return Impl.isTargetIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
+  }
+
+  bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
+                                                        int RetIdx) override {
+    return Impl.isTargetIntrinsicWithStructReturnOverloadAtField(ID, RetIdx);
   }
 
   InstructionCost getScalarizationOverhead(VectorType *Ty,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 38aba183f6a173..9499c660822179 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -396,9 +396,14 @@ class TargetTransformInfoImplBase {
     return false;
   }
 
-  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                              int ScalarOpdIdx) const {
-    return ScalarOpdIdx == -1;
+  bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              int OpdIdx) const {
+    return OpdIdx == -1;
+  }
+
+  bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
+                                                        int RetIdx) const {
+    return RetIdx == 0;
   }
 
   InstructionCost getScalarizationOverhead(VectorType *Ty,
diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index c1016dd7bdddbd..f7febf3e82b125 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -147,8 +147,10 @@ inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
 bool isTriviallyVectorizable(Intrinsic::ID ID);
 
 /// Identifies if the vector form of the intrinsic has a scalar operand.
-bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
-                                        unsigned ScalarOpdIdx);
+/// \p TTI is used to consider target specific intrinsics, if no target specific
+/// intrinsics will be considered then it is appropriate to pass in nullptr.
+bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, unsigned ScalarOpdIdx,
+                                        const TargetTransformInfo *TTI);
 
 /// Identifies if the vector form of the intrinsic is overloaded on the type of
 /// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
@@ -158,9 +160,11 @@ bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx,
                                             const TargetTransformInfo *TTI);
 
 /// Identifies if the vector form of the intrinsic that returns a struct is
-/// overloaded at the struct element index \p RetIdx.
-bool isVectorIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
-                                                      int RetIdx);
+/// overloaded at the struct element index \p RetIdx. /// \p TTI is used to
+/// consider target specific intrinsics, if no target specific intrinsics
+/// will be considered then it is appropriate to pass in nullptr.
+bool isVectorIntrinsicWithStructReturnOverloadAtField(
+    Intrinsic::ID ID, int RetIdx, const TargetTransformInfo *TTI);
 
 /// Returns intrinsic ID for call.
 /// For the input call instruction it finds mapping intrinsic and returns
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index b3583e2819ee4c..1bad9c70223b8d 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -801,9 +801,14 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return false;
   }
 
-  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                              int ScalarOpdIdx) const {
-    return ScalarOpdIdx == -1;
+  bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              int OpdIdx) const {
+    return OpdIdx == -1;
+  }
+
+  bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
+                                                        int RetIdx) const {
+    return RetIdx == 0;
   }
 
   /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 1971c28fc4c4de..caeaec70e3967c 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -3438,7 +3438,7 @@ static Constant *ConstantFoldFixedVectorCall(
     // Gather a column of constants.
     for (unsigned J = 0, JE = Operands.size(); J != JE; ++J) {
       // Some intrinsics use a scalar type for certain arguments.
-      if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, J)) {
+      if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, J, /*TTI=*/nullptr)) {
         Lane[J] = Operands[J];
         continue;
       }
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 1fb2b9836de0cc..5cd09207bfc56a 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -615,9 +615,14 @@ bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
   return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
 }
 
-bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
-    Intrinsic::ID ID, int ScalarOpdIdx) const {
-  return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
+bool TargetTransformInfo::isTargetIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, int OpdIdx) const {
+  return TTIImpl->isTargetIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
+}
+
+bool TargetTransformInfo::isTargetIntrinsicWithStructReturnOverloadAtField(
+    Intrinsic::ID ID, int RetIdx) const {
+  return TTIImpl->isTargetIntrinsicWithOverloadTypeAtArg(ID, RetIdx);
 }
 
 InstructionCost TargetTransformInfo::getScalarizationOverhead(
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 1789671276ffaf..15edaa0a47d037 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -115,7 +115,12 @@ bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) {
 
 /// Identifies if the vector form of the intrinsic has a scalar operand.
 bool llvm::isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
-                                              unsigned ScalarOpdIdx) {
+                                              unsigned ScalarOpdIdx,
+                                              const TargetTransformInfo *TTI) {
+
+  if (TTI && Intrinsic::isTargetIntrinsic(ID))
+    return TTI->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
+
   switch (ID) {
   case Intrinsic::abs:
   case Intrinsic::ctlz:
@@ -138,7 +143,7 @@ bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(
   assert(ID != Intrinsic::not_intrinsic && "Not an intrinsic!");
 
   if (TTI && Intrinsic::isTargetIntrinsic(ID))
-    return TTI->isVectorIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
+    return TTI->isTargetIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
 
   switch (ID) {
   case Intrinsic::fptosi_sat:
@@ -157,8 +162,12 @@ bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(
   }
 }
 
-bool llvm::isVectorIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
-                                                            int RetIdx) {
+bool llvm::isVectorIntrinsicWithStructReturnOverloadAtField(
+    Intrinsic::ID ID, int RetIdx, const TargetTransformInfo *TTI) {
+
+  if (TTI && Intrinsic::isTargetIntrinsic(ID))
+    return TTI->isTargetIntrinsicWithStructReturnOverloadAtField(ID, RetIdx);
+
   switch (ID) {
   case Intrinsic::frexp:
     return RetIdx == 0 || RetIdx == 1;
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 8d457f58e6eede..a87c2063b1e353 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -121,7 +121,7 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
     auto *ArgTy = Arg.value()->getType();
     bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index(),
                                                             /*TTI=*/nullptr);
-    if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
+    if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index(), /*TTI=*/nullptr)) {
       ScalarArgTypes.push_back(ArgTy);
       if (IsOloadTy)
         OloadTys.push_back(ArgTy);
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 3b701e6ca09761..2ba8389c3a30c3 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -743,7 +743,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
       // will only scalarize when the struct elements have the same bitness.
       if (!CurrVS || CurrVS->NumPacked != VS->NumPacked)
         return false;
-      if (isVectorIntrinsicWithStructReturnOverloadAtField(ID, I))
+      if (isVectorIntrinsicWithStructReturnOverloadAtField(ID, I, TTI))
         Tys.push_back(CurrVS->SplitTy);
     }
   }
@@ -794,8 +794,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
       Tys[0] = VS->RemainderTy;
 
     for (unsigned J = 0; J != NumArgs; ++J) {
-      if (isVectorIntrinsicWithScalarOpAtArg(ID, J) ||
-          TTI->isTargetIntrinsicWithScalarOpAtArg(ID, J)) {
+      if (isVectorIntrinsicWithScalarOpAtArg(ID, J, TTI)) {
         ScalarCallOps.push_back(ScalarOperands[J]);
       } else {
         ScalarCallOps.push_back(Scattered[J][I]);
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index f1568781252c06..567d6ca2e5a319 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -927,7 +927,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
         auto *SE = PSE.getSE();
         Intrinsic::ID IntrinID = getVectorIntrinsicIDForCall(CI, TLI);
         for (unsigned Idx = 0; Idx < CI->arg_size(); ++Idx)
-          if (isVectorIntrinsicWithScalarOpAtArg(IntrinID, Idx)) {
+          if (isVectorIntrinsicWithScalarOpAtArg(IntrinID, Idx, TTI)) {
             if (!SE->isLoopInvariant(PSE.getSCEV(CI->getOperand(Idx)),
                                      TheLoop)) {
               reportVectorizationFailure("Found unvectorizable intrinsic",
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index f13d0d80d382a4..a54cddd69d35d3 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1060,7 +1060,8 @@ static bool allSameType(ArrayRef<Value *> VL) {
 /// \returns True if in-tree use also needs extract. This refers to
 /// possible scalar operand in vectorized instruction.
 static bool doesInTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst,
-                                        TargetLibraryInfo *TLI) {
+                                        TargetLibraryInfo *TLI,
+                                        const TargetTransformInfo *TTI) {
   if (!UserInst)
     return false;
   unsigned Opcode = UserInst->getOpcode();
@@ -1077,7 +1078,7 @@ static bool doesInTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst,
     CallInst *CI = cast<CallInst>(UserInst);
     Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
     return any_of(enumerate(CI->args()), [&](auto &&Arg) {
-      return isVectorIntrinsicWithScalarOpAtArg(ID, Arg.index()) &&
+      return isVectorIntrinsicWithScalarOpAtArg(ID, Arg.index(), TTI) &&
              Arg.value().get() == Scalar;
     });
   }
@@ -6417,7 +6418,7 @@ void BoUpSLP::buildExternalUses(
           // be used.
           if (UseEntry->State == TreeEntry::ScatterVectorize ||
               !doesInTreeUserNeedToExtract(
-                  Scalar, getRootEntryInstruction(*UseEntry), TLI)) {
+                  Scalar, getRootEntryInstruction(*UseEntry), TLI, TTI)) {
             LLVM_DEBUG(dbgs() << "SLP: \tInternal user will be removed:" << *U
                               << ".\n");
             assert(!UseEntry->isGather() && "Bad state");
@@ -7724,7 +7725,7 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
     unsigned NumArgs = CI->arg_size();
     SmallVector<Value *, 4> ScalarArgs(NumArgs, nullptr);
     for (unsigned J = 0; J != NumArgs; ++J)
-      if (isVectorIntrinsicWithScalarOpAtArg(ID, J))
+      if (isVectorIntrinsicWithScalarOpAtArg(ID, J, TTI))
         ScalarArgs[J] = CI->getArgOperand(J);
     for (Value *V : VL) {
       CallInst *CI2 = dyn_cast<CallInst>(V);
@@ -7740,7 +7741,7 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
       // Some intrinsics have scalar arguments and should be same in order for
       // them to be vectorized.
       for (unsigned J = 0; J != NumArgs; ++J) {
-        if (isVectorIntrinsicWithScalarOpAtArg(ID, J)) {
+        if (isVectorIntrinsicWithScalarOpAtArg(ID, J, TTI)) {
           Value *A1J = CI2->getArgOperand(J);
           if (ScalarArgs[J] != A1J) {
             LLVM_DEBUG(dbgs()
@@ -8613,7 +8614,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
         SmallVector<ValueList> Operands;
         for (unsigned I : seq<unsigned>(2, CI->arg_size())) {
           Operands.emplace_back();
-          if (isVectorIntrinsicWithScalarOpAtArg(ID, I))
+          if (isVectorIntrinsicWithScalarOpAtArg(ID, I, TTI))
             continue;
           for (Value *V : VL) {
             auto *CI2 = cast<CallInst>(V);
@@ -8634,7 +8635,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
       for (unsigned I : seq<unsigned>(0, CI->arg_size())) {
         // For scalar operands no need to create an entry since no need to
         // vectorize it.
-        if (isVectorIntrinsicWithScalarOpAtArg(ID, I))
+        if (isVectorIntrinsicWithScalarOpAtArg(ID, I, TTI))
           continue;
         ValueList Operands;
         // Prepare the operand vector.
@@ -10834,14 +10835,14 @@ TTI::CastContextHint BoUpSLP::getCastContextHint(const TreeEntry &TE) const {
 
 /// Builds the arguments types vector for the given call instruction with the
 /// given \p ID for the specified vector factor.
-static SmallVector<Type *> buildIntrinsicArgTypes(const CallInst *CI,
-                                                  const Intrinsic::ID ID,
-                                                  const unsigned VF,
-                                                  unsigned MinBW) {
+static SmallVector<Type *>
+buildIntrinsicArgTypes(const CallInst *CI, const Intrinsic::ID ID,
+                       const unsigned VF, unsigned MinBW,
+                       const TargetTransformInfo *TTI) {
   SmallVector<Type *> ArgTys;
   for (auto [Idx, Arg] : enumerate(CI->args())) {
     if (ID != Intrinsic::not_intrinsic) {
-      if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx)) {
+      if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx, TTI)) {
         ArgTys.push_back(Arg->getType());
         continue;
       }
@@ -11520,9 +11521,9 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     auto GetVectorCost = [=](InstructionCost CommonCost) {
       auto *CI = cast<CallInst>(VL0);
       Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
-      SmallVector<Type *> ArgTys =
-          buildIntrinsicArgTypes(CI, ID, VecTy->getNumElements(),
-                                 It != MinBWs.end() ? It->second.first : 0);
+      SmallVector<Type *> ArgTys = buildIntrinsicArgTypes(
+          CI, ID, VecTy->getNumElements(),
+          It != MinBWs.end() ? It->second.first : 0, TTI);
       auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, ArgTys);
       return std::min(VecCallCosts.first, VecCallCosts.second) + CommonCost;
     };
@@ -15644,9 +15645,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
 
       Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
 
-      SmallVector<Type *> ArgTys =
-          buildIntrinsicArgTypes(CI, ID, VecTy->getNumElements(),
-                                 It != MinBWs.end() ? It->second.first : 0);
+      SmallVector<Type *> ArgTys = buildIntrinsicArgTypes(
+          CI, ID, VecTy->getNumElements(),
+          It != MinBWs.end() ? It->second.first : 0, TTI);
       auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, ArgTys);
       bool UseIntrinsic = ID != Intrinsic::not_intrinsic &&
                           VecCallCosts.first <= VecCallCosts.second;
@@ -15662,7 +15663,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         ValueList OpVL;
         // Some intrinsics have scalar arguments. This argument should not be
         // vectorized.
-        if (UseIntrinsic && isVectorIntrinsicWithScalarOpAtArg(ID, I)) {
+        if (UseIntrinsic && isVectorIntrinsicWithScalarOpAtArg(ID, I, TTI)) {
           ScalarArg = CEI->getArgOperand(I);
           // if decided to reduce bitwidth of abs intrinsic, it second argument
           // must be set false (do not return poison, if value issigned min).
@@ -16195,7 +16196,7 @@ BoUpSLP::vectorizeTree(const ExtraValueToDebugLocsMap &ExternallyUsedValues,
                                E->State == TreeEntry::StridedVectorize) &&
                               doesInTreeUserNeedToExtract(
    ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 26, 2024

@llvm/pr-subscribers-llvm-analysis

Author: Finn Plummer (inbelic)

Changes
  • update VectorUtils:isVectorIntrinsicWithScalarOpAtArg to use TTI for
    all uses, to allow specifiction of target specific intrinsics

  • add TTI to the isVectorIntrinsicWithStructReturnOverloadAtField api

  • update TTI api to provide isTargetIntrinsicWith... functions and
    consistently name them

  • update all uses of the api and provide the TTI parameter

Resolves #117030


Patch is 25.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/117635.diff

13 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+21-7)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+8-3)
  • (modified) llvm/include/llvm/Analysis/VectorUtils.h (+9-5)
  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+8-3)
  • (modified) llvm/lib/Analysis/ConstantFolding.cpp (+1-1)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+8-3)
  • (modified) llvm/lib/Analysis/VectorUtils.cpp (+13-4)
  • (modified) llvm/lib/CodeGen/ReplaceWithVeclib.cpp (+1-1)
  • (modified) llvm/lib/Transforms/Scalar/Scalarizer.cpp (+2-3)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp (+1-1)
  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+23-21)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+2-1)
  • (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+13-10)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 985ca1532e0149..cb4793e85e462e 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -898,14 +898,20 @@ class TargetTransformInfo {
 
   bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
 
+  /// Identifies if the vector form of the intrinsic has a scalar operand.
   bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                           unsigned ScalarOpdIdx) const;
 
   /// Identifies if the vector form of the intrinsic is overloaded on the type
   /// of the operand at index \p OpdIdx, or on the return type if \p OpdIdx is
   /// -1.
-  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                              int ScalarOpdIdx) const;
+  bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              int OpdIdx) const;
+
+  /// Identifies if the vector form of the intrinsic that returns a struct is
+  /// overloaded at the struct element index \p RetIdx.
+  bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
+                                                        int RetIdx) const;
 
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
@@ -1999,8 +2005,11 @@ class TargetTransformInfo::Concept {
   virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
   virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                                   unsigned ScalarOpdIdx) = 0;
-  virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                                      int ScalarOpdIdx) = 0;
+  virtual bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                                      int OpdIdx) = 0;
+  virtual bool
+  isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
+                                                   int RetIdx) = 0;
   virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
                                                    const APInt &DemandedElts,
                                                    bool Insert, bool Extract,
@@ -2577,9 +2586,14 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
   }
 
-  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                              int ScalarOpdIdx) override {
-    return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
+  bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              int OpdIdx) override {
+    return Impl.isTargetIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
+  }
+
+  bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
+                                                        int RetIdx) override {
+    return Impl.isTargetIntrinsicWithStructReturnOverloadAtField(ID, RetIdx);
   }
 
   InstructionCost getScalarizationOverhead(VectorType *Ty,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 38aba183f6a173..9499c660822179 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -396,9 +396,14 @@ class TargetTransformInfoImplBase {
     return false;
   }
 
-  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                              int ScalarOpdIdx) const {
-    return ScalarOpdIdx == -1;
+  bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              int OpdIdx) const {
+    return OpdIdx == -1;
+  }
+
+  bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
+                                                        int RetIdx) const {
+    return RetIdx == 0;
   }
 
   InstructionCost getScalarizationOverhead(VectorType *Ty,
diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index c1016dd7bdddbd..f7febf3e82b125 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -147,8 +147,10 @@ inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
 bool isTriviallyVectorizable(Intrinsic::ID ID);
 
 /// Identifies if the vector form of the intrinsic has a scalar operand.
-bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
-                                        unsigned ScalarOpdIdx);
+/// \p TTI is used to consider target specific intrinsics, if no target specific
+/// intrinsics will be considered then it is appropriate to pass in nullptr.
+bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, unsigned ScalarOpdIdx,
+                                        const TargetTransformInfo *TTI);
 
 /// Identifies if the vector form of the intrinsic is overloaded on the type of
 /// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
@@ -158,9 +160,11 @@ bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx,
                                             const TargetTransformInfo *TTI);
 
 /// Identifies if the vector form of the intrinsic that returns a struct is
-/// overloaded at the struct element index \p RetIdx.
-bool isVectorIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
-                                                      int RetIdx);
+/// overloaded at the struct element index \p RetIdx. /// \p TTI is used to
+/// consider target specific intrinsics, if no target specific intrinsics
+/// will be considered then it is appropriate to pass in nullptr.
+bool isVectorIntrinsicWithStructReturnOverloadAtField(
+    Intrinsic::ID ID, int RetIdx, const TargetTransformInfo *TTI);
 
 /// Returns intrinsic ID for call.
 /// For the input call instruction it finds mapping intrinsic and returns
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index b3583e2819ee4c..1bad9c70223b8d 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -801,9 +801,14 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return false;
   }
 
-  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                              int ScalarOpdIdx) const {
-    return ScalarOpdIdx == -1;
+  bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              int OpdIdx) const {
+    return OpdIdx == -1;
+  }
+
+  bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
+                                                        int RetIdx) const {
+    return RetIdx == 0;
   }
 
   /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 1971c28fc4c4de..caeaec70e3967c 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -3438,7 +3438,7 @@ static Constant *ConstantFoldFixedVectorCall(
     // Gather a column of constants.
     for (unsigned J = 0, JE = Operands.size(); J != JE; ++J) {
       // Some intrinsics use a scalar type for certain arguments.
-      if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, J)) {
+      if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, J, /*TTI=*/nullptr)) {
         Lane[J] = Operands[J];
         continue;
       }
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 1fb2b9836de0cc..5cd09207bfc56a 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -615,9 +615,14 @@ bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
   return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
 }
 
-bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
-    Intrinsic::ID ID, int ScalarOpdIdx) const {
-  return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
+bool TargetTransformInfo::isTargetIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, int OpdIdx) const {
+  return TTIImpl->isTargetIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
+}
+
+bool TargetTransformInfo::isTargetIntrinsicWithStructReturnOverloadAtField(
+    Intrinsic::ID ID, int RetIdx) const {
+  return TTIImpl->isTargetIntrinsicWithOverloadTypeAtArg(ID, RetIdx);
 }
 
 InstructionCost TargetTransformInfo::getScalarizationOverhead(
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 1789671276ffaf..15edaa0a47d037 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -115,7 +115,12 @@ bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) {
 
 /// Identifies if the vector form of the intrinsic has a scalar operand.
 bool llvm::isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
-                                              unsigned ScalarOpdIdx) {
+                                              unsigned ScalarOpdIdx,
+                                              const TargetTransformInfo *TTI) {
+
+  if (TTI && Intrinsic::isTargetIntrinsic(ID))
+    return TTI->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
+
   switch (ID) {
   case Intrinsic::abs:
   case Intrinsic::ctlz:
@@ -138,7 +143,7 @@ bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(
   assert(ID != Intrinsic::not_intrinsic && "Not an intrinsic!");
 
   if (TTI && Intrinsic::isTargetIntrinsic(ID))
-    return TTI->isVectorIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
+    return TTI->isTargetIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
 
   switch (ID) {
   case Intrinsic::fptosi_sat:
@@ -157,8 +162,12 @@ bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(
   }
 }
 
-bool llvm::isVectorIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
-                                                            int RetIdx) {
+bool llvm::isVectorIntrinsicWithStructReturnOverloadAtField(
+    Intrinsic::ID ID, int RetIdx, const TargetTransformInfo *TTI) {
+
+  if (TTI && Intrinsic::isTargetIntrinsic(ID))
+    return TTI->isTargetIntrinsicWithStructReturnOverloadAtField(ID, RetIdx);
+
   switch (ID) {
   case Intrinsic::frexp:
     return RetIdx == 0 || RetIdx == 1;
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 8d457f58e6eede..a87c2063b1e353 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -121,7 +121,7 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
     auto *ArgTy = Arg.value()->getType();
     bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index(),
                                                             /*TTI=*/nullptr);
-    if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
+    if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index(), /*TTI=*/nullptr)) {
       ScalarArgTypes.push_back(ArgTy);
       if (IsOloadTy)
         OloadTys.push_back(ArgTy);
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 3b701e6ca09761..2ba8389c3a30c3 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -743,7 +743,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
       // will only scalarize when the struct elements have the same bitness.
       if (!CurrVS || CurrVS->NumPacked != VS->NumPacked)
         return false;
-      if (isVectorIntrinsicWithStructReturnOverloadAtField(ID, I))
+      if (isVectorIntrinsicWithStructReturnOverloadAtField(ID, I, TTI))
         Tys.push_back(CurrVS->SplitTy);
     }
   }
@@ -794,8 +794,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
       Tys[0] = VS->RemainderTy;
 
     for (unsigned J = 0; J != NumArgs; ++J) {
-      if (isVectorIntrinsicWithScalarOpAtArg(ID, J) ||
-          TTI->isTargetIntrinsicWithScalarOpAtArg(ID, J)) {
+      if (isVectorIntrinsicWithScalarOpAtArg(ID, J, TTI)) {
         ScalarCallOps.push_back(ScalarOperands[J]);
       } else {
         ScalarCallOps.push_back(Scattered[J][I]);
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index f1568781252c06..567d6ca2e5a319 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -927,7 +927,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
         auto *SE = PSE.getSE();
         Intrinsic::ID IntrinID = getVectorIntrinsicIDForCall(CI, TLI);
         for (unsigned Idx = 0; Idx < CI->arg_size(); ++Idx)
-          if (isVectorIntrinsicWithScalarOpAtArg(IntrinID, Idx)) {
+          if (isVectorIntrinsicWithScalarOpAtArg(IntrinID, Idx, TTI)) {
             if (!SE->isLoopInvariant(PSE.getSCEV(CI->getOperand(Idx)),
                                      TheLoop)) {
               reportVectorizationFailure("Found unvectorizable intrinsic",
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index f13d0d80d382a4..a54cddd69d35d3 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1060,7 +1060,8 @@ static bool allSameType(ArrayRef<Value *> VL) {
 /// \returns True if in-tree use also needs extract. This refers to
 /// possible scalar operand in vectorized instruction.
 static bool doesInTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst,
-                                        TargetLibraryInfo *TLI) {
+                                        TargetLibraryInfo *TLI,
+                                        const TargetTransformInfo *TTI) {
   if (!UserInst)
     return false;
   unsigned Opcode = UserInst->getOpcode();
@@ -1077,7 +1078,7 @@ static bool doesInTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst,
     CallInst *CI = cast<CallInst>(UserInst);
     Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
     return any_of(enumerate(CI->args()), [&](auto &&Arg) {
-      return isVectorIntrinsicWithScalarOpAtArg(ID, Arg.index()) &&
+      return isVectorIntrinsicWithScalarOpAtArg(ID, Arg.index(), TTI) &&
              Arg.value().get() == Scalar;
     });
   }
@@ -6417,7 +6418,7 @@ void BoUpSLP::buildExternalUses(
           // be used.
           if (UseEntry->State == TreeEntry::ScatterVectorize ||
               !doesInTreeUserNeedToExtract(
-                  Scalar, getRootEntryInstruction(*UseEntry), TLI)) {
+                  Scalar, getRootEntryInstruction(*UseEntry), TLI, TTI)) {
             LLVM_DEBUG(dbgs() << "SLP: \tInternal user will be removed:" << *U
                               << ".\n");
             assert(!UseEntry->isGather() && "Bad state");
@@ -7724,7 +7725,7 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
     unsigned NumArgs = CI->arg_size();
     SmallVector<Value *, 4> ScalarArgs(NumArgs, nullptr);
     for (unsigned J = 0; J != NumArgs; ++J)
-      if (isVectorIntrinsicWithScalarOpAtArg(ID, J))
+      if (isVectorIntrinsicWithScalarOpAtArg(ID, J, TTI))
         ScalarArgs[J] = CI->getArgOperand(J);
     for (Value *V : VL) {
       CallInst *CI2 = dyn_cast<CallInst>(V);
@@ -7740,7 +7741,7 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
       // Some intrinsics have scalar arguments and should be same in order for
       // them to be vectorized.
       for (unsigned J = 0; J != NumArgs; ++J) {
-        if (isVectorIntrinsicWithScalarOpAtArg(ID, J)) {
+        if (isVectorIntrinsicWithScalarOpAtArg(ID, J, TTI)) {
           Value *A1J = CI2->getArgOperand(J);
           if (ScalarArgs[J] != A1J) {
             LLVM_DEBUG(dbgs()
@@ -8613,7 +8614,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
         SmallVector<ValueList> Operands;
         for (unsigned I : seq<unsigned>(2, CI->arg_size())) {
           Operands.emplace_back();
-          if (isVectorIntrinsicWithScalarOpAtArg(ID, I))
+          if (isVectorIntrinsicWithScalarOpAtArg(ID, I, TTI))
             continue;
           for (Value *V : VL) {
             auto *CI2 = cast<CallInst>(V);
@@ -8634,7 +8635,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
       for (unsigned I : seq<unsigned>(0, CI->arg_size())) {
         // For scalar operands no need to create an entry since no need to
         // vectorize it.
-        if (isVectorIntrinsicWithScalarOpAtArg(ID, I))
+        if (isVectorIntrinsicWithScalarOpAtArg(ID, I, TTI))
           continue;
         ValueList Operands;
         // Prepare the operand vector.
@@ -10834,14 +10835,14 @@ TTI::CastContextHint BoUpSLP::getCastContextHint(const TreeEntry &TE) const {
 
 /// Builds the arguments types vector for the given call instruction with the
 /// given \p ID for the specified vector factor.
-static SmallVector<Type *> buildIntrinsicArgTypes(const CallInst *CI,
-                                                  const Intrinsic::ID ID,
-                                                  const unsigned VF,
-                                                  unsigned MinBW) {
+static SmallVector<Type *>
+buildIntrinsicArgTypes(const CallInst *CI, const Intrinsic::ID ID,
+                       const unsigned VF, unsigned MinBW,
+                       const TargetTransformInfo *TTI) {
   SmallVector<Type *> ArgTys;
   for (auto [Idx, Arg] : enumerate(CI->args())) {
     if (ID != Intrinsic::not_intrinsic) {
-      if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx)) {
+      if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx, TTI)) {
         ArgTys.push_back(Arg->getType());
         continue;
       }
@@ -11520,9 +11521,9 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     auto GetVectorCost = [=](InstructionCost CommonCost) {
       auto *CI = cast<CallInst>(VL0);
       Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
-      SmallVector<Type *> ArgTys =
-          buildIntrinsicArgTypes(CI, ID, VecTy->getNumElements(),
-                                 It != MinBWs.end() ? It->second.first : 0);
+      SmallVector<Type *> ArgTys = buildIntrinsicArgTypes(
+          CI, ID, VecTy->getNumElements(),
+          It != MinBWs.end() ? It->second.first : 0, TTI);
       auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, ArgTys);
       return std::min(VecCallCosts.first, VecCallCosts.second) + CommonCost;
     };
@@ -15644,9 +15645,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
 
       Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
 
-      SmallVector<Type *> ArgTys =
-          buildIntrinsicArgTypes(CI, ID, VecTy->getNumElements(),
-                                 It != MinBWs.end() ? It->second.first : 0);
+      SmallVector<Type *> ArgTys = buildIntrinsicArgTypes(
+          CI, ID, VecTy->getNumElements(),
+          It != MinBWs.end() ? It->second.first : 0, TTI);
       auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, ArgTys);
       bool UseIntrinsic = ID != Intrinsic::not_intrinsic &&
                           VecCallCosts.first <= VecCallCosts.second;
@@ -15662,7 +15663,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         ValueList OpVL;
         // Some intrinsics have scalar arguments. This argument should not be
         // vectorized.
-        if (UseIntrinsic && isVectorIntrinsicWithScalarOpAtArg(ID, I)) {
+        if (UseIntrinsic && isVectorIntrinsicWithScalarOpAtArg(ID, I, TTI)) {
           ScalarArg = CEI->getArgOperand(I);
           // if decided to reduce bitwidth of abs intrinsic, it second argument
           // must be set false (do not return poison, if value issigned min).
@@ -16195,7 +16196,7 @@ BoUpSLP::vectorizeTree(const ExtraValueToDebugLocsMap &ExternallyUsedValues,
                                E->State == TreeEntry::StridedVectorize) &&
                               doesInTreeUserNeedToExtract(
    ...
[truncated]

@farzonl
Copy link
Member

farzonl commented Dec 12, 2024

Should we do something similar with isTargetIntrinsicTriviallyScalarizable? Alternatively was thinking we need a ticket to consolidate isTargetIntrinsicTriviallyScalarizable with isTargetIntrinsicTriviallyVectorizable?

@inbelic
Copy link
Contributor Author

inbelic commented Dec 13, 2024

Sure, I think it makes sense to move the ScalarizerVisitor::isTriviallyScalarizable to VectorUtils.cpp.

By consolidate, do you mean to just have one of isTriviallyVectorizable and isTriviallyScalarizable?

@farzonl
Copy link
Member

farzonl commented Dec 13, 2024

Sure, I think it makes sense to move the ScalarizerVisitor::isTriviallyScalarizable to VectorUtils.cpp.

By consolidate, do you mean to just have one of isTriviallyVectorizable and isTriviallyScalarizable?

I was thinking one function with a more neutral name. I don’t have a good suggestion for you though. Maybe something with elementwise in the name? Or maybe we just adopt the vectorizable name?

All that said there might be advantages to having a separate scalarizer function in VectorUtils. For one there are intrinsics we don’t need to figure out vectorization for. I leave it to you to decided what needs to be in VectorUtils.

…c...` api

- update `VectorUtils:isVectorIntrinsicWithScalarOpAtArg` to use TTI for
all uses, to allow specifiction of target specific intrinsics
- add TTI to the `isVectorIntrinsicWithStructReturnOverloadAtField` api
- update TTI api to provide `isTargetIntrinsicWith...` functions and
  consistently name them
- we don't expect a target intrinsic to be replaced with a general
libvec function. so we can just pass in nullptr for now
- we can pass in the already defined TTI
- pass down the previously defined TTI
- use the previously defined TTI
- use the already defined TTI
- pass down the nullptr as there is no folding of any target intrinsics
@inbelic inbelic force-pushed the inbelic/vectorutils-tti branch from 03cc71b to 6067351 Compare December 17, 2024 01:53
@inbelic
Copy link
Contributor Author

inbelic commented Dec 17, 2024

Rebased onto main to resolve merge conflicts and updated to move isTriviallyScalarizable into VectorUtils.

I decided to keep the two functions separate as it will force any future intrinsic that implements scalarization to also implement vectorization, which may be redundant work. But also, because isTriviallyVectorizable is used for other logic that needs to be audited before changing. For instance, isNotCrossLaneOperation, where it is not true that if a function is scalarizable then it is not a cross-lane op (dx_wave_readlane).

I created this issue to track the work needed to update isTriviallyVectorizable to accept target intrinsics and from there we can consider if we want to merge them.

@inbelic inbelic merged commit 45c01e8 into llvm:main Dec 19, 2024
8 checks passed
@inbelic inbelic deleted the inbelic/vectorutils-tti branch January 17, 2025 23:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[NFC] Update isVectorIntrinsicWith... functions to consider TargetTransformInfo
4 participants