From 2c1c26bc88be6354735e83d56ba91994072126d9 Mon Sep 17 00:00:00 2001 From: Andrew Noyes Date: Wed, 30 Oct 2024 11:01:23 -0700 Subject: [PATCH] Enable interleaved range writes --- ConflictSet.cpp | 237 +++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 206 insertions(+), 31 deletions(-) diff --git a/ConflictSet.cpp b/ConflictSet.cpp index 203cfc3..1e4fb54 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -3113,13 +3113,16 @@ struct AddedWriteRange { Node *endNode; }; -AddedWriteRange -addWriteRange(TaggedNodePointer &beginRoot, std::span begin, - TaggedNodePointer &endRoot, std::span end, - InternalVersionT writeVersion, WriteContext *writeContext) { +AddedWriteRange addWriteRange(Node *beginRoot, std::span begin, + Node *endRoot, std::span end, + InternalVersionT writeVersion, + WriteContext *writeContext, + ConflictSet::Impl *impl) { + ++writeContext->accum.range_writes; - Node *beginNode = *insert(&beginRoot, begin, writeVersion, writeContext); + Node *beginNode = + *insert(&getInTree(beginRoot, impl), begin, writeVersion, writeContext); addKey(beginNode); if (!beginNode->entryPresent) { ++writeContext->accum.entries_inserted; @@ -3131,7 +3134,12 @@ addWriteRange(TaggedNodePointer &beginRoot, std::span begin, } beginNode->entry.pointVersion = writeVersion; - Node *endNode = *insert(&endRoot, end, writeVersion, writeContext); + while (endRoot->releaseDeferred) { + endRoot = endRoot->forwardTo; + } + Node *endNode = + *insert(&getInTree(endRoot, impl), end, writeVersion, writeContext); + addKey(endNode); if (!endNode->entryPresent) { ++writeContext->accum.entries_inserted; @@ -3182,7 +3190,7 @@ void addWriteRange(TaggedNodePointer &root, std::span begin, auto [beginNode, endNode] = addWriteRange( *useAsRoot, begin.subspan(lcp, begin.size() - lcp), *useAsRoot, - end.subspan(lcp, end.size() - lcp), writeVersion, writeContext); + end.subspan(lcp, end.size() - lcp), writeVersion, writeContext, impl); eraseInRange(beginNode, endNode, writeContext, impl); } @@ -3350,10 +3358,6 @@ static Continuation iterTable[] = {iter, iter, iter, void begin(Job *job, Context *context) { ++context->readContext.point_read_accum; -#if DEBUG_VERBOSE && !defined(NDEBUG) - fprintf(stderr, "Check point read: %s\n", printable(key).c_str()); -#endif - if (job->begin.size() == 0) [[unlikely]] { // We don't erase the root assert(job->n->entryPresent); @@ -3479,9 +3483,6 @@ static Continuation iterTable[] = {iter, iter, iter, void begin(Job *job, Context *context) { ++context->readContext.prefix_read_accum; -#if DEBUG_VERBOSE && !defined(NDEBUG) - fprintf(stderr, "Check prefix read: %s\n", printable(key).c_str()); -#endif // There's no way to encode a prefix read of "" assert(job->begin.size() > 0); @@ -4098,6 +4099,10 @@ struct Job { std::span remaining; Node *n; int index; + std::span begin; // Range write only + std::span end; // Range write only + Node *endNode; // Range write only + int commonPrefixLen; // Range write only // State for context switching machinery - not application specific Continuation continuation; @@ -4112,6 +4117,9 @@ struct Job { struct Result { Node *insertionPoint; std::span remaining; + + Node *endInsertionPoint; // Range write only + std::span endRemaining; // Range write only }; // State relevant to every insertion @@ -4146,43 +4154,172 @@ PRESERVE_NONE void complete(Job *job, Context *context) { } } -template PRESERVE_NONE void iter(Job *, Context *); +template PRESERVE_NONE void pointIter(Job *, Context *); -static Continuation iterTable[] = {iter, iter, iter, - iter, iter}; +static Continuation pointIterTable[] = {pointIter, pointIter, + pointIter, pointIter, + pointIter}; -template void iter(Job *job, Context *context) { +template void pointIter(Job *job, Context *context) { assert(NodeT::kType == job->n->getType()); NodeT *n = static_cast(job->n); TaggedNodePointer *child = getChildUpdatingMaxVersion(n, job->remaining, context->writeVersion); if (child == nullptr) [[unlikely]] { - context->results[job->index] = {job->n, job->remaining}; + context->results[job->index] = {job->n, job->remaining, nullptr, {}}; MUSTTAIL return complete(job, context); } job->n = *child; if (job->remaining.size() == 0) [[unlikely]] { - context->results[job->index] = {job->n, job->remaining}; + context->results[job->index] = {job->n, job->remaining, nullptr, {}}; MUSTTAIL return complete(job, context); } ++context->iterations; - job->continuation = iterTable[child->getType()]; + job->continuation = pointIterTable[child->getType()]; __builtin_prefetch(job->n); MUSTTAIL return keepGoing(job, context); } +template PRESERVE_NONE void commonPrefixIter(Job *, Context *); +template PRESERVE_NONE void beginIter(Job *, Context *); +template PRESERVE_NONE void endIter(Job *, Context *); + +static Continuation commonPrefixIterTable[] = { + commonPrefixIter, commonPrefixIter, commonPrefixIter, + commonPrefixIter, commonPrefixIter}; + +static Continuation beginIterTable[] = {beginIter, beginIter, + beginIter, beginIter, + beginIter}; + +static Continuation endIterTable[] = {endIter, endIter, + endIter, endIter, + endIter}; + +template void commonPrefixIter(Job *job, Context *context) { + assert(NodeT::kType == job->n->getType()); + NodeT *n = static_cast(job->n); + + TaggedNodePointer *child = + getChildUpdatingMaxVersion(n, job->remaining, context->writeVersion); + if (child == nullptr) [[unlikely]] { + int prefixLen = job->commonPrefixLen - job->remaining.size(); + assert(prefixLen >= 0); + assert(job->n != nullptr); + context->results[job->index] = { + job->n, + job->begin.subspan(prefixLen, job->begin.size() - prefixLen), + job->n, + job->end.subspan(prefixLen, job->end.size() - prefixLen), + }; + MUSTTAIL return complete(job, context); + } + job->n = *child; + ++context->iterations; + if (job->remaining.size() == 0) [[unlikely]] { + job->endNode = job->n; + job->begin = job->begin.subspan(job->commonPrefixLen, + job->begin.size() - job->commonPrefixLen); + job->end = job->end.subspan(job->commonPrefixLen, + job->end.size() - job->commonPrefixLen); + if (job->begin.size() == 0) [[unlikely]] { + job->continuation = endIterTable[child->getType()]; + } else { + job->continuation = beginIterTable[child->getType()]; + } + } else { + job->continuation = commonPrefixIterTable[child->getType()]; + } + __builtin_prefetch(job->n); + MUSTTAIL return keepGoing(job, context); +} + +template void beginIter(Job *job, Context *context) { + assert(NodeT::kType == job->n->getType()); + NodeT *n = static_cast(job->n); + + TaggedNodePointer *child = + getChildUpdatingMaxVersion(n, job->begin, context->writeVersion); + if (child == nullptr) [[unlikely]] { + MUSTTAIL return endIterTable[job->endNode->getType()](job, context); + } + job->n = *child; + if (job->begin.size() == 0) [[unlikely]] { + MUSTTAIL return endIterTable[job->endNode->getType()](job, context); + } + ++context->iterations; + job->continuation = beginIterTable[child->getType()]; + __builtin_prefetch(job->n); + MUSTTAIL return keepGoing(job, context); +} + +template void endIter(Job *job, Context *context) { + assert(NodeT::kType == job->endNode->getType()); + NodeT *endNode = static_cast(job->endNode); + + TaggedNodePointer *child = + getChildUpdatingMaxVersion(endNode, job->end, context->writeVersion); + if (child == nullptr) [[unlikely]] { + context->results[job->index] = {job->n, job->begin, job->endNode, job->end}; + assert(job->endNode != nullptr); + MUSTTAIL return complete(job, context); + } + job->endNode = *child; + if (job->remaining.size() == 0) [[unlikely]] { + context->results[job->index] = {job->n, job->begin, job->endNode, job->end}; + assert(job->endNode != nullptr); + MUSTTAIL return complete(job, context); + } + ++context->iterations; + job->continuation = endIterTable[child->getType()]; + __builtin_prefetch(job->endNode); + MUSTTAIL return keepGoing(job, context); +} + void Job::init(Context *context, int index) { this->index = index; - remaining = std::span(context->writes[index].begin.p, - context->writes[index].begin.len); n = context->root; - if (remaining.size() == 0) [[unlikely]] { - context->results[index] = {n, remaining}; - continuation = interleaved_insert::complete; + if (context->writes[index].end.len == 0) { + goto pointWrite; + } + + begin = std::span(context->writes[index].begin.p, + context->writes[index].begin.len); + end = std::span(context->writes[index].end.p, + context->writes[index].end.len); + + commonPrefixLen = longestCommonPrefix(begin.data(), end.data(), + std::min(begin.size(), end.size())); + if (commonPrefixLen == int(begin.size()) && end.size() == begin.size() + 1 && + end.back() == 0) { + goto pointWrite; + } + + remaining = + std::span(context->writes[index].begin.p, commonPrefixLen); + + if (commonPrefixLen > 0) { + // common prefix iter will set endNode + continuation = commonPrefixIterTable[n->getType()]; + } else if (begin.size() > 0) { + endNode = n; + continuation = beginIterTable[n->getType()]; } else { - continuation = iterTable[n->getType()]; + endNode = n; + continuation = endIterTable[n->getType()]; + } + + return; +pointWrite: + remaining = std::span(context->writes[index].begin.p, + context->writes[index].begin.len); + if (remaining.size() == 0) [[unlikely]] { + context->results[index] = {n, remaining, nullptr, {}}; + continuation = complete; + } else { + continuation = pointIterTable[n->getType()]; } } @@ -4763,12 +4900,51 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { // phase 2. for (int i = 0; i < count; ++i) { + +#if DEBUG_VERBOSE && !defined(NDEBUG) + fprintf(stderr, "search path: %s, begin: %s\n", + getSearchPathPrintable(context.results[i].insertionPoint).c_str(), + printable(writes[i].begin).c_str()); + fprintf( + stderr, "search path: %s, end: %s\n", + getSearchPathPrintable(context.results[i].endInsertionPoint).c_str(), + printable(writes[i].end).c_str()); +#endif + while (context.results[i].insertionPoint->releaseDeferred) { context.results[i].insertionPoint = context.results[i].insertionPoint->forwardTo; } - addPointWrite(getInTree(context.results[i].insertionPoint, this), - context.results[i].remaining, writeVersion, &writeContext); + if (context.results[i].endInsertionPoint == nullptr) { + addPointWrite(getInTree(context.results[i].insertionPoint, this), + context.results[i].remaining, writeVersion, + &writeContext); + } else { + auto [beginNode, endNode] = addWriteRange( + context.results[i].insertionPoint, context.results[i].remaining, + context.results[i].endInsertionPoint, + context.results[i].endRemaining, writeVersion, &writeContext, this); + context.results[i].insertionPoint = beginNode; + context.results[i].endInsertionPoint = endNode; + } + } + + // Phase 3: Erase nodes within written ranges. Going left to right ensures + // that nothing later is on the search path of anything earlier, so we don't + // encounter invalidated nodes. + for (int i = 0; i < count; ++i) { + if (context.results[i].endInsertionPoint != nullptr) { + while (context.results[i].insertionPoint->releaseDeferred) { + context.results[i].insertionPoint = + context.results[i].insertionPoint->forwardTo; + } + while (context.results[i].endInsertionPoint->releaseDeferred) { + context.results[i].endInsertionPoint = + context.results[i].endInsertionPoint->forwardTo; + } + eraseInRange(context.results[i].insertionPoint, + context.results[i].endInsertionPoint, &writeContext, this); + } } if (count > int(sizeof(stackResults) / sizeof(stackResults[0]))) @@ -4792,8 +4968,7 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { #endif #if __has_attribute(preserve_none) && __has_attribute(musttail) - // TODO make this work for sorted range writes - constexpr bool kEnableInterleaved = false; + constexpr bool kEnableInterleaved = true; #else constexpr bool kEnableInterleaved = false; #endif