From 681a961289bbcf21b8f52f31c9239c2fd595ef3b Mon Sep 17 00:00:00 2001 From: Andrew Noyes Date: Fri, 1 Nov 2024 13:45:41 -0700 Subject: [PATCH] Dispatch on type pairs for end iter --- ConflictSet.cpp | 136 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 106 insertions(+), 30 deletions(-) diff --git a/ConflictSet.cpp b/ConflictSet.cpp index f5e0330..8cb40c2 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -4287,11 +4287,8 @@ template PRESERVE_NONE void prefixIter(Job *, Context *); template PRESERVE_NONE void beginIter(Job *, Context *); -template PRESERVE_NONE void endIter(Job *, Context *); - -static Continuation endIterTable[] = {endIter, endIter, - endIter, endIter, - endIter}; +template +PRESERVE_NONE void endIter(Job *, Context *); template struct PrefixIterTable { static constexpr Continuation table[] = { @@ -4319,6 +4316,19 @@ static constexpr Continuation const *beginIterTable[] = { BeginIterTable::table, }; +template struct EndIterTable { + static constexpr Continuation table[] = { + endIter, endIter, + endIter, endIter, + endIter}; +}; + +static constexpr Continuation const *endIterTable[] = { + EndIterTable::table, EndIterTable::table, + EndIterTable::table, EndIterTable::table, + EndIterTable::table, +}; + template void prefixIter(Job *job, Context *context) { assert(NodeTFrom::kType == job->n->getType()); @@ -4358,9 +4368,9 @@ void prefixIter(Job *job, Context *context) { job->end = job->end.subspan(job->commonPrefixLen, job->end.size() - job->commonPrefixLen); if (job->begin.size() == 0) [[unlikely]] { - MUSTTAIL return endIter(job, context); + goto gotoEndIter; } else if (!job->getChildAndIndex(child, job->begin.front())) [[unlikely]] { - MUSTTAIL return endIter(job, context); + goto gotoEndIter; } else { job->continuation = BeginIterTable::table[job->child.getType()]; __builtin_prefetch(job->child); @@ -4377,7 +4387,7 @@ void prefixIter(Job *job, Context *context) { __builtin_prefetch(job->child); MUSTTAIL return keepGoing(job, context); -noNodeOnSearchPath: +noNodeOnSearchPath: { int prefixLen = job->commonPrefixLen - job->remaining.size(); assert(prefixLen >= 0); assert(job->n != nullptr); @@ -4390,6 +4400,22 @@ noNodeOnSearchPath: MUSTTAIL return complete(job, context); } +gotoEndIter: + if (!job->getChildAndIndex(child, job->end.front())) [[unlikely]] { + *job->result = { + job->n, + job->begin, + job->n, + job->end, + }; + MUSTTAIL return complete(job, context); + } else { + job->continuation = EndIterTable::table[job->child.getType()]; + __builtin_prefetch(job->child); + MUSTTAIL return keepGoing(job, context); + } +} + template void beginIter(Job *job, Context *context) { assert(NodeTFrom::kType == job->n->getType()); @@ -4403,7 +4429,7 @@ void beginIter(Job *job, Context *context) { int partialKeyIndex = longestCommonPrefix(child->partialKey(), key.data(), commonLen); if (partialKeyIndex < child->partialKeyLen) { - MUSTTAIL return endIterTable[job->endNode->getType()](job, context); + goto gotoEndIter; } } @@ -4423,39 +4449,82 @@ void beginIter(Job *job, Context *context) { } if (job->begin.size() == 0) [[unlikely]] { - MUSTTAIL return endIterTable[job->endNode->getType()](job, context); + goto gotoEndIter; } if (!job->getChildAndIndex(child, job->begin.front())) [[unlikely]] { - MUSTTAIL return endIterTable[job->endNode->getType()](job, context); + goto gotoEndIter; } ++context->iterations; job->continuation = BeginIterTable::table[job->child.getType()]; __builtin_prefetch(job->child); MUSTTAIL return keepGoing(job, context); + +gotoEndIter: + if (!job->getChildAndIndex(job->endNode, job->end.front())) [[unlikely]] { + *job->result = { + job->n, + job->begin, + job->endNode, + job->end, + }; + MUSTTAIL return complete(job, context); + } else { + MUSTTAIL return endIterTable[job->endNode->getType()][job->child.getType()]( + job, context); + } } -template void endIter(Job *job, Context *context) { - assert(NodeT::kType == job->endNode->getType()); - NodeT *endNode = static_cast(job->endNode); +template +void endIter(Job *job, Context *context) { + assert(NodeTFrom::kType == job->endNode->getType()); + NodeTFrom *endNode = static_cast(job->endNode); + assert(NodeTTo::kType == job->child->getType()); + NodeTTo *child = static_cast(job->child); - TaggedNodePointer *child = - getChildUpdatingMaxVersion(endNode, job->end, context->writeVersion); - if (child == nullptr) [[unlikely]] { - *job->result = {job->n, job->begin, job->endNode, job->end}; - assert(job->endNode != nullptr); - MUSTTAIL return complete(job, context); + auto key = job->end.subspan(1, job->end.size() - 1); + if (child->partialKeyLen > 0) { + int commonLen = std::min(child->partialKeyLen, key.size()); + int partialKeyIndex = + longestCommonPrefix(child->partialKey(), key.data(), commonLen); + if (partialKeyIndex < child->partialKeyLen) { + *job->result = {job->n, job->begin, job->endNode, job->end}; + assert(job->endNode != nullptr); + MUSTTAIL return complete(job, context); + } } - job->endNode = *child; + + // child is on the search path. Commit to advancing and updating max version + job->endNode = child; + job->end = + key.subspan(child->partialKeyLen, key.size() - child->partialKeyLen); + if constexpr (std::is_same_v || + std::is_same_v) { + endNode->childMaxVersion[job->childIndex] = context->writeVersion; + } else if constexpr (std::is_same_v || + std::is_same_v) { + endNode->maxOfMax[job->childIndex >> NodeTFrom::kMaxOfMaxShift] = std::max( + endNode->maxOfMax[job->childIndex >> NodeTFrom::kMaxOfMaxShift], + context->writeVersion); + endNode->childMaxVersion[job->childIndex] = context->writeVersion; + } + if (job->end.size() == 0) [[unlikely]] { *job->result = {job->n, job->begin, job->endNode, job->end}; assert(job->endNode != nullptr); MUSTTAIL return complete(job, context); } + + if (!job->getChildAndIndex(child, job->end.front())) [[unlikely]] { + *job->result = {job->n, job->begin, job->endNode, job->end}; + assert(job->endNode != nullptr); + MUSTTAIL return complete(job, context); + } + ++context->iterations; - job->continuation = endIterTable[child->getType()]; - __builtin_prefetch(job->endNode); + job->continuation = EndIterTable::table[job->child.getType()]; + __builtin_prefetch(job->child); MUSTTAIL return keepGoing(job, context); } @@ -4494,16 +4563,23 @@ void Job::init(Context *context, int index) { } else { continuation = prefixIterTable[n->getType()][child.getType()]; } - } else if (begin.size() > 0) { + } else if (begin.size() > 0 && getChildAndIndex(n, begin.front())) { endNode = n; - if (!getChildAndIndex(n, begin.front())) [[unlikely]] { - continuation = endIterTable[n->getType()]; - } else { - continuation = beginIterTable[n->getType()][child.getType()]; - } + continuation = beginIterTable[n->getType()][child.getType()]; } else { + assert(end.size() > 0); endNode = n; - continuation = endIterTable[n->getType()]; + if (!getChildAndIndex(n, end.front())) [[unlikely]] { + *result = { + n, + begin, + n, + end, + }; + continuation = complete; + } else { + continuation = endIterTable[n->getType()][child.getType()]; + } } return;