diff --git a/ConflictSet.cpp b/ConflictSet.cpp index b6d4469..b272db6 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -3010,26 +3010,144 @@ Node *firstGeqPhysical(Node *n, const std::span key) { } struct CheckJob { - // Returned void* is a function pointer to the next continuation. We have to - // use void* because otherwise the type would be recursive. - typedef void *(*continuation)(CheckJob *); + Node *n; + std::span begin; + InternalVersionT readVersion; + ReadContext *tls; + ConflictSet::Result *result; + + void setResult(bool ok) { + *result = ok ? ConflictSet::Commit : ConflictSet::Conflict; + } + + typedef void (*typeErasedContinuation)(void *); + + // The type of a function that takes a CheckJob* and returns its own type + struct continuation { + typedef continuation (*functionPtrType)(CheckJob *); + functionPtrType func; + continuation operator()(CheckJob *job) { return func(job); } + /*implicit*/ continuation(functionPtrType func) : func(func) {} + continuation() = default; + operator bool() { return func != nullptr; } + }; + continuation next; void init(const ConflictSet::ReadRange *read, ConflictSet::Result *result, - Node *root, int64_t oldestVersionFullPrecision, ReadContext *tls) { - auto begin = std::span(read->begin.p, read->begin.len); - auto end = std::span(read->end.p, read->end.len); - *result = read->readVersion < oldestVersionFullPrecision - ? ConflictSet::TooOld - : (end.size() > 0 - ? checkRangeRead(root, begin, end, - InternalVersionT(read->readVersion), tls) - : checkPointRead(root, begin, - InternalVersionT(read->readVersion), tls)) + Node *root, int64_t oldestVersionFullPrecision, ReadContext *tls); +}; + +namespace check_point_read_state_machine { + +CheckJob::continuation down_left_spine(CheckJob *job); + +// Logically this is the same as performing firstGeq and then checking against +// point or range version according to cmp, but this version short circuits as +// soon as it can prove that there's no conflict. +CheckJob::continuation begin(CheckJob *job) { + ++job->tls->point_read_accum; +#if DEBUG_VERBOSE && !defined(NDEBUG) + fprintf(stderr, "Check point read: %s\n", printable(key).c_str()); +#endif + for (;; ++job->tls->point_read_iterations_accum) { + if (job->begin.size() == 0) { + if (job->n->entryPresent) { + job->setResult(job->n->entry.pointVersion <= job->readVersion); + return nullptr; // Done + } + job->n = getFirstChildExists(job->n); + return down_left_spine; + } + + auto [child, maxV] = getChildAndMaxVersion(job->n, job->begin[0]); + if (child == nullptr) { + auto c = getChildGeq(job->n, job->begin[0]); + if (c != nullptr) { + job->n = c; + return down_left_spine; + } else { + job->n = nextSibling(job->n); + if (job->n == nullptr) { + job->setResult(true); + return nullptr; // Done + } + return down_left_spine; + } + } + + job->n = child; + job->begin = job->begin.subspan(1, job->begin.size() - 1); + + if (job->n->partialKeyLen > 0) { + int commonLen = std::min(job->n->partialKeyLen, job->begin.size()); + int i = longestCommonPrefix(job->n->partialKey(), job->begin.data(), + commonLen); + if (i < commonLen) { + auto c = job->n->partialKey()[i] <=> job->begin[i]; + if (c > 0) { + return down_left_spine; + } else { + job->n = nextSibling(job->n); + if (job->n == nullptr) { + job->setResult(true); + return nullptr; // Done + } + return down_left_spine; + } + } + if (commonLen == job->n->partialKeyLen) { + // partial key matches + job->begin = + job->begin.subspan(commonLen, job->begin.size() - commonLen); + } else if (job->n->partialKeyLen > int(job->begin.size())) { + // n is the first physical node greater than remaining, and there's no + // eq node + return down_left_spine; + } + } + + if (maxV <= job->readVersion) { + ++job->tls->point_read_short_circuit_accum; + job->setResult(true); + return nullptr; // Done + } + } +} + +CheckJob::continuation down_left_spine(CheckJob *job) { + if (job->n->entryPresent) { + job->setResult(job->n->entry.rangeVersion <= job->readVersion); + return nullptr; // Done + } + job->n = getFirstChildExists(job->n); + return down_left_spine; +} + +} // namespace check_point_read_state_machine + +void CheckJob::init(const ConflictSet::ReadRange *read, + ConflictSet::Result *result, Node *root, + int64_t oldestVersionFullPrecision, ReadContext *tls) { + auto begin = std::span(read->begin.p, read->begin.len); + auto end = std::span(read->end.p, read->end.len); + if (read->readVersion < oldestVersionFullPrecision) { + *result = ConflictSet::TooOld; + next = +[](CheckJob *) -> continuation { return nullptr; }; + } else if (end.size() == 0) { + this->begin = begin; + this->n = root; + this->readVersion = InternalVersionT(read->readVersion); + this->result = result; + this->tls = tls; + this->next = (CheckJob::continuation)check_point_read_state_machine::begin; + } else { + *result = checkRangeRead(root, begin, end, + InternalVersionT(read->readVersion), tls) ? ConflictSet::Commit : ConflictSet::Conflict; - next = +[](CheckJob *) -> void * { return nullptr; }; + next = +[](CheckJob *) -> continuation { return nullptr; }; } -}; +} struct __attribute__((visibility("hidden"))) ConflictSet::Impl { @@ -3060,10 +3178,9 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { int prevJob = started - 1; int job = 0; for (;;) { - auto next = - (CheckJob::continuation)inProgress[job].next(inProgress + job); + auto next = inProgress[job].next(inProgress + job); inProgress[job].next = next; - if (next == nullptr) { + if (!next) { if (started == count) { if (prevJob == job) break;