From 8c15cb28d0616cd891f73c1253049a169cb6dd8c Mon Sep 17 00:00:00 2001 From: Andrew Noyes Date: Thu, 18 Jan 2024 10:34:02 -0800 Subject: [PATCH] Finish addWriteNaive for singleton writes --- ConflictSet.cpp | 94 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 85 insertions(+), 9 deletions(-) diff --git a/ConflictSet.cpp b/ConflictSet.cpp index 963fca6..89dba0e 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -85,7 +85,7 @@ Node *createNode(const Key &key, Node *parent, int64_t pointVersion) { result->child[0] = nullptr; result->child[1] = nullptr; result->parent = parent; - result->priority = 0xff & fastRand(); + result->priority = fastRand(); result->len = key.len; memcpy(result + 1, key.p, key.len); return result; @@ -240,7 +240,6 @@ void lastLeqMulti(Node *root, std::span keys, Iterator *results) { }; fprintf(file, "digraph TreeSet {\n"); - fprintf(file, " node [fontname=\"Scientifica\"];\n"); if (node != nullptr) { DebugDotPrinter printer{file}; fprintf(file, "\n"); @@ -248,10 +247,10 @@ void lastLeqMulti(Node *root, std::span keys, Iterator *results) { fprintf(file, "\n"); for (auto iter = extrema(node, false); iter != nullptr; iter = next(iter, true)) { - fprintf(file, " k_%.*s [label=\"k=%.*s;p=%u;m=%d;v=%d,r=%d\"];\n", - iter->len, (const char *)(iter + 1), iter->len, - (const char *)(iter + 1), iter->priority, int(iter->maxVersion), - int(iter->pointVersion), int(iter->rangeVersion)); + fprintf(file, " k_%.*s [label=\"k=%.*s;m=%d;v=%d,r=%d\"];\n", iter->len, + (const char *)(iter + 1), iter->len, (const char *)(iter + 1), + int(iter->maxVersion), int(iter->pointVersion), + int(iter->rangeVersion)); } for (int i = 0; i < printer.id; ++i) { fprintf(file, " null%d [shape=point];\n", i); @@ -272,6 +271,57 @@ void lastLeqMulti(Node *root, std::span keys, Iterator *results) { return Key{itoaBuf, kMaxLength}; } +// Recompute maxVersion, and propagate up the tree as necessary +void updateMaxVersion(Node *n) { + for (;;) { + int64_t maxVersion = std::max(n->pointVersion, n->rangeVersion); + for (int i = 0; i < 2; ++i) { + maxVersion = + std::max(maxVersion, n->child[i] != nullptr ? n->child[i]->maxVersion + : maxVersion); + } + if (n->maxVersion == maxVersion || n->parent == nullptr) { + break; + } + n->maxVersion = maxVersion; + n = n->parent; + } +} + +Node **rotate(Node **node, bool dir) { + // diagram shown for dir == true + /* n + / + l + \ + lr + */ + assert(node != nullptr); + Node *n = *node; + assert(n != nullptr); + Node *parent = n->parent; + Node *l = n->child[!dir]; + assert(l != nullptr); + Node *lr = l->child[dir]; + n->child[!dir] = lr; + if (lr) { + lr->parent = n; + } + l->child[dir] = n; + n->parent = l; + l->parent = parent; + *node = l; + /* l + \ + n + / + lr + */ + updateMaxVersion(n); + updateMaxVersion(l); + return &l->child[dir]; +} + } // namespace struct ConflictSet::Impl { @@ -307,6 +357,7 @@ struct ConflictSet::Impl { Node **current = &root; Node *parent = nullptr; const auto &key = write.begin; + bool inserted = false; for (;;) { if (*current == nullptr) { auto *newNode = createNode(key, parent, write.writeVersion); @@ -315,11 +366,13 @@ struct ConflictSet::Impl { assert(prev != nullptr); assert(prev->rangeVersion <= write.writeVersion); newNode->rangeVersion = prev->rangeVersion; + inserted = true; break; } else { - // TODO this assert won't be valid in the final design - assert((*current)->maxVersion <= write.writeVersion); - (*current)->maxVersion = write.writeVersion; + // This is the key optimization - setting the max version on the way + // down the search path so we only have to do one traversal. + (*current)->maxVersion = + std::max((*current)->maxVersion, write.writeVersion); auto c = key <=> **current; if (c == 0) { (*current)->pointVersion = write.writeVersion; @@ -329,6 +382,29 @@ struct ConflictSet::Impl { current = &((*current)->child[c > 0]); } } + if (inserted) { + auto *n = *current; + assert(n != nullptr); + for (;;) { + if (n->parent == nullptr) { + break; + } + const bool dir = n == n->parent->child[1]; + assert(dir || n == n->parent->child[0]); + // p is the address of the pointer to n->parent in the tree + Node **p = n->parent->parent == nullptr + ? &root + : &n->parent->parent + ->child[n->parent->parent->child[1] == n->parent]; + assert(*p == n->parent); + if (n->parent->priority < n->priority) { + p = rotate(p, !dir); + n = (*p)->parent; + } else { + break; + } + } + } } void addWrites(const WriteRange *writes, int count) {