Skip to content

Commit 488d408

Browse files
authored
Merge pull request #306 from elbeno/zero-bitset
🐛 Ensure bitset works with zero size
2 parents 9a6297d + faea9ef commit 488d408

2 files changed

Lines changed: 78 additions & 57 deletions

File tree

include/stdx/bitset.hpp

Lines changed: 60 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ class bitset {
5757

5858
[[nodiscard]] friend constexpr auto operator==(bitset const &lhs,
5959
bitset const &rhs) -> bool {
60+
if constexpr (N == 0) {
61+
return true;
62+
}
6063
for (auto i = std::size_t{}; i < storage_size - 1; ++i) {
6164
if (lhs.storage[i] != rhs.storage[i]) {
6265
return false;
@@ -191,6 +194,9 @@ class bitset {
191194
#endif
192195

193196
template <typename T> [[nodiscard]] constexpr auto to() const -> T {
197+
if constexpr (N == 0) {
198+
return {};
199+
}
194200
using U = underlying_type_t<T>;
195201
static_assert(
196202
unsigned_integral<U>,
@@ -329,34 +335,27 @@ class bitset {
329335
return *this;
330336
}
331337

332-
[[nodiscard]] constexpr auto all() const -> bool {
338+
[[nodiscard]] constexpr auto count() const -> std::size_t {
339+
if constexpr (N == 0) {
340+
return {};
341+
}
342+
auto n = static_cast<std::size_t>(popcount(highbits()));
333343
for (auto i = std::size_t{}; i < storage_size - 1; ++i) {
334-
if (storage[i] != allbits) {
335-
return false;
336-
}
344+
n += static_cast<std::size_t>(popcount(storage[i]));
337345
}
338-
return highbits() == lastmask;
346+
return n;
347+
}
348+
349+
[[nodiscard]] constexpr auto all() const -> bool {
350+
return count() == size();
339351
}
340352

341353
[[nodiscard]] constexpr auto any() const -> bool {
342-
for (auto i = std::size_t{}; i < storage_size - 1; ++i) {
343-
if (storage[i] != 0) {
344-
return true;
345-
}
346-
}
347-
return highbits() != 0;
354+
return count() != std::size_t{};
348355
}
349356

350357
[[nodiscard]] constexpr auto none() const -> bool { return not any(); }
351358

352-
[[nodiscard]] constexpr auto count() const -> std::size_t {
353-
std::size_t n{};
354-
for (auto i = std::size_t{}; i < storage_size - 1; ++i) {
355-
n += static_cast<std::size_t>(popcount(storage[i]));
356-
}
357-
return n + static_cast<std::size_t>(popcount(highbits()));
358-
}
359-
360359
[[nodiscard]] constexpr auto lowest_unset() const {
361360
std::size_t i = 0;
362361
for (auto e : storage) {
@@ -397,54 +396,58 @@ class bitset {
397396
}
398397

399398
constexpr auto operator<<=(std::size_t pos) LIFETIMEBOUND->bitset & {
400-
auto dst = storage_size - 1;
401-
auto const start = dst - (pos / storage_elem_size);
402-
pos %= storage_elem_size;
403-
404-
if (pos == 0) {
405-
for (auto i = start; i > std::size_t{}; --i) {
406-
storage[dst] = storage[i];
407-
--dst;
399+
if constexpr (N != 0) {
400+
auto dst = storage_size - 1;
401+
auto const start = dst - (pos / storage_elem_size);
402+
pos %= storage_elem_size;
403+
404+
if (pos == 0) {
405+
for (auto i = start; i > std::size_t{}; --i) {
406+
storage[dst] = storage[i];
407+
--dst;
408+
}
409+
} else {
410+
auto const borrow_shift = storage_elem_size - pos;
411+
for (auto i = start; i > std::size_t{}; --i) {
412+
storage[dst] = static_cast<elem_t>(storage[i] << pos);
413+
storage[dst] |=
414+
static_cast<elem_t>(storage[i - 1] >> borrow_shift);
415+
--dst;
416+
}
408417
}
409-
} else {
410-
auto const borrow_shift = storage_elem_size - pos;
411-
for (auto i = start; i > std::size_t{}; --i) {
412-
storage[dst] = static_cast<elem_t>(storage[i] << pos);
413-
storage[dst] |=
414-
static_cast<elem_t>(storage[i - 1] >> borrow_shift);
415-
--dst;
418+
storage[dst] = static_cast<elem_t>(storage.front() << pos);
419+
while (dst > std::size_t{}) {
420+
storage[--dst] = 0;
416421
}
417422
}
418-
storage[dst] = static_cast<elem_t>(storage.front() << pos);
419-
while (dst > std::size_t{}) {
420-
storage[--dst] = 0;
421-
}
422423
return *this;
423424
}
424425

425426
constexpr auto operator>>=(std::size_t pos) LIFETIMEBOUND->bitset & {
426-
auto dst = std::size_t{};
427-
auto const start = pos / storage_elem_size;
428-
pos %= storage_elem_size;
429-
430-
if (pos == 0) {
431-
for (auto i = start; i < storage_size - 1; ++i) {
432-
storage[dst] = storage[i];
433-
++dst;
427+
if constexpr (N != 0) {
428+
auto dst = std::size_t{};
429+
auto const start = pos / storage_elem_size;
430+
pos %= storage_elem_size;
431+
432+
if (pos == 0) {
433+
for (auto i = start; i < storage_size - 1; ++i) {
434+
storage[dst] = storage[i];
435+
++dst;
436+
}
437+
} else {
438+
auto const borrow_shift = storage_elem_size - pos;
439+
for (auto i = start; i < storage_size - 1; ++i) {
440+
storage[dst] = static_cast<elem_t>(storage[i] >> pos);
441+
storage[dst] |=
442+
static_cast<elem_t>(storage[i + 1] << borrow_shift);
443+
++dst;
444+
}
434445
}
435-
} else {
436-
auto const borrow_shift = storage_elem_size - pos;
437-
for (auto i = start; i < storage_size - 1; ++i) {
438-
storage[dst] = static_cast<elem_t>(storage[i] >> pos);
439-
storage[dst] |=
440-
static_cast<elem_t>(storage[i + 1] << borrow_shift);
441-
++dst;
446+
storage[dst++] = static_cast<elem_t>(storage.back() >> pos);
447+
while (dst < storage_size) {
448+
storage[dst++] = 0;
442449
}
443450
}
444-
storage[dst++] = static_cast<elem_t>(storage.back() >> pos);
445-
while (dst < storage_size) {
446-
storage[dst++] = 0;
447-
}
448451
return *this;
449452
}
450453
};

test/bitset.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,3 +499,21 @@ TEST_CASE("construct with a ct_string", "[bitset]") {
499499
stdx::bitset<9ul, std::uint16_t>{0b101010101ul});
500500
}
501501
#endif
502+
503+
TEST_CASE("zero size bitset", "[bitset]") {
504+
constexpr auto bs1 = stdx::bitset<0>{};
505+
STATIC_REQUIRE(bs1.count() == 0u);
506+
STATIC_REQUIRE(bs1.all());
507+
STATIC_REQUIRE(not bs1.any());
508+
STATIC_REQUIRE(bs1.none());
509+
510+
constexpr auto bs2 = stdx::bitset<0>{};
511+
STATIC_REQUIRE(bs1 == bs2);
512+
STATIC_REQUIRE(bs2.to<std::uint8_t>() == 0);
513+
514+
auto bs3 = stdx::bitset<0>{};
515+
bs3 <<= 1;
516+
CHECK(bs3.to<std::uint8_t>() == 0);
517+
bs3 >>= 1;
518+
CHECK(bs3.to<std::uint8_t>() == 0);
519+
}

0 commit comments

Comments
 (0)