From 6f7b5f155f83add541b39b988c134cfe35c7619e Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas
Date: Fri, 22 Nov 2024 10:12:28 -0800
Subject: [PATCH] [SandboxVec][Legality] Diamond reuse multi input

This patch implements the diamond pattern in which we vectorize toward the
top of the diamond from both edges, but the second edge may use elements
from a different vector or just scalar values. Reusing those values requires
some additional packing code (see the lit test).
---
 .../Vectorize/SandboxVectorizer/Legality.h    | 22 ++++++++++--
 .../Vectorize/SandboxVectorizer/Legality.cpp  |  3 +-
 .../SandboxVectorizer/Passes/BottomUpVec.cpp  | 34 +++++++++++++++++++
 .../SandboxVectorizer/bottomup_basic.ll       | 27 +++++++++++++++
 4 files changed, 83 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
index 4858ebaf0770a..f10c535aa820e 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
@@ -81,6 +81,7 @@ enum class LegalityResultID {
   Widen,                   ///> Vectorize by combining scalars to a vector.
   DiamondReuse,            ///> Don't generate new code, reuse existing vector.
   DiamondReuseWithShuffle, ///> Reuse the existing vector but add a shuffle.
+  DiamondReuseMultiInput,  ///> Reuse more than one vector and/or scalars.
 };
 
 /// The reason for vectorizing or not vectorizing.
@@ -108,6 +109,8 @@ struct ToStr {
       return "DiamondReuse";
     case LegalityResultID::DiamondReuseWithShuffle:
       return "DiamondReuseWithShuffle";
+    case LegalityResultID::DiamondReuseMultiInput:
+      return "DiamondReuseMultiInput";
     }
     llvm_unreachable("Unknown LegalityResultID enum");
   }
@@ -287,6 +290,22 @@ class CollectDescr {
   }
 };
 
+/// Legality result for a diamond whose operands have to be collected from more
+/// than one existing vector and/or scalars, as described by a CollectDescr.
+class DiamondReuseMultiInput final : public LegalityResult {
+  friend class LegalityAnalysis;
+  CollectDescr Descr;
+  DiamondReuseMultiInput(CollectDescr &&Descr)
+      : LegalityResult(LegalityResultID::DiamondReuseMultiInput),
+        Descr(std::move(Descr)) {}
+
+public:
+  static bool classof(const LegalityResult *From) {
+    return From->getSubclassID() == LegalityResultID::DiamondReuseMultiInput;
+  }
+  const CollectDescr &getCollectDescr() const { return Descr; }
+};
+
 /// Performs the legality analysis and returns a LegalityResult object.
 class LegalityAnalysis {
   Scheduler Sched;
@@ -312,8 +331,11 @@ class LegalityAnalysis {
       : Sched(AA, Ctx), SE(SE), DL(DL), IMaps(IMaps) {}
   /// A LegalityResult factory.
   template <typename ResultT, typename... ArgsT>
-  ResultT &createLegalityResult(ArgsT... Args) {
-    ResultPool.push_back(std::unique_ptr<ResultT>(new ResultT(Args...)));
+  ResultT &createLegalityResult(ArgsT &&...Args) {
+    // Forward each argument with its original value category: rvalues are
+    // moved into the new result, lvalues owned by the caller stay untouched.
+    ResultPool.push_back(std::unique_ptr<ResultT>(
+        new ResultT(std::forward<ArgsT>(Args)...)));
     return cast<ResultT>(*ResultPool.back());
   }
   /// Checks if it's legal to vectorize the instructions in \p Bndl.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index ad3e38e2f1d92..085f4cd67ab76 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -223,7 +223,8 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
         return createLegalityResult<DiamondReuse>(Vec);
       return createLegalityResult<DiamondReuseWithShuffle>(Vec, Mask);
     }
-    llvm_unreachable("TODO: Unimplemented");
+    return createLegalityResult<DiamondReuseMultiInput>(
+        std::move(CollectDescrs));
   }
 
   if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl))
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index d62023ea01884..c6ab3c1942c33 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -308,6 +308,47 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
     NewVec = createShuffle(VecOp, Mask);
     break;
   }
+  case LegalityResultID::DiamondReuseMultiInput: {
+    // Build the operand vector lane by lane out of the values described by
+    // the CollectDescr: lanes that need it are first extracted from their
+    // source vector, the rest are inserted as they are.
+    const auto &Descr =
+        cast<DiamondReuseMultiInput>(LegalityRes).getCollectDescr();
+    Type *ResTy = FixedVectorType::get(Bndl[0]->getType(), Bndl.size());
+
+    // The packing code is emitted at the insert point returned by
+    // getInsertPointAfterInstrs() for the instructions among the inputs.
+    // NOTE(review): DescrInstrs stays empty if no input is an Instruction;
+    // confirm that getInsertPointAfterInstrs() copes with an empty list.
+    // TODO: Try to get WhereIt without creating a vector.
+    SmallVector<Value *, 4> DescrInstrs;
+    for (const auto &ElmDescr : Descr.getDescrs()) {
+      if (auto *I = dyn_cast<Instruction>(ElmDescr.getValue()))
+        DescrInstrs.push_back(I);
+    }
+    auto WhereIt = getInsertPointAfterInstrs(DescrInstrs);
+
+    // Start from poison and fill in one lane per element descriptor.
+    Value *LastV = PoisonValue::get(ResTy);
+    for (auto [Lane, ElmDescr] : enumerate(Descr.getDescrs())) {
+      Value *VecOp = ElmDescr.getValue();
+      Context &Ctx = VecOp->getContext();
+      auto *Int32Ty = Type::getInt32Ty(Ctx);
+      Value *ValueToInsert = VecOp;
+      if (ElmDescr.needsExtract()) {
+        // The lane comes out of a vector: extract it first.
+        ConstantInt *IdxC =
+            ConstantInt::get(Int32Ty, ElmDescr.getExtractIdx());
+        ValueToInsert =
+            ExtractElementInst::create(VecOp, IdxC, WhereIt, Ctx, "VExt");
+      }
+      ConstantInt *LaneC = ConstantInt::get(Int32Ty, Lane);
+      LastV = InsertElementInst::create(LastV, ValueToInsert, LaneC, WhereIt,
+                                        Ctx, "VIns");
+    }
+    NewVec = LastV;
+    break;
+  }
   case LegalityResultID::Pack: {
     // If we can't vectorize the seeds then just return.
     if (Depth == 0)
diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
index a3798af839908..5b389e25d70d9 100644
--- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
+++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
@@ -242,3 +242,32 @@ define void @diamondWithShuffle(ptr %ptr) {
   store float %sub1, ptr %ptr1
   ret void
 }
+
+; The right-hand fsub operands (%ldX, %ld0) come from a scalar load and from
+; lane 0 of the vectorized load, so they are packed with extract/insert.
+define void @diamondMultiInput(ptr %ptr, ptr %ptrX) {
+; CHECK-LABEL: define void @diamondMultiInput(
+; CHECK-SAME: ptr [[PTR:%.*]], ptr [[PTRX:%.*]]) {
+; CHECK-NEXT:    [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
+; CHECK-NEXT:    [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[LDX:%.*]] = load float, ptr [[PTRX]], align 4
+; CHECK-NEXT:    [[VINS:%.*]] = insertelement <2 x float> poison, float [[LDX]], i32 0
+; CHECK-NEXT:    [[VEXT:%.*]] = extractelement <2 x float> [[VECL]], i32 0
+; CHECK-NEXT:    [[VINS1:%.*]] = insertelement <2 x float> [[VINS]], float [[VEXT]], i32 1
+; CHECK-NEXT:    [[VEC:%.*]] = fsub <2 x float> [[VECL]], [[VINS1]]
+; CHECK-NEXT:    store <2 x float> [[VEC]], ptr [[PTR0]], align 4
+; CHECK-NEXT:    ret void
+;
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  %ld0 = load float, ptr %ptr0
+  %ld1 = load float, ptr %ptr1
+
+  %ldX = load float, ptr %ptrX
+
+  %sub0 = fsub float %ld0, %ldX
+  %sub1 = fsub float %ld1, %ld0
+  store float %sub0, ptr %ptr0
+  store float %sub1, ptr %ptr1
+  ret void
+}