Files
conflict-set/Internal.h
Andrew Noyes f2b5e9b0bf Change max key len to 8, update corpus
Now that we don't have a fixed buffer reserved for partial key bytes,
there's nothing (obvious) that makes testing short versus long keys much
different. maybeDecreaseCapacity is an exception, and we'll write some
tests covering that manually.
2024-03-18 11:55:43 -07:00

705 lines
21 KiB
C++

#pragma once
#include "ConflictSet.h"

#include <algorithm>
#include <bit>
#include <cassert>
#include <compare>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <inttypes.h>
#include <limits>
#include <map>
#include <new>
#include <set>
#include <span>
#include <string>
#include <string_view>
#include <type_traits>
#include <unordered_set>
#include <utility>
#include <vector>

#include <callgrind.h>
// Set to 1 to make TestDriver log every generated read and write to stderr
// (only effective when NDEBUG is not defined).
#define DEBUG_VERBOSE 0
// Set to 1 to make safe_malloc/safe_free track current and peak allocated
// bytes.
#define SHOW_MEMORY 0
/// Three-way lexicographic comparison of two byte strings.
///
/// The common prefix is compared with memcmp (i.e. as unsigned bytes). If the
/// prefixes are equal, the shorter span orders first.
[[nodiscard]] inline auto
operator<=>(const std::span<const uint8_t> &lhs,
            const std::span<const uint8_t> &rhs) noexcept {
  // Keep the prefix length in size_t: narrowing it to int goes negative for
  // spans of 2 GiB or more, which would silently skip the byte comparison.
  const size_t commonLen = std::min(lhs.size(), rhs.size());
  // Skip memcmp for an empty prefix; an empty span's data() may be null.
  if (commonLen > 0) {
    if (auto c = memcmp(lhs.data(), rhs.data(), commonLen) <=> 0; c != 0) {
      return c;
    }
  }
  return lhs.size() <=> rhs.size();
}
// This header contains code that we want to reuse outside of ConflictSet.cpp or
// want to exclude from coverage since it's only testing related.
// GCOVR_EXCL_START
#if SHOW_MEMORY
// Sum of the sizes requested from safe_malloc that have not yet been passed to
// safe_free.
inline int64_t mallocBytes = 0;
// Largest value mallocBytes has reached.
inline int64_t peakMallocBytes = 0;
// Bytes reserved in front of every allocation; safe_malloc stores the
// requested size there so safe_free can subtract it again.
constexpr auto kIntMallocHeaderSize = 16;
#endif
// Allocation wrapper that never hands back null: the process aborts if malloc
// fails. Memory obtained here must be released with `safe_free`.
__attribute__((always_inline)) inline void *safe_malloc(size_t s) {
#if SHOW_MEMORY
  // Account for the request up front and refresh the high-water mark.
  mallocBytes += s;
  if (peakMallocBytes < mallocBytes) {
    peakMallocBytes = mallocBytes;
  }
  const size_t prefix = kIntMallocHeaderSize;
#else
  const size_t prefix = 0;
#endif
  char *raw = static_cast<char *>(malloc(s + prefix));
  if (raw == nullptr) {
    abort();
  }
#if SHOW_MEMORY
  // Remember the requested size in the hidden prefix for safe_free.
  memcpy(raw, &s, sizeof(s));
#endif
  return raw + prefix;
}
// Counterpart of `safe_malloc`; only pass pointers obtained from it.
//
// Despite the name this is no safer than plain free. It exists so the
// allocation and release calls read symmetrically.
__attribute__((always_inline)) inline void safe_free(void *p) {
#if SHOW_MEMORY
  // Step back to the hidden prefix, read the recorded request size, and undo
  // the accounting before releasing the real allocation.
  char *raw = static_cast<char *>(p) - kIntMallocHeaderSize;
  size_t recorded;
  memcpy(&recorded, raw, sizeof(recorded));
  mallocBytes -= recorded;
  free(raw);
#else
  free(p);
#endif
}
// ==================== BEGIN ARENA IMPL ====================
/// Groups allocations that share a lifetime so the cost of malloc/free is
/// amortized across them. Move-only.
struct Arena {
  struct ArenaImpl;

  /// Most recently allocated block, or null when no block has been allocated.
  ArenaImpl *impl = nullptr;

  /// A positive `initialSize` allocates a first block up front.
  explicit Arena(int initialSize = 0);

  Arena(Arena &&other) noexcept;
  Arena &operator=(Arena &&other) noexcept;

  Arena(const Arena &) = delete;
  Arena &operator=(const Arena &) = delete;

  /// O(log n) in the number of allocations
  ~Arena();
};
// Allocation functions used by `new (arena) T` and `new (align, arena) T`.
// Only the `Arena &` forms are usable: every `Arena *` form is deleted so that
// passing a pointer is a compile error. The matching operator delete overloads
// do nothing.

// Aligned scalar form; every other form below forwards to it.
inline void *operator new(size_t size, std::align_val_t align, Arena &arena);
void *operator new(size_t size, std::align_val_t align, Arena *arena) = delete;
[[maybe_unused]] inline void operator delete(void *, std::align_val_t,
                                             Arena &) {}

// Scalar form using the platform's maximum fundamental alignment.
inline void *operator new(size_t size, Arena &arena) {
  constexpr auto kDefaultAlign = std::align_val_t(alignof(std::max_align_t));
  return operator new(size, kDefaultAlign, arena);
}
inline void *operator new(size_t size, Arena *arena) = delete;
[[maybe_unused]] inline void operator delete(void *, Arena &) {}

// Array form using the platform's maximum fundamental alignment.
inline void *operator new[](size_t size, Arena &arena) {
  constexpr auto kDefaultAlign = std::align_val_t(alignof(std::max_align_t));
  return operator new(size, kDefaultAlign, arena);
}
inline void *operator new[](size_t size, Arena *arena) = delete;
[[maybe_unused]] inline void operator delete[](void *, Arena &) {}

// Aligned array form.
inline void *operator new[](size_t size, std::align_val_t align, Arena &arena) {
  return operator new(size, align, arena);
}
inline void *operator new[](size_t size, std::align_val_t align,
                            Arena *arena) = delete;
[[maybe_unused]] inline void operator delete[](void *, std::align_val_t,
                                               Arena &) {}
/// Rounds the pointer `t` up to the next multiple of `align`, which must be a
/// power of two. An already-aligned pointer is returned unchanged.
template <class T> T *align_up(T *t, size_t align) {
  const uintptr_t address = uintptr_t(t);
  const uintptr_t mask = align - 1;
  // Distance to the next multiple of `align`: (-address) mod align.
  const uintptr_t padding = (uintptr_t(0) - address) & mask;
  return reinterpret_cast<T *>(reinterpret_cast<char *>(t) + padding);
}
/// Rounds `unaligned` up to the next multiple of `align`, which must be a
/// power of two.
constexpr inline int align_up(uint32_t unaligned, uint32_t align) {
  const uint32_t mask = align - 1;
  return (unaligned + mask) & ~mask;
}
/// Returns the smallest power of two that is >= x (1 for x == 0).
[[maybe_unused]] constexpr inline uint32_t nextPowerOfTwo(uint32_t x) {
  if (x <= 1) {
    return 1;
  }
  // Number of bits needed to represent x - 1.
  const int bits = 32 - std::countl_zero(x - 1);
  return uint32_t{1} << bits;
}
/// Header at the front of every block an Arena allocates. The block's usable
/// bytes start immediately after this header.
struct Arena::ArenaImpl {
  /// Block that was current before this one was allocated, or null.
  Arena::ArenaImpl *prev;
  /// Usable bytes following the header.
  int capacity;
  /// Bytes of `capacity` already handed out.
  int used;
  /// First usable byte of the block.
  uint8_t *begin() {
    return reinterpret_cast<uint8_t *>(this) + sizeof(ArenaImpl);
  }
};
static_assert(alignof(Arena::ArenaImpl) == 8);
static_assert(sizeof(Arena::ArenaImpl) == 16);
/// Creates an arena. A positive `initialSize` allocates a first block with at
/// least that many usable bytes; otherwise nothing is allocated yet.
inline Arena::Arena(int initialSize) : impl(nullptr) {
  if (initialSize <= 0) {
    return;
  }
  // Header plus payload, rounded up to a multiple of 16 bytes.
  const auto blockBytes = align_up(initialSize + sizeof(ArenaImpl), 16);
  auto *block = static_cast<Arena::ArenaImpl *>(safe_malloc(blockBytes));
  block->prev = nullptr;
  block->capacity = blockBytes - sizeof(ArenaImpl);
  block->used = 0;
  impl = block;
}
/// Frees `impl` and every block reachable from it through `prev`.
inline void onDestroy(Arena::ArenaImpl *impl) {
  for (Arena::ArenaImpl *older = nullptr; impl != nullptr; impl = older) {
    // Read the link before the block holding it is released.
    older = impl->prev;
    safe_free(impl);
  }
}
/// Takes over `other`'s blocks and leaves `other` empty.
[[maybe_unused]] inline Arena::Arena(Arena &&other) noexcept
    : impl(other.impl) {
  other.impl = nullptr;
}
/// Frees this arena's blocks, then takes over `other`'s blocks and leaves
/// `other` empty. Self-assignment is a no-op.
[[maybe_unused]] inline Arena &Arena::operator=(Arena &&other) noexcept {
  // Without this guard a self-move would free the chain and then re-adopt the
  // dangling pointer, leading to a double free when the arena is destroyed.
  if (this != &other) {
    onDestroy(impl);
    impl = std::exchange(other.impl, nullptr);
  }
  return *this;
}
/// Releases every block owned by this arena.
inline Arena::~Arena() {
  onDestroy(this->impl);
}
inline void *operator new(size_t size, std::align_val_t align, Arena &arena) {
int64_t aligned_size = size + size_t(align) - 1;
if (arena.impl == nullptr ||
(arena.impl->capacity - arena.impl->used) < aligned_size) {
auto allocationSize = align_up(
sizeof(Arena::ArenaImpl) +
std::max<int>(aligned_size,
(arena.impl ? std::max<int>(sizeof(Arena::ArenaImpl),
arena.impl->capacity * 2)
: 0)),
16);
auto *impl = (Arena::ArenaImpl *)safe_malloc(allocationSize);
impl->prev = arena.impl;
impl->capacity = allocationSize - sizeof(Arena::ArenaImpl);
impl->used = 0;
arena.impl = impl;
}
auto *result =
align_up(arena.impl->begin() + arena.impl->used, size_t(align));
auto usedDelta = (result - arena.impl->begin()) + size - arena.impl->used;
arena.impl->used += usedDelta;
return result;
}
/// STL-friendly allocator using an arena
template <class T> struct ArenaAlloc {
typedef T value_type;
ArenaAlloc() = delete;
explicit ArenaAlloc(Arena *arena) : arena(arena) {}
Arena *arena;
template <class U> constexpr ArenaAlloc(const ArenaAlloc<U> &other) noexcept {
arena = other.arena;
}
[[nodiscard]] T *allocate(size_t n) {
if (n > 0xfffffffffffffffful / sizeof(T)) { // NOLINT
__builtin_unreachable();
}
return static_cast<T *>((void *)new (std::align_val_t(alignof(T)), *arena)
uint8_t[n * sizeof(T)]); // NOLINT
}
void deallocate(T *, size_t) noexcept {}
};
/// std::vector whose storage comes from an Arena.
template <class T> using Vector = std::vector<T, ArenaAlloc<T>>;
/// Creates an empty Vector<T> that allocates from `arena`.
template <class T> auto vector(Arena &arena) {
  ArenaAlloc<T> alloc(&arena);
  return Vector<T>(alloc);
}
/// std::set whose nodes come from an Arena.
template <class T, class C> using Set = std::set<T, C, ArenaAlloc<T>>;
/// Creates an empty Set<T, C> that allocates from `arena`.
template <class T, class C = std::less<T>> auto set(Arena &arena) {
  ArenaAlloc<T> alloc(&arena);
  return Set<T, C>(alloc);
}
/// Hash functor; only the pointer specialization below is defined.
template <class T> struct MyHash;
/// Hashes a pointer by copying its object representation into a size_t.
template <class T> struct MyHash<T *> {
  size_t operator()(const T *t) const noexcept {
    size_t bits;
    const void *source = &t;
    memcpy(&bits, source, sizeof(bits));
    return bits;
  }
};
/// std::unordered_set hashed with MyHash whose storage comes from an Arena.
template <class T>
using HashSet =
    std::unordered_set<T, MyHash<T>, std::equal_to<T>, ArenaAlloc<T>>;
/// Creates an empty HashSet<T> that allocates from `arena`.
template <class T> auto hashSet(Arena &arena) {
  ArenaAlloc<T> alloc(&arena);
  return HashSet<T>(alloc);
}
/// Two ArenaAllocs are interchangeable exactly when they share an arena.
template <class T, class U>
bool operator==(const ArenaAlloc<T> &lhs, const ArenaAlloc<U> &rhs) {
  const Arena *left = lhs.arena;
  const Arena *right = rhs.arena;
  return left == right;
}
/// True when the two allocators use different arenas.
template <class T, class U>
bool operator!=(const ArenaAlloc<T> &lhs, const ArenaAlloc<U> &rhs) {
  return lhs.arena != rhs.arena;
}
// ==================== END ARENA IMPL ====================
// ==================== BEGIN ARBITRARY IMPL ====================
/// `Arbitrary` is a random number generator whose output is chosen by an
/// adversary. Normally a random source should be fair so that probabilities
/// can be analyzed, e.g. quicksort's expected O(n log n) bound relies on the
/// pivot being drawn uniformly.
///
/// For bug hunting and fuzzing the opposite is useful: the random-number
/// interface stays convenient, but letting the input pick every outcome
/// (say, heads 100 times in a row) can reach far more code paths.
///
/// Once the input bytes are used up, every draw yields 0.
struct Arbitrary {
  Arbitrary() = default;
  explicit Arbitrary(std::span<const uint8_t> bytecode) : bytecode(bytecode) {}

  /// Draws an arbitrary uint32_t
  uint32_t next() { return consume<4>(); }

  /// Draws an arbitrary element from [0, s)
  uint32_t bounded(uint32_t s);

  /// Fill `bytes` with `size` arbitrary bytes
  void randomBytes(uint8_t *bytes, int size) {
    const int copied = std::min<int>(size, bytecode.size());
    if (copied > 0) {
      memcpy(bytes, bytecode.data(), copied);
    }
    bytecode = bytecode.subspan(copied);
    // Whatever the input could not supply is zero-filled.
    memset(bytes + copied, 0, size - copied);
  }

  /// Fill `bytes` with `size` random hex bytes
  void randomHex(uint8_t *bytes, int size) {
    static constexpr char kDigits[] = "0123456789abcdef";
    int written = 0;
    while (written < size) {
      // One input byte yields up to two digits: low nibble first, then high.
      const uint8_t raw = consume<1>();
      bytes[written++] = kDigits[raw & 0xf];
      if (written < size) {
        bytes[written++] = kDigits[raw >> 4];
      }
    }
  }

  /// Draws a T whose bytes are filled by randomBytes
  template <class T, class = std::enable_if_t<std::is_trivially_copyable_v<T>>>
  T randT() {
    T value;
    randomBytes(reinterpret_cast<uint8_t *>(&value), sizeof(T));
    return value;
  }

  /// True while unread input bytes remain
  bool hasEntropy() const { return !bytecode.empty(); }

private:
  /// Pops the next input byte, or returns 0 when none are left.
  uint8_t consumeByte() {
    if (bytecode.empty()) {
      return 0;
    }
    const uint8_t head = bytecode.front();
    bytecode = bytecode.subspan(1);
    return head;
  }

  /// Reads `kBytes` input bytes as a big-endian unsigned integer.
  template <int kBytes> uint32_t consume() {
    static_assert(kBytes <= 4);
    uint32_t acc = 0;
    for (int remaining = kBytes; remaining > 0; --remaining) {
      acc = (acc << 8) | consumeByte();
    }
    return acc;
  }

  /// Input bytes not yet consumed.
  std::span<const uint8_t> bytecode;
};
/// Draws a value in [0, s). Reads just enough whole input bytes to cover
/// `s - 1` (none when s == 1) and reduces the result modulo `s`.
inline uint32_t Arbitrary::bounded(uint32_t s) {
  if (s == 1) {
    return 0;
  }
  // Number of bits needed to represent the largest result, s - 1.
  const int bits = 32 - std::countl_zero(s - 1);
  if (bits <= 8) {
    return consume<1>() % s;
  }
  if (bits <= 16) {
    return consume<2>() % s;
  }
  if (bits <= 24) {
    return consume<3>() % s;
  }
  return consume<4>() % s;
}
// ==================== END ARBITRARY IMPL ====================
// ==================== BEGIN UTILITIES IMPL ====================
// Round-robins over `remaining`, calling Stepwise::step on each element until
// that element's step returns true. Finished elements are swapped to the back
// and dropped from the working range, so `remaining` ends up permuted.
template <class Stepwise> void runInterleaved(std::span<Stepwise> remaining) {
  while (!remaining.empty()) {
    size_t idx = 0;
    while (idx < remaining.size()) {
      if (!remaining[idx].step()) {
        ++idx;
        continue;
      }
      // Finished: move it out of the live range. The element swapped into
      // `idx` is stepped next, so `idx` is not advanced.
      const size_t last = remaining.size() - 1;
      if (idx != last) {
        using std::swap;
        swap(remaining[idx], remaining[last]);
      }
      remaining = remaining.first(last);
    }
  }
}
// Runs each element of `remaining` to completion, one after another, by
// calling Stepwise::step until it returns true.
template <class Stepwise> void runSequential(std::span<Stepwise> remaining) {
  for (Stepwise &task : remaining) {
    bool finished = false;
    do {
      finished = task.step();
    } while (!finished);
  }
}
/// Simple std::map-based conflict set exposing the same check / addWrites /
/// setOldestVersion operations that TestDriver calls on the implementation
/// under test, so the two can be compared.
struct ReferenceImpl {
  explicit ReferenceImpl(int64_t oldestVersion) : oldestVersion(oldestVersion) {
    // The empty key is always present, so every lookup has a predecessor.
    writeVersionMap[""] = oldestVersion;
  }

  /// Evaluates `count` reads and stores one result per read in `results`:
  /// TooOld if the read version is below `oldestVersion`, Conflict if any
  /// write version covering the range exceeds the read version, else Commit.
  void check(const ConflictSet::ReadRange *reads, ConflictSet::Result *results,
             int count) const {
    for (int i = 0; i < count; ++i) {
      if (reads[i].readVersion < oldestVersion) {
        results[i] = ConflictSet::TooOld;
        continue;
      }
      const auto [begin, end] = bounds(reads[i]);
      int64_t maxVersion = oldestVersion;
      // Start at the last entry at or before `begin`; it carries the version
      // that applies to `begin` itself.
      auto iter = --writeVersionMap.upper_bound(begin);
      const auto endIter = writeVersionMap.lower_bound(end);
      for (; iter != endIter; ++iter) {
        maxVersion = std::max(maxVersion, iter->second);
      }
      results[i] = maxVersion > reads[i].readVersion ? ConflictSet::Conflict
                                                     : ConflictSet::Commit;
    }
  }

  /// Records `count` writes, all at `writeVersion`.
  void addWrites(const ConflictSet::WriteRange *writes, int count,
                 int64_t writeVersion) {
    for (int i = 0; i < count; ++i) {
      const auto [begin, end] = bounds(writes[i]);
      // Version in effect at `end` before this write; it must be restored
      // there because the write stops just before `end`.
      const auto prevVersion = (--writeVersionMap.upper_bound(end))->second;
      // Drop every boundary inside [begin, end); the write replaces them all.
      for (auto iter = writeVersionMap.lower_bound(begin),
                endIter = writeVersionMap.lower_bound(end);
           iter != endIter;) {
        iter = writeVersionMap.erase(iter);
      }
      writeVersionMap[begin] = writeVersion;
      writeVersionMap[end] = prevVersion;
    }
  }

  /// Raises the oldest version; it must never move backwards.
  void setOldestVersion(int64_t oldestVersion) {
    // Compare against the member: the parameter shadows it, so an unqualified
    // comparison would test the argument against itself and always pass.
    assert(oldestVersion >= this->oldestVersion);
    this->oldestVersion = oldestVersion;
  }

  /// Reads at versions below this are answered TooOld.
  int64_t oldestVersion;
  /// Each entry gives the write version in effect from its key up to (but not
  /// including) the next entry's key.
  std::map<std::string, int64_t> writeVersionMap;

private:
  /// Returns the range's {begin, end} as owned strings. A zero-length `end`
  /// marks a single-key range, whose exclusive end is `begin` followed by a
  /// zero byte.
  template <class Range>
  static std::pair<std::string, std::string> bounds(const Range &range) {
    std::string begin(reinterpret_cast<const char *>(range.begin.p),
                      range.begin.len);
    std::string end =
        range.end.len == 0
            ? begin + std::string("\x00", 1)
            : std::string(reinterpret_cast<const char *>(range.end.p),
                          range.end.len);
    return {std::move(begin), std::move(end)};
  }
};
using Key = ConflictSet::Key;
/// Literal suffix: `"abc"_s` yields a Key that points at the literal's bytes
/// (no copy is made).
///
/// Written without a space before the suffix: in the spaced form `_s` is an
/// ordinary identifier, which is reserved at global scope, and that form is
/// deprecated.
inline Key operator""_s(const char *str, size_t size) {
  return {reinterpret_cast<const uint8_t *>(str), int(size)};
}
/// Builds a Key holding the object representation of `n` (native byte order),
/// stored in `arena`.
[[maybe_unused]] static Key toKey(Arena &arena, int n) {
  constexpr int kLen = sizeof(int);
  auto *bytes = new (arena) uint8_t[kLen];
  memcpy(bytes, &n, kLen);
  return Key{bytes, kLen};
}
/// Builds the Key for `n` followed by a single zero byte, stored in `arena`.
[[maybe_unused]] static Key toKeyAfter(Arena &arena, int n) {
  constexpr int kLen = sizeof(int) + 1;
  auto *bytes = new (arena) uint8_t[kLen];
  memcpy(bytes, &n, sizeof(int));
  bytes[kLen - 1] = 0;
  return Key{bytes, kLen};
}
/// Renders every byte of `key` as `x` followed by two lowercase hex digits.
inline std::string printable(std::string_view key) {
  static constexpr char kDigits[] = "0123456789abcdef";
  std::string rendered;
  rendered.reserve(key.size() * 3);
  for (const char ch : key) {
    const auto byte = static_cast<uint8_t>(ch);
    rendered.push_back('x');
    rendered.push_back(kDigits[byte >> 4]);
    rendered.push_back(kDigits[byte & 0xf]);
  }
  return rendered;
}
/// Hex rendering of a Key's bytes; see the std::string_view overload.
inline std::string printable(const Key &key) {
  const auto *chars = reinterpret_cast<const char *>(key.p);
  return printable(std::string_view(chars, key.len));
}
/// Hex rendering of a byte span; see the std::string_view overload.
inline std::string printable(std::span<const uint8_t> key) {
  const auto *chars = reinterpret_cast<const char *>(key.data());
  return printable(std::string_view(chars, key.size()));
}
/// Human-readable name of a ConflictSet::Result. Aborts on a value that is
/// none of Commit, Conflict or TooOld.
inline const char *resultToStr(ConflictSet::Result r) {
  if (r == ConflictSet::Commit) {
    return "commit";
  }
  if (r == ConflictSet::Conflict) {
    return "conflict";
  }
  if (r == ConflictSet::TooOld) {
    return "too old";
  }
  abort();
}
namespace {
// Differential test driver. Every decision is drawn from `arbitrary`; each
// call to next() applies one batch of writes and one batch of reads to both
// `cs` (the implementation under test) and `refImpl`, then compares the read
// results.
template <class ConflictSetImpl> struct TestDriver {
// Source of every "random" choice made below.
Arbitrary arbitrary;
// `data`/`size` are the input bytes that drive the test. They are viewed, not
// copied, so they must outlive the driver.
explicit TestDriver(const uint8_t *data, size_t size)
: arbitrary({data, size}) {}
// Version given to the most recent write batch; incremented at the start of
// each write batch.
int64_t writeVersion = 0;
// Oldest version handed to both implementations; only ever raised.
int64_t oldestVersion = 0;
// Implementation under test and the reference it is compared against.
ConflictSetImpl cs{oldestVersion};
ReferenceImpl refImpl{oldestVersion};
// NOTE(review): key lengths are drawn with bounded(kMaxKeyLen), i.e. from
// [0, kMaxKeyLen), so the longest generated key has kMaxKeyLen - 1 bytes --
// confirm the exclusive bound is intended.
constexpr static auto kMaxKeyLen = 8;
// Set to false when the two implementations disagree on a read result.
bool ok = true;
// Call until it returns true, for "done". Check internal invariants etc
// between calls to next.
bool next() {
// Out of input: nothing more to test.
if (!arbitrary.hasEntropy()) {
return true;
}
// Scratch storage for this call's keys, ranges and results.
Arena arena;
// ---- Write batch ----
{
int numPointWrites = arbitrary.bounded(100);
int numRangeWrites = arbitrary.bounded(100);
int64_t v = ++writeVersion;
auto *writes =
new (arena) ConflictSet::WriteRange[numPointWrites + numRangeWrites];
// Sorted set of distinct keys: one per point write, two per range write. The
// views point at bytes allocated from `arena`.
auto keys = set<std::string_view>(arena);
while (int(keys.size()) < numPointWrites + numRangeWrites * 2) {
// Stop the whole test if the input runs out while generating keys.
if (!arbitrary.hasEntropy()) {
return true;
}
int keyLen = arbitrary.bounded(kMaxKeyLen);
auto *begin = new (arena) uint8_t[keyLen];
arbitrary.randomBytes(begin, keyLen);
keys.insert(std::string_view((const char *)begin, keyLen));
}
// Hand the keys out in sorted order.
auto iter = keys.begin();
int i = 0;
for (int pointsRemaining = numPointWrites,
rangesRemaining = numRangeWrites;
pointsRemaining > 0 || rangesRemaining > 0; ++i) {
// While both kinds remain the input chooses; otherwise take whichever kind
// is left. (Despite its name, `pointRead` selects a point *write* here.)
bool pointRead = pointsRemaining > 0 && rangesRemaining > 0
? bool(arbitrary.bounded(2))
: pointsRemaining > 0;
if (pointRead) {
assert(pointsRemaining > 0);
writes[i].begin.p = (const uint8_t *)iter->data();
writes[i].begin.len = iter->size();
// A zero-length end marks a single-key write.
writes[i].end.len = 0;
++iter;
--pointsRemaining;
} else {
assert(rangesRemaining > 0);
// Consecutive keys from the sorted set form [begin, end).
writes[i].begin.p = (const uint8_t *)iter->data();
writes[i].begin.len = iter->size();
++iter;
writes[i].end.p = (const uint8_t *)iter->data();
writes[i].end.len = iter->size();
++iter;
--rangesRemaining;
}
#if DEBUG_VERBOSE && !defined(NDEBUG)
if (writes[i].end.len == 0) {
fprintf(stderr, "Write: {%s} -> %" PRId64 "\n",
printable(writes[i].begin).c_str(), writeVersion);
} else {
fprintf(stderr, "Write: [%s, %s) -> %" PRId64 "\n",
printable(writes[i].begin).c_str(),
printable(writes[i].end).c_str(), writeVersion);
}
#endif
}
// Every generated key must have been used exactly once.
assert(iter == keys.end());
assert(i == numPointWrites + numRangeWrites);
// Only the implementation under test is bracketed by the callgrind
// instrumentation macros.
CALLGRIND_START_INSTRUMENTATION;
cs.addWrites(writes, numPointWrites + numRangeWrites, v);
CALLGRIND_STOP_INSTRUMENTATION;
refImpl.addWrites(writes, numPointWrites + numRangeWrites, v);
// Advance the oldest version to at most 9 versions behind the newest write;
// it never moves backwards.
oldestVersion = std::max<int64_t>(writeVersion - arbitrary.bounded(10),
oldestVersion);
cs.setOldestVersion(oldestVersion);
refImpl.setOldestVersion(oldestVersion);
}
// ---- Read batch ----
{
int numPointReads = arbitrary.bounded(100);
int numRangeReads = arbitrary.bounded(100);
// All reads in this batch share one read version, at most 9 versions behind
// the newest write and never negative.
int64_t v = std::max<int64_t>(writeVersion - arbitrary.bounded(10), 0);
auto *reads =
new (arena) ConflictSet::ReadRange[numPointReads + numRangeReads];
// Same key generation scheme as for the write batch.
auto keys = set<std::string_view>(arena);
while (int(keys.size()) < numPointReads + numRangeReads * 2) {
if (!arbitrary.hasEntropy()) {
return true;
}
int keyLen = arbitrary.bounded(kMaxKeyLen);
auto *begin = new (arena) uint8_t[keyLen];
arbitrary.randomBytes(begin, keyLen);
keys.insert(std::string_view((const char *)begin, keyLen));
}
auto iter = keys.begin();
int i = 0;
for (int pointsRemaining = numPointReads, rangesRemaining = numRangeReads;
pointsRemaining > 0 || rangesRemaining > 0; ++i) {
// While both kinds remain the input chooses; otherwise take whichever kind
// is left.
bool pointRead = pointsRemaining > 0 && rangesRemaining > 0
? bool(arbitrary.bounded(2))
: pointsRemaining > 0;
if (pointRead) {
assert(pointsRemaining > 0);
reads[i].begin.p = (const uint8_t *)iter->data();
reads[i].begin.len = iter->size();
// A zero-length end marks a single-key read.
reads[i].end.len = 0;
++iter;
--pointsRemaining;
} else {
assert(rangesRemaining > 0);
reads[i].begin.p = (const uint8_t *)iter->data();
reads[i].begin.len = iter->size();
++iter;
reads[i].end.p = (const uint8_t *)iter->data();
reads[i].end.len = iter->size();
++iter;
--rangesRemaining;
}
reads[i].readVersion = v;
#if DEBUG_VERBOSE && !defined(NDEBUG)
if (reads[i].end.len == 0) {
fprintf(stderr, "Read: {%s} @ %d\n",
printable(reads[i].begin).c_str(), int(reads[i].readVersion));
} else {
fprintf(stderr, "Read: [%s, %s) @ %d\n",
printable(reads[i].begin).c_str(),
printable(reads[i].end).c_str(), int(reads[i].readVersion));
}
#endif
}
assert(iter == keys.end());
assert(i == numPointReads + numRangeReads);
// results1: implementation under test; results2: reference implementation.
auto *results1 =
new (arena) ConflictSet::Result[numPointReads + numRangeReads];
auto *results2 =
new (arena) ConflictSet::Result[numPointReads + numRangeReads];
CALLGRIND_START_INSTRUMENTATION;
cs.check(reads, results1, numPointReads + numRangeReads);
CALLGRIND_STOP_INSTRUMENTATION;
refImpl.check(reads, results2, numPointReads + numRangeReads);
// Report the first disagreement, mark the run as failed and stop.
for (int i = 0; i < numPointReads + numRangeReads; ++i) {
if (results1[i] != results2[i]) {
if (reads[i].end.len == 0) {
fprintf(stderr,
"Expected %s, got %s for read of {%s} at version %" PRId64
"\n",
resultToStr(results2[i]), resultToStr(results1[i]),
printable(reads[i].begin).c_str(), reads[i].readVersion);
} else {
fprintf(
stderr,
"Expected %s, got %s for read of [%s, %s) at version %" PRId64
"\n",
resultToStr(results2[i]), resultToStr(results1[i]),
printable(reads[i].begin).c_str(),
printable(reads[i].end).c_str(), reads[i].readVersion);
}
ok = false;
return true;
}
}
}
// Both batches agreed; the caller may call next() again.
return false;
}
};
} // namespace
// GCOVR_EXCL_STOP