Skip to content

Commit fd08713

Browse files
authored
[SandboxVec][Legality] Diamond reuse multi input (#123426)
This patch implements the diamond pattern in which we vectorize toward the top of the diamond along both edges, but the second edge may use elements from a different vector, or plain scalar values. Handling this case requires some additional packing code (see the lit test).
1 parent 939f290 commit fd08713

File tree

4 files changed

+83
-3
lines changed

4 files changed

+83
-3
lines changed

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ enum class LegalityResultID {
8181
Widen, ///> Vectorize by combining scalars to a vector.
8282
DiamondReuse, ///> Don't generate new code, reuse existing vector.
8383
DiamondReuseWithShuffle, ///> Reuse the existing vector but add a shuffle.
84+
DiamondReuseMultiInput, ///> Reuse more than one vector and/or scalars.
8485
};
8586

8687
/// The reason for vectorizing or not vectorizing.
@@ -108,6 +109,8 @@ struct ToStr {
108109
return "DiamondReuse";
109110
case LegalityResultID::DiamondReuseWithShuffle:
110111
return "DiamondReuseWithShuffle";
112+
case LegalityResultID::DiamondReuseMultiInput:
113+
return "DiamondReuseMultiInput";
111114
}
112115
llvm_unreachable("Unknown LegalityResultID enum");
113116
}
@@ -287,6 +290,20 @@ class CollectDescr {
287290
}
288291
};
289292

293+
/// Legality result tagged LegalityResultID::DiamondReuseMultiInput, i.e. the
/// case described on the enum as "reuse more than one vector and/or scalars".
/// It carries the CollectDescr that the consumer reads via getCollectDescr().
class DiamondReuseMultiInput final : public LegalityResult {
294+
// The constructor is private; LegalityAnalysis is befriended so that it can
// create instances.
friend class LegalityAnalysis;
295+
// Description of the values to collect; moved in at construction.
CollectDescr Descr;
296+
// Takes ownership of \p Descr by moving it into the member.
DiamondReuseMultiInput(CollectDescr &&Descr)
297+
: LegalityResult(LegalityResultID::DiamondReuseMultiInput),
298+
Descr(std::move(Descr)) {}
299+
300+
public:
301+
/// \returns true if \p From carries the DiamondReuseMultiInput subclass ID.
static bool classof(const LegalityResult *From) {
302+
return From->getSubclassID() == LegalityResultID::DiamondReuseMultiInput;
303+
}
304+
/// \returns the collect description stored in this result.
const CollectDescr &getCollectDescr() const { return Descr; }
305+
};
306+
290307
/// Performs the legality analysis and returns a LegalityResult object.
291308
class LegalityAnalysis {
292309
Scheduler Sched;
@@ -312,8 +329,9 @@ class LegalityAnalysis {
312329
: Sched(AA, Ctx), SE(SE), DL(DL), IMaps(IMaps) {}
313330
/// A LegalityResult factory. Constructs a ResultT from \p Args, appends it to
/// ResultPool (which holds it through a std::unique_ptr) and returns a
/// reference to the stored object.
314331
template <typename ResultT, typename... ArgsT>
315-
ResultT &createLegalityResult(ArgsT... Args) {
316-
ResultPool.push_back(std::unique_ptr<ResultT>(new ResultT(Args...)));
332+
ResultT &createLegalityResult(ArgsT &&...Args) {
333+
ResultPool.push_back(
334+
// Args are forwarding references: forward them so that rvalues are moved
// into the constructor while lvalues passed by the caller are left intact
// (an unconditional std::move would silently move from caller lvalues).
std::unique_ptr<ResultT>(new ResultT(std::forward<ArgsT>(Args)...)));
317335
return cast<ResultT>(*ResultPool.back());
318336
}
319337
/// Checks if it's legal to vectorize the instructions in \p Bndl.

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,8 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
223223
return createLegalityResult<DiamondReuse>(Vec);
224224
return createLegalityResult<DiamondReuseWithShuffle>(Vec, Mask);
225225
}
226-
llvm_unreachable("TODO: Unimplemented");
226+
return createLegalityResult<DiamondReuseMultiInput>(
227+
std::move(CollectDescrs));
227228
}
228229

229230
if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl))

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,40 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
308308
NewVec = createShuffle(VecOp, Mask);
309309
break;
310310
}
311+
case LegalityResultID::DiamondReuseMultiInput: {
312+
const auto &Descr =
313+
cast<DiamondReuseMultiInput>(LegalityRes).getCollectDescr();
314+
Type *ResTy = FixedVectorType::get(Bndl[0]->getType(), Bndl.size());
315+
316+
// TODO: Try to get WhereIt without creating a vector.
317+
SmallVector<Value *, 4> DescrInstrs;
318+
for (const auto &ElmDescr : Descr.getDescrs()) {
319+
if (auto *I = dyn_cast<Instruction>(ElmDescr.getValue()))
320+
DescrInstrs.push_back(I);
321+
}
322+
auto WhereIt = getInsertPointAfterInstrs(DescrInstrs);
323+
324+
Value *LastV = PoisonValue::get(ResTy);
325+
for (auto [Lane, ElmDescr] : enumerate(Descr.getDescrs())) {
326+
Value *VecOp = ElmDescr.getValue();
327+
Context &Ctx = VecOp->getContext();
328+
Value *ValueToInsert;
329+
if (ElmDescr.needsExtract()) {
330+
ConstantInt *IdxC =
331+
ConstantInt::get(Type::getInt32Ty(Ctx), ElmDescr.getExtractIdx());
332+
ValueToInsert = ExtractElementInst::create(VecOp, IdxC, WhereIt,
333+
VecOp->getContext(), "VExt");
334+
} else {
335+
ValueToInsert = VecOp;
336+
}
337+
ConstantInt *LaneC = ConstantInt::get(Type::getInt32Ty(Ctx), Lane);
338+
Value *Ins = InsertElementInst::create(LastV, ValueToInsert, LaneC,
339+
WhereIt, Ctx, "VIns");
340+
LastV = Ins;
341+
}
342+
NewVec = LastV;
343+
break;
344+
}
311345
case LegalityResultID::Pack: {
312346
// If we can't vectorize the seeds then just return.
313347
if (Depth == 0)

llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,3 +242,30 @@ define void @diamondWithShuffle(ptr %ptr) {
242242
store float %sub1, ptr %ptr1
243243
ret void
244244
}
245+
246+
; Diamond with multiple inputs: the left operands {%ld0, %ld1} form the
; vector load, while the right operands {%ldX, %ld0} mix a scalar load with
; lane 0 of that same vector load. The CHECK lines expect the right operand
; vector to be packed with insertelement/extractelement.
define void @diamondMultiInput(ptr %ptr, ptr %ptrX) {
247+
; CHECK-LABEL: define void @diamondMultiInput(
248+
; CHECK-SAME: ptr [[PTR:%.*]], ptr [[PTRX:%.*]]) {
249+
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
250+
; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
251+
; CHECK-NEXT: [[LDX:%.*]] = load float, ptr [[PTRX]], align 4
252+
; CHECK-NEXT: [[VINS:%.*]] = insertelement <2 x float> poison, float [[LDX]], i32 0
253+
; CHECK-NEXT: [[VEXT:%.*]] = extractelement <2 x float> [[VECL]], i32 0
254+
; CHECK-NEXT: [[VINS1:%.*]] = insertelement <2 x float> [[VINS]], float [[VEXT]], i32 1
255+
; CHECK-NEXT: [[VEC:%.*]] = fsub <2 x float> [[VECL]], [[VINS1]]
256+
; CHECK-NEXT: store <2 x float> [[VEC]], ptr [[PTR0]], align 4
257+
; CHECK-NEXT: ret void
258+
;
259+
%ptr0 = getelementptr float, ptr %ptr, i32 0
260+
%ptr1 = getelementptr float, ptr %ptr, i32 1
261+
%ld0 = load float, ptr %ptr0
262+
%ld1 = load float, ptr %ptr1
263+
264+
; Scalar input loaded through the separate pointer argument %ptrX.
%ldX = load float, ptr %ptrX
265+
266+
; %ld0 is used both as a left operand (%sub0) and as a right operand (%sub1).
%sub0 = fsub float %ld0, %ldX
267+
%sub1 = fsub float %ld1, %ld0
268+
store float %sub0, ptr %ptr0
269+
store float %sub1, ptr %ptr1
270+
ret void
271+
}

0 commit comments

Comments
 (0)