From 6de63dd3fe5e6d458644a55436b55e82de79ec39 Mon Sep 17 00:00:00 2001 From: Andrew Noyes Date: Mon, 23 Sep 2024 14:53:16 -0700 Subject: [PATCH] Use preserve_none and put continuation array in CheckAll --- ConflictSet.cpp | 113 +++++++++++++++++++++++++++--------------------- 1 file changed, 64 insertions(+), 49 deletions(-) diff --git a/ConflictSet.cpp b/ConflictSet.cpp index 95dad91..03c97da 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -3045,8 +3045,27 @@ Node *firstGeqPhysical(Node *n, const std::span key) { } } -typedef void (*continuation)(struct CheckAll *, int64_t prevJob, int64_t job, - int64_t started, int64_t count); +#ifndef __has_attribute +#define __has_attribute(x) 0 +#endif + +#if __has_attribute(musttail) +#define MUSTTAIL __attribute__((musttail)) +#else +#define MUSTTAIL +#endif + +#if __has_attribute(preserve_none) +#define CONTINUATION_CALLING_CONVENTION __attribute__((preserve_none)) +#else +#define CONTINUATION_CALLING_CONVENTION +#endif + +typedef CONTINUATION_CALLING_CONVENTION void (*continuation)(struct CheckAll *, + int64_t prevJob, + int64_t job, + int64_t started, + int64_t count); // State relevant to a particular query struct CheckJob { @@ -3054,10 +3073,11 @@ struct CheckJob { *result = ok ? ConflictSet::Commit : ConflictSet::Conflict; } - void init(const ConflictSet::ReadRange *read, ConflictSet::Result *result, - Node *root, int64_t oldestVersionFullPrecision, ReadContext *tls); + [[nodiscard]] continuation init(const ConflictSet::ReadRange *read, + ConflictSet::Result *result, Node *root, + int64_t oldestVersionFullPrecision, + ReadContext *tls); - continuation next; Node *n; ChildAndMaxVersion childAndVersion; std::span begin; @@ -3071,32 +3091,24 @@ struct CheckAll { const ConflictSet::ReadRange *queries; ConflictSet::Result *results; CheckJob inProgress[kConcurrent]; + continuation next[kConcurrent]; int nextJob[kConcurrent]; Node *root; int64_t oldestVersionFullPrecision; ReadContext *tls; }; -#ifndef __has_attribute -#define __has_attribute(x) 0 -#endif - -#if __has_attribute(musttail) -#define MUSTTAIL __attribute__((musttail)) -#else -#define MUSTTAIL -#endif - -void keepGoing(CheckAll *context, int64_t prevJob, int64_t job, int64_t started, - int64_t count) { +CONTINUATION_CALLING_CONVENTION void keepGoing(CheckAll *context, + int64_t prevJob, int64_t job, + int64_t started, int64_t count) { prevJob = job; job = context->nextJob[job]; - MUSTTAIL return context->inProgress[job].next(context, prevJob, job, started, - count); + MUSTTAIL return context->next[job](context, prevJob, job, started, count); } -void complete(CheckAll *context, int64_t prevJob, int64_t job, int64_t started, - int64_t count) { +CONTINUATION_CALLING_CONVENTION void complete(CheckAll *context, + int64_t prevJob, int64_t job, + int64_t started, int64_t count) { if (started == count) { if (prevJob == job) { return; @@ -3105,24 +3117,26 @@ void complete(CheckAll *context, int64_t prevJob, int64_t job, int64_t started, job = prevJob; } else { int temp = started++; - context->inProgress[job].init( + context->next[job] = context->inProgress[job].init( context->queries + temp, context->results + temp, context->root, context->oldestVersionFullPrecision, context->tls); } prevJob = job; job = context->nextJob[job]; - MUSTTAIL return context->inProgress[job].next(context, prevJob, job, started, - count); + MUSTTAIL return context->next[job](context, prevJob, job, started, count); } namespace check_point_read_state_machine { -void down_left_spine(struct CheckAll *, int64_t prevJob, int64_t job, - int64_t started, int64_t count); -void iter(struct CheckAll *, int64_t prevJob, int64_t job, int64_t started, - int64_t count); -void begin(struct CheckAll *, int64_t prevJob, int64_t job, int64_t started, - int64_t count); +CONTINUATION_CALLING_CONVENTION void +down_left_spine(struct CheckAll *, int64_t prevJob, int64_t job, + int64_t started, int64_t count); +CONTINUATION_CALLING_CONVENTION void iter(struct CheckAll *, int64_t prevJob, + int64_t job, int64_t started, + int64_t count); +CONTINUATION_CALLING_CONVENTION void begin(struct CheckAll *, int64_t prevJob, + int64_t job, int64_t started, + int64_t count); void begin(struct CheckAll *context, int64_t prevJob, int64_t job, int64_t started, int64_t count) { @@ -3138,13 +3152,13 @@ void begin(struct CheckAll *context, int64_t prevJob, int64_t job, MUSTTAIL return complete(context, prevJob, job, started, count); } j->n = getFirstChildExists(j->n); - j->next = down_left_spine; + context->next[job] = down_left_spine; __builtin_prefetch(j->n); MUSTTAIL return keepGoing(context, prevJob, job, started, count); } j->childAndVersion = getChildAndMaxVersion(j->n, j->begin[0]); - j->next = iter; + context->next[job] = iter; __builtin_prefetch(j->childAndVersion.child); MUSTTAIL return keepGoing(context, prevJob, job, started, count); } @@ -3156,7 +3170,7 @@ void iter(struct CheckAll *context, int64_t prevJob, int64_t job, auto c = getChildGeq(j->n, j->begin[0]); if (c != nullptr) { j->n = c; - j->next = down_left_spine; + context->next[job] = down_left_spine; __builtin_prefetch(j->n); MUSTTAIL return keepGoing(context, prevJob, job, started, count); } else { @@ -3165,7 +3179,7 @@ void iter(struct CheckAll *context, int64_t prevJob, int64_t job, j->setResult(true); MUSTTAIL return complete(context, prevJob, job, started, count); } - j->next = down_left_spine; + context->next[job] = down_left_spine; __builtin_prefetch(j->n); MUSTTAIL return keepGoing(context, prevJob, job, started, count); } @@ -3180,15 +3194,15 @@ void iter(struct CheckAll *context, int64_t prevJob, int64_t job, if (i < commonLen) { auto c = j->n->partialKey()[i] <=> j->begin[i]; if (c > 0) { - j->next = down_left_spine; - MUSTTAIL return j->next(context, prevJob, job, started, count); + context->next[job] = down_left_spine; + MUSTTAIL return down_left_spine(context, prevJob, job, started, count); } else { j->n = nextSibling(j->n); if (j->n == nullptr) { j->setResult(true); MUSTTAIL return complete(context, prevJob, job, started, count); } - j->next = down_left_spine; + context->next[job] = down_left_spine; __builtin_prefetch(j->n); MUSTTAIL return keepGoing(context, prevJob, job, started, count); } @@ -3199,8 +3213,8 @@ void iter(struct CheckAll *context, int64_t prevJob, int64_t job, } else if (j->n->partialKeyLen > int(j->begin.size())) { // n is the first physical node greater than remaining, and there's no // eq node - j->next = down_left_spine; - MUSTTAIL return j->next(context, prevJob, job, started, count); + context->next[job] = down_left_spine; + MUSTTAIL return down_left_spine(context, prevJob, job, started, count); } } @@ -3218,7 +3232,7 @@ void iter(struct CheckAll *context, int64_t prevJob, int64_t job, MUSTTAIL return complete(context, prevJob, job, started, count); } j->n = getFirstChildExists(j->n); - j->next = down_left_spine; + context->next[job] = down_left_spine; __builtin_prefetch(j->n); MUSTTAIL return keepGoing(context, prevJob, job, started, count); } @@ -3244,20 +3258,21 @@ void down_left_spine(struct CheckAll *context, int64_t prevJob, int64_t job, } // namespace check_point_read_state_machine -void CheckJob::init(const ConflictSet::ReadRange *read, - ConflictSet::Result *result, Node *root, - int64_t oldestVersionFullPrecision, ReadContext *tls) { +continuation CheckJob::init(const ConflictSet::ReadRange *read, + ConflictSet::Result *result, Node *root, + int64_t oldestVersionFullPrecision, + ReadContext *tls) { auto begin = std::span(read->begin.p, read->begin.len); auto end = std::span(read->end.p, read->end.len); if (read->readVersion < oldestVersionFullPrecision) { *result = ConflictSet::TooOld; - next = complete; + return complete; } else if (end.size() == 0) { this->begin = begin; this->n = root; this->readVersion = InternalVersionT(read->readVersion); this->result = result; - this->next = check_point_read_state_machine::begin; + return check_point_read_state_machine::begin; // *result = // checkPointRead(root, begin, InternalVersionT(read->readVersion), tls) // ? ConflictSet::Commit @@ -3268,7 +3283,7 @@ void CheckJob::init(const ConflictSet::ReadRange *read, InternalVersionT(read->readVersion), tls) ? ConflictSet::Commit : ConflictSet::Conflict; - next = complete; + return complete; } } @@ -3295,14 +3310,14 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { int64_t started = std::min(context.kConcurrent, count); for (int i = 0; i < started; i++) { - context.inProgress[i].init(reads + i, result + i, root, - oldestVersionFullPrecision, &tls); + context.next[i] = context.inProgress[i].init( + reads + i, result + i, root, oldestVersionFullPrecision, &tls); context.nextJob[i] = i + 1; } context.nextJob[started - 1] = 0; int prevJob = started - 1; int job = 0; - context.inProgress[job].next(&context, prevJob, job, started, count); + context.next[job](&context, prevJob, job, started, count); for (int i = 0; i < count; ++i) { assert(reads[i].readVersion >= 0);