diff --git a/ConflictSet.cpp b/ConflictSet.cpp index 9bf5dd0..95dad91 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -3045,21 +3045,15 @@ Node *firstGeqPhysical(Node *n, const std::span key) { } } +typedef void (*continuation)(struct CheckAll *, int64_t prevJob, int64_t job, + int64_t started, int64_t count); + +// State relevant to a particular query struct CheckJob { void setResult(bool ok) { *result = ok ? ConflictSet::Commit : ConflictSet::Conflict; } - // The type of a function that takes a CheckJob* and returns its own type - struct continuation { - typedef continuation (*functionPtrType)(CheckJob *); - functionPtrType func; - continuation operator()(CheckJob *job) { return func(job); } - /*implicit*/ continuation(functionPtrType func) : func(func) {} - continuation() = default; - operator bool() { return func != nullptr; } - }; - void init(const ConflictSet::ReadRange *read, ConflictSet::Result *result, Node *root, int64_t oldestVersionFullPrecision, ReadContext *tls); @@ -3068,116 +3062,184 @@ struct CheckJob { ChildAndMaxVersion childAndVersion; std::span begin; InternalVersionT readVersion; - ReadContext *tls; ConflictSet::Result *result; }; +// State relevant to all queries +struct CheckAll { + constexpr static int kConcurrent = 32; + const ConflictSet::ReadRange *queries; + ConflictSet::Result *results; + CheckJob inProgress[kConcurrent]; + int nextJob[kConcurrent]; + Node *root; + int64_t oldestVersionFullPrecision; + ReadContext *tls; +}; + +#ifndef __has_attribute +#define __has_attribute(x) 0 +#endif + +#if __has_attribute(musttail) +#define MUSTTAIL __attribute__((musttail)) +#else +#define MUSTTAIL +#endif + +void keepGoing(CheckAll *context, int64_t prevJob, int64_t job, int64_t started, + int64_t count) { + prevJob = job; + job = context->nextJob[job]; + MUSTTAIL return context->inProgress[job].next(context, prevJob, job, started, + count); +} + +void complete(CheckAll *context, int64_t prevJob, int64_t job, int64_t started, + int64_t count) { + if (started == count) { + if (prevJob == job) { + return; + } + context->nextJob[prevJob] = context->nextJob[job]; + job = prevJob; + } else { + int temp = started++; + context->inProgress[job].init( + context->queries + temp, context->results + temp, context->root, + context->oldestVersionFullPrecision, context->tls); + } + prevJob = job; + job = context->nextJob[job]; + MUSTTAIL return context->inProgress[job].next(context, prevJob, job, started, + count); +} + namespace check_point_read_state_machine { -CheckJob::continuation down_left_spine(CheckJob *job); -CheckJob::continuation iter(CheckJob *job); +void down_left_spine(struct CheckAll *, int64_t prevJob, int64_t job, + int64_t started, int64_t count); +void iter(struct CheckAll *, int64_t prevJob, int64_t job, int64_t started, + int64_t count); +void begin(struct CheckAll *, int64_t prevJob, int64_t job, int64_t started, + int64_t count); -CheckJob::continuation begin(CheckJob *job) { - ++job->tls->point_read_accum; +void begin(struct CheckAll *context, int64_t prevJob, int64_t job, + int64_t started, int64_t count) { + ++context->tls->point_read_accum; #if DEBUG_VERBOSE && !defined(NDEBUG) fprintf(stderr, "Check point read: %s\n", printable(key).c_str()); #endif + auto *j = context->inProgress + job; - if (job->begin.size() == 0) { - if (job->n->entryPresent) { - job->setResult(job->n->entry.pointVersion <= job->readVersion); - return nullptr; // Done + if (j->begin.size() == 0) { + if (j->n->entryPresent) { + j->setResult(j->n->entry.pointVersion <= j->readVersion); + MUSTTAIL return complete(context, prevJob, job, started, count); } - job->n = getFirstChildExists(job->n); - __builtin_prefetch(job->n); - return down_left_spine; + j->n = getFirstChildExists(j->n); + j->next = down_left_spine; + __builtin_prefetch(j->n); + MUSTTAIL return keepGoing(context, prevJob, job, started, count); } - job->childAndVersion = getChildAndMaxVersion(job->n, job->begin[0]); - __builtin_prefetch(job->childAndVersion.child); - return iter; + j->childAndVersion = getChildAndMaxVersion(j->n, j->begin[0]); + j->next = iter; + __builtin_prefetch(j->childAndVersion.child); + MUSTTAIL return keepGoing(context, prevJob, job, started, count); } -CheckJob::continuation iter(CheckJob *job) { - if (job->childAndVersion.child == nullptr) { - auto c = getChildGeq(job->n, job->begin[0]); +void iter(struct CheckAll *context, int64_t prevJob, int64_t job, + int64_t started, int64_t count) { + auto *j = context->inProgress + job; + if (j->childAndVersion.child == nullptr) { + auto c = getChildGeq(j->n, j->begin[0]); if (c != nullptr) { - job->n = c; - __builtin_prefetch(job->n); - return down_left_spine; + j->n = c; + j->next = down_left_spine; + __builtin_prefetch(j->n); + MUSTTAIL return keepGoing(context, prevJob, job, started, count); } else { - job->n = nextSibling(job->n); - if (job->n == nullptr) { - job->setResult(true); - return nullptr; // Done + j->n = nextSibling(j->n); + if (j->n == nullptr) { + j->setResult(true); + MUSTTAIL return complete(context, prevJob, job, started, count); } - __builtin_prefetch(job->n); - return down_left_spine; + j->next = down_left_spine; + __builtin_prefetch(j->n); + MUSTTAIL return keepGoing(context, prevJob, job, started, count); } } - job->n = job->childAndVersion.child; - job->begin = job->begin.subspan(1, job->begin.size() - 1); + j->n = j->childAndVersion.child; + j->begin = j->begin.subspan(1, j->begin.size() - 1); - if (job->n->partialKeyLen > 0) { - int commonLen = std::min(job->n->partialKeyLen, job->begin.size()); - int i = - longestCommonPrefix(job->n->partialKey(), job->begin.data(), commonLen); + if (j->n->partialKeyLen > 0) { + int commonLen = std::min(j->n->partialKeyLen, j->begin.size()); + int i = longestCommonPrefix(j->n->partialKey(), j->begin.data(), commonLen); if (i < commonLen) { - auto c = job->n->partialKey()[i] <=> job->begin[i]; + auto c = j->n->partialKey()[i] <=> j->begin[i]; if (c > 0) { - return down_left_spine(job); + j->next = down_left_spine; + MUSTTAIL return j->next(context, prevJob, job, started, count); } else { - job->n = nextSibling(job->n); - if (job->n == nullptr) { - job->setResult(true); - return nullptr; // Done + j->n = nextSibling(j->n); + if (j->n == nullptr) { + j->setResult(true); + MUSTTAIL return complete(context, prevJob, job, started, count); } - __builtin_prefetch(job->n); - return down_left_spine; + j->next = down_left_spine; + __builtin_prefetch(j->n); + MUSTTAIL return keepGoing(context, prevJob, job, started, count); } } - if (commonLen == job->n->partialKeyLen) { + if (commonLen == j->n->partialKeyLen) { // partial key matches - job->begin = job->begin.subspan(commonLen, job->begin.size() - commonLen); - } else if (job->n->partialKeyLen > int(job->begin.size())) { + j->begin = j->begin.subspan(commonLen, j->begin.size() - commonLen); + } else if (j->n->partialKeyLen > int(j->begin.size())) { // n is the first physical node greater than remaining, and there's no // eq node - return down_left_spine(job); + j->next = down_left_spine; + MUSTTAIL return j->next(context, prevJob, job, started, count); } } - if (job->childAndVersion.maxVersion <= job->readVersion) { - ++job->tls->point_read_short_circuit_accum; - job->setResult(true); - return nullptr; // Done + if (j->childAndVersion.maxVersion <= j->readVersion) { + ++context->tls->point_read_short_circuit_accum; + j->setResult(true); + MUSTTAIL return complete(context, prevJob, job, started, count); } - ++job->tls->point_read_iterations_accum; + ++context->tls->point_read_iterations_accum; - if (job->begin.size() == 0) { - if (job->n->entryPresent) { - job->setResult(job->n->entry.pointVersion <= job->readVersion); - return nullptr; // Done + if (j->begin.size() == 0) { + if (j->n->entryPresent) { + j->setResult(j->n->entry.pointVersion <= j->readVersion); + MUSTTAIL return complete(context, prevJob, job, started, count); } - job->n = getFirstChildExists(job->n); - __builtin_prefetch(job->n); - return down_left_spine; + j->n = getFirstChildExists(j->n); + j->next = down_left_spine; + __builtin_prefetch(j->n); + MUSTTAIL return keepGoing(context, prevJob, job, started, count); } - job->childAndVersion = getChildAndMaxVersion(job->n, job->begin[0]); - __builtin_prefetch(job->childAndVersion.child); - return iter; + j->childAndVersion = getChildAndMaxVersion(j->n, j->begin[0]); + __builtin_prefetch(j->childAndVersion.child); + // j->next is already iter + MUSTTAIL return keepGoing(context, prevJob, job, started, count); } -CheckJob::continuation down_left_spine(CheckJob *job) { - if (job->n->entryPresent) { - job->setResult(job->n->entry.rangeVersion <= job->readVersion); - return nullptr; // Done +void down_left_spine(struct CheckAll *context, int64_t prevJob, int64_t job, + int64_t started, int64_t count) { + auto *j = context->inProgress + job; + if (j->n->entryPresent) { + j->setResult(j->n->entry.rangeVersion <= j->readVersion); + MUSTTAIL return complete(context, prevJob, job, started, count); } - job->n = getFirstChildExists(job->n); - __builtin_prefetch(job->n); - return down_left_spine; + j->n = getFirstChildExists(j->n); + __builtin_prefetch(j->n); + // j->next is already down_left_spine + MUSTTAIL return keepGoing(context, prevJob, job, started, count); } } // namespace check_point_read_state_machine @@ -3189,25 +3251,24 @@ void CheckJob::init(const ConflictSet::ReadRange *read, auto end = std::span(read->end.p, read->end.len); if (read->readVersion < oldestVersionFullPrecision) { *result = ConflictSet::TooOld; - next = +[](CheckJob *) -> continuation { return nullptr; }; + next = complete; } else if (end.size() == 0) { this->begin = begin; this->n = root; this->readVersion = InternalVersionT(read->readVersion); this->result = result; - this->tls = tls; this->next = check_point_read_state_machine::begin; // *result = // checkPointRead(root, begin, InternalVersionT(read->readVersion), tls) // ? ConflictSet::Commit // : ConflictSet::Conflict; - // next = +[](CheckJob *) -> continuation { return nullptr; }; + // next = complete; } else { *result = checkRangeRead(root, begin, end, InternalVersionT(read->readVersion), tls) ? ConflictSet::Commit : ConflictSet::Conflict; - next = +[](CheckJob *) -> continuation { return nullptr; }; + next = complete; } } @@ -3225,38 +3286,23 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { tls.impl = this; int64_t check_byte_accum = 0; - constexpr int kConcurrent = 32; - CheckJob inProgress[kConcurrent]; - int nextJob[kConcurrent]; + CheckAll context; + context.oldestVersionFullPrecision = oldestVersionFullPrecision; + context.queries = reads; + context.results = result; + context.root = root; + context.tls = &tls; - int started = std::min(kConcurrent, count); + int64_t started = std::min(context.kConcurrent, count); for (int i = 0; i < started; i++) { - inProgress[i].init(reads + i, result + i, root, - oldestVersionFullPrecision, &tls); - nextJob[i] = i + 1; + context.inProgress[i].init(reads + i, result + i, root, + oldestVersionFullPrecision, &tls); + context.nextJob[i] = i + 1; } - nextJob[started - 1] = 0; - + context.nextJob[started - 1] = 0; int prevJob = started - 1; int job = 0; - for (;;) { - auto next = inProgress[job].next(inProgress + job); - inProgress[job].next = next; - if (!next) { - if (started == count) { - if (prevJob == job) - break; - nextJob[prevJob] = nextJob[job]; - job = prevJob; - } else { - int temp = started++; - inProgress[job].init(reads + temp, result + temp, root, - oldestVersionFullPrecision, &tls); - } - } - prevJob = job; - job = nextJob[job]; - } + context.inProgress[job].next(&context, prevJob, job, started, count); for (int i = 0; i < count; ++i) { assert(reads[i].readVersion >= 0);