From 2706b2f65ec7311cf80eed859a496e37e2fad22b Mon Sep 17 00:00:00 2001
From: Andrew Noyes
Date: Mon, 28 Oct 2024 16:02:56 -0700
Subject: [PATCH] Implement "phase 1" of interleaved point writes

---
 ConflictSet.cpp | 194 +++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 184 insertions(+), 10 deletions(-)

diff --git a/ConflictSet.cpp b/ConflictSet.cpp
index 98a0cdd..b445612 100644
--- a/ConflictSet.cpp
+++ b/ConflictSet.cpp
@@ -4057,6 +4057,125 @@ void Job::init(const ConflictSet::ReadRange *read, ConflictSet::Result *result,
 }
 } // namespace check
 
+namespace interleaved_insert {
+
+typedef PRESERVE_NONE void (*Continuation)(struct Job *, struct Context *);
+
+// State relevant to an individual insertion
+struct Job {
+  std::span<const uint8_t> remaining;
+  Node *n;
+  int index;
+
+  // State for context switching machinery - not application specific
+  Continuation continuation;
+  Job *prev;
+  Job *next;
+  void init(Context *, int index);
+};
+
+// Result of an insertion. The search path of insertionPoint + remaining == the
+// original key, and there is no existing node in the tree further along the
+// search path of the original key
+struct Result {
+  Node *insertionPoint;
+  std::span<const uint8_t> remaining;
+};
+
+// State relevant to every insertion
+struct Context {
+  int count;
+  int64_t started;
+  const ConflictSet::WriteRange *writes;
+  Node *root;
+  InternalVersionT writeVersion;
+  Result *results;
+};
+
+PRESERVE_NONE void keepGoing(Job *job, Context *context) {
+  fprintf(stderr, "search path: %s, Remaining: %s\n",
+          getSearchPathPrintable(job->n).c_str(),
+          printable(job->remaining).c_str());
+  job = job->next;
+  MUSTTAIL return job->continuation(job, context);
+}
+
+PRESERVE_NONE void complete(Job *job, Context *context) {
+  fprintf(stderr, "search path: %s, Remaining: %s\n",
+          getSearchPathPrintable(job->n).c_str(),
+          printable(job->remaining).c_str());
+  if (context->started == context->count) {
+    if (job->prev == job) {
+      return;
+    }
+    job->prev->next = job->next;
+    job->next->prev = job->prev;
+    job = job->prev;
+  } else {
+    int temp = context->started++;
+    job->init(context, temp);
+  }
+  MUSTTAIL return keepGoing(job, context);
+}
+
+template <class NodeT> PRESERVE_NONE void iter(Job *, Context *);
+
+static Continuation iterTable[] = {iter, iter, iter,
+                                   iter, iter};
+
+PRESERVE_NONE void begin(Job *job, Context *context) {
+  if (job->remaining.size() == 0) [[unlikely]] {
+    context->results[job->index] = {job->n, job->remaining};
+    MUSTTAIL return complete(job, context);
+  }
+
+  TaggedNodePointer *child =
+      getChildUpdatingMaxVersion(job->n, job->remaining, context->writeVersion);
+
+  if (child == nullptr) [[unlikely]] {
+    context->results[job->index] = {job->n, job->remaining};
+    MUSTTAIL return complete(job, context);
+  }
+  job->n = *child;
+  if (job->remaining.size() == 0) [[unlikely]] {
+    context->results[job->index] = {job->n, job->remaining};
+    MUSTTAIL return complete(job, context);
+  }
+  job->continuation = iterTable[child->getType()];
+  __builtin_prefetch(job->n);
+  MUSTTAIL return keepGoing(job, context);
+}
+
+template <class NodeT> void iter(Job *job, Context *context) {
+  assert(NodeT::kType == job->n->getType());
+  NodeT *n = static_cast<NodeT *>(job->n);
+
+  TaggedNodePointer *child =
+      getChildUpdatingMaxVersion(n, job->remaining, context->writeVersion);
+  if (child == nullptr) [[unlikely]] {
+    context->results[job->index] = {job->n, job->remaining};
+    MUSTTAIL return complete(job, context);
+  }
+  job->n = *child;
+  if (job->remaining.size() == 0) [[unlikely]] {
+    context->results[job->index] = {job->n, job->remaining};
+    MUSTTAIL return complete(job, context);
+  }
+  job->continuation = iterTable[child->getType()];
+  __builtin_prefetch(job->n);
+  MUSTTAIL return keepGoing(job, context);
+}
+
+void Job::init(Context *context, int index) {
+  this->index = index;
+  this->continuation = interleaved_insert::begin;
+  this->remaining = std::span(context->writes[index].begin.p,
+                              context->writes[index].begin.len);
+  this->n = context->root;
+}
+
+} // namespace interleaved_insert
+
 // Sequential implementations
 namespace {
 // Logically this is the same as performing firstGeq and then checking against
@@ -4583,6 +4702,50 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
     check_bytes_total.add(check_byte_accum);
   }
 
+  void interleavedPointWrites(const WriteRange *writes, int count,
+                              InternalVersionT writeVersion) {
+    // Phase 1: Search for insertion points concurrently, without modifying the
+    // structure of the tree.
+
+    if (count == 0) {
+      return;
+    }
+
+#if __has_attribute(preserve_none) && __has_attribute(musttail)
+    constexpr int kConcurrent = 16;
+    interleaved_insert::Job inProgress[kConcurrent];
+    interleaved_insert::Context context;
+    context.writeVersion = writeVersion;
+    context.count = count;
+    context.root = root;
+    context.writes = writes;
+    context.results = (interleaved_insert::Result *)safe_malloc(
+        sizeof(interleaved_insert::Result) * count);
+    int64_t started = std::min(kConcurrent, count);
+    context.started = started;
+    for (int i = 0; i < started; i++) {
+      inProgress[i].init(&context, i);
+    }
+    for (int i = 0; i < started - 1; i++) {
+      inProgress[i].next = inProgress + i + 1;
+    }
+    for (int i = 1; i < started; i++) {
+      inProgress[i].prev = inProgress + i - 1;
+    }
+    inProgress[0].prev = inProgress + started - 1;
+    inProgress[started - 1].next = inProgress;
+
+    // Kick off the sequence of tail calls that finally returns once all jobs
+    // are done
+    inProgress->continuation(inProgress, &context);
+
+#endif
+
+    // Phase 2: Perform insertions. Nodes may be upsized during this phase, but
+    // old nodes get forwarding pointers installed and are released after
+    // phase 2.
+  }
+
   void addWrites(const WriteRange *writes, int count, int64_t writeVersion) {
 #if !USE_64_BIT
     // There could be other conflict sets in the same thread. We need
@@ -4624,17 +4787,28 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
       }
     }
 
+    bool allPointWrites = true;
     for (int i = 0; i < count; ++i) {
-      const auto &w = writes[i];
-      writeContext.accum.write_bytes += w.begin.len + w.end.len;
-      auto begin = std::span(w.begin.p, w.begin.len);
-      auto end = std::span(w.end.p, w.end.len);
-      if (w.end.len > 0) {
-        addWriteRange(root, begin, end, InternalVersionT(writeVersion),
-                      &writeContext, this);
-      } else {
-        addPointWrite(root, begin, InternalVersionT(writeVersion),
-                      &writeContext);
+      if (writes[i].end.len > 0) {
+        allPointWrites = false;
+        break;
+      }
+    }
+    if (0 && allPointWrites) {
+      interleavedPointWrites(writes, count, InternalVersionT(writeVersion));
+    } else {
+      for (int i = 0; i < count; ++i) {
+        const auto &w = writes[i];
+        writeContext.accum.write_bytes += w.begin.len + w.end.len;
+        auto begin = std::span(w.begin.p, w.begin.len);
+        auto end = std::span(w.end.p, w.end.len);
+        if (w.end.len > 0) {
+          addWriteRange(root, begin, end, InternalVersionT(writeVersion),
+                        &writeContext, this);
+        } else {
+          addPointWrite(root, begin, InternalVersionT(writeVersion),
+                        &writeContext);
+        }
       }
     }
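
Note on the phase 1 technique above: each in-flight insertion is a small "job" on a circular doubly linked list; after visiting a node, a job prefetches the next node it will need and tail-calls into the next job's continuation, so other searches make progress while the prefetched cache line arrives. Below is a minimal standalone sketch of that round-robin continuation pattern, not code from this patch: ListNode, ListJob, Sched, step, kJobs, etc. are invented for illustration, and the preserve_none calling convention and the "refill with a new job on completion" step from complete() are omitted for brevity.

// Sketch: interleaving several dependent pointer chases via tail-called
// continuations, in the spirit of the patch's keepGoing/complete/begin.
#include <cstdio>
#include <vector>

#if __has_attribute(musttail)
#define MUSTTAIL __attribute__((musttail))
#else
#define MUSTTAIL // fallback: plain calls; fine here since the lists are short
#endif

struct ListNode {
  int value;
  ListNode *next;
};

struct ListJob;
struct Sched;
typedef void (*Cont)(ListJob *, Sched *);

struct ListJob {
  ListNode *pos;         // current node in this job's list
  long sum;              // per-job result
  Cont cont;             // continuation to resume this job
  ListJob *prev, *next;  // circular list of live jobs
};

struct Sched {
  int live;  // jobs still running
};

// Switch to the next job; its continuation resumes where it left off.
void keepGoing(ListJob *job, Sched *s) {
  job = job->next;
  MUSTTAIL return job->cont(job, s);
}

// Unlink a finished job from the ring; return once the last one is done.
void complete(ListJob *job, Sched *s) {
  if (--s->live == 0)
    return;
  job->prev->next = job->next;
  job->next->prev = job->prev;
  job = job->prev;
  MUSTTAIL return keepGoing(job, s);
}

// One step of a job: consume the current node, prefetch the next one, and
// yield to the other jobs so the prefetch has time to land.
void step(ListJob *job, Sched *s) {
  job->sum += job->pos->value;
  job->pos = job->pos->next;
  if (job->pos == nullptr)
    MUSTTAIL return complete(job, s);
  __builtin_prefetch(job->pos);
  MUSTTAIL return keepGoing(job, s);
}

int main() {
  constexpr int kJobs = 4, kLen = 8;
  std::vector<ListNode> nodes(kJobs * kLen);
  ListJob jobs[kJobs];
  for (int j = 0; j < kJobs; j++) {
    for (int i = 0; i < kLen; i++)
      nodes[j * kLen + i] = {j + i,
                             i + 1 < kLen ? &nodes[j * kLen + i + 1] : nullptr};
    jobs[j] = {&nodes[j * kLen], 0, step, &jobs[(j + kJobs - 1) % kJobs],
               &jobs[(j + 1) % kJobs]};
  }
  Sched s{kJobs};
  jobs[0].cont(&jobs[0], &s);  // returns once every job has completed
  for (int j = 0; j < kJobs; j++)
    std::printf("job %d sum = %ld\n", j, jobs[j].sum);
}

With musttail, the whole traversal runs as a loop of tail calls in effectively one stack frame, which is what makes this continuation-passing scheduler usable for arbitrarily many steps; without it the fallback recurses once per step, which is only acceptable for small examples like this one.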