#include "ConflictSet.h"

#include <algorithm>
#include <cassert>
#include <compare>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <memory>
#include <span>
#include <utility>
#include <vector>

using Key = ConflictSet::Key;

// Three-way, byte-wise lexicographic comparison of two keys. If one key is a
// prefix of the other, the shorter key orders first.
static auto operator<=>(const Key &lhs, const Key &rhs) {
  const int minLen = std::min(lhs.len, rhs.len);
  // An empty key may carry a null pointer (the root node is created from
  // `{nullptr, 0}`), and handing a null pointer to memcmp is undefined even
  // when the length is zero, so only call it when there are bytes to compare.
  const int c = minLen > 0 ? memcmp(lhs.p, rhs.p, minLen) : 0;
  return c != 0 ? c <=> 0 : lhs.len <=> rhs.len;
}

namespace {

// A node in the tree representing write conflict history. This tree maintains
// several invariants:
// 1. BST invariant: all keys in the tree rooted at the left child of a node
// compare less than that node's key, and all keys in the tree rooted at the
// right child of a node compare greater than that node's key.
// 2. Heap invariant: the priority of a node is greater than all the priorities
// of its children (transitively)
// 3. Max invariant: `maxVersion` is the max among all values of `pointVersion`
// and `rangeVersion` for this node and its children (transitively)
// 4. The lowest key (an empty byte sequence) is always physically present in
// the tree so that "last less than or equal" queries are always well-defined.

// Logically, the contents of the tree represent a "range map" where all of the
// infinitely many points in the key space are associated with a writeVersion.
// If a point is physically present in the tree, then its writeVersion is its
// node's `pointVersion`. Otherwise, its writeVersion is the `rangeVersion` of
// the node with the last key less than point.
struct Node {
  // See "Max invariant" above
  int64_t maxVersion;
  // The write version of the point in the key space represented by this node's
  // key
  int64_t pointVersion;
  // The write version of the range immediately after this node's key, until
  // just before the next key in the tree. I.e. (this key, next key)
  int64_t rangeVersion;
  // child[0] is the left child or nullptr. child[1] is the right child or
  // nullptr
  Node *child[2];
  // The parent of this node in the tree, or nullptr if this node is the root
  Node *parent;
  // As a treap, this tree satisfies the heap invariant on each node's priority
  uint32_t priority;
  // The length of this node's key
  int len;
  // The contents of this node's key
  // uint8_t[len];
  // (The key bytes are stored directly after the struct, which is why the
  // comparisons below read from `this + 1`.)

  // Byte-wise lexicographic comparison of this node's key with another node's
  // key. Both operands point just past a Node object, so neither pointer is
  // null.
  auto operator<=>(const Node &other) const {
    const int minLen = std::min(len, other.len);
    const int c = memcmp(this + 1, &other + 1, minLen);
    return c != 0 ? c <=> 0 : len <=> other.len;
  }

  // Byte-wise lexicographic comparison of this node's key with `other`.
  auto operator<=>(const ConflictSet::Key &other) const {
    const int minLen = std::min(len, other.len);
    // `other.p` may be null when `other` is empty; memcmp must not be given a
    // null pointer even for a zero-length comparison.
    const int c = minLen > 0 ? memcmp(this + 1, other.p, minLen) : 0;
    return c != 0 ? c <=> 0 : len <=> other.len;
  }
};

// TODO: use a better prng. This is technically vulnerable to a
// denial-of-service attack that can make conflict-checking linear in the
// number of nodes in the tree.
thread_local uint32_t gSeed = 1013904223L;

// Returns the current value of the per-thread seed, then advances the seed
// with one linear-congruential step.
uint32_t fastRand() {
  auto result = gSeed;
  gSeed = gSeed * 1664525L + 1013904223L;
  return result;
}

// Note: `rangeVersion` is left uninitialized.
Node *createNode(const Key &key, Node *parent, int64_t pointVersion) { assert(key.len <= std::numeric_limits::max()); Node *result = (Node *)malloc(sizeof(Node) + key.len); result->maxVersion = pointVersion; result->pointVersion = pointVersion; result->child[0] = nullptr; result->child[1] = nullptr; result->parent = parent; result->priority = 0xff & fastRand(); result->len = key.len; memcpy(result + 1, key.p, key.len); return result; } void destroyNode(Node *node) { assert(node->child[0] == nullptr); assert(node->child[1] == nullptr); free(node); } struct Iterator { Node *node; int cmp; }; void lastLeqMulti(Node *root, std::span keys, Iterator *results) { assert(std::is_sorted(keys.begin(), keys.end())); if (keys.size() == 0) { return; } struct Coro { Node *current; Node *result; const Key *key; int resultC = -1; int index; int lcp[2]{}; std::strong_ordering c = std::strong_ordering::equal; Coro() {} Coro(Node *current, Node *result, const Key &key, int index) : current(current), result(result), key(&key), index(index) {} bool step() { if (current == nullptr) { return true; } c = *current <=> *key; if (c == 0) { result = current; resultC = 0; return true; } result = c < 0 ? 
current : result; current = current->child[c < 0]; return false; } }; auto coros = std::unique_ptr{new Coro[keys.size()]}; // Descend until queries for front and back diverge Node *current = root; Node *resultP = nullptr; auto coroBegin = Coro(current, resultP, keys.front(), -1); auto coroEnd = Coro(current, resultP, keys.back(), -1); for (;;) { bool done1 = coroBegin.step(); bool done2 = coroEnd.step(); if (!done1 && !done2 && coroBegin.c == coroEnd.c) { assert(coroBegin.current == coroEnd.current); assert(coroBegin.result == coroEnd.result); current = coroBegin.current; resultP = coroBegin.result; } else { break; } } int index = 0; { auto iter = coros.get(); for (const auto &k : keys) { *iter++ = Coro(current, resultP, k, index++); } } auto remaining = std::span(coros.get(), keys.size()); while (remaining.size() > 0) { for (int i = 0; i < int(remaining.size());) { bool done = remaining[i].step(); if (done) { const auto &c = remaining[i]; results[c.index] = Iterator{c.result, c.resultC}; if (i != int(remaining.size()) - 1) { remaining[i] = remaining.back(); } remaining = remaining.subspan(0, remaining.size() - 1); } else { ++i; } } } } // Return a pointer to the node whose key immediately follows `n`'s key (if // `dir` is false, precedes). Return nullptr if none exists. [[maybe_unused]] Node *next(Node *n, bool dir) { // Traverse left spine of right child (when moving right, i.e. dir = true) if (n->child[dir]) { n = n->child[dir]; while (n->child[!dir]) { n = n->child[!dir]; } } else { // Search upward for a node such that we're the left child (when moving // right, i.e. dir = true) while (n->parent && n == n->parent->child[dir]) { n = n->parent; } n = n->parent; } return n; } // Return a pointer to the node whose key is greatest among keys in the tree // rooted at `n` (if dir = false, least). Return nullptr if none exists (i.e. // `n` is null). 
[[maybe_unused]] Node *extrema(Node *n, bool dir) { if (n == nullptr) { return nullptr; } while (n->child[dir] != nullptr) { n = n->child[dir]; } return n; } [[maybe_unused]] void debugPrintDot(FILE *file, Node *node) { struct DebugDotPrinter { explicit DebugDotPrinter(FILE *file) : file(file) {} void print(Node *node) { for (int i = 0; i < 2; ++i) { if (node->child[i] != nullptr) { fprintf(file, " k_%.*s -> k_%.*s;\n", node->len, (const char *)(node + 1), node->child[i]->len, (const char *)(node->child[i] + 1)); print(node->child[i]); } else { fprintf(file, " k_%.*s -> null%d;\n", node->len, (const char *)(node + 1), id); ++id; } } } int id = 0; FILE *file; }; fprintf(file, "digraph TreeSet {\n"); fprintf(file, " node [fontname=\"Scientifica\"];\n"); if (node != nullptr) { DebugDotPrinter printer{file}; fprintf(file, "\n"); printer.print(node); fprintf(file, "\n"); for (auto iter = extrema(node, false); iter != nullptr; iter = next(iter, true)) { fprintf(file, " k_%.*s [label=\"k=%.*s;p=%u;m=%d;v=%d,r=%d\"];\n", iter->len, (const char *)(iter + 1), iter->len, (const char *)(iter + 1), iter->priority, int(iter->maxVersion), int(iter->pointVersion), int(iter->rangeVersion)); } for (int i = 0; i < printer.id; ++i) { fprintf(file, " null%d [shape=point];\n", i); } } fprintf(file, "}\n"); } } // namespace struct ConflictSet::Impl { Node *root; int64_t oldestVersion; explicit Impl(int64_t oldestVersion) noexcept : root(createNode({nullptr, 0}, nullptr, oldestVersion)), oldestVersion(oldestVersion) { root->rangeVersion = oldestVersion; } void check(const ReadRange *reads, Result *results, int count) const { auto iters = std::unique_ptr{new Iterator[count]}; auto begins = std::unique_ptr{new Key[count]}; for (int i = 0; i < count; ++i) { begins.get()[i] = reads[i].begin; } lastLeqMulti(root, std::span(begins.get(), count), iters.get()); // TODO check non-singleton reads lol for (int i = 0; i < count; ++i) { assert(reads[i].end.len == 0); assert(iters[i].node != nullptr); 
if ((iters[i].cmp == 0 ? iters[i].node->pointVersion : iters[i].node->rangeVersion) > reads[i].readVersion) { results[i] = ConflictSet::Conflict; } } } void addWriteNaive(const WriteRange &write) { // TODO handle non-singleton writes lol Node **current = &root; Node *parent = nullptr; const auto &key = write.begin; for (;;) { if (*current == nullptr) { auto *newNode = createNode(key, parent, write.writeVersion); *current = newNode; auto *prev = ::next(newNode, false); assert(prev != nullptr); assert(prev->rangeVersion <= write.writeVersion); newNode->rangeVersion = prev->rangeVersion; break; } else { // TODO this assert won't be valid in the final design assert((*current)->maxVersion <= write.writeVersion); (*current)->maxVersion = write.writeVersion; auto c = key <=> **current; if (c == 0) { (*current)->pointVersion = write.writeVersion; break; } parent = *current; current = &((*current)->child[c > 0]); } } } void addWrites(const WriteRange *writes, int count) { for (const auto &w : std::span(writes, count)) { addWriteNaive(w); } } void setOldestVersion(int64_t oldestVersion) { assert(oldestVersion > this->oldestVersion); this->oldestVersion = oldestVersion; } ~Impl() { std::vector toFree; if (root != nullptr) { toFree.push_back(root); } while (toFree.size() > 0) { Node *n = toFree.back(); toFree.pop_back(); for (int i = 0; i < 2; ++i) { auto *c = std::exchange(n->child[i], nullptr); if (c != nullptr) { toFree.push_back(c); } } destroyNode(n); } } }; void ConflictSet::check(const ReadRange *reads, Result *results, int count) const { return impl->check(reads, results, count); } void ConflictSet::addWrites(const WriteRange *writes, int count) { return impl->addWrites(writes, count); } void ConflictSet::setOldestVersion(int64_t oldestVersion) { return impl->setOldestVersion(oldestVersion); } ConflictSet::ConflictSet(int64_t oldestVersion) : impl(new Impl{oldestVersion}) {} ConflictSet::~ConflictSet() { delete impl; } ConflictSet::ConflictSet(ConflictSet &&other) 
noexcept : impl(std::exchange(other.impl, nullptr)) {} ConflictSet &ConflictSet::operator=(ConflictSet &&other) noexcept { impl = std::exchange(other.impl, nullptr); return *this; } #ifdef ENABLE_TESTS int main(void) { ConflictSet::Impl cs{0}; ConflictSet::WriteRange write; write.begin.p = (const uint8_t *)"0000"; write.begin.len = 4; write.end.len = 0; write.writeVersion = 1; cs.addWrites(&write, 1); debugPrintDot(stdout, cs.root); } #endif