Skip to content

[SandboxVectorizer] New class to actually collect and manage seeds #113386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 23, 2024

Conversation

Sterling-Augustine
Copy link
Contributor

This relands d91318b, with test-only changes to make gcc-10 happy.

This relands d91318b, with test-only
changes to make gcc-10 happy.
@llvmbot
Copy link
Member

llvmbot commented Oct 22, 2024

@llvm/pr-subscribers-llvm-transforms

Author: None (Sterling-Augustine)

Changes

This relands d91318b, with test-only changes to make gcc-10 happy.


Full diff: https://github.com/llvm/llvm-project/pull/113386.diff

4 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h (+27)
  • (added) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h (+30)
  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp (+67)
  • (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp (+184)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h
index a4512862136a8b..ed1cb8488c29eb 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h
@@ -284,6 +284,33 @@ class SeedContainer {
 #endif // NDEBUG
 };
 
+class SeedCollector {
+  SeedContainer StoreSeeds;
+  SeedContainer LoadSeeds;
+  Context &Ctx;
+
+  /// \Returns the number of SeedBundle groups for all seed types.
+  /// This is to be used for limiting compilation time.
+  unsigned totalNumSeedGroups() const {
+    return StoreSeeds.size() + LoadSeeds.size();
+  }
+
+public:
+  SeedCollector(BasicBlock *BB, ScalarEvolution &SE);
+  ~SeedCollector();
+
+  iterator_range<SeedContainer::iterator> getStoreSeeds() {
+    return {StoreSeeds.begin(), StoreSeeds.end()};
+  }
+  iterator_range<SeedContainer::iterator> getLoadSeeds() {
+    return {LoadSeeds.begin(), LoadSeeds.end()};
+  }
+#ifndef NDEBUG
+  void print(raw_ostream &OS) const;
+  LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+
 } // namespace llvm::sandboxir
 
 #endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_SEEDCOLLECTOR_H
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
new file mode 100644
index 00000000000000..64f57edb38484e
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
@@ -0,0 +1,30 @@
+//===- VecUtils.h -----------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Collector for SandboxVectorizer related convenience functions that don't
+// belong in other classes.
+
+#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
+#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
+
+class Utils {
+public:
+  /// \Returns the number of elements in \p Ty. That is the number of lanes if a
+  /// fixed vector or 1 if scalar. ScalableVectors have unknown size and
+  /// therefore are unsupported.
+  static int getNumElements(Type *Ty) {
+    assert(!isa<ScalableVectorType>(Ty));
+    return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getNumElements() : 1;
+  }
+  /// Returns \p Ty if scalar or its element type if vector.
+  static Type *getElementType(Type *Ty) {
+    return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getElementType() : Ty;
+  }
+}
+
+#endif LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
index 66fac080a7b7cc..0d928af1902073 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
@@ -22,6 +22,16 @@ namespace llvm::sandboxir {
 cl::opt<unsigned> SeedBundleSizeLimit(
     "sbvec-seed-bundle-size-limit", cl::init(32), cl::Hidden,
     cl::desc("Limit the size of the seed bundle to cap compilation time."));
+#define LoadSeedsDef "loads"
+#define StoreSeedsDef "stores"
+cl::opt<std::string> CollectSeeds(
+    "sbvec-collect-seeds", cl::init(LoadSeedsDef "," StoreSeedsDef), cl::Hidden,
+    cl::desc("Collect these seeds. Use empty for none or a comma-separated "
+             "list of '" LoadSeedsDef "' and '" StoreSeedsDef "'."));
+cl::opt<unsigned> SeedGroupsLimit(
+    "sbvec-seed-groups-limit", cl::init(256), cl::Hidden,
+    cl::desc("Limit the number of collected seeds groups in a BB to "
+             "cap compilation time."));
 
 MutableArrayRef<Instruction *> SeedBundle::getSlice(unsigned StartIdx,
                                                     unsigned MaxVecRegBits,
@@ -131,4 +141,61 @@ void SeedContainer::print(raw_ostream &OS) const {
 LLVM_DUMP_METHOD void SeedContainer::dump() const { print(dbgs()); }
 #endif // NDEBUG
 
+template <typename LoadOrStoreT> static bool isValidMemSeed(LoadOrStoreT *LSI) {
+  if (LSI->isSimple())
+    return true;
+  auto *Ty = Utils::getExpectedType(LSI);
+  // Omit types that are architecturally unvectorizable
+  if (Ty->isX86_FP80Ty() || Ty->isPPC_FP128Ty())
+    return false;
+  // Omit vector types without compile-time-known lane counts
+  if (isa<ScalableVectorType>(Ty))
+    return false;
+  if (auto *VTy = dyn_cast<FixedVectorType>(Ty))
+    return VectorType::isValidElementType(VTy->getElementType());
+  return VectorType::isValidElementType(Ty);
+}
+
+template bool isValidMemSeed<LoadInst>(LoadInst *LSI);
+template bool isValidMemSeed<StoreInst>(StoreInst *LSI);
+
+SeedCollector::SeedCollector(BasicBlock *BB, ScalarEvolution &SE)
+    : StoreSeeds(SE), LoadSeeds(SE), Ctx(BB->getContext()) {
+  // TODO: Register a callback for updating the Collector data structures upon
+  // instr removal
+
+  bool CollectStores = CollectSeeds.find(StoreSeedsDef) != std::string::npos;
+  bool CollectLoads = CollectSeeds.find(LoadSeedsDef) != std::string::npos;
+  if (!CollectStores && !CollectLoads)
+    return;
+  // Actually collect the seeds.
+  for (auto &I : *BB) {
+    if (StoreInst *SI = dyn_cast<StoreInst>(&I))
+      if (CollectStores && isValidMemSeed(SI))
+        StoreSeeds.insert(SI);
+    if (LoadInst *LI = dyn_cast<LoadInst>(&I))
+      if (CollectLoads && isValidMemSeed(LI))
+        LoadSeeds.insert(LI);
+    // Cap compilation time.
+    if (totalNumSeedGroups() > SeedGroupsLimit)
+      break;
+  }
+}
+
+SeedCollector::~SeedCollector() {
+  // TODO: Unregister the callback for updating the seed datastructures upon
+  // instr removal
+}
+
+#ifndef NDEBUG
+void SeedCollector::print(raw_ostream &OS) const {
+  OS << "=== StoreSeeds ===\n";
+  StoreSeeds.print(OS);
+  OS << "=== LoadSeeds ===\n";
+  LoadSeeds.print(OS);
+}
+
+void SeedCollector::dump() const { print(dbgs()); }
+#endif
+
 } // namespace llvm::sandboxir
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
index 82b230d50c4ec9..b1b0156a4fe31a 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
@@ -22,6 +22,22 @@
 
 using namespace llvm;
 
+// TODO: gcc-10 has a bug that causes the below line not to compile due to some
+// macro-magic in gunit in combination with a class with pure-virtual
+// function. Once gcc-10 is no longer supported, replace this function with
+// something like the following:
+//
+// EXPECT_THAT(SB, testing::ElementsAre(St0, St1, St2, St3));
+static void
+ExpectThatElementsAre(sandboxir::SeedBundle &SR,
+                      llvm::ArrayRef<sandboxir::Instruction *> Contents) {
+  EXPECT_EQ(range_size(SR), Contents.size());
+  auto CI = Contents.begin();
+  if (range_size(SR) == Contents.size())
+    for (auto &S : SR)
+      EXPECT_EQ(S, *CI++);
+}
+
 struct SeedBundleTest : public testing::Test {
   LLVMContext C;
   std::unique_ptr<Module> M;
@@ -268,3 +284,171 @@ define void @foo(ptr %ptrA, float %val, ptr %ptrB) {
   }
   EXPECT_EQ(Cnt, 0u);
 }
+
+TEST_F(SeedBundleTest, ConsecutiveStores) {
+  // Where "Consecutive" means the stores address consecutive locations in
+  // memory, but not in program order. Check to see that the collector puts them
+  // in the proper order for vectorization.
+  parseIR(C, R"IR(
+define void @foo(ptr noalias %ptr, float %val) {
+bb:
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  %ptr2 = getelementptr float, ptr %ptr, i32 2
+  %ptr3 = getelementptr float, ptr %ptr, i32 3
+  store float %val, ptr %ptr0
+  store float %val, ptr %ptr2
+  store float %val, ptr %ptr1
+  store float %val, ptr %ptr3
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  DominatorTree DT(LLVMF);
+  TargetLibraryInfoImpl TLII;
+  TargetLibraryInfo TLI(TLII);
+  DataLayout DL(M->getDataLayout());
+  LoopInfo LI(DT);
+  AssumptionCache AC(LLVMF);
+  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto BB = F.begin();
+  sandboxir::SeedCollector SC(&*BB, SE);
+
+  // Find the stores
+  auto It = std::next(BB->begin(), 4);
+  // StX with X as the order by offset in memory
+  auto *St0 = &*It++;
+  auto *St2 = &*It++;
+  auto *St1 = &*It++;
+  auto *St3 = &*It++;
+
+  auto StoreSeedsRange = SC.getStoreSeeds();
+  auto &SB = *StoreSeedsRange.begin();
+  //  Expect just one vector of store seeds
+  EXPECT_EQ(range_size(StoreSeedsRange), 1u);
+  ExpectThatElementsAre(SB, {St0, St1, St2, St3});
+}
+
+TEST_F(SeedBundleTest, StoresWithGaps) {
+  parseIR(C, R"IR(
+define void @foo(ptr noalias %ptr, float %val) {
+bb:
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 3
+  %ptr2 = getelementptr float, ptr %ptr, i32 5
+  %ptr3 = getelementptr float, ptr %ptr, i32 7
+  store float %val, ptr %ptr0
+  store float %val, ptr %ptr2
+  store float %val, ptr %ptr1
+  store float %val, ptr %ptr3
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  DominatorTree DT(LLVMF);
+  TargetLibraryInfoImpl TLII;
+  TargetLibraryInfo TLI(TLII);
+  DataLayout DL(M->getDataLayout());
+  LoopInfo LI(DT);
+  AssumptionCache AC(LLVMF);
+  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto BB = F.begin();
+  sandboxir::SeedCollector SC(&*BB, SE);
+
+  // Find the stores
+  auto It = std::next(BB->begin(), 4);
+  // StX with X as the order by offset in memory
+  auto *St0 = &*It++;
+  auto *St2 = &*It++;
+  auto *St1 = &*It++;
+  auto *St3 = &*It++;
+
+  auto StoreSeedsRange = SC.getStoreSeeds();
+  auto &SB = *StoreSeedsRange.begin();
+  // Expect just one vector of store seeds
+  EXPECT_EQ(range_size(StoreSeedsRange), 1u);
+  ExpectThatElementsAre(SB, {St0, St1, St2, St3});
+}
+
+TEST_F(SeedBundleTest, VectorStores) {
+  parseIR(C, R"IR(
+define void @foo(ptr noalias %ptr, <2 x float> %val) {
+bb:
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  store <2 x float> %val, ptr %ptr1
+  store <2 x float> %val, ptr %ptr0
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  DominatorTree DT(LLVMF);
+  TargetLibraryInfoImpl TLII;
+  TargetLibraryInfo TLI(TLII);
+  DataLayout DL(M->getDataLayout());
+  LoopInfo LI(DT);
+  AssumptionCache AC(LLVMF);
+  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto BB = F.begin();
+  sandboxir::SeedCollector SC(&*BB, SE);
+
+  // Find the stores
+  auto It = std::next(BB->begin(), 2);
+  // StX with X as the order by offset in memory
+  auto *St1 = &*It++;
+  auto *St0 = &*It++;
+
+  auto StoreSeedsRange = SC.getStoreSeeds();
+  EXPECT_EQ(range_size(StoreSeedsRange), 1u);
+  auto &SB = *StoreSeedsRange.begin();
+  ExpectThatElementsAre(SB, {St0, St1});
+}
+
+TEST_F(SeedBundleTest, MixedScalarVectors) {
+  parseIR(C, R"IR(
+define void @foo(ptr noalias %ptr, float %v, <2 x float> %val) {
+bb:
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  %ptr3 = getelementptr float, ptr %ptr, i32 3
+  store float %v, ptr %ptr0
+  store float %v, ptr %ptr3
+  store <2 x float> %val, ptr %ptr1
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  DominatorTree DT(LLVMF);
+  TargetLibraryInfoImpl TLII;
+  TargetLibraryInfo TLI(TLII);
+  DataLayout DL(M->getDataLayout());
+  LoopInfo LI(DT);
+  AssumptionCache AC(LLVMF);
+  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto BB = F.begin();
+  sandboxir::SeedCollector SC(&*BB, SE);
+
+  // Find the stores
+  auto It = std::next(BB->begin(), 3);
+  // StX with X as the order by offset in memory
+  auto *St0 = &*It++;
+  auto *St3 = &*It++;
+  auto *St1 = &*It++;
+
+  auto StoreSeedsRange = SC.getStoreSeeds();
+  EXPECT_EQ(range_size(StoreSeedsRange), 1u);
+  auto &SB = *StoreSeedsRange.begin();
+  ExpectThatElementsAre(SB, {St0, St1, St3});
+}

@llvmbot
Copy link
Member

llvmbot commented Oct 22, 2024

@llvm/pr-subscribers-vectorizers

Author: None (Sterling-Augustine)

Changes

This relands d91318b, with test-only changes to make gcc-10 happy.


Full diff: https://github.com/llvm/llvm-project/pull/113386.diff

4 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h (+27)
  • (added) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h (+30)
  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp (+67)
  • (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp (+184)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h
index a4512862136a8b..ed1cb8488c29eb 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h
@@ -284,6 +284,33 @@ class SeedContainer {
 #endif // NDEBUG
 };
 
+class SeedCollector {
+  SeedContainer StoreSeeds;
+  SeedContainer LoadSeeds;
+  Context &Ctx;
+
+  /// \Returns the number of SeedBundle groups for all seed types.
+  /// This is to be used for limiting compilation time.
+  unsigned totalNumSeedGroups() const {
+    return StoreSeeds.size() + LoadSeeds.size();
+  }
+
+public:
+  SeedCollector(BasicBlock *BB, ScalarEvolution &SE);
+  ~SeedCollector();
+
+  iterator_range<SeedContainer::iterator> getStoreSeeds() {
+    return {StoreSeeds.begin(), StoreSeeds.end()};
+  }
+  iterator_range<SeedContainer::iterator> getLoadSeeds() {
+    return {LoadSeeds.begin(), LoadSeeds.end()};
+  }
+#ifndef NDEBUG
+  void print(raw_ostream &OS) const;
+  LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+
 } // namespace llvm::sandboxir
 
 #endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_SEEDCOLLECTOR_H
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
new file mode 100644
index 00000000000000..64f57edb38484e
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
@@ -0,0 +1,30 @@
+//===- VecUtils.h -----------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Collector for SandboxVectorizer related convenience functions that don't
+// belong in other classes.
+
+#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
+#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
+
+class Utils {
+public:
+  /// \Returns the number of elements in \p Ty. That is the number of lanes if a
+  /// fixed vector or 1 if scalar. ScalableVectors have unknown size and
+  /// therefore are unsupported.
+  static int getNumElements(Type *Ty) {
+    assert(!isa<ScalableVectorType>(Ty));
+    return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getNumElements() : 1;
+  }
+  /// Returns \p Ty if scalar or its element type if vector.
+  static Type *getElementType(Type *Ty) {
+    return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getElementType() : Ty;
+  }
+}
+
+#endif LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
index 66fac080a7b7cc..0d928af1902073 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
@@ -22,6 +22,16 @@ namespace llvm::sandboxir {
 cl::opt<unsigned> SeedBundleSizeLimit(
     "sbvec-seed-bundle-size-limit", cl::init(32), cl::Hidden,
     cl::desc("Limit the size of the seed bundle to cap compilation time."));
+#define LoadSeedsDef "loads"
+#define StoreSeedsDef "stores"
+cl::opt<std::string> CollectSeeds(
+    "sbvec-collect-seeds", cl::init(LoadSeedsDef "," StoreSeedsDef), cl::Hidden,
+    cl::desc("Collect these seeds. Use empty for none or a comma-separated "
+             "list of '" LoadSeedsDef "' and '" StoreSeedsDef "'."));
+cl::opt<unsigned> SeedGroupsLimit(
+    "sbvec-seed-groups-limit", cl::init(256), cl::Hidden,
+    cl::desc("Limit the number of collected seeds groups in a BB to "
+             "cap compilation time."));
 
 MutableArrayRef<Instruction *> SeedBundle::getSlice(unsigned StartIdx,
                                                     unsigned MaxVecRegBits,
@@ -131,4 +141,61 @@ void SeedContainer::print(raw_ostream &OS) const {
 LLVM_DUMP_METHOD void SeedContainer::dump() const { print(dbgs()); }
 #endif // NDEBUG
 
+template <typename LoadOrStoreT> static bool isValidMemSeed(LoadOrStoreT *LSI) {
+  if (LSI->isSimple())
+    return true;
+  auto *Ty = Utils::getExpectedType(LSI);
+  // Omit types that are architecturally unvectorizable
+  if (Ty->isX86_FP80Ty() || Ty->isPPC_FP128Ty())
+    return false;
+  // Omit vector types without compile-time-known lane counts
+  if (isa<ScalableVectorType>(Ty))
+    return false;
+  if (auto *VTy = dyn_cast<FixedVectorType>(Ty))
+    return VectorType::isValidElementType(VTy->getElementType());
+  return VectorType::isValidElementType(Ty);
+}
+
+template bool isValidMemSeed<LoadInst>(LoadInst *LSI);
+template bool isValidMemSeed<StoreInst>(StoreInst *LSI);
+
+SeedCollector::SeedCollector(BasicBlock *BB, ScalarEvolution &SE)
+    : StoreSeeds(SE), LoadSeeds(SE), Ctx(BB->getContext()) {
+  // TODO: Register a callback for updating the Collector data structures upon
+  // instr removal
+
+  bool CollectStores = CollectSeeds.find(StoreSeedsDef) != std::string::npos;
+  bool CollectLoads = CollectSeeds.find(LoadSeedsDef) != std::string::npos;
+  if (!CollectStores && !CollectLoads)
+    return;
+  // Actually collect the seeds.
+  for (auto &I : *BB) {
+    if (StoreInst *SI = dyn_cast<StoreInst>(&I))
+      if (CollectStores && isValidMemSeed(SI))
+        StoreSeeds.insert(SI);
+    if (LoadInst *LI = dyn_cast<LoadInst>(&I))
+      if (CollectLoads && isValidMemSeed(LI))
+        LoadSeeds.insert(LI);
+    // Cap compilation time.
+    if (totalNumSeedGroups() > SeedGroupsLimit)
+      break;
+  }
+}
+
+SeedCollector::~SeedCollector() {
+  // TODO: Unregister the callback for updating the seed datastructures upon
+  // instr removal
+}
+
+#ifndef NDEBUG
+void SeedCollector::print(raw_ostream &OS) const {
+  OS << "=== StoreSeeds ===\n";
+  StoreSeeds.print(OS);
+  OS << "=== LoadSeeds ===\n";
+  LoadSeeds.print(OS);
+}
+
+void SeedCollector::dump() const { print(dbgs()); }
+#endif
+
 } // namespace llvm::sandboxir
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
index 82b230d50c4ec9..b1b0156a4fe31a 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
@@ -22,6 +22,22 @@
 
 using namespace llvm;
 
+// TODO: gcc-10 has a bug that causes the below line not to compile due to some
+// macro-magic in gunit in combination with a class with pure-virtual
+// function. Once gcc-10 is no longer supported, replace this function with
+// something like the following:
+//
+// EXPECT_THAT(SB, testing::ElementsAre(St0, St1, St2, St3));
+static void
+ExpectThatElementsAre(sandboxir::SeedBundle &SR,
+                      llvm::ArrayRef<sandboxir::Instruction *> Contents) {
+  EXPECT_EQ(range_size(SR), Contents.size());
+  auto CI = Contents.begin();
+  if (range_size(SR) == Contents.size())
+    for (auto &S : SR)
+      EXPECT_EQ(S, *CI++);
+}
+
 struct SeedBundleTest : public testing::Test {
   LLVMContext C;
   std::unique_ptr<Module> M;
@@ -268,3 +284,171 @@ define void @foo(ptr %ptrA, float %val, ptr %ptrB) {
   }
   EXPECT_EQ(Cnt, 0u);
 }
+
+TEST_F(SeedBundleTest, ConsecutiveStores) {
+  // Where "Consecutive" means the stores address consecutive locations in
+  // memory, but not in program order. Check to see that the collector puts them
+  // in the proper order for vectorization.
+  parseIR(C, R"IR(
+define void @foo(ptr noalias %ptr, float %val) {
+bb:
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  %ptr2 = getelementptr float, ptr %ptr, i32 2
+  %ptr3 = getelementptr float, ptr %ptr, i32 3
+  store float %val, ptr %ptr0
+  store float %val, ptr %ptr2
+  store float %val, ptr %ptr1
+  store float %val, ptr %ptr3
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  DominatorTree DT(LLVMF);
+  TargetLibraryInfoImpl TLII;
+  TargetLibraryInfo TLI(TLII);
+  DataLayout DL(M->getDataLayout());
+  LoopInfo LI(DT);
+  AssumptionCache AC(LLVMF);
+  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto BB = F.begin();
+  sandboxir::SeedCollector SC(&*BB, SE);
+
+  // Find the stores
+  auto It = std::next(BB->begin(), 4);
+  // StX with X as the order by offset in memory
+  auto *St0 = &*It++;
+  auto *St2 = &*It++;
+  auto *St1 = &*It++;
+  auto *St3 = &*It++;
+
+  auto StoreSeedsRange = SC.getStoreSeeds();
+  auto &SB = *StoreSeedsRange.begin();
+  //  Expect just one vector of store seeds
+  EXPECT_EQ(range_size(StoreSeedsRange), 1u);
+  ExpectThatElementsAre(SB, {St0, St1, St2, St3});
+}
+
+TEST_F(SeedBundleTest, StoresWithGaps) {
+  parseIR(C, R"IR(
+define void @foo(ptr noalias %ptr, float %val) {
+bb:
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 3
+  %ptr2 = getelementptr float, ptr %ptr, i32 5
+  %ptr3 = getelementptr float, ptr %ptr, i32 7
+  store float %val, ptr %ptr0
+  store float %val, ptr %ptr2
+  store float %val, ptr %ptr1
+  store float %val, ptr %ptr3
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  DominatorTree DT(LLVMF);
+  TargetLibraryInfoImpl TLII;
+  TargetLibraryInfo TLI(TLII);
+  DataLayout DL(M->getDataLayout());
+  LoopInfo LI(DT);
+  AssumptionCache AC(LLVMF);
+  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto BB = F.begin();
+  sandboxir::SeedCollector SC(&*BB, SE);
+
+  // Find the stores
+  auto It = std::next(BB->begin(), 4);
+  // StX with X as the order by offset in memory
+  auto *St0 = &*It++;
+  auto *St2 = &*It++;
+  auto *St1 = &*It++;
+  auto *St3 = &*It++;
+
+  auto StoreSeedsRange = SC.getStoreSeeds();
+  auto &SB = *StoreSeedsRange.begin();
+  // Expect just one vector of store seeds
+  EXPECT_EQ(range_size(StoreSeedsRange), 1u);
+  ExpectThatElementsAre(SB, {St0, St1, St2, St3});
+}
+
+TEST_F(SeedBundleTest, VectorStores) {
+  parseIR(C, R"IR(
+define void @foo(ptr noalias %ptr, <2 x float> %val) {
+bb:
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  store <2 x float> %val, ptr %ptr1
+  store <2 x float> %val, ptr %ptr0
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  DominatorTree DT(LLVMF);
+  TargetLibraryInfoImpl TLII;
+  TargetLibraryInfo TLI(TLII);
+  DataLayout DL(M->getDataLayout());
+  LoopInfo LI(DT);
+  AssumptionCache AC(LLVMF);
+  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto BB = F.begin();
+  sandboxir::SeedCollector SC(&*BB, SE);
+
+  // Find the stores
+  auto It = std::next(BB->begin(), 2);
+  // StX with X as the order by offset in memory
+  auto *St1 = &*It++;
+  auto *St0 = &*It++;
+
+  auto StoreSeedsRange = SC.getStoreSeeds();
+  EXPECT_EQ(range_size(StoreSeedsRange), 1u);
+  auto &SB = *StoreSeedsRange.begin();
+  ExpectThatElementsAre(SB, {St0, St1});
+}
+
+TEST_F(SeedBundleTest, MixedScalarVectors) {
+  parseIR(C, R"IR(
+define void @foo(ptr noalias %ptr, float %v, <2 x float> %val) {
+bb:
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  %ptr3 = getelementptr float, ptr %ptr, i32 3
+  store float %v, ptr %ptr0
+  store float %v, ptr %ptr3
+  store <2 x float> %val, ptr %ptr1
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  DominatorTree DT(LLVMF);
+  TargetLibraryInfoImpl TLII;
+  TargetLibraryInfo TLI(TLII);
+  DataLayout DL(M->getDataLayout());
+  LoopInfo LI(DT);
+  AssumptionCache AC(LLVMF);
+  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto BB = F.begin();
+  sandboxir::SeedCollector SC(&*BB, SE);
+
+  // Find the stores
+  auto It = std::next(BB->begin(), 3);
+  // StX with X as the order by offset in memory
+  auto *St0 = &*It++;
+  auto *St3 = &*It++;
+  auto *St1 = &*It++;
+
+  auto StoreSeedsRange = SC.getStoreSeeds();
+  EXPECT_EQ(range_size(StoreSeedsRange), 1u);
+  auto &SB = *StoreSeedsRange.begin();
+  ExpectThatElementsAre(SB, {St0, St1, St3});
+}

@Sterling-Augustine Sterling-Augustine merged commit f1be516 into llvm:main Oct 23, 2024
11 checks passed
@frobtech frobtech mentioned this pull request Oct 25, 2024
@@ -131,4 +141,61 @@ void SeedContainer::print(raw_ostream &OS) const {
LLVM_DUMP_METHOD void SeedContainer::dump() const { print(dbgs()); }
#endif // NDEBUG

template <typename LoadOrStoreT> static bool isValidMemSeed(LoadOrStoreT *LSI) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a lot of the checks here don't seem to be tested

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's always hard to decide where to checkpoint work in progress.

From the original description:

"There are many more tests to add, but I would like to get this reviewed
and the details sorted out before it grows too big."

Tests coming are for loads and additional validity checks.

NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants