Files
conflict-set/ConflictSet.cpp
2024-01-18 17:26:55 -08:00

802 lines
24 KiB
C++

#include "ConflictSet.h"
#include <algorithm>
#include <cassert>
#include <compare>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <memory>
#include <new>
#include <random>
#include <span>
#include <string_view>
#include <utility>
#include <vector>
// Set to 1 to truncate node priorities to 8 bits and include them ("p=") in
// debugPrintDot output.
#define SHOW_PRIORITY 0
// Set to 1 to trace insertion steps and heap fix-up rotations on stderr.
#define DEBUG 0
using Key = ConflictSet::Key;
/// Three-way lexicographic comparison of the bytes of two keys; when one key
/// is a prefix of the other, the shorter key orders first.
static auto operator<=>(const Key &lhs, const Key &rhs) {
  const int minLen = std::min(lhs.len, rhs.len);
  // memcmp requires valid pointers even for a zero length, and an empty key
  // may carry a null `p` (e.g. `{nullptr, 0}`), so skip the call when there
  // are no bytes to compare.
  const int c = minLen == 0 ? 0 : memcmp(lhs.p, rhs.p, minLen);
  return c != 0 ? c <=> 0 : lhs.len <=> rhs.len;
}
// ==================== BEGIN ARENA IMPL ====================
/// Group allocations with similar lifetimes to amortize the cost of
/// malloc/free. Individual allocations are never freed; all blocks owned by
/// the arena are released together when it is destroyed or move-assigned over.
struct Arena {
  /// Header of one malloc'd block of arena memory (defined out of line).
  struct ArenaImpl;

  /// When `initialSize` is positive, eagerly allocates a first block with room
  /// for at least that many bytes.
  explicit Arena(int initialSize = 0);
  /// O(log n) in the number of allocations
  ~Arena();

  // Move-only.
  Arena(Arena &&other) noexcept;
  Arena &operator=(Arena &&other) noexcept;
  Arena(const Arena &) = delete;
  Arena &operator=(const Arena &) = delete;

private:
  // Newest block, or nullptr before the first allocation.
  ArenaImpl *impl = nullptr;
  friend void *operator new(size_t size, std::align_val_t align, Arena &arena);
};
// Placement allocation functions for arena-allocated objects:
//   new (arena) T(...)                       -> alignof(std::max_align_t)
//   new (std::align_val_t(A), arena) T(...)  -> explicit alignment A
// Every Arena& form forwards to the aligned scalar operator new declared
// below, which is defined further down in this file.
//
// The placement operator delete overloads are intentionally no-ops: the
// language only calls a placement delete when a constructor in the matching
// new-expression throws, and arena memory is reclaimed when the Arena itself
// is destroyed.
//
// The overloads taking `Arena *` are deleted. Without them, e.g.
// `new (&arena) T` would silently select the standard placement form
// `operator new(size_t, void *)` and construct T on top of the Arena object.
inline void operator delete(void *, std::align_val_t, Arena &) {}
void *operator new(size_t size, std::align_val_t align, Arena &arena);
void *operator new(size_t size, std::align_val_t align, Arena *arena) = delete;
inline void operator delete(void *, Arena &) {}
// Unspecified alignment defaults to alignof(std::max_align_t).
inline void *operator new(size_t size, Arena &arena) {
  return operator new(size, std::align_val_t(alignof(std::max_align_t)), arena);
}
inline void *operator new(size_t size, Arena *arena) = delete;
inline void operator delete[](void *, Arena &) {}
inline void *operator new[](size_t size, Arena &arena) {
  return operator new(size, arena);
}
inline void *operator new[](size_t size, Arena *arena) = delete;
inline void operator delete[](void *, std::align_val_t, Arena &) {}
inline void *operator new[](size_t size, std::align_val_t align, Arena &arena) {
  return operator new(size, align, arena);
}
inline void *operator new[](size_t size, std::align_val_t align,
                            Arena *arena) = delete;
/// Round the address `t` up to the next multiple of `align`.
/// align must be a power of two
template <class T> T *align_up(T *t, size_t align) {
  const uintptr_t address = uintptr_t(t);
  const uintptr_t mask = align - 1;
  const uintptr_t padding = ((address + mask) & ~mask) - address;
  // Offset the original pointer rather than casting the rounded integer back,
  // keeping the arithmetic in pointer space.
  return reinterpret_cast<T *>(reinterpret_cast<char *>(t) + padding);
}
/// Round `unaligned` up to the next multiple of `align`.
/// align must be a power of two
constexpr inline int align_up(uint32_t unaligned, uint32_t align) {
  const uint32_t mask = align - 1;
  return (unaligned + mask) & ~mask;
}
/// Returns the smallest power of two >= x.
/// For x > 2^31 the answer (2^32) does not fit in 32 bits; 0 is returned
/// (2^32 wrapped to 32 bits) instead of performing an out-of-range shift.
constexpr inline uint32_t nextPowerOfTwo(uint32_t x) {
  if (x <= 1) {
    return 1;
  }
  if (x > (uint32_t(1) << 31)) {
    return 0;
  }
  // Shift an unsigned one so that producing 2^31 never touches the sign bit
  // of an int.
  return uint32_t(1) << (32 - __builtin_clz(x - 1));
}
/// \private
/// Header stored at the start of every malloc'd arena block; the usable bytes
/// follow immediately after it. Blocks form a singly linked list through
/// `prev`, from the newest block (Arena::impl) to the oldest.
struct Arena::ArenaImpl {
  // The previously allocated (older) block, or nullptr for the oldest block.
  Arena::ArenaImpl *prev;
  // Number of usable bytes after this header.
  int capacity;
  // Number of those bytes already handed out, including alignment padding.
  int used;
  // First usable byte, directly after this header.
  uint8_t *begin() { return reinterpret_cast<uint8_t *>(this + 1); }
};
// The header is exactly 16 bytes, so begin() keeps the block's 16-byte
// alignment (assuming malloc returns 16-byte-aligned memory — platform
// dependent).
static_assert(sizeof(Arena::ArenaImpl) == 16);
static_assert(alignof(Arena::ArenaImpl) == 8);
/// Creates an arena. When `initialSize` is positive, a first block with room
/// for at least `initialSize` bytes is allocated eagerly; otherwise nothing is
/// allocated until the first request. Aborts if malloc fails.
Arena::Arena(int initialSize) : impl(nullptr) {
  if (initialSize <= 0) {
    return;
  }
  // Header plus payload, rounded up to a multiple of 16 bytes.
  const int allocationSize = align_up(initialSize + sizeof(ArenaImpl), 16);
  auto *block = static_cast<Arena::ArenaImpl *>(malloc(allocationSize));
  if (block == nullptr) {
    // Fail loudly instead of dereferencing a null pointer below.
    fprintf(stderr, "Arena: failed to allocate %d bytes\n", allocationSize);
    fflush(stderr);
    abort();
  }
  block->prev = nullptr;
  block->capacity = allocationSize - sizeof(ArenaImpl);
  block->used = 0;
  impl = block;
}
namespace {
/// Frees every block in the chain starting at `impl` (newest first), following
/// `prev` links. Accepts nullptr (an arena that never allocated).
void onDestroy(Arena::ArenaImpl *impl) {
  for (Arena::ArenaImpl *block = impl; block != nullptr;) {
    Arena::ArenaImpl *older = block->prev;
    free(block);
    block = older;
  }
}
} // namespace
/// Takes ownership of `other`'s blocks, leaving `other` empty.
Arena::Arena(Arena &&other) noexcept
    : impl(std::exchange(other.impl, nullptr)) {}
/// Releases this arena's blocks and takes ownership of `other`'s, leaving
/// `other` empty. Self-move-assignment is a no-op; without the guard it would
/// free the blocks and then keep pointing at them.
Arena &Arena::operator=(Arena &&other) noexcept {
  if (this != &other) {
    onDestroy(impl);
    impl = std::exchange(other.impl, nullptr);
  }
  return *this;
}
/// Frees every block this arena owns.
Arena::~Arena() { onDestroy(impl); }
void *operator new(size_t size, std::align_val_t align, Arena &arena) {
int64_t aligned_size = size + size_t(align) - 1;
if (arena.impl == nullptr ||
(arena.impl->capacity - arena.impl->used) < aligned_size) {
auto allocationSize = align_up(
sizeof(Arena::ArenaImpl) +
std::max<int>(aligned_size,
(arena.impl ? std::max<int>(sizeof(Arena::ArenaImpl),
arena.impl->capacity * 2)
: 0)),
16);
auto *impl = (Arena::ArenaImpl *)malloc(allocationSize);
impl->prev = arena.impl;
impl->capacity = allocationSize - sizeof(Arena::ArenaImpl);
impl->used = 0;
arena.impl = impl;
}
auto *result =
align_up(arena.impl->begin() + arena.impl->used, size_t(align));
auto usedDelta = (result - arena.impl->begin()) + size - arena.impl->used;
arena.impl->used += usedDelta;
return result;
}
/// STL-friendly allocator using an arena
template <class T> struct ArenaAlloc {
typedef T value_type;
ArenaAlloc() = delete;
explicit ArenaAlloc(Arena *arena) : arena(arena) {}
Arena *arena;
template <class U> constexpr ArenaAlloc(const ArenaAlloc<U> &other) noexcept {
arena = other.arena;
}
[[nodiscard]] T *allocate(size_t n) {
if (n > 0xfffffffffffffffful / sizeof(T)) { // NOLINT
fprintf(stderr, "Requested bad alloc! sizeof(T): %zu, n: %zu\n",
sizeof(T), n); // NOLINT
fflush(stderr);
abort();
}
return static_cast<T *>((void *)new (std::align_val_t(alignof(T)), *arena)
uint8_t[n * sizeof(T)]); // NOLINT
}
void deallocate(T *, size_t) noexcept {}
private:
};
template <class T, class U>
bool operator==(const ArenaAlloc<T> &lhs, const ArenaAlloc<U> &rhs) {
return lhs.arena == rhs.arena;
}
template <class T, class U>
bool operator!=(const ArenaAlloc<T> &lhs, const ArenaAlloc<U> &rhs) {
return !(lhs == rhs);
}
// ==================== END ARENA IMPL ====================
namespace {
// A node in the tree representing write conflict history. This tree maintains
// several invariants:
// 1. BST invariant: all keys in the tree rooted at the left child of a node
// compare less than that node's key, and all keys in the tree rooted at the
// right child of a node compare greater than that node's key.
// 2. Heap invariant: the priority of a node is >= all the priorities
// of its children (transitively)
// 3. Max invariant: `maxVersion` is the max among all values of `pointVersion`
// and `rangeVersion` for this node and its children (transitively)
// 4. The lowest key (an empty byte sequence) is always physically present in
// the tree so that "last less than or equal" queries are always well-defined.
// Logically, the contents of the tree represent a "range map" where all of the
// infinitely many points in the key space are associated with a writeVersion.
// If a point is physically present in the tree, then its writeVersion is its
// node's `pointVersion`. Otherwise, its writeVersion is the `rangeVersion` of
// the node with the last key less than the point.
struct Node {
  // See "Max invariant" above
  int64_t maxVersion;
  // The write version of the point in the key space represented by this node's
  // key
  int64_t pointVersion;
  // The write version of the range immediately after this node's key, until
  // just before the next key in the tree. I.e. (this key, next key)
  int64_t rangeVersion;
  // child[0] is the left child or nullptr. child[1] is the right child or
  // nullptr
  Node *child[2];
  // The parent of this node in the tree, or nullptr if this node is the root
  Node *parent;
  // As a treap, this tree satisfies the heap invariant on each node's priority
  uint32_t priority;
  // The length of this node's key
  int len;
  // The contents of this node's key are stored inline, immediately after this
  // struct:
  // uint8_t[len];

  // Three-way byte-wise comparison of this node's key with `other`; when one
  // is a prefix of the other, the shorter key orders first.
  auto operator<=>(const ConflictSet::Key &other) const {
    const int minLen = std::min<int>(len, other.len);
    // memcmp requires valid pointers even for a zero length, and an empty Key
    // may carry a null `p`, so skip the call when there is nothing to compare.
    const int c = minLen == 0 ? 0 : memcmp(this + 1, other.p, minLen);
    return c != 0 ? c <=> 0 : len <=> other.len;
  }
};
// TODO: use a better prng. This is technically vulnerable to a
// denial-of-service attack that can make conflict-checking linear in the
// number of nodes in the tree.
thread_local uint32_t gSeed = 1013904223L;
/// Linear congruential generator over the per-thread state `gSeed`
/// (state' = state * 1664525 + 1013904223 mod 2^32). Returns the current
/// state, then advances it.
uint32_t fastRand() {
  return std::exchange(gSeed, gSeed * 1664525u + 1013904223u);
}
// Allocate a childless node holding a copy of `key`'s bytes, with the given
// parent, a fresh random priority, and both `pointVersion` and `maxVersion`
// set to `pointVersion`. Release it with destroyNode. Aborts if malloc fails.
// Note: `rangeVersion` is left uninitialized.
Node *createNode(const Key &key, Node *parent, int64_t pointVersion) {
  assert(key.len <= std::numeric_limits<int>::max());
  // The key bytes live inline, directly after the Node header.
  auto *result = static_cast<Node *>(malloc(sizeof(Node) + key.len));
  if (result == nullptr) {
    fprintf(stderr, "createNode: failed to allocate node\n");
    fflush(stderr);
    abort();
  }
  result->maxVersion = pointVersion;
  result->pointVersion = pointVersion;
  result->child[0] = nullptr;
  result->child[1] = nullptr;
  result->parent = parent;
  result->priority = fastRand();
#if SHOW_PRIORITY
  result->priority &= 0xff;
#endif
  result->len = key.len;
  // An empty key may have a null `p` (e.g. `{nullptr, 0}`), and memcpy from a
  // null pointer is undefined even when the length is 0.
  if (key.len > 0) {
    memcpy(result + 1, key.p, key.len);
  }
  return result;
}
// Release a node obtained from createNode. Both child links must already be
// cleared (checked by assert in debug builds).
void destroyNode(Node *node) {
  assert(node->child[0] == nullptr);
  assert(node->child[1] == nullptr);
  free(static_cast<void *>(node));
}
// Result of a "last less than or equal" search (see lastLeqMulti): `node` is
// the node with the greatest key <= the query key, and `cmp` is 0 when that
// node's key equals the query key (StepwiseLastLeq reports -1 otherwise).
struct Iterator {
  Node *node;
  int cmp;
};
// Drive every element of `remaining` to completion by calling step() on each
// one in round-robin order; an element is finished once step() returns true.
// Finished elements are swapped past the shrinking end of the span, so the
// underlying storage is permuted as a side effect. If `stepLimit` is
// non-negative, return after that many step() calls.
template <class Stepwise>
void runInterleaved(std::span<Stepwise> remaining, int stepLimit = -1) {
  while (!remaining.empty()) {
    int i = 0;
    while (i < int(remaining.size())) {
      if (stepLimit-- == 0) {
        return;
      }
      if (!remaining[i].step()) {
        ++i;
        continue;
      }
      // Retire element i: move it to the last slot, then drop that slot.
      const int last = int(remaining.size()) - 1;
      if (i != last) {
        using std::swap;
        swap(remaining[i], remaining[last]);
      }
      remaining = remaining.first(remaining.size() - 1);
    }
  }
}
// Drive each element of `remaining` to completion in order, calling step()
// until it returns true before moving to the next element. If `stepLimit` is
// non-negative, return after that many step() calls.
template <class Stepwise>
void runSequential(std::span<Stepwise> remaining, int stepLimit = -1) {
  for (auto &r : remaining) {
    do {
      if (stepLimit-- == 0) {
        return;
      }
    } while (!r.step());
  }
}
// Resumable "last less than or equal" search for one key. Each step() examines
// one node. Once step() returns true, `result` is the node with the greatest
// key <= *key among those visited (or the starting `result` if none
// qualified), and `resultC` is 0 if that node's key equals *key, else -1.
struct StepwiseLastLeq {
  // Next node to examine; nullptr once the search has left the tree.
  Node *current;
  // Best candidate so far.
  Node *result;
  const Key *key;
  int resultC = -1;
  // Caller-defined tag, e.g. the key's position in a batch.
  int index;
  // Outcome of the most recent comparison of *current against *key.
  std::strong_ordering c = std::strong_ordering::equal;

  StepwiseLastLeq() {}
  StepwiseLastLeq(Node *current, Node *result, const Key &key, int index)
      : current(current), result(result), key(&key), index(index) {}

  bool step() {
    if (!current) {
      return true;
    }
    c = *current <=> *key;
    if (c == 0) {
      result = current;
      resultC = 0;
      return true;
    }
    const bool currentIsLess = c < 0;
    if (currentIsLess) {
      // current qualifies; any better candidate lies in its right subtree.
      result = current;
    }
    current = current->child[currentIsLess];
    return false;
  }
};
// For each key in `keys` (which must be sorted), find the node with the
// greatest key <= it and store it as results[i] for keys[i]. Scratch memory
// comes from `arena`.
void lastLeqMulti(Arena &arena, Node *root, std::span<Key> keys,
                  Iterator *results) {
  assert(std::is_sorted(keys.begin(), keys.end()));
  if (keys.empty()) {
    return;
  }
  // While the searches for the smallest and largest key move in the same
  // direction, every key in between does too, so walk that prefix only once.
  Node *sharedCurrent = root;
  Node *sharedResult = nullptr;
  StepwiseLastLeq stepwiseFront(sharedCurrent, sharedResult, keys.front(), -1);
  StepwiseLastLeq stepwiseBack(sharedCurrent, sharedResult, keys.back(), -1);
  for (;;) {
    const bool frontDone = stepwiseFront.step();
    const bool backDone = stepwiseBack.step();
    if (frontDone || backDone || stepwiseFront.c != stepwiseBack.c) {
      break;
    }
    assert(stepwiseFront.current == stepwiseBack.current);
    assert(stepwiseFront.result == stepwiseBack.result);
    sharedCurrent = stepwiseFront.current;
    sharedResult = stepwiseFront.result;
  }
  // Resume every search from the last position the front and back agreed on.
  auto *searches = new (arena) StepwiseLastLeq[keys.size()];
  for (size_t i = 0; i < keys.size(); ++i) {
    searches[i] =
        StepwiseLastLeq(sharedCurrent, sharedResult, keys[i], int(i));
  }
  auto pending = std::span<StepwiseLastLeq>(searches, keys.size());
  runInterleaved(pending);
  // runInterleaved permutes the searches, so route each result via `index`.
  for (const StepwiseLastLeq &search : pending) {
    results[search.index] = Iterator{search.result, search.resultC};
  }
}
// Return the in-order neighbor of `n`: the node with the next larger key when
// `dir` is true, the next smaller key when false. Return nullptr if none exists.
[[maybe_unused]] Node *next(Node *n, bool dir) {
  if (Node *sub = n->child[dir]; sub != nullptr) {
    // The neighbor is the nearest node of that subtree (left spine of the
    // right child when moving right).
    while (sub->child[!dir] != nullptr) {
      sub = sub->child[!dir];
    }
    return sub;
  }
  // Otherwise climb while we are the `dir`-side child; the first ancestor we
  // reach from its other side is the neighbor (nullptr past the root).
  Node *cur = n;
  while (cur->parent != nullptr && cur->parent->child[dir] == cur) {
    cur = cur->parent;
  }
  return cur->parent;
}
// Return the node with the greatest key in the tree rooted at `n` (least key
// when `dir` is false), found by following `dir` children to the end. Return
// nullptr if `n` is null.
[[maybe_unused]] Node *extrema(Node *n, bool dir) {
  Node *cur = n;
  while (cur != nullptr && cur->child[dir] != nullptr) {
    cur = cur->child[dir];
  }
  return cur;
}
// Write the tree rooted at `node` to `file` as a Graphviz "dot" digraph.
// Nodes are named k_<key> and labeled with key (k), maxVersion (m),
// pointVersion (v) and rangeVersion (r), plus priority (p) when SHOW_PRIORITY
// is set. Empty child slots are drawn as numbered point nodes. Versions are
// printed as full 64-bit values rather than being truncated to int.
[[maybe_unused]] void debugPrintDot(FILE *file, Node *node) {
  struct DebugDotPrinter {
    explicit DebugDotPrinter(FILE *file) : file(file) {}
    // Emit one edge per child slot, recursing into present children and
    // pointing absent ones at a fresh nullN node.
    void print(Node *node) {
      for (int i = 0; i < 2; ++i) {
        if (node->child[i] != nullptr) {
          fprintf(file, " k_%.*s -> k_%.*s;\n", node->len,
                  (const char *)(node + 1), node->child[i]->len,
                  (const char *)(node->child[i] + 1));
          print(node->child[i]);
        } else {
          fprintf(file, " k_%.*s -> null%d;\n", node->len,
                  (const char *)(node + 1), id);
          ++id;
        }
      }
    }
    // Number of null placeholder nodes emitted so far.
    int id = 0;
    FILE *file;
  };
  fprintf(file, "digraph TreeSet {\n");
  if (node != nullptr) {
    DebugDotPrinter printer{file};
    fprintf(file, "\n");
    printer.print(node);
    fprintf(file, "\n");
    // Node labels, in key order.
    for (auto iter = extrema(node, false); iter != nullptr;
         iter = next(iter, true)) {
      fprintf(file,
              " k_%.*s [label=\"k=\\\"%.*s\\\" "
#if SHOW_PRIORITY
              "p=%u "
#endif
              "m=%lld v=%lld r=%lld\"];\n",
              iter->len, (const char *)(iter + 1), iter->len,
              (const char *)(iter + 1),
#if SHOW_PRIORITY
              iter->priority,
#endif
              (long long)iter->maxVersion, (long long)iter->pointVersion,
              (long long)iter->rangeVersion);
    }
    for (int i = 0; i < printer.id; ++i) {
      fprintf(file, " null%d [shape=point];\n", i);
    }
  }
  fprintf(file, "}\n");
}
// Test helper: render the low 16 bits of `n` as exactly four lowercase,
// zero-padded hex digits (e.g. 26 -> "001a"), stored in memory owned by
// `arena`.
[[maybe_unused]] Key toKey(Arena &arena, int n) {
  constexpr int kMaxLength = 4;
  uint8_t *itoaBuf = new (arena) uint8_t[kMaxLength];
  // Convert the unsigned representation and always write exactly kMaxLength
  // digits. Negative inputs therefore can't index the digit table with a
  // negative value, and inputs >= 16^kMaxLength can't write before the start
  // of the buffer.
  auto value = static_cast<uint32_t>(n);
  for (int i = kMaxLength - 1; i >= 0; --i) {
    itoaBuf[i] = "0123456789abcdef"[value % 16];
    value /= 16;
  }
  return Key{itoaBuf, kMaxLength};
}
// Recompute maxVersion from the node's own versions and its children's
// maxVersion, then repeat for each ancestor, stopping as soon as a node's
// stored value is already correct.
// TODO interleave this? Will require careful analysis for correctness, and the
// performance gains may not be worth it.
void updateMaxVersion(Node *n) {
  for (Node *cur = n; cur != nullptr; cur = cur->parent) {
    int64_t recomputed = std::max(cur->pointVersion, cur->rangeVersion);
    for (Node *child : cur->child) {
      if (child != nullptr) {
        recomputed = std::max(recomputed, child->maxVersion);
      }
    }
    if (cur->maxVersion == recomputed) {
      // Unchanged here, so no ancestor can change either.
      return;
    }
    cur->maxVersion = recomputed;
  }
}
// Rotate the subtree held in the slot `*node` (the root pointer or a parent's
// child slot) so that its child on side !dir becomes the subtree root, which
// is written back to `*node`. Parent pointers and the maxVersion of the two
// rotated nodes (and, as needed, their ancestors) are updated.
/* Shown for dir == true:

       n            l
      /              \
     l       ==>      n
      \              /
       lr          lr
*/
void rotate(Node **node, bool dir) {
  assert(node != nullptr);
  Node *n = *node;
  assert(n != nullptr);
  Node *l = n->child[!dir];
  assert(l != nullptr);
  Node *inner = l->child[dir];

  // The inner grandchild changes sides: from under l to under n.
  n->child[!dir] = inner;
  if (inner != nullptr) {
    inner->parent = n;
  }
  // l takes n's place beneath n's former parent, and n hangs below l.
  l->parent = n->parent;
  l->child[dir] = n;
  n->parent = l;
  *node = l;

  // n's subtree changed and l now contains it, so fix them bottom-up.
  updateMaxVersion(n);
  updateMaxVersion(l);
}
// Recursively verify that every child in the subtree rooted at `node` (which
// must be non-null) points back at its parent. Each mismatch is reported on
// stderr and clears `success`.
void checkParentPointers(Node *node, bool &success) {
  for (int i = 0; i < 2; ++i) {
    Node *child = node->child[i];
    if (child == nullptr) {
      continue;
    }
    if (child->parent != node) {
      fprintf(stderr, "%.*s child %d has parent pointer %p. Expected %p\n",
              node->len, (const char *)(node + 1), i, (void *)child->parent,
              (void *)node);
      // A broken parent pointer is an invariant violation, not just a warning.
      success = false;
    }
    checkParentPointers(child, success);
  }
}
// Recursively recompute the correct maxVersion for every node in the subtree
// rooted at `node` (which must be non-null). Each node whose stored value
// differs is reported on stderr and clears `success`. Returns the correct
// maxVersion of `node`.
int64_t checkMaxVersion(Node *node, bool &success) {
  int64_t expected = std::max(node->pointVersion, node->rangeVersion);
  for (Node *child : node->child) {
    if (child != nullptr) {
      expected = std::max(expected, checkMaxVersion(child, success));
    }
  }
  if (node->maxVersion != expected) {
    fprintf(stderr, "%.*s has max version %lld. Expected %lld\n", node->len,
            (const char *)(node + 1), (long long)node->maxVersion,
            (long long)expected);
    // Only an actual mismatch is a failure.
    success = false;
  }
  return expected;
}
// Check the treap invariants documented on Node for the tree rooted at `node`:
// - the smallest key is the empty key;
// - in-order keys strictly increase (BST);
// - no node has a lower priority than a child (heap);
// - maxVersion values are correct;
// - parent pointers are consistent.
// Every violation found is reported on stderr. Returns true iff none was
// found. The BST check is reported through the return value, so it also runs
// in NDEBUG builds.
bool checkInvariants(Node *node) {
  if (node == nullptr) {
    fprintf(stderr, "tree is empty; the empty key must always be present\n");
    return false;
  }
  bool success = true;
  Arena arena;
  std::vector<std::string_view, ArenaAlloc<std::string_view>> keys{
      ArenaAlloc<std::string_view>(&arena)};
  for (auto iter = extrema(node, false); iter != nullptr;
       iter = next(iter, true)) {
    keys.push_back(std::string_view((char *)(iter + 1), iter->len));
    // Check heap invariant
    for (int i = 0; i < 2; ++i) {
      if (iter->child[i] != nullptr) {
        if (iter->priority < iter->child[i]->priority) {
          fprintf(stderr, "%.*s has priority < its child %.*s\n", iter->len,
                  (const char *)(iter + 1), iter->child[i]->len,
                  (const char *)(iter->child[i] + 1));
          success = false;
        }
      }
    }
  }
  // Check lowest-key invariant: the in-order traversal starts at the empty key.
  if (!keys.front().empty()) {
    fprintf(stderr, "smallest key %.*s is not the empty key\n",
            int(keys.front().size()), keys.front().data());
    success = false;
  }
  // Check bst invariant: in-order keys must be strictly increasing, so
  // duplicates are violations too.
  for (size_t i = 1; i < keys.size(); ++i) {
    if (!(keys[i - 1] < keys[i])) {
      fprintf(stderr, "%.*s is not less than its successor %.*s\n",
              int(keys[i - 1].size()), keys[i - 1].data(),
              int(keys[i].size()), keys[i].data());
      success = false;
    }
  }
  checkMaxVersion(node, success);
  checkParentPointers(node, success);
  // TODO Compare logical contents of map with
  // reference implementation
  return success;
}
} // namespace
// Implementation behind ConflictSet's public API: a treap of Nodes (see the
// invariants documented on Node) mapping every point in the key space to the
// version of its most recent write.
struct ConflictSet::Impl {
  // Root of the treap; initialized in the constructor with the empty key.
  Node *root;
  // The value most recently passed to the constructor or setOldestVersion.
  int64_t oldestVersion;
  // Start with a single node for the empty key whose point and range versions
  // are both `oldestVersion`, so every key initially maps to `oldestVersion`.
  explicit Impl(int64_t oldestVersion) noexcept
      : root(createNode({nullptr, 0}, nullptr, oldestVersion)),
        oldestVersion(oldestVersion) {
    root->rangeVersion = oldestVersion;
  }
  // For each of the `count` reads, look up the write version recorded for
  // reads[i].begin and set results[i] to ConflictSet::Conflict if it is newer
  // than reads[i].readVersion. Only point reads (end.len == 0) are handled so
  // far; this is asserted.
  // NOTE(review): results[i] is written only on conflict; presumably callers
  // initialize `results` beforehand — confirm.
  void check(const ReadRange *reads, Result *results, int count) const {
    // Scratch allocations for this call, released when `arena` is destroyed.
    Arena arena;
    auto *iters = new (arena) Iterator[count];
    auto *begins = new (arena) Key[count];
    for (int i = 0; i < count; ++i) {
      begins[i] = reads[i].begin;
    }
    // NOTE(review): lastLeqMulti asserts its keys are sorted, so this assumes
    // reads arrive sorted by begin key — confirm with callers.
    lastLeqMulti(arena, root, std::span<Key>(begins, count), iters);
    // TODO check non-singleton reads lol
    for (int i = 0; i < count; ++i) {
      assert(reads[i].end.len == 0);
      assert(iters[i].node != nullptr);
      // An exact key match uses that node's point version; otherwise the key
      // falls in the range after the node and uses its range version.
      if ((iters[i].cmp == 0
               ? iters[i].node->pointVersion
               : iters[i].node->rangeVersion) > reads[i].readVersion) {
        results[i] = ConflictSet::Conflict;
      }
    }
  }
  // Resumable insertion of one (key, writeVersion) pair. Each step() descends
  // one level, raising maxVersion along the way. It finishes either by
  // updating the pointVersion of an existing node with an equal key, or by
  // attaching a new leaf.
  struct StepwiseInsert {
    // Search phase state. After this phase, the heap invariant may be violated
    // for (*current)->parent.
    // `current` is the slot (&root or a parent's child slot) being examined.
    Node **current;
    // The node that owns `current`, or nullptr while `current` is &root.
    Node *parent;
    const Key *key;
    int64_t writeVersion;
    StepwiseInsert() {}
    StepwiseInsert(Node **root, const Key &key, int64_t writeVersion)
        : current(root), parent(nullptr), key(&key),
          writeVersion(writeVersion) {}
    bool step() {
#if DEBUG
      fprintf(stderr, "Step insert of %.*s. At node: %.*s\n", key->len, key->p,
              (*current) ? (*current)->len : 7,
              (*current) ? (const char *)((*current) + 1) : "nullptr");
#endif
      if (*current == nullptr) {
        auto *newNode = createNode(*key, parent, writeVersion);
        *current = newNode;
        // We could interleave the iteration in next, but we'd need a careful
        // analysis for correctness and it's unlikely to be worthwhile.
        // The new key used to lie in its predecessor's range, so the range
        // following the new node keeps that range's version.
        auto *prev = ::next(newNode, false);
        assert(prev != nullptr);
        assert(prev->rangeVersion <= writeVersion);
        newNode->rangeVersion = prev->rangeVersion;
        return true;
      } else {
        // This is the key optimization - setting the max version on the way
        // down the search path so we only have to do one traversal.
        (*current)->maxVersion = std::max((*current)->maxVersion, writeVersion);
        auto c = *key <=> **current;
        if (c == 0) {
          (*current)->pointVersion = writeVersion;
          return true;
        }
        parent = *current;
        current = &((*current)->child[c > 0]);
      }
      return false;
    }
  };
  // Record each write's begin key at its writeVersion. Only point writes
  // (end.len == 0) are handled so far; this is asserted. The inserts run
  // interleaved; afterwards, heap-invariant violations are repaired with
  // rotations.
  void addWrites(const WriteRange *writes, int count) {
    Arena arena;
    auto *stepwiseInserts = new (arena) StepwiseInsert[count];
    for (int i = 0; i < count; ++i) {
      // TODO handle non-singleton writes lol
      assert(writes[i].end.len == 0);
      stepwiseInserts[i] =
          StepwiseInsert{&root, writes[i].begin, writes[i].writeVersion};
    }
    // TODO Descend until queries for front and back diverge
    // Mitigate potential n^2 behavior of insertion by shuffling the insertion
    // order. Not sure how this interacts with interleaved insertion but it's
    // probably fine.
    // TODO better/faster RNG?
    std::mt19937 g(fastRand());
    std::shuffle(stepwiseInserts, stepwiseInserts + count, g);
    runInterleaved(std::span<StepwiseInsert>(stepwiseInserts, count));
    // Seed the heap fix-up with the node each insert finished at (created or
    // updated in place).
    std::vector<Node *, ArenaAlloc<Node *>> workList{
        ArenaAlloc<Node *>(&arena)};
    workList.reserve(count);
    for (int i = 0; i < count; ++i) {
      workList.push_back(*stepwiseInserts[i].current);
    }
    // Rotate nodes above parents with lower priority, re-checking the nodes
    // whose parent changed, until no work remains.
    while (!workList.empty()) {
      Node *n = workList.back();
      workList.pop_back();
#if DEBUG
      fprintf(stderr, "\tcheck heap invariant %.*s\n", n->len,
              (const char *)(n + 1));
#endif
      if (n->parent == nullptr) {
        continue;
      }
      // dir is true when n is its parent's right child.
      const bool dir = n == n->parent->child[1];
      assert(dir || n == n->parent->child[0]);
      // p is the address of the pointer to n->parent in the tree
      Node **p = n->parent->parent == nullptr
                     ? &root
                     : &n->parent->parent
                            ->child[n->parent->parent->child[1] == n->parent];
      assert(*p == n->parent);
      if (n->parent->priority < n->priority) {
#if DEBUG
        fprintf(stderr, "\trotate %.*s %s\n", n->len, (const char *)(n + 1),
                !dir ? "right" : "left");
#endif
        // Lift n above its parent; afterwards *p == n.
        rotate(p, !dir);
        // Re-check n against its new parent, and n's former inner child lr,
        // which now sits under n's former parent.
        workList.push_back(*p);
        assert((*p)->child[!dir] != nullptr);
        auto *lr = (*p)->child[!dir]->child[dir];
        if (lr != nullptr) {
          workList.push_back(lr);
        }
      }
    }
  }
  // Currently only records the new oldest version, which must be strictly
  // greater than the previous one (asserted).
  void setOldestVersion(int64_t oldestVersion) {
    assert(oldestVersion > this->oldestVersion);
    this->oldestVersion = oldestVersion;
  }
  // Free every node using an explicit work stack (no recursion). Each node's
  // child links are cleared before destroyNode, which asserts they are null.
  ~Impl() {
    Arena arena;
    std::vector<Node *, ArenaAlloc<Node *>> toFree{ArenaAlloc<Node *>(&arena)};
    if (root != nullptr) {
      toFree.push_back(root);
    }
    while (toFree.size() > 0) {
      Node *n = toFree.back();
      toFree.pop_back();
      for (int i = 0; i < 2; ++i) {
        auto *c = std::exchange(n->child[i], nullptr);
        if (c != nullptr) {
          toFree.push_back(c);
        }
      }
      destroyNode(n);
    }
  }
};
// Public entry points; each forwards directly to the matching Impl method.
void ConflictSet::check(const ReadRange *reads, Result *results,
                        int count) const {
  impl->check(reads, results, count);
}
void ConflictSet::addWrites(const WriteRange *writes, int count) {
  impl->addWrites(writes, count);
}
void ConflictSet::setOldestVersion(int64_t oldestVersion) {
  impl->setOldestVersion(oldestVersion);
}
// Allocate and construct the implementation. Aborts if malloc fails instead
// of constructing into (or dereferencing) a null pointer.
ConflictSet::ConflictSet(int64_t oldestVersion) : impl(nullptr) {
  void *storage = malloc(sizeof(Impl));
  if (storage == nullptr) {
    fprintf(stderr, "ConflictSet: failed to allocate implementation\n");
    fflush(stderr);
    abort();
  }
  impl = new (storage) Impl{oldestVersion};
}
// A moved-from ConflictSet has a null impl, which must not be destroyed.
ConflictSet::~ConflictSet() {
  if (impl != nullptr) {
    impl->~Impl();
    free(impl);
  }
}
// Steals `other`'s implementation, leaving `other` empty.
ConflictSet::ConflictSet(ConflictSet &&other) noexcept
    : impl(std::exchange(other.impl, nullptr)) {}
// Releases the current implementation (previously leaked), then steals
// `other`'s. Self-move-assignment is a no-op.
ConflictSet &ConflictSet::operator=(ConflictSet &&other) noexcept {
  if (this != &other) {
    if (impl != nullptr) {
      impl->~Impl();
      free(impl);
    }
    impl = std::exchange(other.impl, nullptr);
  }
  return *this;
}
#ifdef ENABLE_TESTS
// Smoke test: insert kNumKeys point writes with increasing versions, dump the
// tree in dot format to stdout, and verify the treap invariants. The exit
// status is nonzero if any invariant is violated.
int main(void) {
  int64_t writeVersion = 0;
  ConflictSet::Impl cs{writeVersion};
  constexpr int kNumKeys = 5;
  ConflictSet::WriteRange write[kNumKeys];
  Arena arena;
  for (int i = 0; i < kNumKeys; ++i) {
    write[i].begin = toKey(arena, i);
    write[i].end.len = 0;
    write[i].writeVersion = ++writeVersion;
  }
  cs.addWrites(write, kNumKeys);
  debugPrintDot(stdout, cs.root);
  // Propagate failures to the exit status so test runners notice them.
  return checkInvariants(cs.root) ? 0 : 1;
}
#endif