diff --git a/ConflictSet.cpp b/ConflictSet.cpp index 8edd983..0a6b389 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -161,7 +161,6 @@ struct Node { Node *parent = nullptr; // The max write version over all keys that start with the search path up to // this point - int64_t maxVersion; Entry entry; int16_t numChildren = 0; bool entryPresent = false; @@ -314,27 +313,7 @@ int64_t getChildMaxVersion(Node *self, uint8_t index) { } // Precondition - an entry for index must exist in the node -void setParentsChildMaxVersion(Node *self) { - int index = self->parentsIndex; - self = self->parent; - if (self == nullptr) { - return; - } - if (self->type <= Type::Node16) { - auto *self16 = static_cast(self); - int i = getNodeIndex(self16, index); - self16->children[i].childMaxVersion = self16->children[i].child->maxVersion; - } else if (self->type == Type::Node48) { - auto *self48 = static_cast(self); - assert(self48->bitSet.test(index)); - self48->children[self48->index[index]].childMaxVersion = - self48->children[self48->index[index]].child->maxVersion; - } else { - auto *self256 = static_cast(self); - self256->children[index].childMaxVersion = - self256->children[index].child->maxVersion; - } -} +int64_t &maxVersion(Node *n, ConflictSet::Impl *); Node *getChild(Node *self, uint8_t index) { if (self->type <= Type::Node16) { @@ -875,13 +854,13 @@ std::string getSearchPathPrintable(Node *n); // point or range version according to cmp, but this version short circuits as // soon as it can prove that there's no conflict. bool checkPointRead(Node *n, const std::span key, - int64_t readVersion) { + int64_t readVersion, ConflictSet::Impl *impl) { #if DEBUG_VERBOSE && !defined(NDEBUG) fprintf(stderr, "Check point read: %s\n", printable(key).c_str()); #endif auto remaining = key; for (;;) { - if (n->maxVersion <= readVersion) { + if (maxVersion(n, impl) <= readVersion) { return true; } if (remaining.size() == 0) { @@ -1040,13 +1019,14 @@ Vector getSearchPath(Arena &arena, Node *n) { // Return true if the max version among all keys that start with key + [child], // where begin < child < end, is <= readVersion bool checkRangeStartsWith(Node *n, std::span key, int begin, - int end, int64_t readVersion) { + int end, int64_t readVersion, + ConflictSet::Impl *impl) { #if DEBUG_VERBOSE && !defined(NDEBUG) fprintf(stderr, "%s(%02x,%02x)*\n", printable(key).c_str(), begin, end); #endif auto remaining = key; for (;;) { - if (n->maxVersion <= readVersion) { + if (maxVersion(n, impl) <= readVersion) { return true; } if (remaining.size() == 0) { @@ -1090,7 +1070,7 @@ bool checkRangeStartsWith(Node *n, std::span key, int begin, if (n->entryPresent && n->entry.rangeVersion > readVersion) { return false; } - return n->maxVersion <= readVersion; + return maxVersion(n, impl) <= readVersion; } return true; } @@ -1114,8 +1094,9 @@ downLeftSpine: // that are >= key is <= readVersion struct CheckRangeLeftSide { CheckRangeLeftSide(Node *n, std::span key, int prefixLen, - int64_t readVersion) - : n(n), remaining(key), prefixLen(prefixLen), readVersion(readVersion) { + int64_t readVersion, ConflictSet::Impl *impl) + : n(n), remaining(key), prefixLen(prefixLen), readVersion(readVersion), + impl(impl) { #if DEBUG_VERBOSE && !defined(NDEBUG) fprintf(stderr, "Check range left side from %s for keys starting with %s\n", printable(key).c_str(), @@ -1127,6 +1108,7 @@ struct CheckRangeLeftSide { std::span remaining; int prefixLen; int64_t readVersion; + ConflictSet::Impl *impl; int searchPathLen = 0; bool ok; @@ -1135,13 +1117,13 @@ struct CheckRangeLeftSide { bool step() { switch (phase) { case Search: { - if (n->maxVersion <= readVersion) { + if (maxVersion(n, impl) <= readVersion) { ok = true; return true; } if (remaining.size() == 0) { assert(searchPathLen >= prefixLen); - ok = n->maxVersion <= readVersion; + ok = maxVersion(n, impl) <= readVersion; return true; } @@ -1161,7 +1143,7 @@ struct CheckRangeLeftSide { return downLeftSpine(); } n = getChildExists(n, c); - ok = n->maxVersion <= readVersion; + ok = maxVersion(n, impl) <= readVersion; return true; } else { n = nextSibling(n); @@ -1188,7 +1170,7 @@ struct CheckRangeLeftSide { ok = false; return true; } - ok = n->maxVersion <= readVersion; + ok = maxVersion(n, impl) <= readVersion; return true; } else { n = nextSibling(n); @@ -1207,7 +1189,7 @@ struct CheckRangeLeftSide { ok = false; return true; } - ok = n->maxVersion <= readVersion; + ok = maxVersion(n, impl) <= readVersion; return true; } } @@ -1240,9 +1222,9 @@ struct CheckRangeLeftSide { // that are < key is <= readVersion struct CheckRangeRightSide { CheckRangeRightSide(Node *n, std::span key, int prefixLen, - int64_t readVersion) + int64_t readVersion, ConflictSet::Impl *impl) : n(n), key(key), remaining(key), prefixLen(prefixLen), - readVersion(readVersion) { + readVersion(readVersion), impl(impl) { #if DEBUG_VERBOSE && !defined(NDEBUG) fprintf(stderr, "Check range right side to %s for keys starting with %s\n", printable(key).c_str(), @@ -1255,6 +1237,7 @@ struct CheckRangeRightSide { std::span remaining; int prefixLen; int64_t readVersion; + ConflictSet::Impl *impl; int searchPathLen = 0; bool ok; @@ -1353,7 +1336,7 @@ struct CheckRangeRightSide { bool backtrack() { for (;;) { - if (searchPathLen > prefixLen && n->maxVersion > readVersion) { + if (searchPathLen > prefixLen && maxVersion(n, impl) > readVersion) { ok = false; return true; } @@ -1385,12 +1368,13 @@ struct CheckRangeRightSide { }; bool checkRangeRead(Node *n, std::span begin, - std::span end, int64_t readVersion) { + std::span end, int64_t readVersion, + ConflictSet::Impl *impl) { int lcp = longestCommonPrefix(begin.data(), end.data(), std::min(begin.size(), end.size())); if (lcp == int(begin.size()) && end.size() == begin.size() + 1 && end.back() == 0) { - return checkPointRead(n, begin, readVersion); + return checkPointRead(n, begin, readVersion, impl); } SearchStepWise search{n, begin.subspan(0, lcp)}; @@ -1399,7 +1383,7 @@ bool checkRangeRead(Node *n, std::span begin, assert(getSearchPath(arena, search.n) <=> begin.subspan(0, lcp - search.remaining.size()) == 0); - if (search.n->maxVersion <= readVersion) { + if (maxVersion(search.n, impl) <= readVersion) { return true; } if (search.step()) { @@ -1419,19 +1403,19 @@ bool checkRangeRead(Node *n, std::span begin, lcp -= consumed; if (lcp == int(begin.size())) { - CheckRangeRightSide checkRangeRightSide{n, end, lcp, readVersion}; + CheckRangeRightSide checkRangeRightSide{n, end, lcp, readVersion, impl}; while (!checkRangeRightSide.step()) ; return checkRangeRightSide.ok; } if (!checkRangeStartsWith(n, begin.subspan(0, lcp), begin[lcp], end[lcp], - readVersion)) { + readVersion, impl)) { return false; } - CheckRangeLeftSide checkRangeLeftSide{n, begin, lcp + 1, readVersion}; - CheckRangeRightSide checkRangeRightSide{n, end, lcp + 1, readVersion}; + CheckRangeLeftSide checkRangeLeftSide{n, begin, lcp + 1, readVersion, impl}; + CheckRangeRightSide checkRangeRightSide{n, end, lcp + 1, readVersion, impl}; for (;;) { bool leftDone = checkRangeLeftSide.step(); @@ -1462,7 +1446,8 @@ bool checkRangeRead(Node *n, std::span begin, // `maxVersion` at least `writeVersion` as a postcondition. template [[nodiscard]] Node *insert(Node **self, std::span key, - int64_t writeVersion, NodeAllocators *allocators) { + int64_t writeVersion, NodeAllocators *allocators, + ConflictSet::Impl *impl) { for (;;) { @@ -1473,6 +1458,7 @@ template (*self)->partialKey, key.data(), commonLen); if (partialKeyIndex < (*self)->partialKeyLen) { auto *old = *self; + int64_t oldMaxVersion = maxVersion(old, impl); *self = allocators->node4.allocate(); @@ -1485,7 +1471,7 @@ template old; old->parent = *self; old->parentsIndex = old->partialKey[partialKeyIndex]; - setParentsChildMaxVersion(old); + maxVersion(old, impl) = oldMaxVersion; memmove(old->partialKey, old->partialKey + partialKeyIndex + 1, old->partialKeyLen - (partialKeyIndex + 1)); @@ -1505,8 +1491,8 @@ template } if constexpr (kBegin) { - (*self)->maxVersion = std::max((*self)->maxVersion, writeVersion); - setParentsChildMaxVersion(*self); + auto &m = maxVersion(*self, impl); + m = std::max(m, writeVersion); } if (key.size() == 0) { @@ -1514,8 +1500,8 @@ template } if constexpr (!kBegin) { - (*self)->maxVersion = std::max((*self)->maxVersion, writeVersion); - setParentsChildMaxVersion(*self); + auto &m = maxVersion(*self, impl); + m = std::max(m, writeVersion); } auto &child = getOrCreateChild(*self, key.front(), allocators); @@ -1523,9 +1509,8 @@ template child = allocators->node4.allocate(); child->parent = *self; child->parentsIndex = key.front(); - child->maxVersion = + maxVersion(child, impl) = kBegin ? writeVersion : std::numeric_limits::lowest(); - setParentsChildMaxVersion(child); } self = &child; @@ -1553,14 +1538,13 @@ void destroyTree(Node *root) { void addPointWrite(Node *&root, int64_t oldestVersion, std::span key, int64_t writeVersion, - NodeAllocators *allocators) { - auto *n = insert(&root, key, writeVersion, allocators); + NodeAllocators *allocators, ConflictSet::Impl *impl) { + auto *n = insert(&root, key, writeVersion, allocators, impl); if (!n->entryPresent) { auto *p = nextLogical(n); n->entryPresent = true; n->entry.pointVersion = writeVersion; - n->maxVersion = writeVersion; - setParentsChildMaxVersion(n); + maxVersion(n, impl) = writeVersion; n->entry.rangeVersion = p != nullptr ? p->entry.rangeVersion : oldestVersion; } else { @@ -1570,13 +1554,15 @@ void addPointWrite(Node *&root, int64_t oldestVersion, void addWriteRange(Node *&root, int64_t oldestVersion, std::span begin, std::span end, - int64_t writeVersion, NodeAllocators *allocators) { + int64_t writeVersion, NodeAllocators *allocators, + ConflictSet::Impl *impl) { int lcp = longestCommonPrefix(begin.data(), end.data(), std::min(begin.size(), end.size())); if (lcp == int(begin.size()) && end.size() == begin.size() + 1 && end.back() == 0) { - return addPointWrite(root, oldestVersion, begin, writeVersion, allocators); + return addPointWrite(root, oldestVersion, begin, writeVersion, allocators, + impl); } auto remaining = begin.subspan(0, lcp); @@ -1597,8 +1583,8 @@ void addWriteRange(Node *&root, int64_t oldestVersion, break; } - n->maxVersion = std::max(n->maxVersion, writeVersion); - setParentsChildMaxVersion(n); + auto &m = maxVersion(n, impl); + m = std::max(m, writeVersion); remaining = remaining.subspan(n->partialKeyLen + 1, remaining.size() - (n->partialKeyLen + 1)); @@ -1614,7 +1600,8 @@ void addWriteRange(Node *&root, int64_t oldestVersion, begin = begin.subspan(consumed, begin.size() - consumed); end = end.subspan(consumed, end.size() - consumed); - auto *beginNode = insert(useAsRoot, begin, writeVersion, allocators); + auto *beginNode = + insert(useAsRoot, begin, writeVersion, allocators, impl); const bool insertedBegin = !std::exchange(beginNode->entryPresent, true); @@ -1623,14 +1610,14 @@ void addWriteRange(Node *&root, int64_t oldestVersion, beginNode->entry.rangeVersion = p != nullptr ? p->entry.rangeVersion : oldestVersion; beginNode->entry.pointVersion = writeVersion; - beginNode->maxVersion = writeVersion; + maxVersion(beginNode, impl) = writeVersion; } - beginNode->maxVersion = std::max(beginNode->maxVersion, writeVersion); - setParentsChildMaxVersion(beginNode); + auto &m = maxVersion(beginNode, impl); + m = std::max(m, writeVersion); beginNode->entry.pointVersion = std::max(beginNode->entry.pointVersion, writeVersion); - auto *endNode = insert(useAsRoot, end, writeVersion, allocators); + auto *endNode = insert(useAsRoot, end, writeVersion, allocators, impl); const bool insertedEnd = !std::exchange(endNode->entryPresent, true); @@ -1638,15 +1625,14 @@ void addWriteRange(Node *&root, int64_t oldestVersion, auto *p = nextLogical(endNode); endNode->entry.pointVersion = p != nullptr ? p->entry.rangeVersion : oldestVersion; - endNode->maxVersion = - std::max(endNode->maxVersion, endNode->entry.pointVersion); - setParentsChildMaxVersion(endNode); + auto &m = maxVersion(endNode, impl); + m = std::max(m, endNode->entry.pointVersion); } endNode->entry.rangeVersion = writeVersion; if (insertedEnd) { // beginNode may have been invalidated - beginNode = insert(useAsRoot, begin, writeVersion, allocators); + beginNode = insert(useAsRoot, begin, writeVersion, allocators, impl); } for (beginNode = nextLogical(beginNode); beginNode != endNode;) { @@ -1771,7 +1757,7 @@ Iterator firstGeq(Node *n, std::string_view key) { struct __attribute__((visibility("hidden"))) ConflictSet::Impl { - void check(const ReadRange *reads, Result *result, int count) const { + void check(const ReadRange *reads, Result *result, int count) { for (int i = 0; i < count; ++i) { const auto &r = reads[i]; auto begin = std::span(r.begin.p, r.begin.len); @@ -1779,8 +1765,8 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { result[i] = reads[i].readVersion < oldestVersion ? TooOld : (end.size() > 0 - ? checkRangeRead(root, begin, end, reads[i].readVersion) - : checkPointRead(root, begin, reads[i].readVersion)) + ? checkRangeRead(root, begin, end, reads[i].readVersion, this) + : checkPointRead(root, begin, reads[i].readVersion, this)) ? Commit : Conflict; } @@ -1794,10 +1780,11 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { if (w.end.len > 0) { keyUpdates += 2; addWriteRange(root, oldestVersion, begin, end, w.writeVersion, - &allocators); + &allocators, this); } else { keyUpdates += 1; - addPointWrite(root, oldestVersion, begin, w.writeVersion, &allocators); + addPointWrite(root, oldestVersion, begin, w.writeVersion, &allocators, + this); } } } @@ -1840,7 +1827,7 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { explicit Impl(int64_t oldestVersion) : oldestVersion(oldestVersion) { // Insert "" root = allocators.node4.allocate(); - root->maxVersion = oldestVersion; + rootMaxVersion = oldestVersion; root->entry.pointVersion = oldestVersion; root->entry.rangeVersion = oldestVersion; root->entryPresent = true; @@ -1854,9 +1841,31 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { int64_t keyUpdates = 0; Node *root; + int64_t rootMaxVersion; int64_t oldestVersion; }; +// Precondition - an entry for index must exist in the node +int64_t &maxVersion(Node *n, ConflictSet::Impl *impl) { + int index = n->parentsIndex; + n = n->parent; + if (n == nullptr) { + return impl->rootMaxVersion; + } + if (n->type <= Type::Node16) { + auto *n16 = static_cast(n); + int i = getNodeIndex(n16, index); + return n16->children[i].childMaxVersion; + } else if (n->type == Type::Node48) { + auto *n48 = static_cast(n); + assert(n48->bitSet.test(index)); + return n48->children[n48->index[index]].childMaxVersion; + } else { + auto *n256 = static_cast(n); + return n256->children[index].childMaxVersion; + } +} + // ==================== END IMPLEMENTATION ==================== // GCOVR_EXCL_START @@ -1988,13 +1997,15 @@ std::string getSearchPath(Node *n) { return std::string((const char *)result.data(), result.size()); } -[[maybe_unused]] void debugPrintDot(FILE *file, Node *node) { +[[maybe_unused]] void debugPrintDot(FILE *file, Node *node, + ConflictSet::Impl *impl) { constexpr int kSeparation = 3; struct DebugDotPrinter { - explicit DebugDotPrinter(FILE *file) : file(file) {} + explicit DebugDotPrinter(FILE *file, ConflictSet::Impl *impl) + : file(file), impl(impl) {} void print(Node *n, int y = 0) { assert(n != nullptr); @@ -2002,12 +2013,12 @@ std::string getSearchPath(Node *n) { fprintf(file, " k_%p [label=\"m=%" PRId64 " p=%" PRId64 " r=%" PRId64 "\n%s\", pos=\"%d,%d!\"];\n", - (void *)n, n->maxVersion, n->entry.pointVersion, + (void *)n, maxVersion(n, impl), n->entry.pointVersion, n->entry.rangeVersion, getPartialKeyPrintable(n).c_str(), x, y); } else { fprintf(file, " k_%p [label=\"m=%" PRId64 "\n%s\", pos=\"%d,%d!\"];\n", - (void *)n, n->maxVersion, getPartialKeyPrintable(n).c_str(), x, - y); + (void *)n, maxVersion(n, impl), + getPartialKeyPrintable(n).c_str(), x, y); } x += kSeparation; for (int child = getChildGeq(n, 0); child >= 0; @@ -2019,12 +2030,13 @@ std::string getSearchPath(Node *n) { } int x = 0; FILE *file; + ConflictSet::Impl *impl; }; fprintf(file, "digraph ConflictSet {\n"); fprintf(file, " node [shape = box];\n"); assert(node != nullptr); - DebugDotPrinter printer{file}; + DebugDotPrinter printer{file, impl}; printer.print(node); fprintf(file, "}\n"); } @@ -2043,15 +2055,16 @@ void checkParentPointers(Node *node, bool &success) { } [[maybe_unused]] int64_t checkMaxVersion(Node *root, Node *node, - int64_t oldestVersion, bool &success) { + int64_t oldestVersion, bool &success, + ConflictSet::Impl *impl) { int64_t expected = std::numeric_limits::lowest(); if (node->entryPresent) { expected = std::max(expected, node->entry.pointVersion); } for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) { auto *child = getChildExists(node, i); - expected = std::max(expected, - checkMaxVersion(root, child, oldestVersion, success)); + expected = std::max( + expected, checkMaxVersion(root, child, oldestVersion, success, impl)); if (child->entryPresent) { expected = std::max(expected, child->entry.rangeVersion); } @@ -2067,17 +2080,19 @@ void checkParentPointers(Node *node, bool &success) { } if (node->parent != nullptr && getChildMaxVersion(node->parent, node->parentsIndex) != - node->maxVersion) { + maxVersion(node, impl)) { fprintf(stderr, "%s has max version %" PRId64 " . But parent has child max version %" PRId64 "\n", - getSearchPathPrintable(node).c_str(), node->maxVersion, + getSearchPathPrintable(node).c_str(), maxVersion(node, impl), getChildMaxVersion(node->parent, node->parentsIndex)); success = false; } - if (node->maxVersion > oldestVersion && node->maxVersion != expected) { + if (maxVersion(node, impl) > oldestVersion && + maxVersion(node, impl) != expected) { fprintf(stderr, "%s has max version %" PRId64 " . Expected %" PRId64 "\n", - getSearchPathPrintable(node).c_str(), node->maxVersion, expected); + getSearchPathPrintable(node).c_str(), maxVersion(node, impl), + expected); success = false; } return expected; @@ -2099,11 +2114,12 @@ void checkParentPointers(Node *node, bool &success) { return total; } -bool checkCorrectness(Node *node, int64_t oldestVersion) { +bool checkCorrectness(Node *node, int64_t oldestVersion, + ConflictSet::Impl *impl) { bool success = true; checkParentPointers(node, success); - checkMaxVersion(node, node, oldestVersion, success); + checkMaxVersion(node, node, oldestVersion, success, impl); checkEntriesExist(node, success); return success; @@ -2134,7 +2150,7 @@ void printTree() { write[i].writeVersion = ++writeVersion; } cs.addWrites(write, kNumKeys); - debugPrintDot(stdout, cs.root); + debugPrintDot(stdout, cs.root, &cs); } int main(void) { @@ -2151,16 +2167,17 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { for (;;) { bool done = driver.next(); if (!driver.ok) { - debugPrintDot(stdout, driver.cs.root); + debugPrintDot(stdout, driver.cs.root, &driver.cs); fflush(stdout); abort(); } #if DEBUG_VERBOSE && !defined(NDEBUG) fprintf(stderr, "Check correctness\n"); #endif - bool success = checkCorrectness(driver.cs.root, driver.cs.oldestVersion); + bool success = + checkCorrectness(driver.cs.root, driver.cs.oldestVersion, &driver.cs); if (!success) { - debugPrintDot(stdout, driver.cs.root); + debugPrintDot(stdout, driver.cs.root, &driver.cs); fflush(stdout); abort(); }