diff --git a/.clangd b/.clangd index 8b1da38..5702466 100644 --- a/.clangd +++ b/.clangd @@ -1,2 +1,2 @@ CompileFlags: - Add: [-DENABLE_TESTS] + Add: [-DENABLE_TESTS, -UNDEBUG] diff --git a/ConflictSet.cpp b/ConflictSet.cpp index d6b08b8..8d0e85f 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -1,9 +1,19 @@ #include "ConflictSet.h" #include -#include +#include +#include +#include #include +using Key = ConflictSet::Key; + +auto operator<=>(const Key &lhs, const Key &rhs) { + const int minLen = std::min(lhs.len, rhs.len); + const int c = memcmp(lhs.p, rhs.p, minLen); + return c != 0 ? c <=> 0 : lhs.len <=> rhs.len; +} + namespace { // A node in the tree representing write conflict history. This tree maintains // several invariants: @@ -49,10 +59,10 @@ struct Node { const int c = memcmp(this + 1, &other + 1, minLen); return c != 0 ? c <=> 0 : len <=> other.len; } - auto operator<=>(std::string_view other) const { - const int minLen = std::min(len, other.size()); - const int c = memcmp(this + 1, other.data(), minLen); - return c != 0 ? c <=> 0 : len <=> int(other.size()); + auto operator<=>(const ConflictSet::Key &other) const { + const int minLen = std::min(len, other.len); + const int c = memcmp(this + 1, other.p, minLen); + return c != 0 ? c <=> 0 : len <=> other.len; } }; @@ -67,17 +77,17 @@ uint32_t fastRand() { } // Note: `rangeVersion` is left uninitialized. -Node *createNode(std::string_view key, Node *parent, int64_t pointVersion) { - assert(key.size() <= std::numeric_limits::max()); - Node *result = (Node *)malloc(sizeof(Node) + key.size()); +Node *createNode(const Key &key, Node *parent, int64_t pointVersion) { + assert(key.len <= std::numeric_limits::max()); + Node *result = (Node *)malloc(sizeof(Node) + key.len); result->maxVersion = pointVersion; result->pointVersion = pointVersion; result->child[0] = nullptr; result->child[1] = nullptr; result->parent = parent; - result->priority = fastRand(); - result->len = key.size(); - memcpy(result + 1, key.data(), key.size()); + result->priority = 0xff & fastRand(); + result->len = key.len; + memcpy(result + 1, key.p, key.len); return result; } @@ -87,6 +97,92 @@ void destroyNode(Node *node) { free(node); } +struct Iterator { + Node *node; + int cmp; +}; + +void lastLeqMulti(Node *root, std::span keys, Iterator *results) { + assert(std::is_sorted(keys.begin(), keys.end())); + + if (keys.size() == 0) { + return; + } + + struct Coro { + Node *current; + Node *result; + const Key *key; + int resultC = -1; + int index; + int lcp[2]{}; + std::strong_ordering c = std::strong_ordering::equal; + + Coro() {} + Coro(Node *current, Node *result, const Key &key, int index) + : current(current), result(result), key(&key), index(index) {} + + bool step() { + if (current == nullptr) { + return true; + } + c = *current <=> *key; + if (c == 0) { + result = current; + resultC = 0; + return true; + } + result = c < 0 ? current : result; + current = current->child[c < 0]; + return false; + } + }; + + auto coros = std::unique_ptr{new Coro[keys.size()]}; + + // Descend until queries for front and back diverge + Node *current = root; + Node *resultP = nullptr; + auto coroBegin = Coro(current, resultP, keys.front(), -1); + auto coroEnd = Coro(current, resultP, keys.back(), -1); + for (;;) { + bool done1 = coroBegin.step(); + bool done2 = coroEnd.step(); + if (!done1 && !done2 && coroBegin.c == coroEnd.c) { + assert(coroBegin.current == coroEnd.current); + assert(coroBegin.result == coroEnd.result); + current = coroBegin.current; + resultP = coroBegin.result; + } else { + break; + } + } + + int index = 0; + { + auto iter = coros.get(); + for (const auto &k : keys) { + *iter++ = Coro(current, resultP, k, index++); + } + } + auto remaining = std::span(coros.get(), keys.size()); + while (remaining.size() > 0) { + for (int i = 0; i < int(remaining.size());) { + bool done = remaining[i].step(); + if (done) { + const auto &c = remaining[i]; + results[c.index] = Iterator{c.result, c.resultC}; + if (i != int(remaining.size()) - 1) { + remaining[i] = remaining.back(); + } + remaining = remaining.subspan(0, remaining.size() - 1); + } else { + ++i; + } + } + } +} + // Return a pointer to the node whose key immediately follows `n`'s key (if // `dir` is false, precedes). Return nullptr if none exists. [[maybe_unused]] Node *next(Node *n, bool dir) { @@ -128,15 +224,14 @@ void destroyNode(Node *node) { void print(Node *node) { for (int i = 0; i < 2; ++i) { - if (node->child[0] != nullptr) { - fprintf(file, " _%.*s -> _%.*s;\n", node->len, - (const char *)(node + 1), node->child[0]->len, - (const char *)(node->child[0] + 1)); - print(node->child[0]); + if (node->child[i] != nullptr) { + fprintf(file, " k_%.*s -> k_%.*s;\n", node->len, + (const char *)(node + 1), node->child[i]->len, + (const char *)(node->child[i] + 1)); + print(node->child[i]); } else { - fprintf(file, " _%.*s -> null%d;\n", node->len, + fprintf(file, " k_%.*s -> null%d;\n", node->len, (const char *)(node + 1), id); - fprintf(file, " null%d [shape=point];\n", id); ++id; } } @@ -147,12 +242,21 @@ void destroyNode(Node *node) { fprintf(file, "digraph TreeSet {\n"); fprintf(file, " node [fontname=\"Scientifica\"];\n"); - for (auto iter = extrema(node, false); iter != nullptr; - iter = next(iter, true)) { - fprintf(file, " _%.*s;\n", node->len, (const char *)(node + 1)); - } if (node != nullptr) { - DebugDotPrinter{file}.print(node); + DebugDotPrinter printer{file}; + fprintf(file, "\n"); + printer.print(node); + fprintf(file, "\n"); + for (auto iter = extrema(node, false); iter != nullptr; + iter = next(iter, true)) { + fprintf(file, " k_%.*s [label=\"k=%.*s;p=%u;m=%d;v=%d,r=%d\"];\n", + iter->len, (const char *)(iter + 1), iter->len, + (const char *)(iter + 1), iter->priority, int(iter->maxVersion), + int(iter->pointVersion), int(iter->rangeVersion)); + } + for (int i = 0; i < printer.id; ++i) { + fprintf(file, " null%d [shape=point];\n", i); + } } fprintf(file, "}\n"); } @@ -163,13 +267,64 @@ struct ConflictSet::Impl { Node *root; int64_t oldestVersion; explicit Impl(int64_t oldestVersion) noexcept - : root(createNode("", nullptr, oldestVersion)), + : root(createNode({nullptr, 0}, nullptr, oldestVersion)), oldestVersion(oldestVersion) { root->rangeVersion = oldestVersion; } - void check(const ReadRange *reads, Result *results, int count) const {} - void addWrites(const WriteRange *writes, int count) {} + void check(const ReadRange *reads, Result *results, int count) const { + auto iters = std::unique_ptr{new Iterator[count]}; + auto begins = std::unique_ptr{new Key[count]}; + for (int i = 0; i < count; ++i) { + begins.get()[i] = reads[i].begin; + } + lastLeqMulti(root, std::span(begins.get(), count), iters.get()); + // TODO check non-singleton reads lol + for (int i = 0; i < count; ++i) { + assert(reads[i].end.len == 0); + assert(iters[i].node != nullptr); + if ((iters[i].cmp == 0 + ? iters[i].node->pointVersion + : iters[i].node->rangeVersion) > reads[i].readVersion) { + results[i] = ConflictSet::Conflict; + } + } + } + + void addWriteNaive(const WriteRange &write) { + // TODO handle non-singleton writes lol + Node **current = &root; + Node *parent = nullptr; + const auto &key = write.begin; + for (;;) { + if (*current == nullptr) { + auto *newNode = createNode(key, parent, write.writeVersion); + *current = newNode; + auto *prev = ::next(newNode, false); + assert(prev != nullptr); + assert(prev->rangeVersion <= write.writeVersion); + newNode->rangeVersion = prev->rangeVersion; + break; + } else { + // TODO this assert won't be valid in the final design + assert((*current)->maxVersion <= write.writeVersion); + (*current)->maxVersion = write.writeVersion; + auto c = key <=> **current; + if (c == 0) { + (*current)->pointVersion = write.writeVersion; + break; + } + parent = *current; + current = &((*current)->child[c > 0]); + } + } + } + + void addWrites(const WriteRange *writes, int count) { + for (const auto &w : std::span(writes, count)) { + addWriteNaive(w); + } + } void setOldestVersion(int64_t oldestVersion) { assert(oldestVersion > this->oldestVersion); @@ -224,6 +379,12 @@ ConflictSet &ConflictSet::operator=(ConflictSet &&other) noexcept { #ifdef ENABLE_TESTS int main(void) { ConflictSet::Impl cs{0}; + ConflictSet::WriteRange write; + write.begin.p = (const uint8_t *)"0000"; + write.begin.len = 4; + write.end.len = 0; + write.writeVersion = 1; + cs.addWrites(&write, 1); debugPrintDot(stdout, cs.root); } #endif \ No newline at end of file diff --git a/ConflictSet.h b/ConflictSet.h index 41f3853..4f818ca 100644 --- a/ConflictSet.h +++ b/ConflictSet.h @@ -32,6 +32,8 @@ struct ConflictSet { /// `end` having length 0 denotes that this range is the single key {begin}. /// Otherwise this denotes the range [begin, end) Key end; + /// Write version must be greater than all write versions in all previous + /// calls to `addWrites` int64_t writeVersion; }; /// `reads` must be sorted ascending, and must not have adjacent or