diff --git a/ConflictSet.cpp b/ConflictSet.cpp index 5a03fc4..c0ae817 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -3208,34 +3208,297 @@ Node *firstGeqPhysical(Node *n, const std::span key) { } } +#ifndef __has_attribute +#define __has_attribute(x) 0 +#endif + +#if __has_attribute(musttail) +#define MUSTTAIL __attribute__((musttail)) +#else +#define MUSTTAIL +#endif + +#if __has_attribute(preserve_none) +#define PRESERVE_NONE __attribute__((preserve_none)) +#else +#define PRESERVE_NONE +#endif + +#if __has_attribute(flatten) +#define FLATTEN __attribute__((flatten)) +#else +#define FLATTEN +#endif + +typedef PRESERVE_NONE void (*Continuation)(struct CheckJob *, + struct CheckContext *); + +// State relevant to an individual query +struct CheckJob { + void setResult(bool ok) { + *result = ok ? ConflictSet::Commit : ConflictSet::Conflict; + } + + void init(const ConflictSet::ReadRange *read, ConflictSet::Result *result, + Node *root, int64_t oldestVersionFullPrecision, ReadContext *tls); + + Node *n; + std::span begin; + InternalVersionT readVersion; + ConflictSet::Result *result; + Continuation continuation; + CheckJob *prev; + CheckJob *next; +}; + +// State relevant to every query +struct CheckContext { + int count; + int64_t oldestVersionFullPrecision; + Node *root; + const ConflictSet::ReadRange *queries; + ConflictSet::Result *results; + int64_t started; + ReadContext *tls; +}; + +FLATTEN PRESERVE_NONE void keepGoing(CheckJob *job, CheckContext *context) { + job = job->next; + MUSTTAIL return job->continuation(job, context); +} + +FLATTEN PRESERVE_NONE void complete(CheckJob *job, CheckContext *context) { + if (context->started == context->count) { + if (job->prev == job) { + return; + } + job->prev->next = job->next; + job->next->prev = job->prev; + job = job->prev; + } else { + int temp = context->started++; + job->init(context->queries + temp, context->results + temp, context->root, + context->oldestVersionFullPrecision, context->tls); + } + MUSTTAIL return keepGoing(job, context); +} + +namespace check_point_read_state_machine { + +FLATTEN PRESERVE_NONE void begin(CheckJob *, CheckContext *); + +template +FLATTEN PRESERVE_NONE void iter(CheckJob *, CheckContext *); + +FLATTEN PRESERVE_NONE void down_left_spine(CheckJob *, CheckContext *); + +static Continuation iterTable[] = {iter, iter, iter, + iter, iter}; + +void begin(CheckJob *job, CheckContext *context) { + ++context->tls->point_read_accum; +#if DEBUG_VERBOSE && !defined(NDEBUG) + fprintf(stderr, "Check point read: %s\n", printable(key).c_str()); +#endif + + if (job->begin.size() == 0) [[unlikely]] { + if (job->n->entryPresent) { + job->setResult(job->n->entry.pointVersion <= job->readVersion); + MUSTTAIL return complete(job, context); + } + job->n = getFirstChildExists(job->n); + job->continuation = down_left_spine; + __builtin_prefetch(job->n); + MUSTTAIL return keepGoing(job, context); + } + + auto taggedChild = getChild(job->n, job->begin[0]); + Node *child = taggedChild; + if (child == nullptr) [[unlikely]] { + auto c = getChildGeq(job->n, job->begin[0]); + if (c != nullptr) { + job->n = c; + job->continuation = down_left_spine; + __builtin_prefetch(job->n); + MUSTTAIL return keepGoing(job, context); + } else { + job->n = nextSibling(job->n); + if (job->n == nullptr) { + job->setResult(true); + MUSTTAIL return complete(job, context); + } + job->continuation = down_left_spine; + __builtin_prefetch(job->n); + MUSTTAIL return keepGoing(job, context); + } + } + job->continuation = iterTable[taggedChild.getType()]; + job->n = child; + __builtin_prefetch(child); + MUSTTAIL return keepGoing(job, context); +} + +template void iter(CheckJob *job, CheckContext *context) { + + assert(NodeT::kType == job->n->getType()); + NodeT *n = static_cast(job->n); + job->begin = job->begin.subspan(1, job->begin.size() - 1); + + if (n->partialKeyLen > 0) { + int commonLen = std::min(n->partialKeyLen, job->begin.size()); + int i = longestCommonPrefix(n->partialKey(), job->begin.data(), commonLen); + if (i < commonLen) [[unlikely]] { + auto c = n->partialKey()[i] <=> job->begin[i]; + if (c > 0) { + job->continuation = down_left_spine; + MUSTTAIL return down_left_spine(job, context); + } else { + job->n = nextSibling(n); + if (job->n == nullptr) { + job->setResult(true); + MUSTTAIL return complete(job, context); + } + job->continuation = down_left_spine; + __builtin_prefetch(job->n); + MUSTTAIL return keepGoing(job, context); + } + } + if (commonLen == n->partialKeyLen) { + // partial key matches + job->begin = job->begin.subspan(commonLen, job->begin.size() - commonLen); + } else if (n->partialKeyLen > int(job->begin.size())) [[unlikely]] { + // n is the first physical node greater than remaining, and there's no + // eq node + job->continuation = down_left_spine; + MUSTTAIL return down_left_spine(job, context); + } + } + + ++context->tls->point_read_iterations_accum; + + if (job->begin.size() == 0) [[unlikely]] { + if (n->entryPresent) { + job->setResult(n->entry.pointVersion <= job->readVersion); + MUSTTAIL return complete(job, context); + } + job->n = getFirstChildExists(n); + job->continuation = down_left_spine; + __builtin_prefetch(job->n); + MUSTTAIL return keepGoing(job, context); + } + + auto taggedChild = getChild(n, job->begin[0]); + Node *child = taggedChild; + if (child == nullptr) [[unlikely]] { + auto c = getChildGeq(n, job->begin[0]); + if (c != nullptr) { + job->n = c; + job->continuation = down_left_spine; + __builtin_prefetch(job->n); + MUSTTAIL return keepGoing(job, context); + } else { + job->n = nextSibling(job->n); + if (job->n == nullptr) { + job->setResult(true); + MUSTTAIL return complete(job, context); + } + job->continuation = down_left_spine; + __builtin_prefetch(job->n); + MUSTTAIL return keepGoing(job, context); + } + } + job->continuation = iterTable[taggedChild.getType()]; + job->n = child; + __builtin_prefetch(child); + MUSTTAIL return keepGoing(job, context); +} + +void down_left_spine(CheckJob *job, CheckContext *context) { + if (job->n->entryPresent) { + job->setResult(job->n->entry.rangeVersion <= job->readVersion); + MUSTTAIL return complete(job, context); + } + job->n = getFirstChildExists(job->n); + __builtin_prefetch(job->n); + MUSTTAIL return keepGoing(job, context); +} + +} // namespace check_point_read_state_machine + +void CheckJob::init(const ConflictSet::ReadRange *read, + ConflictSet::Result *result, Node *root, + int64_t oldestVersionFullPrecision, ReadContext *tls) { + auto begin = std::span(read->begin.p, read->begin.len); + auto end = std::span(read->end.p, read->end.len); + if (read->readVersion < oldestVersionFullPrecision) [[unlikely]] { + *result = ConflictSet::TooOld; + continuation = complete; + } else if (end.size() == 0) { + this->begin = begin; + this->n = root; + this->readVersion = InternalVersionT(read->readVersion); + this->result = result; + continuation = check_point_read_state_machine::begin; + } else { + *result = checkRangeRead(root, begin, end, + InternalVersionT(read->readVersion), tls) + ? ConflictSet::Commit + : ConflictSet::Conflict; + continuation = complete; + } +} + struct __attribute__((visibility("hidden"))) ConflictSet::Impl { void check(const ReadRange *reads, Result *result, int count) { + assert(oldestVersionFullPrecision >= + newestVersionFullPrecision - kNominalVersionWindow); + + if (count == 0) { + return; + } + ReadContext tls; tls.impl = this; int64_t check_byte_accum = 0; + constexpr int kConcurrent = 16; + CheckJob inProgress[kConcurrent]; + CheckContext context; + context.count = count; + context.oldestVersionFullPrecision = oldestVersionFullPrecision; + context.root = root; + context.queries = reads; + context.results = result; + context.tls = &tls; + int64_t started = std::min(kConcurrent, count); + context.started = started; + for (int i = 0; i < started; i++) { + inProgress[i].init(reads + i, result + i, root, + oldestVersionFullPrecision, &tls); + } + for (int i = 0; i < started - 1; i++) { + inProgress[i].next = inProgress + i + 1; + } + for (int i = 1; i < started; i++) { + inProgress[i].prev = inProgress + i - 1; + } + inProgress[0].prev = inProgress + started - 1; + inProgress[started - 1].next = inProgress; + + // Kick off the sequence of tail calls that finally returns once all jobs + // are done + inProgress->continuation(inProgress, &context); + for (int i = 0; i < count; ++i) { assert(reads[i].readVersion >= 0); assert(reads[i].readVersion <= newestVersionFullPrecision); const auto &r = reads[i]; check_byte_accum += r.begin.len + r.end.len; - auto begin = std::span(r.begin.p, r.begin.len); - auto end = std::span(r.end.p, r.end.len); - assert(oldestVersionFullPrecision >= - newestVersionFullPrecision - kNominalVersionWindow); - result[i] = - reads[i].readVersion < oldestVersionFullPrecision ? TooOld - : (end.size() > 0 - ? checkRangeRead(root, begin, end, - InternalVersionT(reads[i].readVersion), &tls) - : checkPointRead(root, begin, - InternalVersionT(reads[i].readVersion), &tls)) - ? Commit - : Conflict; tls.commits_accum += result[i] == Commit; tls.conflicts_accum += result[i] == Conflict; tls.too_olds_accum += result[i] == TooOld; } + point_read_total.add(tls.point_read_accum); prefix_read_total.add(tls.prefix_read_accum); range_read_total.add(tls.range_read_accum);