diff --git a/ConflictSet.cpp b/ConflictSet.cpp index 3761264..f5e0330 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -4285,13 +4285,10 @@ void pointIter(Job *job, Context *context) { template PRESERVE_NONE void prefixIter(Job *, Context *); -template PRESERVE_NONE void beginIter(Job *, Context *); +template +PRESERVE_NONE void beginIter(Job *, Context *); template PRESERVE_NONE void endIter(Job *, Context *); -static Continuation beginIterTable[] = {beginIter, beginIter, - beginIter, beginIter, - beginIter}; - static Continuation endIterTable[] = {endIter, endIter, endIter, endIter, endIter}; @@ -4309,6 +4306,19 @@ static constexpr Continuation const *prefixIterTable[] = { PrefixIterTable::table, }; +template struct BeginIterTable { + static constexpr Continuation table[] = { + beginIter, beginIter, + beginIter, beginIter, + beginIter}; +}; + +static constexpr Continuation const *beginIterTable[] = { + BeginIterTable::table, BeginIterTable::table, + BeginIterTable::table, BeginIterTable::table, + BeginIterTable::table, +}; + template void prefixIter(Job *job, Context *context) { assert(NodeTFrom::kType == job->n->getType()); @@ -4348,11 +4358,14 @@ void prefixIter(Job *job, Context *context) { job->end = job->end.subspan(job->commonPrefixLen, job->end.size() - job->commonPrefixLen); if (job->begin.size() == 0) [[unlikely]] { - job->continuation = endIterTable[child->getType()]; + MUSTTAIL return endIter(job, context); + } else if (!job->getChildAndIndex(child, job->begin.front())) [[unlikely]] { + MUSTTAIL return endIter(job, context); } else { - job->continuation = beginIterTable[child->getType()]; + job->continuation = BeginIterTable::table[job->child.getType()]; + __builtin_prefetch(job->child); + MUSTTAIL return keepGoing(job, context); } - MUSTTAIL return keepGoing(job, context); } if (!job->getChildAndIndex(child, job->remaining.front())) [[unlikely]] { @@ -4377,22 +4390,49 @@ noNodeOnSearchPath: MUSTTAIL return complete(job, context); } -template void beginIter(Job *job, Context *context) { - assert(NodeT::kType == job->n->getType()); - NodeT *n = static_cast(job->n); +template +void beginIter(Job *job, Context *context) { + assert(NodeTFrom::kType == job->n->getType()); + NodeTFrom *n = static_cast(job->n); + assert(NodeTTo::kType == job->child->getType()); + NodeTTo *child = static_cast(job->child); - TaggedNodePointer *child = - getChildUpdatingMaxVersion(n, job->begin, context->writeVersion); - if (child == nullptr) [[unlikely]] { - MUSTTAIL return endIterTable[job->endNode->getType()](job, context); + auto key = job->begin.subspan(1, job->begin.size() - 1); + if (child->partialKeyLen > 0) { + int commonLen = std::min(child->partialKeyLen, key.size()); + int partialKeyIndex = + longestCommonPrefix(child->partialKey(), key.data(), commonLen); + if (partialKeyIndex < child->partialKeyLen) { + MUSTTAIL return endIterTable[job->endNode->getType()](job, context); + } } - job->n = *child; + + // child is on the search path. Commit to advancing and updating max version + job->n = child; + job->begin = + key.subspan(child->partialKeyLen, key.size() - child->partialKeyLen); + if constexpr (std::is_same_v || + std::is_same_v) { + n->childMaxVersion[job->childIndex] = context->writeVersion; + } else if constexpr (std::is_same_v || + std::is_same_v) { + n->maxOfMax[job->childIndex >> NodeTFrom::kMaxOfMaxShift] = + std::max(n->maxOfMax[job->childIndex >> NodeTFrom::kMaxOfMaxShift], + context->writeVersion); + n->childMaxVersion[job->childIndex] = context->writeVersion; + } + if (job->begin.size() == 0) [[unlikely]] { MUSTTAIL return endIterTable[job->endNode->getType()](job, context); } + + if (!job->getChildAndIndex(child, job->begin.front())) [[unlikely]] { + MUSTTAIL return endIterTable[job->endNode->getType()](job, context); + } + ++context->iterations; - job->continuation = beginIterTable[child->getType()]; - __builtin_prefetch(job->n); + job->continuation = BeginIterTable::table[job->child.getType()]; + __builtin_prefetch(job->child); MUSTTAIL return keepGoing(job, context); } @@ -4456,7 +4496,11 @@ void Job::init(Context *context, int index) { } } else if (begin.size() > 0) { endNode = n; - continuation = beginIterTable[n->getType()]; + if (!getChildAndIndex(n, begin.front())) [[unlikely]] { + continuation = endIterTable[n->getType()]; + } else { + continuation = beginIterTable[n->getType()][child.getType()]; + } } else { endNode = n; continuation = endIterTable[n->getType()];