Skip to content

[SandboxVec][Legality] Diamond reuse multi input #123426

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
vporpo merged 1 commit into llvm:main
Jan 22, 2025
Merged

Conversation

vporpo
Copy link
Contributor

@vporpo vporpo commented Jan 17, 2025

This patch implements the diamond pattern where we are vectorizing toward the top of the diamond from both edges, but the second edge may use elements from a different vector or just scalar values. This requires some additional packing code (see lit test).

@llvmbot
Copy link
Member

llvmbot commented Jan 17, 2025

@llvm/pr-subscribers-vectorizers

@llvm/pr-subscribers-llvm-transforms

Author: vporpo (vporpo)

Changes

This patch implements the diamond pattern where we are vectorizing toward the top of the diamond from both edges, but the second edge may use elements from a different vector or just scalar values. This requires some additional packing code (see lit test).


Full diff: https://github.com/llvm/llvm-project/pull/123426.diff

4 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h (+20-2)
  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp (+2-1)
  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp (+34)
  • (modified) llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll (+27)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
index 4858ebaf0770aa..f10c535aa820ee 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
@@ -81,6 +81,7 @@ enum class LegalityResultID {
   Widen,                   ///> Vectorize by combining scalars to a vector.
   DiamondReuse,            ///> Don't generate new code, reuse existing vector.
   DiamondReuseWithShuffle, ///> Reuse the existing vector but add a shuffle.
+  DiamondReuseMultiInput,  ///> Reuse more than one vector and/or scalars.
 };
 
 /// The reason for vectorizing or not vectorizing.
@@ -108,6 +109,8 @@ struct ToStr {
       return "DiamondReuse";
     case LegalityResultID::DiamondReuseWithShuffle:
       return "DiamondReuseWithShuffle";
+    case LegalityResultID::DiamondReuseMultiInput:
+      return "DiamondReuseMultiInput";
     }
     llvm_unreachable("Unknown LegalityResultID enum");
   }
@@ -287,6 +290,20 @@ class CollectDescr {
   }
 };
 
+class DiamondReuseMultiInput final : public LegalityResult {
+  friend class LegalityAnalysis;
+  CollectDescr Descr;
+  DiamondReuseMultiInput(CollectDescr &&Descr)
+      : LegalityResult(LegalityResultID::DiamondReuseMultiInput),
+        Descr(std::move(Descr)) {}
+
+public:
+  static bool classof(const LegalityResult *From) {
+    return From->getSubclassID() == LegalityResultID::DiamondReuseMultiInput;
+  }
+  const CollectDescr &getCollectDescr() const { return Descr; }
+};
+
 /// Performs the legality analysis and returns a LegalityResult object.
 class LegalityAnalysis {
   Scheduler Sched;
@@ -312,8 +329,9 @@ class LegalityAnalysis {
       : Sched(AA, Ctx), SE(SE), DL(DL), IMaps(IMaps) {}
   /// A LegalityResult factory.
   template <typename ResultT, typename... ArgsT>
-  ResultT &createLegalityResult(ArgsT... Args) {
-    ResultPool.push_back(std::unique_ptr<ResultT>(new ResultT(Args...)));
+  ResultT &createLegalityResult(ArgsT &&...Args) {
+    ResultPool.push_back(
+        std::unique_ptr<ResultT>(new ResultT(std::forward<ArgsT>(Args)...)));
     return cast<ResultT>(*ResultPool.back());
   }
   /// Checks if it's legal to vectorize the instructions in \p Bndl.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index ad3e38e2f1d923..085f4cd67ab76e 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -223,7 +223,8 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
         return createLegalityResult<DiamondReuse>(Vec);
       return createLegalityResult<DiamondReuseWithShuffle>(Vec, Mask);
     }
-    llvm_unreachable("TODO: Unimplemented");
+    return createLegalityResult<DiamondReuseMultiInput>(
+        std::move(CollectDescrs));
   }
 
   if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl))
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index d62023ea018846..c6ab3c1942c330 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -308,6 +308,40 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
     NewVec = createShuffle(VecOp, Mask);
     break;
   }
+  case LegalityResultID::DiamondReuseMultiInput: {
+    const auto &Descr =
+        cast<DiamondReuseMultiInput>(LegalityRes).getCollectDescr();
+    Type *ResTy = FixedVectorType::get(Bndl[0]->getType(), Bndl.size());
+
+    // TODO: Try to get WhereIt without creating a vector.
+    SmallVector<Value *, 4> DescrInstrs;
+    for (const auto &ElmDescr : Descr.getDescrs()) {
+      if (auto *I = dyn_cast<Instruction>(ElmDescr.getValue()))
+        DescrInstrs.push_back(I);
+    }
+    auto WhereIt = getInsertPointAfterInstrs(DescrInstrs);
+
+    Value *LastV = PoisonValue::get(ResTy);
+    for (auto [Lane, ElmDescr] : enumerate(Descr.getDescrs())) {
+      Value *VecOp = ElmDescr.getValue();
+      Context &Ctx = VecOp->getContext();
+      Value *ValueToInsert;
+      if (ElmDescr.needsExtract()) {
+        ConstantInt *IdxC =
+            ConstantInt::get(Type::getInt32Ty(Ctx), ElmDescr.getExtractIdx());
+        ValueToInsert = ExtractElementInst::create(VecOp, IdxC, WhereIt,
+                                                   VecOp->getContext(), "VExt");
+      } else {
+        ValueToInsert = VecOp;
+      }
+      ConstantInt *LaneC = ConstantInt::get(Type::getInt32Ty(Ctx), Lane);
+      Value *Ins = InsertElementInst::create(LastV, ValueToInsert, LaneC,
+                                             WhereIt, Ctx, "VIns");
+      LastV = Ins;
+    }
+    NewVec = LastV;
+    break;
+  }
   case LegalityResultID::Pack: {
     // If we can't vectorize the seeds then just return.
     if (Depth == 0)
diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
index a3798af8399087..5b389e25d70d95 100644
--- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
+++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
@@ -242,3 +242,30 @@ define void @diamondWithShuffle(ptr %ptr) {
   store float %sub1, ptr %ptr1
   ret void
 }
+
+define void @diamondMultiInput(ptr %ptr, ptr %ptrX) {
+; CHECK-LABEL: define void @diamondMultiInput(
+; CHECK-SAME: ptr [[PTR:%.*]], ptr [[PTRX:%.*]]) {
+; CHECK-NEXT:    [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
+; CHECK-NEXT:    [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[LDX:%.*]] = load float, ptr [[PTRX]], align 4
+; CHECK-NEXT:    [[VINS:%.*]] = insertelement <2 x float> poison, float [[LDX]], i32 0
+; CHECK-NEXT:    [[VEXT:%.*]] = extractelement <2 x float> [[VECL]], i32 0
+; CHECK-NEXT:    [[VINS1:%.*]] = insertelement <2 x float> [[VINS]], float [[VEXT]], i32 1
+; CHECK-NEXT:    [[VEC:%.*]] = fsub <2 x float> [[VECL]], [[VINS1]]
+; CHECK-NEXT:    store <2 x float> [[VEC]], ptr [[PTR0]], align 4
+; CHECK-NEXT:    ret void
+;
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  %ld0 = load float, ptr %ptr0
+  %ld1 = load float, ptr %ptr1
+
+  %ldX = load float, ptr %ptrX
+
+  %sub0 = fsub float %ld0, %ldX
+  %sub1 = fsub float %ld1, %ld0
+  store float %sub0, ptr %ptr0
+  store float %sub1, ptr %ptr1
+  ret void
+}

This patch implements the diamond pattern where we are vectorizing toward the
top of the diamond from both edges, but the second edge may use elements from
a different vector or just scalar values. This requires some additional packing
code (see lit test).
@vporpo vporpo merged commit fd08713 into llvm:main Jan 22, 2025
6 of 8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants