#include "ConflictSet.h"

#include <algorithm>
#include <cassert>
#include <compare>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <new>
#include <random>
#include <span>
#include <string_view>
#include <utility>
#include <vector>

#define SHOW_PRIORITY 0
#define DEBUG 0

using Key = ConflictSet::Key;

/// Lexicographic three-way comparison of two keys: compares the common prefix
/// bytewise, and breaks ties by length (the shorter key orders first).
static auto operator<=>(const Key &lhs, const Key &rhs) {
  const int minLen = std::min(lhs.len, rhs.len);
  // Skip memcmp for an empty prefix: an empty key may carry a null pointer
  // (the root key is built from {nullptr, 0}), and passing a null pointer to
  // memcmp is undefined even when the count is zero.
  const int c = minLen == 0 ? 0 : memcmp(lhs.p, rhs.p, minLen);
  return c != 0 ? c <=> 0 : lhs.len <=> rhs.len;
}

// ==================== BEGIN ARENA IMPL ====================

/// Group allocations with similar lifetimes to amortize the cost of malloc/free
struct Arena {
  /// If `initialSize` > 0, eagerly allocates a first block with at least that
  /// many usable bytes. Otherwise the first block is allocated lazily.
  explicit Arena(int initialSize = 0);
  /// O(log n) in the number of allocations
  ~Arena();
  struct ArenaImpl;
  Arena(const Arena &) = delete;
  Arena &operator=(const Arena &) = delete;
  Arena(Arena &&other) noexcept;
  Arena &operator=(Arena &&other) noexcept;

private:
  // Most recently allocated block, or nullptr if nothing is allocated yet.
  ArenaImpl *impl = nullptr;
  friend void *operator new(size_t size, std::align_val_t align, Arena &arena);
};

// Placement new/delete overloads that allocate from an Arena. The delete
// overloads are no-ops: memory is only released when the arena is destroyed.
// The `Arena *` overloads are deleted so that passing a pointer by mistake is
// a compile error.
inline void operator delete(void *, std::align_val_t, Arena &) {}
void *operator new(size_t size, std::align_val_t align, Arena &arena);
void *operator new(size_t size, std::align_val_t align, Arena *arena) = delete;

inline void operator delete(void *, Arena &) {}
inline void *operator new(size_t size, Arena &arena) {
  return operator new(size, std::align_val_t(alignof(std::max_align_t)), arena);
}
inline void *operator new(size_t size, Arena *arena) = delete;

inline void operator delete[](void *, Arena &) {}
inline void *operator new[](size_t size, Arena &arena) {
  return operator new(size, arena);
}
inline void *operator new[](size_t size, Arena *arena) = delete;

inline void operator delete[](void *, std::align_val_t, Arena &) {}
inline void *operator new[](size_t size, std::align_val_t align, Arena &arena) {
  return operator new(size, align, arena);
}
inline void *operator new[](size_t size, std::align_val_t align,
                            Arena *arena) = delete;

/// Rounds the pointer `t` up to the next multiple of `align`.
/// align must be a power of two
template <class T> T *align_up(T *t, size_t align) {
  auto unaligned = uintptr_t(t);
  auto aligned = (unaligned + align - 1) & ~(align - 1);
  return reinterpret_cast<T *>(reinterpret_cast<char *>(t) + aligned -
                               unaligned);
}

/// Rounds `unaligned` up to the next multiple of `align`.
/// align must be a power of two
constexpr inline int align_up(uint32_t unaligned, uint32_t align) {
  return (unaligned + align - 1) & ~(align - 1);
}

/// Returns the smallest power of two >= x. Requires x <= 2^31 (the result is
/// not representable otherwise).
constexpr inline uint32_t nextPowerOfTwo(uint32_t x) {
  // Shift an unsigned 1 so that a result of 2^31 does not overflow a signed
  // int.
  return x <= 1 ? 1 : uint32_t(1) << (32 - __builtin_clz(x - 1));
}

/// \private
/// Header of one malloc'ed arena block. The usable bytes follow the header
/// directly in the same allocation.
struct Arena::ArenaImpl {
  // Previously allocated block (singly linked list), or nullptr
  Arena::ArenaImpl *prev;
  // Number of usable bytes following this header
  int capacity;
  // Number of usable bytes already handed out
  int used;
  uint8_t *begin() { return reinterpret_cast<uint8_t *>(this + 1); }
};

static_assert(sizeof(Arena::ArenaImpl) == 16);
static_assert(alignof(Arena::ArenaImpl) == 8);

namespace {

// Mallocs a block of `allocationSize` total bytes (header included), links it
// in front of `prev`, and returns it. Aborts if malloc fails, mirroring
// ArenaAlloc::allocate, since callers have no way to report the failure.
Arena::ArenaImpl *allocateBlock(Arena::ArenaImpl *prev, int allocationSize) {
  auto *block = static_cast<Arena::ArenaImpl *>(malloc(allocationSize));
  if (block == nullptr) {
    fprintf(stderr, "Arena: failed to allocate %d bytes\n", allocationSize);
    fflush(stderr);
    abort();
  }
  block->prev = prev;
  block->capacity = allocationSize - sizeof(Arena::ArenaImpl);
  block->used = 0;
  return block;
}

// Frees `impl` and every block reachable through `prev`.
void onDestroy(Arena::ArenaImpl *impl) {
  while (impl) {
    auto *prev = impl->prev;
    free(impl);
    impl = prev;
  }
}

} // namespace

Arena::Arena(int initialSize) : impl(nullptr) {
  if (initialSize > 0) {
    auto allocationSize = align_up(initialSize + sizeof(ArenaImpl), 16);
    impl = allocateBlock(nullptr, allocationSize);
  }
}

Arena::Arena(Arena &&other) noexcept
    : impl(std::exchange(other.impl, nullptr)) {}

Arena &Arena::operator=(Arena &&other) noexcept {
  // Guard against self-move: without it the blocks would be freed and then
  // the dangling pointer re-adopted, leading to a double free on destruction.
  if (this != &other) {
    onDestroy(impl);
    impl = std::exchange(other.impl, nullptr);
  }
  return *this;
}

Arena::~Arena() { onDestroy(impl); }

/// Bump-allocates `size` bytes aligned to `align` from `arena`, allocating a
/// new block first if the current one cannot fit the request.
/// NOTE(review): sizes are tracked as `int`, so this presumably assumes
/// individual requests and total block sizes stay well below 2^31 bytes.
void *operator new(size_t size, std::align_val_t align, Arena &arena) {
  // Worst case number of bytes needed once alignment padding is accounted for
  int64_t aligned_size = size + size_t(align) - 1;
  if (arena.impl == nullptr ||
      (arena.impl->capacity - arena.impl->used) < aligned_size) {
    // Grow geometrically (double the previous capacity) unless this single
    // request needs more than that.
    const int growth =
        arena.impl ? std::max<int>(sizeof(Arena::ArenaImpl),
                                   arena.impl->capacity * 2)
                   : 0;
    auto allocationSize = align_up(
        sizeof(Arena::ArenaImpl) + std::max<int>(aligned_size, growth), 16);
    arena.impl = allocateBlock(arena.impl, allocationSize);
  }
  auto *result =
      align_up(arena.impl->begin() + arena.impl->used, size_t(align));
  // Account for both the alignment padding and the allocation itself
  auto usedDelta = (result - arena.impl->begin()) + size - arena.impl->used;
  arena.impl->used += usedDelta;
  return result;
}

/// STL-friendly allocator using an arena
template <class T> struct ArenaAlloc {
  typedef T value_type;

  ArenaAlloc() = delete;
  explicit ArenaAlloc(Arena *arena) : arena(arena) {}

  Arena *arena;

  // Rebinding conversion: allocators for different value types share the
  // same arena.
  template <class U>
  constexpr ArenaAlloc(const ArenaAlloc<U> &other) noexcept {
    arena = other.arena;
  }

  /// Allocates uninitialized storage for `n` objects of type T. Aborts if
  /// `n * sizeof(T)` would overflow.
  [[nodiscard]] T *allocate(size_t n) {
    if (n > 0xfffffffffffffffful / sizeof(T)) { // NOLINT
      fprintf(stderr, "Requested bad alloc! sizeof(T): %zu, n: %zu\n",
              sizeof(T), n); // NOLINT
      fflush(stderr);
      abort();
    }

    return static_cast<T *>((void *)new (std::align_val_t(alignof(T)), *arena)
                                uint8_t[n * sizeof(T)]); // NOLINT
  }

  /// No-op: memory is reclaimed when the arena is destroyed.
  void deallocate(T *, size_t) noexcept {}

private:
};

template <class T, class U>
bool operator==(const ArenaAlloc<T> &lhs, const ArenaAlloc<U> &rhs) {
  return lhs.arena == rhs.arena;
}

template <class T, class U>
bool operator!=(const ArenaAlloc<T> &lhs, const ArenaAlloc<U> &rhs) {
  return !(lhs == rhs);
}

// ==================== END ARENA IMPL ====================

namespace {
// A node in the tree representing write conflict history. This tree maintains
// several invariants:
// 1. BST invariant: all keys in the tree rooted at the left child of a node
// compare less than that node's key, and all keys in the tree rooted at the
// right child of a node compare greater than that node's key.
// 2. Heap invariant: the priority of a node is >= all the priorities
// of its children (transitively)
// 3. Max invariant: `maxVersion` is the max among all values of `pointVersion`
// and `rangeVersion` for this node and its children (transitively)
// 4. The lowest key (an empty byte sequence) is always physically present in
// the tree so that "last less than or equal" queries are always well-defined.

// Logically, the contents of the tree represent a "range map" where all of the
// infinitely many points in the key space are associated with a writeVersion.
// If a point is physically present in the tree, then its writeVersion is its
// node's `pointVersion`. Otherwise, its writeVersion is the `rangeVersion` of
// the node with the last key less than point.
struct Node {
  // See "Max invariant" above
  int64_t maxVersion;
  // The write version of the point in the key space represented by this node's
  // key
  int64_t pointVersion;
  // The write version of the range immediately after this node's key, until
  // just before the next key in the tree. I.e. (this key, next key)
  int64_t rangeVersion;
  // child[0] is the left child or nullptr. child[1] is the right child or
  // nullptr
  Node *child[2];
  // The parent of this node in the tree, or nullptr if this node is the root
  Node *parent;
  // As a treap, this tree satisfies the heap invariant on each node's priority
  uint32_t priority;
  // The length of this node's key
  int len;
  // The contents of this node's key
  // uint8_t[len];

  // Lexicographic comparison of this node's key (stored immediately after the
  // node) against `other`; ties on the common prefix are broken by length.
  auto operator<=>(const ConflictSet::Key &other) const {
    const int minLen = std::min(len, other.len);
    // `other.p` may be null for an empty key; memcmp must not see it.
    const int c = minLen == 0 ? 0 : memcmp(this + 1, other.p, minLen);
    return c != 0 ? c <=> 0 : len <=> other.len;
  }
};

// TODO: use a better prng. This is technically vulnerable to a
// denial-of-service attack that can make conflict-checking linear in the
// number of nodes in the tree.
thread_local uint32_t gSeed = 1013904223L;
// Linear congruential generator; returns the current seed and advances it.
uint32_t fastRand() {
  auto result = gSeed;
  gSeed = gSeed * 1664525L + 1013904223L;
  return result;
}

// Note: `rangeVersion` is left uninitialized.
Node *createNode(const Key &key, Node *parent, int64_t pointVersion) { assert(key.len <= std::numeric_limits::max()); Node *result = (Node *)malloc(sizeof(Node) + key.len); result->maxVersion = pointVersion; result->pointVersion = pointVersion; result->child[0] = nullptr; result->child[1] = nullptr; result->parent = parent; result->priority = fastRand(); #if SHOW_PRIORITY result->priority &= 0xff; #endif result->len = key.len; memcpy(result + 1, key.p, key.len); return result; } void destroyNode(Node *node) { assert(node->child[0] == nullptr); assert(node->child[1] == nullptr); free(node); } struct Iterator { Node *node; int cmp; }; // Call Stepwise::step for each element of remaining until it returns true. // Applies a permutation to remaining as a side effect. template void runInterleaved(std::span remaining, int stepLimit = -1) { while (remaining.size() > 0) { for (int i = 0; i < int(remaining.size());) { if (stepLimit-- == 0) { return; } bool done = remaining[i].step(); if (done) { if (i != int(remaining.size()) - 1) { using std::swap; swap(remaining[i], remaining.back()); } remaining = remaining.subspan(0, remaining.size() - 1); } else { ++i; } } } }; template void runSequential(std::span remaining, int stepLimit = -1) { for (auto &r : remaining) { if (stepLimit-- == 0) { return; } while (!r.step()) { if (stepLimit-- == 0) { return; } } } } struct StepwiseLastLeq { Node *current; Node *result; const Key *key; int resultC = -1; int index; std::strong_ordering c = std::strong_ordering::equal; StepwiseLastLeq() {} StepwiseLastLeq(Node *current, Node *result, const Key &key, int index) : current(current), result(result), key(&key), index(index) {} bool step() { if (current == nullptr) { return true; } c = *current <=> *key; if (c == 0) { result = current; resultC = 0; return true; } result = c < 0 ? 
current : result; current = current->child[c < 0]; return false; } }; void lastLeqMulti(Arena &arena, Node *root, std::span keys, Iterator *results) { assert(std::is_sorted(keys.begin(), keys.end())); if (keys.size() == 0) { return; } auto *stepwiseLastLeqs = new (arena) StepwiseLastLeq[keys.size()]; // Descend until queries for front and back diverge Node *current = root; Node *resultP = nullptr; auto stepwiseFront = StepwiseLastLeq(current, resultP, keys.front(), -1); auto stepwiseBack = StepwiseLastLeq(current, resultP, keys.back(), -1); for (;;) { bool done1 = stepwiseFront.step(); bool done2 = stepwiseBack.step(); if (!done1 && !done2 && stepwiseFront.c == stepwiseBack.c) { assert(stepwiseFront.current == stepwiseBack.current); assert(stepwiseFront.result == stepwiseBack.result); current = stepwiseFront.current; resultP = stepwiseFront.result; } else { break; } } int index = 0; { auto iter = stepwiseLastLeqs; for (const auto &k : keys) { *iter++ = StepwiseLastLeq(current, resultP, k, index++); } } auto stepwiseSpan = std::span(stepwiseLastLeqs, keys.size()); runInterleaved(stepwiseSpan); for (const auto &stepwise : stepwiseSpan) { results[stepwise.index] = Iterator{stepwise.result, stepwise.resultC}; } } // Return a pointer to the node whose key immediately follows `n`'s key (if // `dir` is false, precedes). Return nullptr if none exists. [[maybe_unused]] Node *next(Node *n, bool dir) { // Traverse left spine of right child (when moving right, i.e. dir = true) if (n->child[dir]) { n = n->child[dir]; while (n->child[!dir]) { n = n->child[!dir]; } } else { // Search upward for a node such that we're the left child (when moving // right, i.e. dir = true) while (n->parent && n == n->parent->child[dir]) { n = n->parent; } n = n->parent; } return n; } // Return a pointer to the node whose key is greatest among keys in the tree // rooted at `n` (if dir = false, least). Return nullptr if none exists (i.e. // `n` is null). 
[[maybe_unused]] Node *extrema(Node *n, bool dir) { if (n == nullptr) { return nullptr; } while (n->child[dir] != nullptr) { n = n->child[dir]; } return n; } [[maybe_unused]] void debugPrintDot(FILE *file, Node *node) { struct DebugDotPrinter { explicit DebugDotPrinter(FILE *file) : file(file) {} void print(Node *node) { for (int i = 0; i < 2; ++i) { if (node->child[i] != nullptr) { fprintf(file, " k_%.*s -> k_%.*s;\n", node->len, (const char *)(node + 1), node->child[i]->len, (const char *)(node->child[i] + 1)); print(node->child[i]); } else { fprintf(file, " k_%.*s -> null%d;\n", node->len, (const char *)(node + 1), id); ++id; } } } int id = 0; FILE *file; }; fprintf(file, "digraph TreeSet {\n"); if (node != nullptr) { DebugDotPrinter printer{file}; fprintf(file, "\n"); printer.print(node); fprintf(file, "\n"); for (auto iter = extrema(node, false); iter != nullptr; iter = next(iter, true)) { fprintf(file, " k_%.*s [label=\"k=\\\"%.*s\\\" " #if SHOW_PRIORITY "p=%u " #endif "m=%d v=%d r=%d\"];\n", iter->len, (const char *)(iter + 1), iter->len, (const char *)(iter + 1), #if SHOW_PRIORITY iter->priority, #endif int(iter->maxVersion), int(iter->pointVersion), int(iter->rangeVersion)); } for (int i = 0; i < printer.id; ++i) { fprintf(file, " null%d [shape=point];\n", i); } } fprintf(file, "}\n"); } [[maybe_unused]] Key toKey(Arena &arena, int n) { constexpr int kMaxLength = 4; // TODO use arena allocation int i = kMaxLength; uint8_t *itoaBuf = new (arena) uint8_t[kMaxLength]; memset(itoaBuf, '0', kMaxLength); do { itoaBuf[--i] = "0123456789abcdef"[n % 16]; n /= 16; } while (n); return Key{itoaBuf, kMaxLength}; } // Recompute maxVersion, and propagate up the tree as necessary // TODO interleave this? Will require careful analysis for correctness, and the // performance gains may not be worth it. 
void updateMaxVersion(Node *n) { for (;;) { int64_t maxVersion = std::max(n->pointVersion, n->rangeVersion); for (int i = 0; i < 2; ++i) { maxVersion = std::max(maxVersion, n->child[i] != nullptr ? n->child[i]->maxVersion : maxVersion); } if (n->maxVersion == maxVersion) { break; } n->maxVersion = maxVersion; if (n->parent == nullptr) { break; } n = n->parent; } } void rotate(Node **node, bool dir) { // diagram shown for dir == true /* n / l \ lr */ assert(node != nullptr); Node *n = *node; assert(n != nullptr); Node *parent = n->parent; Node *l = n->child[!dir]; assert(l != nullptr); Node *lr = l->child[dir]; n->child[!dir] = lr; if (lr) { lr->parent = n; } l->child[dir] = n; n->parent = l; l->parent = parent; *node = l; /* l \ n / lr */ updateMaxVersion(n); updateMaxVersion(l); } void checkParentPointers(Node *node, bool &success) { for (int i = 0; i < 2; ++i) { if (node->child[i] != nullptr) { if (node->child[i]->parent != node) { fprintf(stderr, "%.*s child %d has parent pointer %p. Expected %p\n", node->len, (const char *)(node + 1), i, (void *)node->child[i]->parent, (void *)node); success = false; } checkParentPointers(node->child[i], success); } } } int64_t checkMaxVersion(Node *node, bool &success) { int64_t expected = std::max(node->pointVersion, node->rangeVersion); for (int i = 0; i < 2; ++i) { if (node->child[i] != nullptr) { expected = std::max(expected, checkMaxVersion(node->child[i], success)); } } if (node->maxVersion != expected) { fprintf(stderr, "%.*s has max version %d. 
Expected %d\n", node->len, (const char *)(node + 1), int(node->maxVersion), int(expected)); success = false; } return expected; } bool checkInvariants(Node *node) { bool success = true; // Check bst invariant Arena arena; std::vector> keys{ ArenaAlloc(&arena)}; for (auto iter = extrema(node, false); iter != nullptr; iter = next(iter, true)) { keys.push_back(std::string_view((char *)(iter + 1), iter->len)); for (int i = 0; i < 2; ++i) { if (iter->child[i] != nullptr) { if (iter->priority < iter->child[i]->priority) { fprintf(stderr, "%.*s has priority < its child %.*s\n", iter->len, (const char *)(iter + 1), iter->child[i]->len, (const char *)(iter->child[i] + 1)); success = false; } } } } assert(std::is_sorted(keys.begin(), keys.end())); checkMaxVersion(node, success); checkParentPointers(node, success); // TODO Compare logical contents of map with // reference implementation return success; } } // namespace struct ConflictSet::Impl { Node *root; int64_t oldestVersion; explicit Impl(int64_t oldestVersion) noexcept : root(createNode({nullptr, 0}, nullptr, oldestVersion)), oldestVersion(oldestVersion) { root->rangeVersion = oldestVersion; } void check(const ReadRange *reads, Result *results, int count) const { Arena arena; auto *iters = new (arena) Iterator[count]; auto *begins = new (arena) Key[count]; for (int i = 0; i < count; ++i) { begins[i] = reads[i].begin; } lastLeqMulti(arena, root, std::span(begins, count), iters); // TODO check non-singleton reads lol for (int i = 0; i < count; ++i) { assert(reads[i].end.len == 0); assert(iters[i].node != nullptr); if ((iters[i].cmp == 0 ? iters[i].node->pointVersion : iters[i].node->rangeVersion) > reads[i].readVersion) { results[i] = ConflictSet::Conflict; } } } struct StepwiseInsert { // After this phase, the heap invariant may be violated for // (*current)->parent. 
Node **current; Node *parent; const Key *key; int64_t writeVersion; StepwiseInsert() {} StepwiseInsert(Node **root, const Key &key, int64_t writeVersion) : current(root), parent(nullptr), key(&key), writeVersion(writeVersion) {} bool step() { #if DEBUG fprintf(stderr, "Step insert of %.*s. At node: %.*s\n", key->len, key->p, (*current) ? (*current)->len : 7, (*current) ? (const char *)((*current) + 1) : "nullptr"); #endif if (*current == nullptr) { auto *newNode = createNode(*key, parent, writeVersion); *current = newNode; // We could interleave the iteration in ::next, but we'd need a careful // analysis for correctness and it's unlikely to be worthwhile. auto *prev = ::next(newNode, false); assert(prev != nullptr); assert(prev->rangeVersion <= writeVersion); newNode->rangeVersion = prev->rangeVersion; return true; } else { // This is the key optimization - setting the max version on the way // down the search path so we only have to do one traversal. (*current)->maxVersion = std::max((*current)->maxVersion, writeVersion); auto c = *key <=> **current; if (c == 0) { (*current)->pointVersion = writeVersion; return true; } parent = *current; current = &((*current)->child[c > 0]); } return false; } }; void addWrites(const WriteRange *writes, int count) { Arena arena; auto *stepwiseInserts = new (arena) StepwiseInsert[count]; for (int i = 0; i < count; ++i) { // TODO handle non-singleton writes lol assert(writes[i].end.len == 0); stepwiseInserts[i] = StepwiseInsert{&root, writes[i].begin, writes[i].writeVersion}; } // TODO Descend until queries for front and back diverge // Mitigate potential n^2 behavior of insertion by shuffling the insertion // order. Not sure how this interacts with interleaved insertion but it's // probably fine. // TODO better/faster RNG? 
std::mt19937 g(fastRand()); std::shuffle(stepwiseInserts, stepwiseInserts + count, g); runInterleaved(std::span(stepwiseInserts, count)); std::vector> workList{ ArenaAlloc(&arena)}; workList.reserve(count); for (int i = 0; i < count; ++i) { workList.push_back(*stepwiseInserts[i].current); } while (!workList.empty()) { Node *n = workList.back(); workList.pop_back(); #if DEBUG fprintf(stderr, "\tcheck heap invariant %.*s\n", n->len, (const char *)(n + 1)); #endif if (n->parent == nullptr) { continue; } const bool dir = n == n->parent->child[1]; assert(dir || n == n->parent->child[0]); // p is the address of the pointer to n->parent in the tree Node **p = n->parent->parent == nullptr ? &root : &n->parent->parent ->child[n->parent->parent->child[1] == n->parent]; assert(*p == n->parent); if (n->parent->priority < n->priority) { #if DEBUG fprintf(stderr, "\trotate %.*s %s\n", n->len, (const char *)(n + 1), !dir ? "right" : "left"); #endif rotate(p, !dir); workList.push_back(*p); assert((*p)->child[!dir] != nullptr); auto *lr = (*p)->child[!dir]->child[dir]; if (lr != nullptr) { workList.push_back(lr); } } } } void setOldestVersion(int64_t oldestVersion) { assert(oldestVersion > this->oldestVersion); this->oldestVersion = oldestVersion; } ~Impl() { Arena arena; std::vector> toFree{ArenaAlloc(&arena)}; if (root != nullptr) { toFree.push_back(root); } while (toFree.size() > 0) { Node *n = toFree.back(); toFree.pop_back(); for (int i = 0; i < 2; ++i) { auto *c = std::exchange(n->child[i], nullptr); if (c != nullptr) { toFree.push_back(c); } } destroyNode(n); } } }; void ConflictSet::check(const ReadRange *reads, Result *results, int count) const { return impl->check(reads, results, count); } void ConflictSet::addWrites(const WriteRange *writes, int count) { return impl->addWrites(writes, count); } void ConflictSet::setOldestVersion(int64_t oldestVersion) { return impl->setOldestVersion(oldestVersion); } ConflictSet::ConflictSet(int64_t oldestVersion) : 
impl(new(malloc(sizeof(Impl))) Impl{oldestVersion}) {} ConflictSet::~ConflictSet() { impl->~Impl(); free(impl); } ConflictSet::ConflictSet(ConflictSet &&other) noexcept : impl(std::exchange(other.impl, nullptr)) {} ConflictSet &ConflictSet::operator=(ConflictSet &&other) noexcept { impl = std::exchange(other.impl, nullptr); return *this; } #ifdef ENABLE_TESTS int main(void) { int64_t writeVersion = 0; ConflictSet::Impl cs{writeVersion}; constexpr int kNumKeys = 5; ConflictSet::WriteRange write[kNumKeys]; Arena arena; for (int i = 0; i < kNumKeys; ++i) { write[i].begin = toKey(arena, i); write[i].end.len = 0; write[i].writeVersion = ++writeVersion; } cs.addWrites(write, kNumKeys); debugPrintDot(stdout, cs.root); bool success = checkInvariants(cs.root); return success ? 0 : 1; } #endif