Skip to content

Commit 87e4b68

Browse files
authored
[SandboxVec][Legality] Implement ShuffleMask (#123404)
This patch implements a helper ShuffleMask data structure that helps describe shuffles of elements across lanes.
1 parent 22d4ff1 commit 87e4b68

File tree

6 files changed

+219
-17
lines changed

6 files changed

+219
-17
lines changed

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h

Lines changed: 79 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,62 @@ class LegalityAnalysis;
2525
class Value;
2626
class InstrMaps;
2727

28+
class ShuffleMask {
29+
public:
30+
using IndicesVecT = SmallVector<int, 8>;
31+
32+
private:
33+
IndicesVecT Indices;
34+
35+
public:
36+
ShuffleMask(SmallVectorImpl<int> &&Indices) : Indices(std::move(Indices)) {}
37+
ShuffleMask(std::initializer_list<int> Indices) : Indices(Indices) {}
38+
explicit ShuffleMask(ArrayRef<int> Indices) : Indices(Indices) {}
39+
operator ArrayRef<int>() const { return Indices; }
40+
/// Creates and returns an identity shuffle mask of size \p Sz.
41+
/// For example if Sz == 4 the returned mask is {0, 1, 2, 3}.
42+
static ShuffleMask getIdentity(unsigned Sz) {
43+
IndicesVecT Indices;
44+
Indices.reserve(Sz);
45+
for (auto Idx : seq<int>(0, (int)Sz))
46+
Indices.push_back(Idx);
47+
return ShuffleMask(std::move(Indices));
48+
}
49+
/// \Returns true if the mask is a perfect identity mask with consecutive
50+
/// indices, i.e., performs no lane shuffling, like 0,1,2,3...
51+
bool isIdentity() const {
52+
for (auto [Idx, Elm] : enumerate(Indices)) {
53+
if ((int)Idx != Elm)
54+
return false;
55+
}
56+
return true;
57+
}
58+
bool operator==(const ShuffleMask &Other) const {
59+
return Indices == Other.Indices;
60+
}
61+
bool operator!=(const ShuffleMask &Other) const { return !(*this == Other); }
62+
size_t size() const { return Indices.size(); }
63+
int operator[](int Idx) const { return Indices[Idx]; }
64+
using const_iterator = IndicesVecT::const_iterator;
65+
const_iterator begin() const { return Indices.begin(); }
66+
const_iterator end() const { return Indices.end(); }
67+
#ifndef NDEBUG
68+
friend raw_ostream &operator<<(raw_ostream &OS, const ShuffleMask &Mask) {
69+
Mask.print(OS);
70+
return OS;
71+
}
72+
void print(raw_ostream &OS) const {
73+
interleave(Indices, OS, [&OS](auto Elm) { OS << Elm; }, ",");
74+
}
75+
LLVM_DUMP_METHOD void dump() const;
76+
#endif
77+
};
78+
2879
enum class LegalityResultID {
29-
Pack, ///> Collect scalar values.
30-
Widen, ///> Vectorize by combining scalars to a vector.
31-
DiamondReuse, ///> Don't generate new code, reuse existing vector.
80+
Pack, ///> Collect scalar values.
81+
Widen, ///> Vectorize by combining scalars to a vector.
82+
DiamondReuse, ///> Don't generate new code, reuse existing vector.
83+
DiamondReuseWithShuffle, ///> Reuse the existing vector but add a shuffle.
3284
};
3385

3486
/// The reason for vectorizing or not vectorizing.
@@ -54,6 +106,8 @@ struct ToStr {
54106
return "Widen";
55107
case LegalityResultID::DiamondReuse:
56108
return "DiamondReuse";
109+
case LegalityResultID::DiamondReuseWithShuffle:
110+
return "DiamondReuseWithShuffle";
57111
}
58112
llvm_unreachable("Unknown LegalityResultID enum");
59113
}
@@ -154,6 +208,22 @@ class DiamondReuse final : public LegalityResult {
154208
Value *getVector() const { return Vec; }
155209
};
156210

211+
class DiamondReuseWithShuffle final : public LegalityResult {
212+
friend class LegalityAnalysis;
213+
Value *Vec;
214+
ShuffleMask Mask;
215+
DiamondReuseWithShuffle(Value *Vec, const ShuffleMask &Mask)
216+
: LegalityResult(LegalityResultID::DiamondReuseWithShuffle), Vec(Vec),
217+
Mask(Mask) {}
218+
219+
public:
220+
static bool classof(const LegalityResult *From) {
221+
return From->getSubclassID() == LegalityResultID::DiamondReuseWithShuffle;
222+
}
223+
Value *getVector() const { return Vec; }
224+
const ShuffleMask &getMask() const { return Mask; }
225+
};
226+
157227
class Pack final : public LegalityResultWithReason {
158228
Pack(ResultReason Reason)
159229
: LegalityResultWithReason(LegalityResultID::Pack, Reason) {}
@@ -192,23 +262,22 @@ class CollectDescr {
192262
CollectDescr(SmallVectorImpl<ExtractElementDescr> &&Descrs)
193263
: Descrs(std::move(Descrs)) {}
194264
/// If all elements come from a single vector input, then return that vector
195-
/// and whether we need a shuffle to get them in order.
196-
std::optional<std::pair<Value *, bool>> getSingleInput() const {
265+
/// and also the shuffle mask required to get them in order.
266+
std::optional<std::pair<Value *, ShuffleMask>> getSingleInput() const {
197267
const auto &Descr0 = *Descrs.begin();
198268
Value *V0 = Descr0.getValue();
199269
if (!Descr0.needsExtract())
200270
return std::nullopt;
201-
bool NeedsShuffle = Descr0.getExtractIdx() != 0;
202-
int Lane = 1;
271+
ShuffleMask::IndicesVecT MaskIndices;
272+
MaskIndices.push_back(Descr0.getExtractIdx());
203273
for (const auto &Descr : drop_begin(Descrs)) {
204274
if (!Descr.needsExtract())
205275
return std::nullopt;
206276
if (Descr.getValue() != V0)
207277
return std::nullopt;
208-
if (Descr.getExtractIdx() != Lane++)
209-
NeedsShuffle = true;
278+
MaskIndices.push_back(Descr.getExtractIdx());
210279
}
211-
return std::make_pair(V0, NeedsShuffle);
280+
return std::make_pair(V0, ShuffleMask(std::move(MaskIndices)));
212281
}
213282
bool hasVectorInputs() const {
214283
return any_of(Descrs, [](const auto &D) { return D.needsExtract(); });

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class BottomUpVec final : public FunctionPass {
3636
/// Erases all dead instructions from the dead instruction candidates
3737
/// collected during vectorization.
3838
void tryEraseDeadInstrs();
39+
/// Creates a shuffle instruction that shuffles \p VecOp according to \p Mask.
40+
Value *createShuffle(Value *VecOp, const ShuffleMask &Mask);
3941
/// Packs all elements of \p ToPack into a vector and returns that vector.
4042
Value *createPack(ArrayRef<Value *> ToPack);
4143
void collectPotentiallyDeadInstrs(ArrayRef<Value *> Bndl);

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ namespace llvm::sandboxir {
2020
#define DEBUG_TYPE "SBVec:Legality"
2121

2222
#ifndef NDEBUG
23+
void ShuffleMask::dump() const {
24+
print(dbgs());
25+
dbgs() << "\n";
26+
}
27+
2328
void LegalityResult::dump() const {
2429
print(dbgs());
2530
dbgs() << "\n";
@@ -213,13 +218,12 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
213218
auto CollectDescrs = getHowToCollectValues(Bndl);
214219
if (CollectDescrs.hasVectorInputs()) {
215220
if (auto ValueShuffleOpt = CollectDescrs.getSingleInput()) {
216-
auto [Vec, NeedsShuffle] = *ValueShuffleOpt;
217-
if (!NeedsShuffle)
221+
auto [Vec, Mask] = *ValueShuffleOpt;
222+
if (Mask.isIdentity())
218223
return createLegalityResult<DiamondReuse>(Vec);
219-
llvm_unreachable("TODO: Unimplemented");
220-
} else {
221-
llvm_unreachable("TODO: Unimplemented");
224+
return createLegalityResult<DiamondReuseWithShuffle>(Vec, Mask);
222225
}
226+
llvm_unreachable("TODO: Unimplemented");
223227
}
224228

225229
if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl))

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,12 @@ void BottomUpVec::tryEraseDeadInstrs() {
179179
DeadInstrCandidates.clear();
180180
}
181181

182+
Value *BottomUpVec::createShuffle(Value *VecOp, const ShuffleMask &Mask) {
183+
BasicBlock::iterator WhereIt = getInsertPointAfterInstrs({VecOp});
184+
return ShuffleVectorInst::create(VecOp, VecOp, Mask, WhereIt,
185+
VecOp->getContext(), "VShuf");
186+
}
187+
182188
Value *BottomUpVec::createPack(ArrayRef<Value *> ToPack) {
183189
BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(ToPack);
184190

@@ -295,6 +301,13 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
295301
NewVec = cast<DiamondReuse>(LegalityRes).getVector();
296302
break;
297303
}
304+
case LegalityResultID::DiamondReuseWithShuffle: {
305+
auto *VecOp = cast<DiamondReuseWithShuffle>(LegalityRes).getVector();
306+
const ShuffleMask &Mask =
307+
cast<DiamondReuseWithShuffle>(LegalityRes).getMask();
308+
NewVec = createShuffle(VecOp, Mask);
309+
break;
310+
}
298311
case LegalityResultID::Pack: {
299312
// If we can't vectorize the seeds then just return.
300313
if (Depth == 0)

llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,24 @@ define void @diamond(ptr %ptr) {
221221
store float %sub1, ptr %ptr1
222222
ret void
223223
}
224+
225+
define void @diamondWithShuffle(ptr %ptr) {
226+
; CHECK-LABEL: define void @diamondWithShuffle(
227+
; CHECK-SAME: ptr [[PTR:%.*]]) {
228+
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
229+
; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
230+
; CHECK-NEXT: [[VSHUF:%.*]] = shufflevector <2 x float> [[VECL]], <2 x float> [[VECL]], <2 x i32> <i32 1, i32 0>
231+
; CHECK-NEXT: [[VEC:%.*]] = fsub <2 x float> [[VECL]], [[VSHUF]]
232+
; CHECK-NEXT: store <2 x float> [[VEC]], ptr [[PTR0]], align 4
233+
; CHECK-NEXT: ret void
234+
;
235+
%ptr0 = getelementptr float, ptr %ptr, i32 0
236+
%ptr1 = getelementptr float, ptr %ptr, i32 1
237+
%ld0 = load float, ptr %ptr0
238+
%ld1 = load float, ptr %ptr1
239+
%sub0 = fsub float %ld0, %ld1
240+
%sub1 = fsub float %ld1, %ld0
241+
store float %sub0, ptr %ptr0
242+
store float %sub1, ptr %ptr1
243+
ret void
244+
}

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "llvm/SandboxIR/Instruction.h"
2020
#include "llvm/Support/SourceMgr.h"
2121
#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
22+
#include "gmock/gmock.h"
2223
#include "gtest/gtest.h"
2324

2425
using namespace llvm;
@@ -321,7 +322,7 @@ define void @foo(ptr %ptr) {
321322
sandboxir::CollectDescr CD(std::move(Descrs));
322323
EXPECT_TRUE(CD.getSingleInput());
323324
EXPECT_EQ(CD.getSingleInput()->first, VLd);
324-
EXPECT_EQ(CD.getSingleInput()->second, false);
325+
EXPECT_THAT(CD.getSingleInput()->second, testing::ElementsAre(0, 1));
325326
EXPECT_TRUE(CD.hasVectorInputs());
326327
}
327328
{
@@ -331,7 +332,7 @@ define void @foo(ptr %ptr) {
331332
sandboxir::CollectDescr CD(std::move(Descrs));
332333
EXPECT_TRUE(CD.getSingleInput());
333334
EXPECT_EQ(CD.getSingleInput()->first, VLd);
334-
EXPECT_EQ(CD.getSingleInput()->second, true);
335+
EXPECT_THAT(CD.getSingleInput()->second, testing::ElementsAre(1, 0));
335336
EXPECT_TRUE(CD.hasVectorInputs());
336337
}
337338
{
@@ -352,3 +353,95 @@ define void @foo(ptr %ptr) {
352353
EXPECT_FALSE(CD.hasVectorInputs());
353354
}
354355
}
356+
357+
TEST_F(LegalityTest, ShuffleMask) {
358+
{
359+
// Check SmallVector constructor.
360+
SmallVector<int> Indices({0, 1, 2, 3});
361+
sandboxir::ShuffleMask Mask(std::move(Indices));
362+
EXPECT_THAT(Mask, testing::ElementsAre(0, 1, 2, 3));
363+
}
364+
{
365+
// Check initializer_list constructor.
366+
sandboxir::ShuffleMask Mask({0, 1, 2, 3});
367+
EXPECT_THAT(Mask, testing::ElementsAre(0, 1, 2, 3));
368+
}
369+
{
370+
// Check ArrayRef constructor.
371+
sandboxir::ShuffleMask Mask(ArrayRef<int>({0, 1, 2, 3}));
372+
EXPECT_THAT(Mask, testing::ElementsAre(0, 1, 2, 3));
373+
}
374+
{
375+
// Check operator ArrayRef<int>().
376+
sandboxir::ShuffleMask Mask({0, 1, 2, 3});
377+
ArrayRef<int> Array = Mask;
378+
EXPECT_THAT(Array, testing::ElementsAre(0, 1, 2, 3));
379+
}
380+
{
381+
// Check getIdentity().
382+
auto IdentityMask = sandboxir::ShuffleMask::getIdentity(4);
383+
EXPECT_THAT(IdentityMask, testing::ElementsAre(0, 1, 2, 3));
384+
EXPECT_TRUE(IdentityMask.isIdentity());
385+
}
386+
{
387+
// Check isIdentity().
388+
sandboxir::ShuffleMask Mask1({0, 1, 2, 3});
389+
EXPECT_TRUE(Mask1.isIdentity());
390+
sandboxir::ShuffleMask Mask2({1, 2, 3, 4});
391+
EXPECT_FALSE(Mask2.isIdentity());
392+
}
393+
{
394+
// Check operator==().
395+
sandboxir::ShuffleMask Mask1({0, 1, 2, 3});
396+
sandboxir::ShuffleMask Mask2({0, 1, 2, 3});
397+
EXPECT_TRUE(Mask1 == Mask2);
398+
EXPECT_FALSE(Mask1 != Mask2);
399+
}
400+
{
401+
// Check operator!=().
402+
sandboxir::ShuffleMask Mask1({0, 1, 2, 3});
403+
sandboxir::ShuffleMask Mask2({0, 1, 2, 4});
404+
EXPECT_TRUE(Mask1 != Mask2);
405+
EXPECT_FALSE(Mask1 == Mask2);
406+
}
407+
{
408+
// Check size().
409+
sandboxir::ShuffleMask Mask({0, 1, 2, 3});
410+
EXPECT_EQ(Mask.size(), 4u);
411+
}
412+
{
413+
// Check operator[].
414+
sandboxir::ShuffleMask Mask({0, 1, 2, 3});
415+
for (auto [Idx, Elm] : enumerate(Mask)) {
416+
EXPECT_EQ(Elm, Mask[Idx]);
417+
}
418+
}
419+
{
420+
// Check begin(), end().
421+
sandboxir::ShuffleMask Mask({0, 1, 2, 3});
422+
sandboxir::ShuffleMask::const_iterator Begin = Mask.begin();
423+
sandboxir::ShuffleMask::const_iterator End = Mask.begin();
424+
int Idx = 0;
425+
for (auto It = Begin; It != End; ++It) {
426+
EXPECT_EQ(*It, Mask[Idx++]);
427+
}
428+
}
429+
#ifndef NDEBUG
430+
{
431+
// Check print(OS).
432+
sandboxir::ShuffleMask Mask({0, 1, 2, 3});
433+
std::string Str;
434+
raw_string_ostream OS(Str);
435+
Mask.print(OS);
436+
EXPECT_EQ(Str, "0,1,2,3");
437+
}
438+
{
439+
// Check operator<<().
440+
sandboxir::ShuffleMask Mask({0, 1, 2, 3});
441+
std::string Str;
442+
raw_string_ostream OS(Str);
443+
OS << Mask;
444+
EXPECT_EQ(Str, "0,1,2,3");
445+
}
446+
#endif // NDEBUG
447+
}

0 commit comments

Comments
 (0)