diff --git a/ConflictSet.cpp b/ConflictSet.cpp index ac30fce..e8ded02 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -1963,155 +1963,6 @@ Node *nextSibling(Node *node) { } } -// Logically this is the same as performing firstGeq and then checking against -// point or range version according to cmp, but this version short circuits as -// soon as it can prove that there's no conflict. -bool checkPointRead(Node *n, const std::span key, - InternalVersionT readVersion, ReadContext *tls) { - ++tls->point_read_accum; -#if DEBUG_VERBOSE && !defined(NDEBUG) - fprintf(stderr, "Check point read: %s\n", printable(key).c_str()); -#endif - auto remaining = key; - for (;; ++tls->point_read_iterations_accum) { - if (remaining.size() == 0) { - if (n->entryPresent) { - return n->entry.pointVersion <= readVersion; - } - n = getFirstChildExists(n); - goto downLeftSpine; - } - - auto [c, maxV] = getChildAndMaxVersion(n, remaining[0]); - Node *child = c; - if (child == nullptr) { - auto c = getChildGeq(n, remaining[0]); - if (c != nullptr) { - n = c; - goto downLeftSpine; - } else { - n = nextSibling(n); - if (n == nullptr) { - return true; - } - goto downLeftSpine; - } - } - - n = child; - remaining = remaining.subspan(1, remaining.size() - 1); - - if (n->partialKeyLen > 0) { - int commonLen = std::min(n->partialKeyLen, remaining.size()); - int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen); - if (i < commonLen) { - auto c = n->partialKey()[i] <=> remaining[i]; - if (c > 0) { - goto downLeftSpine; - } else { - n = nextSibling(n); - if (n == nullptr) { - return true; - } - goto downLeftSpine; - } - } - if (commonLen == n->partialKeyLen) { - // partial key matches - remaining = remaining.subspan(commonLen, remaining.size() - commonLen); - } else if (n->partialKeyLen > int(remaining.size())) { - // n is the first physical node greater than remaining, and there's no - // eq node - goto downLeftSpine; - } - } - - if (maxV <= readVersion) { - ++tls->point_read_short_circuit_accum; - return true; - } - } -downLeftSpine: - for (; !n->entryPresent; n = getFirstChildExists(n)) { - } - return n->entry.rangeVersion <= readVersion; -} - -// Logically this is the same as performing firstGeq and then checking against -// max version or range version if this prefix doesn't exist, but this version -// short circuits as soon as it can prove that there's no conflict. -bool checkPrefixRead(Node *n, const std::span key, - InternalVersionT readVersion, ReadContext *tls) { - ++tls->prefix_read_accum; -#if DEBUG_VERBOSE && !defined(NDEBUG) - fprintf(stderr, "Check prefix read: %s\n", printable(key).c_str()); -#endif - auto remaining = key; - for (;; ++tls->prefix_read_iterations_accum) { - if (remaining.size() == 0) { - // There's no way to encode a prefix read of "", so n is not the root - return maxVersion(n) <= readVersion; - } - - auto [c, maxV] = getChildAndMaxVersion(n, remaining[0]); - Node *child = c; - if (child == nullptr) { - auto c = getChildGeq(n, remaining[0]); - if (c != nullptr) { - n = c; - goto downLeftSpine; - } else { - n = nextSibling(n); - if (n == nullptr) { - return true; - } - goto downLeftSpine; - } - } - - n = child; - remaining = remaining.subspan(1, remaining.size() - 1); - - if (n->partialKeyLen > 0) { - int commonLen = std::min(n->partialKeyLen, remaining.size()); - int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen); - if (i < commonLen) { - auto c = n->partialKey()[i] <=> remaining[i]; - if (c > 0) { - goto downLeftSpine; - } else { - n = nextSibling(n); - if (n == nullptr) { - return true; - } - goto downLeftSpine; - } - } - if (commonLen == n->partialKeyLen) { - // partial key matches - remaining = remaining.subspan(commonLen, remaining.size() - commonLen); - } else if (n->partialKeyLen > int(remaining.size())) { - // n is the first physical node greater than remaining, and there's no - // eq node. All physical nodes that start with prefix are reachable from - // n. - if (maxVersion(n) > readVersion) { - return false; - } - goto downLeftSpine; - } - } - - if (maxV <= readVersion) { - ++tls->prefix_read_short_circuit_accum; - return true; - } - } -downLeftSpine: - for (; !n->entryPresent; n = getFirstChildExists(n)) { - } - return n->entry.rangeVersion <= readVersion; -} - #ifdef HAS_AVX uint32_t compare16(const InternalVersionT *vs, InternalVersionT rv) { #if USE_64_BIT @@ -2815,20 +2666,9 @@ downLeftSpine: } } // namespace -bool checkRangeRead(Node *n, std::span begin, +bool checkRangeRead(int lcp, Node *n, std::span begin, std::span end, InternalVersionT readVersion, ReadContext *tls) { - int lcp = longestCommonPrefix(begin.data(), end.data(), - std::min(begin.size(), end.size())); - if (lcp == int(begin.size()) && end.size() == begin.size() + 1 && - end.back() == 0) { - return checkPointRead(n, begin, readVersion, tls); - } - if (lcp == int(begin.size() - 1) && end.size() == begin.size() && - int(begin.back()) + 1 == int(end.back())) { - return checkPrefixRead(n, begin, readVersion, tls); - } - ++tls->range_read_accum; auto remaining = begin.subspan(0, lcp); @@ -3429,6 +3269,134 @@ void down_left_spine(CheckJob *job, CheckContext *context) { } // namespace check_point_read_state_machine +namespace check_prefix_read_state_machine { + +FLATTEN PRESERVE_NONE void begin(CheckJob *, CheckContext *); + +template +FLATTEN PRESERVE_NONE void iter(CheckJob *, CheckContext *); + +FLATTEN PRESERVE_NONE void down_left_spine(CheckJob *, CheckContext *); + +static Continuation iterTable[] = {iter, iter, iter, + iter, iter}; + +void begin(CheckJob *job, CheckContext *context) { + ++context->tls->prefix_read_accum; +#if DEBUG_VERBOSE && !defined(NDEBUG) + fprintf(stderr, "Check prefix read: %s\n", printable(key).c_str()); +#endif + + // There's no way to encode a prefix read of "" + assert(job->begin.size() > 0); + + auto taggedChild = getChild(job->n, job->begin[0]); + Node *child = taggedChild; + if (child == nullptr) [[unlikely]] { + auto c = getChildGeq(job->n, job->begin[0]); + if (c != nullptr) { + job->n = c; + job->continuation = down_left_spine; + __builtin_prefetch(job->n); + MUSTTAIL return keepGoing(job, context); + } else { + // The root never has a next sibling + job->setResult(true); + MUSTTAIL return complete(job, context); + } + } + job->continuation = iterTable[taggedChild.getType()]; + job->n = child; + __builtin_prefetch(child); + MUSTTAIL return keepGoing(job, context); +} + +template void iter(CheckJob *job, CheckContext *context) { + + assert(NodeT::kType == job->n->getType()); + NodeT *n = static_cast(job->n); + job->begin = job->begin.subspan(1, job->begin.size() - 1); + + if (n->partialKeyLen > 0) { + int commonLen = std::min(n->partialKeyLen, job->begin.size()); + int i = longestCommonPrefix(n->partialKey(), job->begin.data(), commonLen); + if (i < commonLen) [[unlikely]] { + auto c = n->partialKey()[i] <=> job->begin[i]; + if (c > 0) { + job->continuation = down_left_spine; + MUSTTAIL return down_left_spine(job, context); + } else { + job->n = nextSibling(n); + if (job->n == nullptr) { + job->setResult(true); + MUSTTAIL return complete(job, context); + } + job->continuation = down_left_spine; + __builtin_prefetch(job->n); + MUSTTAIL return keepGoing(job, context); + } + } + if (commonLen == n->partialKeyLen) { + // partial key matches + job->begin = job->begin.subspan(commonLen, job->begin.size() - commonLen); + } else if (n->partialKeyLen > int(job->begin.size())) [[unlikely]] { + // n is the first physical node greater than remaining, and there's no + // eq node. All physical nodes that start with prefix are reachable from + // n. + if (maxVersion(n) > job->readVersion) { + job->setResult(false); + MUSTTAIL return complete(job, context); + } + job->continuation = down_left_spine; + MUSTTAIL return down_left_spine(job, context); + } + } + + ++context->tls->prefix_read_iterations_accum; + + if (job->begin.size() == 0) [[unlikely]] { + job->setResult(maxVersion(job->n) <= job->readVersion); + MUSTTAIL return complete(job, context); + } + + auto taggedChild = getChild(n, job->begin[0]); + Node *child = taggedChild; + if (child == nullptr) [[unlikely]] { + auto c = getChildGeq(n, job->begin[0]); + if (c != nullptr) { + job->n = c; + job->continuation = down_left_spine; + __builtin_prefetch(job->n); + MUSTTAIL return keepGoing(job, context); + } else { + job->n = nextSibling(job->n); + if (job->n == nullptr) { + job->setResult(true); + MUSTTAIL return complete(job, context); + } + job->continuation = down_left_spine; + __builtin_prefetch(job->n); + MUSTTAIL return keepGoing(job, context); + } + } + job->continuation = iterTable[taggedChild.getType()]; + job->n = child; + __builtin_prefetch(child); + MUSTTAIL return keepGoing(job, context); +} + +void down_left_spine(CheckJob *job, CheckContext *context) { + if (job->n->entryPresent) { + job->setResult(job->n->entry.rangeVersion <= job->readVersion); + MUSTTAIL return complete(job, context); + } + job->n = getFirstChildExists(job->n); + __builtin_prefetch(job->n); + MUSTTAIL return keepGoing(job, context); +} + +} // namespace check_prefix_read_state_machine + namespace check_range_read_state_machine { FLATTEN PRESERVE_NONE void begin(CheckJob *, CheckContext *); @@ -3444,15 +3412,13 @@ FLATTEN PRESERVE_NONE void begin(CheckJob *job, CheckContext *context) { if (lcp == int(job->begin.size() - 1) && job->end.size() == job->begin.size() && int(job->begin.back()) + 1 == int(job->end.back())) { - *job->result = - checkPrefixRead(job->n, job->begin, job->readVersion, context->tls) - ? ConflictSet::Commit - : ConflictSet::Conflict; - return complete(job, context); + job->continuation = check_prefix_read_state_machine::begin; + // Call directly since we have nothing to prefetch + MUSTTAIL return job->continuation(job, context); } - *job->result = checkRangeRead(job->n, job->begin, job->end, job->readVersion, - context->tls) + *job->result = checkRangeRead(lcp, job->n, job->begin, job->end, + job->readVersion, context->tls) ? ConflictSet::Commit : ConflictSet::Conflict; return complete(job, context);