Skip to content

[ADT] Bitset: add shift operators, word accessors, and etc#193400

Open
JiachenYuan wants to merge 1 commit intollvm:mainfrom
JiachenYuan:perf/jiachen/bitset_prepare_for_lbm
Open

[ADT] Bitset: add shift operators, word accessors, and etc#193400
JiachenYuan wants to merge 1 commit intollvm:mainfrom
JiachenYuan:perf/jiachen/bitset_prepare_for_lbm

Conversation

@JiachenYuan
Copy link
Copy Markdown
Contributor

This PR is split out from #191757 per reviewer request. It has the following changes to llvm::Bitset<N>:

  • Added operator<</<<=/>>/>>=, getNumWords(), getWord(), and findLastSet().
  • Moved the std::array<> constructor from protected to public and explicit.

A follow-up PR will use these to re-implement LaneBitmask as a llvm::Bitset wrapper.


The unit test in the PR is largely generated by LLMs. I have reviewed it and manually applied changes to cover more edge cases.

@JiachenYuan JiachenYuan marked this pull request as ready for review April 22, 2026 03:55
@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Apr 22, 2026

@llvm/pr-subscribers-llvm-adt

Author: Jiachen Yuan (JiachenYuan)

Changes

This PR is split out from #191757 per reviewer request. It has the following changes to llvm::Bitset&lt;N&gt;:

  • Added operator&lt;&lt;/&lt;&lt;=/&gt;&gt;/&gt;&gt;=, getNumWords(), getWord(), and findLastSet().
  • Moved the std::array&lt;&gt; constructor from protected to public and explicit.

A follow-up PR will use these to re-implement LaneBitmask as a llvm::Bitset wrapper.


The unit test in the PR is largely generated by LLMs. I have reviewed it and manually applied changes to cover more edge cases.


Full diff: https://github.com/llvm/llvm-project/pull/193400.diff

2 Files Affected:

  • (modified) llvm/include/llvm/ADT/Bitset.h (+89-4)
  • (modified) llvm/unittests/ADT/BitsetTest.cpp (+198)
diff --git a/llvm/include/llvm/ADT/Bitset.h b/llvm/include/llvm/ADT/Bitset.h
index 9dc0f24b1d9f5..3cb2b7d28d83b 100644
--- a/llvm/include/llvm/ADT/Bitset.h
+++ b/llvm/include/llvm/ADT/Bitset.h
@@ -51,8 +51,9 @@ template <unsigned NumBits> class Bitset {
 
   constexpr void maskLastWord() { Bits[getLastWordIndex()] &= RemainderMask; }
 
-protected:
-  constexpr Bitset(const std::array<uint64_t, (NumBits + 63) / 64> &B) {
+public:
+  explicit constexpr Bitset(
+      const std::array<uint64_t, (NumBits + 63) / 64> &B) {
     if constexpr (sizeof(BitWord) == sizeof(uint64_t)) {
       for (size_t I = 0; I != B.size(); ++I)
         Bits[I] = B[I];
@@ -70,8 +71,6 @@ template <unsigned NumBits> class Bitset {
     }
     maskLastWord();
   }
-
-public:
   constexpr Bitset() = default;
   constexpr Bitset(std::initializer_list<unsigned> Init) {
     for (auto I : Init)
@@ -194,6 +193,92 @@ template <unsigned NumBits> class Bitset {
     }
     return false;
   }
+
+  constexpr Bitset &operator<<=(unsigned N) {
+    if (N == 0)
+      return *this;
+    if (N >= NumBits) {
+      return *this = Bitset();
+    }
+    const unsigned WordShift = N / BitwordBits;
+    const unsigned BitShift = N % BitwordBits;
+    if (BitShift == 0) {
+      for (int I = NumWords - 1; I >= static_cast<int>(WordShift); --I)
+        Bits[I] = Bits[I - WordShift];
+    } else {
+      const unsigned CarryShift = BitwordBits - BitShift;
+      for (int I = NumWords - 1; I > static_cast<int>(WordShift); --I) {
+        Bits[I] = (Bits[I - WordShift] << BitShift) |
+                  (Bits[I - WordShift - 1] >> CarryShift);
+      }
+      Bits[WordShift] = Bits[0] << BitShift;
+    }
+    for (unsigned I = 0; I < WordShift; ++I)
+      Bits[I] = 0;
+    maskLastWord();
+    return *this;
+  }
+
+  constexpr Bitset operator<<(unsigned N) const {
+    Bitset Result(*this);
+    Result <<= N;
+    return Result;
+  }
+
+  constexpr Bitset &operator>>=(unsigned N) {
+    if (N == 0)
+      return *this;
+    if (N >= NumBits) {
+      return *this = Bitset();
+    }
+    const unsigned WordShift = N / BitwordBits;
+    const unsigned BitShift = N % BitwordBits;
+    if (BitShift == 0) {
+      for (unsigned I = 0; I < NumWords - WordShift; ++I)
+        Bits[I] = Bits[I + WordShift];
+    } else {
+      const unsigned CarryShift = BitwordBits - BitShift;
+      for (unsigned I = 0; I < NumWords - WordShift - 1; ++I) {
+        Bits[I] = (Bits[I + WordShift] >> BitShift) |
+                  (Bits[I + WordShift + 1] << CarryShift);
+      }
+      Bits[NumWords - WordShift - 1] = Bits[NumWords - 1] >> BitShift;
+    }
+    for (unsigned I = NumWords - WordShift; I < NumWords; ++I)
+      Bits[I] = 0;
+    maskLastWord();
+    return *this;
+  }
+
+  constexpr Bitset operator>>(unsigned N) const {
+    Bitset Result(*this);
+    Result >>= N;
+    return Result;
+  }
+
+  /// Return the I-th 64-bit word of the bitset, from least significant to most.
+  constexpr uint64_t getWord(unsigned I) const {
+    if constexpr (BitwordBits == 64) {
+      return Bits[I];
+    } else {
+      static_assert(BitwordBits == 32, "Unsupported word size");
+      uint64_t Lo = (2 * I < NumWords) ? Bits[2 * I] : 0;
+      uint64_t Hi = (2 * I + 1 < NumWords) ? Bits[2 * I + 1] : 0;
+      return Lo | (Hi << 32);
+    }
+  }
+
+  /// Return the index of the highest set bit, or -1 if no bits are set.
+  constexpr int findLastSet() const {
+    for (int I = NumWords - 1; I >= 0; --I)
+      if (Bits[I] != 0)
+        return I * BitwordBits +
+               (BitwordBits - 1 - countl_zero_constexpr(Bits[I]));
+    return -1;
+  }
+
+  /// Return the number of 64-bit words needed to hold all bits.
+  static constexpr unsigned getNumWords() { return (NumBits + 63) / 64; }
 };
 
 } // end namespace llvm
diff --git a/llvm/unittests/ADT/BitsetTest.cpp b/llvm/unittests/ADT/BitsetTest.cpp
index 678197e31a379..ee3ef07d01979 100644
--- a/llvm/unittests/ADT/BitsetTest.cpp
+++ b/llvm/unittests/ADT/BitsetTest.cpp
@@ -294,4 +294,202 @@ TEST(BitsetTest, BitwiseOperators) {
                 TestXor128.test(127));
 }
 
+TEST(BitsetTest, ShiftOperators) {
+  // Test left shift.
+  static_assert((Bitset<64>({0}) << 10).test(10));
+  static_assert(!(Bitset<64>({0}) << 10).test(0));
+  static_assert((Bitset<64>({63}) << 1).none());
+  static_assert((Bitset<128>({0}) << 64).test(64));
+  static_assert((Bitset<128>({63}) << 1).test(64));
+  static_assert((Bitset<128>({127}) << 1).none());
+
+  // Test right shift.
+  static_assert((Bitset<64>({10}) >> 10).test(0));
+  static_assert(!(Bitset<64>({10}) >> 10).test(10));
+  static_assert((Bitset<64>({0}) >> 1).none());
+  static_assert((Bitset<128>({64}) >> 64).test(0));
+  static_assert((Bitset<128>({64}) >> 1).test(63));
+  static_assert((Bitset<128>({0}) >> 1).none());
+
+  // Test shift by 0.
+  static_assert((Bitset<64>({10, 20}) << 0) == Bitset<64>({10, 20}));
+  static_assert((Bitset<64>({10, 20}) >> 0) == Bitset<64>({10, 20}));
+
+  // Test shift by NumBits (clears all).
+  static_assert((Bitset<64>({0, 63}) << 64).none());
+  static_assert((Bitset<64>({0, 63}) >> 64).none());
+  static_assert((Bitset<128>({0, 127}) << 128).none());
+  static_assert((Bitset<128>({0, 127}) >> 128).none());
+}
+
+TEST(BitsetTest, GetNumWords64) {
+  static_assert(Bitset<1>::getNumWords() == 1);
+  static_assert(Bitset<32>::getNumWords() == 1);
+  static_assert(Bitset<64>::getNumWords() == 1);
+  static_assert(Bitset<65>::getNumWords() == 2);
+  static_assert(Bitset<96>::getNumWords() == 2);
+  static_assert(Bitset<128>::getNumWords() == 2);
+  static_assert(Bitset<129>::getNumWords() == 3);
+}
+
+TEST(BitsetTest, GetWord) {
+  // Single-word bitset.
+  constexpr auto B64 = Bitset<64>(std::array<uint64_t, 1>{0xdeadbeefcafe1234});
+  static_assert(B64.getWord(0) == 0xdeadbeefcafe1234);
+
+  // Multi-word bitset.
+  constexpr auto B128 = Bitset<128>(
+      std::array<uint64_t, 2>{0x1111222233334444, 0xaaaabbbbccccdddd});
+  static_assert(B128.getWord(0) == 0x1111222233334444);
+  static_assert(B128.getWord(1) == 0xaaaabbbbccccdddd);
+
+  // Partial last word — high bits should be masked off.
+  constexpr auto B96 = Bitset<96>(
+      std::array<uint64_t, 2>{0xffffffffffffffff, 0xffffffffffffffff});
+  static_assert(B96.getWord(0) == 0xffffffffffffffff);
+  // Only lower 32 bits.
+  static_assert(B96.getWord(1) == 0x00000000ffffffff);
+
+  // Empty bitset.
+  static_assert(Bitset<64>().getWord(0) == 0);
+  static_assert(Bitset<128>().getWord(0) == 0);
+  static_assert(Bitset<128>().getWord(1) == 0);
+}
+
+TEST(BitsetTest, FindLastSet) {
+  // Empty bitset returns -1.
+  static_assert(Bitset<64>().findLastSet() == -1);
+  static_assert(Bitset<128>().findLastSet() == -1);
+
+  // Single bit set.
+  static_assert(Bitset<64>({0}).findLastSet() == 0);
+  static_assert(Bitset<64>({63}).findLastSet() == 63);
+  static_assert(Bitset<64>({31}).findLastSet() == 31);
+  static_assert(Bitset<128>({0}).findLastSet() == 0);
+  static_assert(Bitset<128>({64}).findLastSet() == 64);
+  static_assert(Bitset<128>({127}).findLastSet() == 127);
+
+  // Multiple bits — returns highest.
+  static_assert(Bitset<64>({0, 10, 50}).findLastSet() == 50);
+  static_assert(Bitset<128>({0, 63, 64, 100}).findLastSet() == 100);
+
+  // All bits set.
+  static_assert(Bitset<64>().set().findLastSet() == 63);
+  static_assert(Bitset<128>().set().findLastSet() == 127);
+  static_assert(Bitset<96>().set().findLastSet() == 95);
+
+  // Non-power-of-2 sizes.
+  static_assert(Bitset<33>({32}).findLastSet() == 32);
+  static_assert(Bitset<33>({0, 32}).findLastSet() == 32);
+  static_assert(Bitset<65>({64}).findLastSet() == 64);
+}
+
+TEST(BitsetTest, ShiftMultiWords) {
+  constexpr auto B192 = Bitset<192>({0, 64, 128});
+  static_assert((B192 << 1) == Bitset<192>({1, 65, 129}));
+  static_assert((B192 >> 1) == Bitset<192>({63, 127}));
+  static_assert((B192 << 64) == Bitset<192>({64, 128}));
+  static_assert((B192 >> 64) == Bitset<192>({0, 64}));
+  static_assert((Bitset<192>({63, 127}) << 1) == Bitset<192>({64, 128}));
+  static_assert((Bitset<192>({64, 128}) >> 1) == Bitset<192>({63, 127}));
+}
+
+TEST(BitsetTest, ShiftBoundaryBitShifts) {
+  static_assert((Bitset<128>({1}) << 63) == Bitset<128>({64}));
+  static_assert((Bitset<128>({64}) >> 63) == Bitset<128>({1}));
+  static_assert((Bitset<192>({1, 65}) << 63) == Bitset<192>({64, 128}));
+  // Shift by NumBits - 1.
+  static_assert((Bitset<64>({0}) << 63) == Bitset<64>({63}));
+  static_assert((Bitset<64>({63}) >> 63) == Bitset<64>({0}));
+  static_assert((Bitset<33>({0}) << 32) == Bitset<33>({32}));
+  // Full-width shift of a fully-set bitset loses exactly one bit.
+  static_assert((Bitset<128>().set() << 1).count() == 127);
+  static_assert((Bitset<128>().set() >> 1).count() == 127);
+  static_assert((Bitset<100>().set() >> 1).count() == 99);
+}
+
+TEST(BitsetTest, ShiftExcessAmount) {
+  static_assert((Bitset<64>().set() << 65).none());
+  static_assert((Bitset<64>().set() >> 200).none());
+  static_assert((Bitset<33>({0, 10, 32}) << 1000).none());
+  static_assert((Bitset<128>({0, 127}) >> 1000).none());
+  static_assert((Bitset<192>().set() << 193).none());
+}
+
+TEST(BitsetTest, ShiftAssignReturnsReference) {
+  constexpr Bitset<64> L = [] {
+    Bitset<64> X({0});
+    (X <<= 3) <<= 2;
+    return X;
+  }();
+  static_assert(L == Bitset<64>({5}));
+
+  constexpr Bitset<128> R = [] {
+    Bitset<128> X({100});
+    (X >>= 30) >>= 10;
+    return X;
+  }();
+  static_assert(R == Bitset<128>({60}));
+}
+
+TEST(BitsetTest, GetWordConsistencyWithTest) {
+  // For every set bit, getWord must report it in the expected 64-bit word.
+  constexpr auto B100 = Bitset<100>({0, 50, 64, 99});
+  static_assert((B100.getWord(0) & 1) != 0);
+  static_assert((B100.getWord(0) & (uint64_t(1) << 50)) != 0);
+  static_assert((B100.getWord(1) & 1) != 0);
+  static_assert((B100.getWord(1) & (uint64_t(1) << 35)) != 0);
+}
+
+TEST(BitsetTest, GetWordAfterMutation) {
+  // getWord reflects subsequent set / shift.
+  constexpr auto B = [] {
+    Bitset<128> X;
+    X.set(5).set(70);
+    return X;
+  }();
+  static_assert(B.getWord(0) == (uint64_t(1) << 5));
+  static_assert(B.getWord(1) == (uint64_t(1) << 6));
+
+  constexpr auto Shifted = Bitset<128>({5}) << 64;
+  static_assert(Shifted.getWord(0) == 0);
+  static_assert(Shifted.getWord(1) == (uint64_t(1) << 5));
+}
+
+TEST(BitsetTest, GetNumWordsMoreWidths) {
+  static_assert(Bitset<2>::getNumWords() == 1);
+  static_assert(Bitset<192>::getNumWords() == 3);
+  static_assert(Bitset<193>::getNumWords() == 4);
+  static_assert(Bitset<256>::getNumWords() == 4);
+}
+
+TEST(BitsetTest, FindLastSetSmallWidths) {
+  static_assert(Bitset<1>().findLastSet() == -1);
+  static_assert(Bitset<1>({0}).findLastSet() == 0);
+  static_assert(Bitset<2>({0, 1}).findLastSet() == 1);
+  static_assert(Bitset<32>({31}).findLastSet() == 31);
+  static_assert(Bitset<32>().set().findLastSet() == 31);
+}
+
+TEST(BitsetTest, FindLastSetMultiWordScan) {
+  static_assert(Bitset<192>({70}).findLastSet() == 70);
+  static_assert(Bitset<192>({64, 70, 127}).findLastSet() == 127);
+  static_assert(Bitset<192>({3}).findLastSet() == 3);
+  static_assert(Bitset<100>({99}).findLastSet() == 99);
+}
+
+TEST(BitsetTest, FindLastSetAfterMutation) {
+  constexpr auto A = Bitset<128>({0, 50, 100}).reset(100);
+  static_assert(A.findLastSet() == 50);
+
+  constexpr auto B = Bitset<64>({10}) << 20;
+  static_assert(B.findLastSet() == 30);
+
+  constexpr auto C = Bitset<64>({63}) >> 10;
+  static_assert(C.findLastSet() == 53);
+
+  constexpr auto D = Bitset<64>({63}) << 1;
+  static_assert(D.findLastSet() == -1);
+}
+
 } // namespace

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants