Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 93 additions & 4 deletions llvm/include/llvm/ADT/Bitset.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "llvm/ADT/bit.h"
#include <array>
#include <cassert>
#include <climits>
#include <cstdint>

Expand Down Expand Up @@ -51,8 +52,9 @@ template <unsigned NumBits> class Bitset {

constexpr void maskLastWord() { Bits[getLastWordIndex()] &= RemainderMask; }

protected:
constexpr Bitset(const std::array<uint64_t, (NumBits + 63) / 64> &B) {
public:
explicit constexpr Bitset(
const std::array<uint64_t, (NumBits + 63) / 64> &B) {
if constexpr (sizeof(BitWord) == sizeof(uint64_t)) {
for (size_t I = 0; I != B.size(); ++I)
Bits[I] = B[I];
Expand All @@ -70,8 +72,6 @@ template <unsigned NumBits> class Bitset {
}
maskLastWord();
}

public:
constexpr Bitset() = default;
constexpr Bitset(std::initializer_list<unsigned> Init) {
for (auto I : Init)
Expand Down Expand Up @@ -194,6 +194,95 @@ template <unsigned NumBits> class Bitset {
}
return false;
}

constexpr Bitset &operator<<=(unsigned N) {
if (N == 0)
return *this;
if (N >= NumBits) {
return *this = Bitset();
}
const unsigned WordShift = N / BitwordBits;
const unsigned BitShift = N % BitwordBits;
if (BitShift == 0) {
for (int I = NumWords - 1; I >= static_cast<int>(WordShift); --I)
Bits[I] = Bits[I - WordShift];
} else {
const unsigned CarryShift = BitwordBits - BitShift;
for (int I = NumWords - 1; I > static_cast<int>(WordShift); --I) {
Bits[I] = (Bits[I - WordShift] << BitShift) |
(Bits[I - WordShift - 1] >> CarryShift);
}
Bits[WordShift] = Bits[0] << BitShift;
}
for (unsigned I = 0; I < WordShift; ++I)
Bits[I] = 0;
maskLastWord();
return *this;
}

constexpr Bitset operator<<(unsigned N) const {
Bitset Result(*this);
Result <<= N;
return Result;
}

constexpr Bitset &operator>>=(unsigned N) {
if (N == 0)
return *this;
if (N >= NumBits) {
return *this = Bitset();
}
const unsigned WordShift = N / BitwordBits;
const unsigned BitShift = N % BitwordBits;
if (BitShift == 0) {
for (unsigned I = 0; I < NumWords - WordShift; ++I)
Bits[I] = Bits[I + WordShift];
} else {
const unsigned CarryShift = BitwordBits - BitShift;
for (unsigned I = 0; I < NumWords - WordShift - 1; ++I) {
Bits[I] = (Bits[I + WordShift] >> BitShift) |
(Bits[I + WordShift + 1] << CarryShift);
}
Bits[NumWords - WordShift - 1] = Bits[NumWords - 1] >> BitShift;
}
for (unsigned I = NumWords - WordShift; I < NumWords; ++I)
Bits[I] = 0;
maskLastWord();
return *this;
}

constexpr Bitset operator>>(unsigned N) const {
Bitset Result(*this);
Result >>= N;
return Result;
}

/// Return the I-th 64-bit word of the bitset, from least significant to most.
constexpr uint64_t getWord64(unsigned I) const {
assert(I < getNumWords64() && "Word index out of range");
if constexpr (BitwordBits == 64) {
return Bits[I];
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could also do some check or workaround for index-out-of-bounds, like you did for the 32bit case

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, thanks for catching this! I added an assertion to check index-out-of-bounds.

} else {
static_assert(BitwordBits == 32, "Unsupported word size");
// When Bitword is 32-bit, for a valid I, the first word is always
// present, but the second may not be present.
uint64_t Lo = Bits[2 * I];
uint64_t Hi = (2 * I + 1 < NumWords) ? Bits[2 * I + 1] : 0;
return Lo | (Hi << 32);
}
}

/// Return the index of the highest set bit, or -1 if no bits are set.
constexpr int findLastSet() const {
for (int I = NumWords - 1; I >= 0; --I)
if (Bits[I] != 0)
return I * BitwordBits +
(BitwordBits - 1 - countl_zero_constexpr(Bits[I]));
return -1;
}

/// Return the number of 64-bit words needed to hold all bits.
static constexpr unsigned getNumWords64() { return (NumBits + 63) / 64; }
};

} // end namespace llvm
Expand Down
198 changes: 198 additions & 0 deletions llvm/unittests/ADT/BitsetTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,4 +294,202 @@ TEST(BitsetTest, BitwiseOperators) {
TestXor128.test(127));
}

TEST(BitsetTest, ShiftOperators) {
// Test left shift.
static_assert((Bitset<64>({0}) << 10).test(10));
static_assert(!(Bitset<64>({0}) << 10).test(0));
static_assert((Bitset<64>({63}) << 1).none());
static_assert((Bitset<128>({0}) << 64).test(64));
static_assert((Bitset<128>({63}) << 1).test(64));
static_assert((Bitset<128>({127}) << 1).none());

// Test right shift.
static_assert((Bitset<64>({10}) >> 10).test(0));
static_assert(!(Bitset<64>({10}) >> 10).test(10));
static_assert((Bitset<64>({0}) >> 1).none());
static_assert((Bitset<128>({64}) >> 64).test(0));
static_assert((Bitset<128>({64}) >> 1).test(63));
static_assert((Bitset<128>({0}) >> 1).none());

// Test shift by 0.
static_assert((Bitset<64>({10, 20}) << 0) == Bitset<64>({10, 20}));
static_assert((Bitset<64>({10, 20}) >> 0) == Bitset<64>({10, 20}));

// Test shift by NumBits (clears all).
static_assert((Bitset<64>({0, 63}) << 64).none());
static_assert((Bitset<64>({0, 63}) >> 64).none());
static_assert((Bitset<128>({0, 127}) << 128).none());
static_assert((Bitset<128>({0, 127}) >> 128).none());
}

TEST(BitsetTest, GetNumWords64) {
static_assert(Bitset<1>::getNumWords64() == 1);
static_assert(Bitset<32>::getNumWords64() == 1);
static_assert(Bitset<64>::getNumWords64() == 1);
static_assert(Bitset<65>::getNumWords64() == 2);
static_assert(Bitset<96>::getNumWords64() == 2);
static_assert(Bitset<128>::getNumWords64() == 2);
static_assert(Bitset<129>::getNumWords64() == 3);
}

TEST(BitsetTest, GetWord) {
// Single-word bitset.
constexpr auto B64 = Bitset<64>(std::array<uint64_t, 1>{0xdeadbeefcafe1234});
static_assert(B64.getWord64(0) == 0xdeadbeefcafe1234);

// Multi-word bitset.
constexpr auto B128 = Bitset<128>(
std::array<uint64_t, 2>{0x1111222233334444, 0xaaaabbbbccccdddd});
static_assert(B128.getWord64(0) == 0x1111222233334444);
static_assert(B128.getWord64(1) == 0xaaaabbbbccccdddd);

// Partial last word — high bits should be masked off.
constexpr auto B96 = Bitset<96>(
std::array<uint64_t, 2>{0xffffffffffffffff, 0xffffffffffffffff});
static_assert(B96.getWord64(0) == 0xffffffffffffffff);
// Only lower 32 bits.
static_assert(B96.getWord64(1) == 0x00000000ffffffff);

// Empty bitset.
static_assert(Bitset<64>().getWord64(0) == 0);
static_assert(Bitset<128>().getWord64(0) == 0);
static_assert(Bitset<128>().getWord64(1) == 0);
}

TEST(BitsetTest, FindLastSet) {
// Empty bitset returns -1.
static_assert(Bitset<64>().findLastSet() == -1);
static_assert(Bitset<128>().findLastSet() == -1);

// Single bit set.
static_assert(Bitset<64>({0}).findLastSet() == 0);
static_assert(Bitset<64>({63}).findLastSet() == 63);
static_assert(Bitset<64>({31}).findLastSet() == 31);
static_assert(Bitset<128>({0}).findLastSet() == 0);
static_assert(Bitset<128>({64}).findLastSet() == 64);
static_assert(Bitset<128>({127}).findLastSet() == 127);

// Multiple bits — returns highest.
static_assert(Bitset<64>({0, 10, 50}).findLastSet() == 50);
static_assert(Bitset<128>({0, 63, 64, 100}).findLastSet() == 100);

// All bits set.
static_assert(Bitset<64>().set().findLastSet() == 63);
static_assert(Bitset<128>().set().findLastSet() == 127);
static_assert(Bitset<96>().set().findLastSet() == 95);

// Non-power-of-2 sizes.
static_assert(Bitset<33>({32}).findLastSet() == 32);
static_assert(Bitset<33>({0, 32}).findLastSet() == 32);
static_assert(Bitset<65>({64}).findLastSet() == 64);
}

TEST(BitsetTest, ShiftMultiWords) {
constexpr auto B192 = Bitset<192>({0, 64, 128});
static_assert((B192 << 1) == Bitset<192>({1, 65, 129}));
static_assert((B192 >> 1) == Bitset<192>({63, 127}));
static_assert((B192 << 64) == Bitset<192>({64, 128}));
static_assert((B192 >> 64) == Bitset<192>({0, 64}));
static_assert((Bitset<192>({63, 127}) << 1) == Bitset<192>({64, 128}));
static_assert((Bitset<192>({64, 128}) >> 1) == Bitset<192>({63, 127}));
}

TEST(BitsetTest, ShiftBoundaryBitShifts) {
static_assert((Bitset<128>({1}) << 63) == Bitset<128>({64}));
static_assert((Bitset<128>({64}) >> 63) == Bitset<128>({1}));
static_assert((Bitset<192>({1, 65}) << 63) == Bitset<192>({64, 128}));
// Shift by NumBits - 1.
static_assert((Bitset<64>({0}) << 63) == Bitset<64>({63}));
static_assert((Bitset<64>({63}) >> 63) == Bitset<64>({0}));
static_assert((Bitset<33>({0}) << 32) == Bitset<33>({32}));
// Full-width shift of a fully-set bitset loses exactly one bit.
static_assert((Bitset<128>().set() << 1).count() == 127);
static_assert((Bitset<128>().set() >> 1).count() == 127);
static_assert((Bitset<100>().set() >> 1).count() == 99);
}

TEST(BitsetTest, ShiftExcessAmount) {
static_assert((Bitset<64>().set() << 65).none());
static_assert((Bitset<64>().set() >> 200).none());
static_assert((Bitset<33>({0, 10, 32}) << 1000).none());
static_assert((Bitset<128>({0, 127}) >> 1000).none());
static_assert((Bitset<192>().set() << 193).none());
}

TEST(BitsetTest, ShiftAssignReturnsReference) {
constexpr Bitset<64> L = [] {
Bitset<64> X({0});
(X <<= 3) <<= 2;
return X;
}();
static_assert(L == Bitset<64>({5}));

constexpr Bitset<128> R = [] {
Bitset<128> X({100});
(X >>= 30) >>= 10;
return X;
}();
static_assert(R == Bitset<128>({60}));
}

TEST(BitsetTest, GetWordConsistencyWithTest) {
// For every set bit, getWord must report it in the expected 64-bit word.
constexpr auto B100 = Bitset<100>({0, 50, 64, 99});
static_assert((B100.getWord64(0) & 1) != 0);
static_assert((B100.getWord64(0) & (uint64_t(1) << 50)) != 0);
static_assert((B100.getWord64(1) & 1) != 0);
static_assert((B100.getWord64(1) & (uint64_t(1) << 35)) != 0);
}

TEST(BitsetTest, GetWordAfterMutation) {
// getWord reflects subsequent set / shift.
constexpr auto B = [] {
Bitset<128> X;
X.set(5).set(70);
return X;
}();
static_assert(B.getWord64(0) == (uint64_t(1) << 5));
static_assert(B.getWord64(1) == (uint64_t(1) << 6));

constexpr auto Shifted = Bitset<128>({5}) << 64;
static_assert(Shifted.getWord64(0) == 0);
static_assert(Shifted.getWord64(1) == (uint64_t(1) << 5));
}

TEST(BitsetTest, GetNumWordsMoreWidths) {
static_assert(Bitset<2>::getNumWords64() == 1);
static_assert(Bitset<192>::getNumWords64() == 3);
static_assert(Bitset<193>::getNumWords64() == 4);
static_assert(Bitset<256>::getNumWords64() == 4);
}

TEST(BitsetTest, FindLastSetSmallWidths) {
static_assert(Bitset<1>().findLastSet() == -1);
static_assert(Bitset<1>({0}).findLastSet() == 0);
static_assert(Bitset<2>({0, 1}).findLastSet() == 1);
static_assert(Bitset<32>({31}).findLastSet() == 31);
static_assert(Bitset<32>().set().findLastSet() == 31);
}

TEST(BitsetTest, FindLastSetMultiWordScan) {
static_assert(Bitset<192>({70}).findLastSet() == 70);
static_assert(Bitset<192>({64, 70, 127}).findLastSet() == 127);
static_assert(Bitset<192>({3}).findLastSet() == 3);
static_assert(Bitset<100>({99}).findLastSet() == 99);
}

TEST(BitsetTest, FindLastSetAfterMutation) {
constexpr auto A = Bitset<128>({0, 50, 100}).reset(100);
static_assert(A.findLastSet() == 50);

constexpr auto B = Bitset<64>({10}) << 20;
static_assert(B.findLastSet() == 30);

constexpr auto C = Bitset<64>({63}) >> 10;
static_assert(C.findLastSet() == 53);

constexpr auto D = Bitset<64>({63}) << 1;
static_assert(D.findLastSet() == -1);
}

} // namespace