From 9363d7866cb22138ef7ca055971379fcd4ee300c Mon Sep 17 00:00:00 2001 From: Andrew Noyes Date: Wed, 7 Feb 2024 15:33:15 -0800 Subject: [PATCH] Specify maxVersion meaning --- ConflictSet.cpp | 142 +++++++++++++++++++++++++++++++++++------------- Internal.h | 1 + 2 files changed, 106 insertions(+), 37 deletions(-) diff --git a/ConflictSet.cpp b/ConflictSet.cpp index c828f39..7ca8296 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -36,6 +36,8 @@ enum class Type : int8_t { struct Node { /* begin section that's copied to the next node */ Node *parent = nullptr; + // The max write version over all keys that start with the search path up to + // this point int64_t maxVersion; Entry entry; int16_t numChildren = 0; @@ -615,7 +617,7 @@ struct Iterator { }; namespace { -std::string getSearchPath(Node *n) { +std::string getSearchPathPrintable(Node *n) { Arena arena; if (n == nullptr) { return ""; @@ -639,6 +641,43 @@ std::string getSearchPath(Node *n) { return std::string(); } } + +std::string strinc(std::string_view str, bool &ok) { + int index; + for (index = str.size() - 1; index >= 0; index--) + if (str[index] != 255) + break; + + // Must not be called with a string that consists only of zero or more '\xff' + // bytes. + if (index < 0) { + ok = false; + return {}; + } + ok = true; + + auto r = std::string(str.substr(0, index + 1)); + ((uint8_t &)r[r.size() - 1])++; + return r; +} + +std::string getSearchPath(Node *n) { + assert(n != nullptr); + Arena arena; + auto result = vector(arena); + for (;;) { + for (int i = n->partialKeyLen - 1; i >= 0; --i) { + result.push_back(n->partialKey[i]); + } + if (n->parent == nullptr) { + break; + } + result.push_back(n->parentsIndex); + n = n->parent; + } + std::reverse(result.begin(), result.end()); + return std::string((const char *)result.data(), result.size()); +} } // namespace Iterator firstGeq(Node *n, const std::span key) { @@ -710,20 +749,19 @@ downLeftSpine: } } +Iterator firstGeq(Node *n, std::string_view key) { + return firstGeq( + n, std::span((const uint8_t *)key.data(), key.size())); +} + // Logically this is the same as performing firstGeq and then checking against -// point or range version according to cmp, but this version short circuits if -// it can prove that both point and range versions of firstGeq are <= -// readVersion. +// point or range version according to cmp, but this version short circuits as +// soon as it can prove that there's no conflict bool checkPointRead(Node *n, const std::span key, int64_t readVersion) { auto remaining = key; Node *nextSib = nullptr; for (;;) { - if (std::max(nextSib != nullptr ? nextSib->maxVersion - : std::numeric_limits::lowest(), - n->maxVersion) <= readVersion) { - return true; - } if (n->partialKeyLen > 0) { int commonLen = std::min(n->partialKeyLen, remaining.size()); for (int i = 0; i < commonLen; ++i) { @@ -747,6 +785,9 @@ bool checkPointRead(Node *n, const std::span key, goto downLeftSpine; } } + if (n->maxVersion <= readVersion) { + return true; + } if (remaining.size() == 0) { if (n->entryPresent) { return n->entry.pointVersion <= readVersion; @@ -780,9 +821,6 @@ downLeftSpine: return true; } for (;;) { - if (n->maxVersion <= readVersion) { - return true; - } if (n->entryPresent) { return n->entry.rangeVersion <= readVersion; } @@ -817,10 +855,10 @@ bool checkRangeRead(Node *n, const std::span begin, #if DEBUG_VERBOSE && !defined(NDEBUG) fprintf(stderr, "firstGeq for `%s' got `%s'\n", printable(begin).c_str(), - getSearchPath(left.n).c_str()); + getSearchPathPrintable(left.n).c_str()); fprintf(stderr, "firstGeq for `%s' got `%s'\n", printable(end).c_str(), - getSearchPath(right.n).c_str()); - fprintf(stderr, "lca `%s'\n", getSearchPath(lca).c_str()); + getSearchPathPrintable(right.n).c_str()); + fprintf(stderr, "lca `%s'\n", getSearchPathPrintable(lca).c_str()); #endif if (left.n != nullptr && left.cmp != 0 && left.n->entry.rangeVersion > readVersion) { @@ -875,11 +913,11 @@ bool checkRangeRead(Node *n, const std::span begin, } // Returns a pointer to the newly inserted node. caller is reponsible for -// setting 'entry' fields on the result, which may have !entryPresent. The -// search path for `key` will have maxVersion at least `writeVersion` as a -// postcondition. +// setting 'entry' fields and `maxVersion` on the result, which may have +// !entryPresent. The search path of the result's parent will have +// `maxVersion` at least `writeVersion` as a postcondition. [[nodiscard]] Node *insert(Node **self_, std::span key, - int64_t writeVersion) { + int64_t writeVersion, bool begin) { for (;;) { auto &self = *self_; // Handle an existing partial key @@ -914,18 +952,27 @@ bool checkRangeRead(Node *n, const std::span begin, key = key.subspan(self->partialKeyLen, key.size() - self->partialKeyLen); } - self->maxVersion = std::max(self->maxVersion, writeVersion); + if (begin) { + self->maxVersion = std::max(self->maxVersion, writeVersion); + } if (key.size() == 0) { return self; } + + if (!begin) { + self->maxVersion = std::max(self->maxVersion, writeVersion); + } + auto &child = getOrCreateChild(self, key.front()); if (!child) { child = newNode(); - child->maxVersion = writeVersion; child->parent = self; child->parentsIndex = key.front(); + child->maxVersion = + begin ? writeVersion : std::numeric_limits::lowest(); } + self_ = &child; key = key.subspan(1, key.size() - 1); } @@ -977,7 +1024,7 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { if (w.end.len > 0) { auto *begin = insert(&root, std::span(w.begin.p, w.begin.len), - w.writeVersion); + w.writeVersion, true); const bool insertedBegin = !std::exchange(begin->entryPresent, true); @@ -985,11 +1032,15 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { auto *p = nextLogical(begin); begin->entry.rangeVersion = p != nullptr ? p->entry.rangeVersion : oldestVersion; + begin->entry.pointVersion = w.writeVersion; + begin->maxVersion = w.writeVersion; } - begin->entry.pointVersion = w.writeVersion; + begin->maxVersion = std::max(begin->maxVersion, w.writeVersion); + begin->entry.pointVersion = + std::max(begin->entry.pointVersion, w.writeVersion); auto *end = insert(&root, std::span(w.end.p, w.end.len), - w.writeVersion); + w.writeVersion, false); const bool insertedEnd = !std::exchange(end->entryPresent, true); @@ -997,6 +1048,7 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { auto *p = nextLogical(end); end->entry.pointVersion = p != nullptr ? p->entry.rangeVersion : oldestVersion; + end->maxVersion = std::max(end->maxVersion, end->entry.pointVersion); } end->entry.rangeVersion = w.writeVersion; @@ -1019,16 +1071,18 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { } else { auto *n = insert(&root, std::span(w.begin.p, w.begin.len), - w.writeVersion); + w.writeVersion, true); if (!n->entryPresent) { auto *p = nextLogical(n); n->entryPresent = true; n->entry.pointVersion = w.writeVersion; + n->maxVersion = w.writeVersion; n->entry.rangeVersion = p != nullptr ? p->entry.rangeVersion : oldestVersion; } else { n->entry.pointVersion = std::max(n->entry.pointVersion, w.writeVersion); + n->maxVersion = std::max(n->maxVersion, w.writeVersion); } } } @@ -1134,10 +1188,11 @@ namespace { " k_%p [label=\"m=%" PRId64 " p=%" PRId64 " r=%" PRId64 "\n%s\", pos=\"%d,%d!\"];\n", (void *)n, n->maxVersion, n->entry.pointVersion, - n->entry.rangeVersion, getSearchPath(n).c_str(), x, y); + n->entry.rangeVersion, getSearchPathPrintable(n).c_str(), x, y); } else { fprintf(file, " k_%p [label=\"m=%" PRId64 "\n%s\", pos=\"%d,%d!\"];\n", - (void *)n, n->maxVersion, getSearchPath(n).c_str(), x, y); + (void *)n, n->maxVersion, getSearchPathPrintable(n).c_str(), x, + y); } x += kSeparation; for (int child = getChildGeq(n, 0); child >= 0; @@ -1164,7 +1219,7 @@ void checkParentPointers(Node *node, bool &success) { auto *child = getChildExists(node, i); if (child->parent != node) { fprintf(stderr, "%s child %d has parent pointer %p. Expected %p\n", - getSearchPath(node).c_str(), i, (void *)child->parent, + getSearchPathPrintable(node).c_str(), i, (void *)child->parent, (void *)node); success = false; } @@ -1172,18 +1227,31 @@ void checkParentPointers(Node *node, bool &success) { } } -[[maybe_unused]] int64_t checkMaxVersion(Node *node, bool &success) { - int64_t expected = - node->entryPresent - ? std::max(node->entry.pointVersion, node->entry.rangeVersion) - : std::numeric_limits::lowest(); +[[maybe_unused]] int64_t checkMaxVersion(Node *root, Node *node, + bool &success) { + int64_t expected = std::numeric_limits::lowest(); + if (node->entryPresent) { + expected = std::max(expected, node->entry.pointVersion); + } for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) { auto *child = getChildExists(node, i); - expected = std::max(expected, checkMaxVersion(child, success)); + expected = std::max(expected, checkMaxVersion(root, child, success)); + if (child->entryPresent) { + expected = std::max(expected, child->entry.rangeVersion); + } + } + auto key = getSearchPath(root); + bool ok; + auto inc = strinc(key, ok); + if (ok) { + auto borrowed = firstGeq(root, inc); + if (borrowed.n != nullptr) { + expected = std::max(expected, borrowed.n->entry.rangeVersion); + } } if (node->maxVersion != expected) { fprintf(stderr, "%s has max version %" PRId64 " . Expected %" PRId64 "\n", - getSearchPath(node).c_str(), node->maxVersion, expected); + getSearchPathPrintable(node).c_str(), node->maxVersion, expected); success = false; } return expected; @@ -1198,7 +1266,7 @@ void checkParentPointers(Node *node, bool &success) { if (e == 0) { Arena arena; fprintf(stderr, "%s has child %02x with no reachable entries\n", - getSearchPath(node).c_str(), i); + getSearchPathPrintable(node).c_str(), i); success = false; } } @@ -1209,7 +1277,7 @@ bool checkCorrectness(Node *node) { bool success = true; checkParentPointers(node, success); - checkMaxVersion(node, success); + checkMaxVersion(node, node, success); checkEntriesExist(node, success); return success; diff --git a/Internal.h b/Internal.h index a618dee..01a88aa 100644 --- a/Internal.h +++ b/Internal.h @@ -530,6 +530,7 @@ template struct TestDriver { { int numPointReads = arbitrary.bounded(100); int numRangeReads = arbitrary.bounded(100); + numRangeReads = 0; int64_t v = std::max(writeVersion - arbitrary.bounded(10), 0); auto *reads = new (arena) ConflictSet::ReadRange[numPointReads + numRangeReads];