[ADT] Bitset: add shift operators, word accessors, and etc#193400
Open
JiachenYuan wants to merge 1 commit intollvm:mainfrom
Open
[ADT] Bitset: add shift operators, word accessors, and etc#193400JiachenYuan wants to merge 1 commit intollvm:mainfrom
JiachenYuan wants to merge 1 commit intollvm:mainfrom
Conversation
Member
|
@llvm/pr-subscribers-llvm-adt Author: Jiachen Yuan (JiachenYuan) ChangesThis PR is split out from #191757 per reviewer request. It has the following changes to
A follow-up PR will use these to re-implement The unit test in the PR is largely generated by LLMs. I have reviewed it and manually applied changes to cover more edge cases. Full diff: https://github.com/llvm/llvm-project/pull/193400.diff 2 Files Affected:
diff --git a/llvm/include/llvm/ADT/Bitset.h b/llvm/include/llvm/ADT/Bitset.h
index 9dc0f24b1d9f5..3cb2b7d28d83b 100644
--- a/llvm/include/llvm/ADT/Bitset.h
+++ b/llvm/include/llvm/ADT/Bitset.h
@@ -51,8 +51,9 @@ template <unsigned NumBits> class Bitset {
constexpr void maskLastWord() { Bits[getLastWordIndex()] &= RemainderMask; }
-protected:
- constexpr Bitset(const std::array<uint64_t, (NumBits + 63) / 64> &B) {
+public:
+ explicit constexpr Bitset(
+ const std::array<uint64_t, (NumBits + 63) / 64> &B) {
if constexpr (sizeof(BitWord) == sizeof(uint64_t)) {
for (size_t I = 0; I != B.size(); ++I)
Bits[I] = B[I];
@@ -70,8 +71,6 @@ template <unsigned NumBits> class Bitset {
}
maskLastWord();
}
-
-public:
constexpr Bitset() = default;
constexpr Bitset(std::initializer_list<unsigned> Init) {
for (auto I : Init)
@@ -194,6 +193,92 @@ template <unsigned NumBits> class Bitset {
}
return false;
}
+
+ constexpr Bitset &operator<<=(unsigned N) {
+ if (N == 0)
+ return *this;
+ if (N >= NumBits) {
+ return *this = Bitset();
+ }
+ const unsigned WordShift = N / BitwordBits;
+ const unsigned BitShift = N % BitwordBits;
+ if (BitShift == 0) {
+ for (int I = NumWords - 1; I >= static_cast<int>(WordShift); --I)
+ Bits[I] = Bits[I - WordShift];
+ } else {
+ const unsigned CarryShift = BitwordBits - BitShift;
+ for (int I = NumWords - 1; I > static_cast<int>(WordShift); --I) {
+ Bits[I] = (Bits[I - WordShift] << BitShift) |
+ (Bits[I - WordShift - 1] >> CarryShift);
+ }
+ Bits[WordShift] = Bits[0] << BitShift;
+ }
+ for (unsigned I = 0; I < WordShift; ++I)
+ Bits[I] = 0;
+ maskLastWord();
+ return *this;
+ }
+
+ constexpr Bitset operator<<(unsigned N) const {
+ Bitset Result(*this);
+ Result <<= N;
+ return Result;
+ }
+
+ constexpr Bitset &operator>>=(unsigned N) {
+ if (N == 0)
+ return *this;
+ if (N >= NumBits) {
+ return *this = Bitset();
+ }
+ const unsigned WordShift = N / BitwordBits;
+ const unsigned BitShift = N % BitwordBits;
+ if (BitShift == 0) {
+ for (unsigned I = 0; I < NumWords - WordShift; ++I)
+ Bits[I] = Bits[I + WordShift];
+ } else {
+ const unsigned CarryShift = BitwordBits - BitShift;
+ for (unsigned I = 0; I < NumWords - WordShift - 1; ++I) {
+ Bits[I] = (Bits[I + WordShift] >> BitShift) |
+ (Bits[I + WordShift + 1] << CarryShift);
+ }
+ Bits[NumWords - WordShift - 1] = Bits[NumWords - 1] >> BitShift;
+ }
+ for (unsigned I = NumWords - WordShift; I < NumWords; ++I)
+ Bits[I] = 0;
+ maskLastWord();
+ return *this;
+ }
+
+ constexpr Bitset operator>>(unsigned N) const {
+ Bitset Result(*this);
+ Result >>= N;
+ return Result;
+ }
+
+ /// Return the I-th 64-bit word of the bitset, from least significant to most.
+ constexpr uint64_t getWord(unsigned I) const {
+ if constexpr (BitwordBits == 64) {
+ return Bits[I];
+ } else {
+ static_assert(BitwordBits == 32, "Unsupported word size");
+ uint64_t Lo = (2 * I < NumWords) ? Bits[2 * I] : 0;
+ uint64_t Hi = (2 * I + 1 < NumWords) ? Bits[2 * I + 1] : 0;
+ return Lo | (Hi << 32);
+ }
+ }
+
+ /// Return the index of the highest set bit, or -1 if no bits are set.
+ constexpr int findLastSet() const {
+ for (int I = NumWords - 1; I >= 0; --I)
+ if (Bits[I] != 0)
+ return I * BitwordBits +
+ (BitwordBits - 1 - countl_zero_constexpr(Bits[I]));
+ return -1;
+ }
+
+ /// Return the number of 64-bit words needed to hold all bits.
+ static constexpr unsigned getNumWords() { return (NumBits + 63) / 64; }
};
} // end namespace llvm
diff --git a/llvm/unittests/ADT/BitsetTest.cpp b/llvm/unittests/ADT/BitsetTest.cpp
index 678197e31a379..ee3ef07d01979 100644
--- a/llvm/unittests/ADT/BitsetTest.cpp
+++ b/llvm/unittests/ADT/BitsetTest.cpp
@@ -294,4 +294,202 @@ TEST(BitsetTest, BitwiseOperators) {
TestXor128.test(127));
}
+TEST(BitsetTest, ShiftOperators) {
+ // Test left shift.
+ static_assert((Bitset<64>({0}) << 10).test(10));
+ static_assert(!(Bitset<64>({0}) << 10).test(0));
+ static_assert((Bitset<64>({63}) << 1).none());
+ static_assert((Bitset<128>({0}) << 64).test(64));
+ static_assert((Bitset<128>({63}) << 1).test(64));
+ static_assert((Bitset<128>({127}) << 1).none());
+
+ // Test right shift.
+ static_assert((Bitset<64>({10}) >> 10).test(0));
+ static_assert(!(Bitset<64>({10}) >> 10).test(10));
+ static_assert((Bitset<64>({0}) >> 1).none());
+ static_assert((Bitset<128>({64}) >> 64).test(0));
+ static_assert((Bitset<128>({64}) >> 1).test(63));
+ static_assert((Bitset<128>({0}) >> 1).none());
+
+ // Test shift by 0.
+ static_assert((Bitset<64>({10, 20}) << 0) == Bitset<64>({10, 20}));
+ static_assert((Bitset<64>({10, 20}) >> 0) == Bitset<64>({10, 20}));
+
+ // Test shift by NumBits (clears all).
+ static_assert((Bitset<64>({0, 63}) << 64).none());
+ static_assert((Bitset<64>({0, 63}) >> 64).none());
+ static_assert((Bitset<128>({0, 127}) << 128).none());
+ static_assert((Bitset<128>({0, 127}) >> 128).none());
+}
+
+TEST(BitsetTest, GetNumWords64) {
+ static_assert(Bitset<1>::getNumWords() == 1);
+ static_assert(Bitset<32>::getNumWords() == 1);
+ static_assert(Bitset<64>::getNumWords() == 1);
+ static_assert(Bitset<65>::getNumWords() == 2);
+ static_assert(Bitset<96>::getNumWords() == 2);
+ static_assert(Bitset<128>::getNumWords() == 2);
+ static_assert(Bitset<129>::getNumWords() == 3);
+}
+
+TEST(BitsetTest, GetWord) {
+ // Single-word bitset.
+ constexpr auto B64 = Bitset<64>(std::array<uint64_t, 1>{0xdeadbeefcafe1234});
+ static_assert(B64.getWord(0) == 0xdeadbeefcafe1234);
+
+ // Multi-word bitset.
+ constexpr auto B128 = Bitset<128>(
+ std::array<uint64_t, 2>{0x1111222233334444, 0xaaaabbbbccccdddd});
+ static_assert(B128.getWord(0) == 0x1111222233334444);
+ static_assert(B128.getWord(1) == 0xaaaabbbbccccdddd);
+
+ // Partial last word — high bits should be masked off.
+ constexpr auto B96 = Bitset<96>(
+ std::array<uint64_t, 2>{0xffffffffffffffff, 0xffffffffffffffff});
+ static_assert(B96.getWord(0) == 0xffffffffffffffff);
+ // Only lower 32 bits.
+ static_assert(B96.getWord(1) == 0x00000000ffffffff);
+
+ // Empty bitset.
+ static_assert(Bitset<64>().getWord(0) == 0);
+ static_assert(Bitset<128>().getWord(0) == 0);
+ static_assert(Bitset<128>().getWord(1) == 0);
+}
+
+TEST(BitsetTest, FindLastSet) {
+ // Empty bitset returns -1.
+ static_assert(Bitset<64>().findLastSet() == -1);
+ static_assert(Bitset<128>().findLastSet() == -1);
+
+ // Single bit set.
+ static_assert(Bitset<64>({0}).findLastSet() == 0);
+ static_assert(Bitset<64>({63}).findLastSet() == 63);
+ static_assert(Bitset<64>({31}).findLastSet() == 31);
+ static_assert(Bitset<128>({0}).findLastSet() == 0);
+ static_assert(Bitset<128>({64}).findLastSet() == 64);
+ static_assert(Bitset<128>({127}).findLastSet() == 127);
+
+ // Multiple bits — returns highest.
+ static_assert(Bitset<64>({0, 10, 50}).findLastSet() == 50);
+ static_assert(Bitset<128>({0, 63, 64, 100}).findLastSet() == 100);
+
+ // All bits set.
+ static_assert(Bitset<64>().set().findLastSet() == 63);
+ static_assert(Bitset<128>().set().findLastSet() == 127);
+ static_assert(Bitset<96>().set().findLastSet() == 95);
+
+ // Non-power-of-2 sizes.
+ static_assert(Bitset<33>({32}).findLastSet() == 32);
+ static_assert(Bitset<33>({0, 32}).findLastSet() == 32);
+ static_assert(Bitset<65>({64}).findLastSet() == 64);
+}
+
+TEST(BitsetTest, ShiftMultiWords) {
+ constexpr auto B192 = Bitset<192>({0, 64, 128});
+ static_assert((B192 << 1) == Bitset<192>({1, 65, 129}));
+ static_assert((B192 >> 1) == Bitset<192>({63, 127}));
+ static_assert((B192 << 64) == Bitset<192>({64, 128}));
+ static_assert((B192 >> 64) == Bitset<192>({0, 64}));
+ static_assert((Bitset<192>({63, 127}) << 1) == Bitset<192>({64, 128}));
+ static_assert((Bitset<192>({64, 128}) >> 1) == Bitset<192>({63, 127}));
+}
+
+TEST(BitsetTest, ShiftBoundaryBitShifts) {
+ static_assert((Bitset<128>({1}) << 63) == Bitset<128>({64}));
+ static_assert((Bitset<128>({64}) >> 63) == Bitset<128>({1}));
+ static_assert((Bitset<192>({1, 65}) << 63) == Bitset<192>({64, 128}));
+ // Shift by NumBits - 1.
+ static_assert((Bitset<64>({0}) << 63) == Bitset<64>({63}));
+ static_assert((Bitset<64>({63}) >> 63) == Bitset<64>({0}));
+ static_assert((Bitset<33>({0}) << 32) == Bitset<33>({32}));
+ // Full-width shift of a fully-set bitset loses exactly one bit.
+ static_assert((Bitset<128>().set() << 1).count() == 127);
+ static_assert((Bitset<128>().set() >> 1).count() == 127);
+ static_assert((Bitset<100>().set() >> 1).count() == 99);
+}
+
+TEST(BitsetTest, ShiftExcessAmount) {
+ static_assert((Bitset<64>().set() << 65).none());
+ static_assert((Bitset<64>().set() >> 200).none());
+ static_assert((Bitset<33>({0, 10, 32}) << 1000).none());
+ static_assert((Bitset<128>({0, 127}) >> 1000).none());
+ static_assert((Bitset<192>().set() << 193).none());
+}
+
+TEST(BitsetTest, ShiftAssignReturnsReference) {
+ constexpr Bitset<64> L = [] {
+ Bitset<64> X({0});
+ (X <<= 3) <<= 2;
+ return X;
+ }();
+ static_assert(L == Bitset<64>({5}));
+
+ constexpr Bitset<128> R = [] {
+ Bitset<128> X({100});
+ (X >>= 30) >>= 10;
+ return X;
+ }();
+ static_assert(R == Bitset<128>({60}));
+}
+
+TEST(BitsetTest, GetWordConsistencyWithTest) {
+ // For every set bit, getWord must report it in the expected 64-bit word.
+ constexpr auto B100 = Bitset<100>({0, 50, 64, 99});
+ static_assert((B100.getWord(0) & 1) != 0);
+ static_assert((B100.getWord(0) & (uint64_t(1) << 50)) != 0);
+ static_assert((B100.getWord(1) & 1) != 0);
+ static_assert((B100.getWord(1) & (uint64_t(1) << 35)) != 0);
+}
+
+TEST(BitsetTest, GetWordAfterMutation) {
+ // getWord reflects subsequent set / shift.
+ constexpr auto B = [] {
+ Bitset<128> X;
+ X.set(5).set(70);
+ return X;
+ }();
+ static_assert(B.getWord(0) == (uint64_t(1) << 5));
+ static_assert(B.getWord(1) == (uint64_t(1) << 6));
+
+ constexpr auto Shifted = Bitset<128>({5}) << 64;
+ static_assert(Shifted.getWord(0) == 0);
+ static_assert(Shifted.getWord(1) == (uint64_t(1) << 5));
+}
+
+TEST(BitsetTest, GetNumWordsMoreWidths) {
+ static_assert(Bitset<2>::getNumWords() == 1);
+ static_assert(Bitset<192>::getNumWords() == 3);
+ static_assert(Bitset<193>::getNumWords() == 4);
+ static_assert(Bitset<256>::getNumWords() == 4);
+}
+
+TEST(BitsetTest, FindLastSetSmallWidths) {
+ static_assert(Bitset<1>().findLastSet() == -1);
+ static_assert(Bitset<1>({0}).findLastSet() == 0);
+ static_assert(Bitset<2>({0, 1}).findLastSet() == 1);
+ static_assert(Bitset<32>({31}).findLastSet() == 31);
+ static_assert(Bitset<32>().set().findLastSet() == 31);
+}
+
+TEST(BitsetTest, FindLastSetMultiWordScan) {
+ static_assert(Bitset<192>({70}).findLastSet() == 70);
+ static_assert(Bitset<192>({64, 70, 127}).findLastSet() == 127);
+ static_assert(Bitset<192>({3}).findLastSet() == 3);
+ static_assert(Bitset<100>({99}).findLastSet() == 99);
+}
+
+TEST(BitsetTest, FindLastSetAfterMutation) {
+ constexpr auto A = Bitset<128>({0, 50, 100}).reset(100);
+ static_assert(A.findLastSet() == 50);
+
+ constexpr auto B = Bitset<64>({10}) << 20;
+ static_assert(B.findLastSet() == 30);
+
+ constexpr auto C = Bitset<64>({63}) >> 10;
+ static_assert(C.findLastSet() == 53);
+
+ constexpr auto D = Bitset<64>({63}) << 1;
+ static_assert(D.findLastSet() == -1);
+}
+
} // namespace
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR is split out from #191757 per reviewer request. It has the following changes to
llvm::Bitset<N>:operator<</<<=/>>/>>=,getNumWords(),getWord(), andfindLastSet().std::array<>constructor from protected to public and explicit.A follow-up PR will use these to re-implement
LaneBitmaskas allvm::Bitsetwrapper.The unit test in the PR is largely generated by LLMs. I have reviewed it and manually applied changes to cover more edge cases.