diff --git a/ConflictSet.cpp b/ConflictSet.cpp index ebb7d93..110fec5 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -939,6 +939,54 @@ Node *getChild(Node *self, uint8_t index) { } } +struct ChildAndMaxVersion { + Node *child; + InternalVersionT maxVersion; +}; + +ChildAndMaxVersion getChildAndMaxVersion(Node0 *, uint8_t) { return {}; } +ChildAndMaxVersion getChildAndMaxVersion(Node3 *self, uint8_t index) { + int i = getNodeIndex(self, index); + if (i < 0) { + return {}; + } + return {self->children[i], self->childMaxVersion[i]}; +} +ChildAndMaxVersion getChildAndMaxVersion(Node16 *self, uint8_t index) { + int i = getNodeIndex(self, index); + if (i < 0) { + return {}; + } + return {self->children[i], self->childMaxVersion[i]}; +} +ChildAndMaxVersion getChildAndMaxVersion(Node48 *self, uint8_t index) { + int i = self->index[index]; + if (i < 0) { + return {}; + } + return {self->children[i], self->childMaxVersion[i]}; +} +ChildAndMaxVersion getChildAndMaxVersion(Node256 *self, uint8_t index) { + return {self->children[index], self->childMaxVersion[index]}; +} + +ChildAndMaxVersion getChildAndMaxVersion(Node *self, uint8_t index) { + switch (self->getType()) { + case Type_Node0: + return getChildAndMaxVersion(static_cast(self), index); + case Type_Node3: + return getChildAndMaxVersion(static_cast(self), index); + case Type_Node16: + return getChildAndMaxVersion(static_cast(self), index); + case Type_Node48: + return getChildAndMaxVersion(static_cast(self), index); + case Type_Node256: + return getChildAndMaxVersion(static_cast(self), index); + default: // GCOVR_EXCL_LINE + __builtin_unreachable(); // GCOVR_EXCL_LINE + } +} + template Node *getChildGeqSimd(NodeT *self, int child) { static_assert(std::is_same_v || std::is_same_v); @@ -1783,6 +1831,7 @@ int longestCommonPrefix(const uint8_t *ap, const uint8_t *bp, int cl) { // Performs a physical search for remaining struct SearchStepWise { Node *n; + InternalVersionT maxV; std::span remaining; SearchStepWise() {} @@ -1795,7 +1844,8 @@ struct SearchStepWise { if (remaining.size() == 0) { return true; } - auto *child = getChild(n, remaining[0]); + auto [child, v] = getChildAndMaxVersion(n, remaining[0]); + maxV = v; if (child == nullptr) { return true; } @@ -1825,12 +1875,7 @@ bool checkPointRead(Node *n, const std::span key, fprintf(stderr, "Check point read: %s\n", printable(key).c_str()); #endif auto remaining = key; - auto *impl = tls->impl; for (;; ++tls->point_read_iterations_accum) { - if (maxVersion(n, impl) <= readVersion) { - ++tls->point_read_short_circuit_accum; - return true; - } if (remaining.size() == 0) { if (n->entryPresent) { return n->entry.pointVersion <= readVersion; @@ -1839,7 +1884,7 @@ bool checkPointRead(Node *n, const std::span key, goto downLeftSpine; } - auto *child = getChild(n, remaining[0]); + auto [child, maxV] = getChildAndMaxVersion(n, remaining[0]); if (child == nullptr) { auto c = getChildGeq(n, remaining[0]); if (c != nullptr) { @@ -1881,6 +1926,11 @@ bool checkPointRead(Node *n, const std::span key, goto downLeftSpine; } } + + if (maxV <= readVersion) { + ++tls->point_read_short_circuit_accum; + return true; + } } downLeftSpine: for (; !n->entryPresent; n = getFirstChildExists(n)) { @@ -1900,17 +1950,11 @@ bool checkPrefixRead(Node *n, const std::span key, auto remaining = key; auto *impl = tls->impl; for (;; ++tls->prefix_read_iterations_accum) { - auto m = maxVersion(n, impl); if (remaining.size() == 0) { - return m <= readVersion; + return maxVersion(n, impl) <= readVersion; } - if (m <= readVersion) { - ++tls->prefix_read_short_circuit_accum; - return true; - } - - auto *child = getChild(n, remaining[0]); + auto [child, maxV] = getChildAndMaxVersion(n, remaining[0]); if (child == nullptr) { auto c = getChildGeq(n, remaining[0]); if (c != nullptr) { @@ -1956,6 +2000,11 @@ bool checkPrefixRead(Node *n, const std::span key, goto downLeftSpine; } } + + if (maxV <= readVersion) { + ++tls->prefix_read_short_circuit_accum; + return true; + } } downLeftSpine: for (; !n->entryPresent; n = getFirstChildExists(n)) { @@ -2511,10 +2560,6 @@ struct CheckRangeLeftSide { bool ok; bool step() { - if (maxVersion(n, impl) <= readVersion) { - ok = true; - return true; - } if (remaining.size() == 0) { assert(searchPathLen >= prefixLen); ok = maxVersion(n, impl) <= readVersion; @@ -2528,7 +2573,7 @@ struct CheckRangeLeftSide { } } - auto *child = getChild(n, remaining[0]); + auto [child, maxV] = getChildAndMaxVersion(n, remaining[0]); if (child == nullptr) { auto c = getChildGeq(n, remaining[0]); if (c != nullptr) { @@ -2591,6 +2636,10 @@ struct CheckRangeLeftSide { return true; } } + if (maxV <= readVersion) { + ok = true; + return true; + } return false; } @@ -2752,18 +2801,17 @@ bool checkRangeRead(Node *n, std::span begin, SearchStepWise search{n, begin.subspan(0, lcp)}; Arena arena; - auto *impl = tls->impl; for (;; ++tls->range_read_iterations_accum) { assert(getSearchPath(arena, search.n) <=> begin.subspan(0, lcp - search.remaining.size()) == 0); - if (maxVersion(search.n, impl) <= readVersion) { - ++tls->range_read_short_circuit_accum; - return true; - } if (search.step()) { break; } + if (search.maxV <= readVersion) { + ++tls->range_read_short_circuit_accum; + return true; + } } assert(getSearchPath(arena, search.n) <=> begin.subspan(0, lcp - search.remaining.size()) ==