diff --git a/ConflictSet.cpp b/ConflictSet.cpp index a88e9f0..3761264 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -4283,14 +4283,11 @@ void pointIter(Job *job, Context *context) { MUSTTAIL return keepGoing(job, context); } -template PRESERVE_NONE void commonPrefixIter(Job *, Context *); +template +PRESERVE_NONE void prefixIter(Job *, Context *); template PRESERVE_NONE void beginIter(Job *, Context *); template PRESERVE_NONE void endIter(Job *, Context *); -static Continuation commonPrefixIterTable[] = { - commonPrefixIter, commonPrefixIter, commonPrefixIter, - commonPrefixIter, commonPrefixIter}; - static Continuation beginIterTable[] = {beginIter, beginIter, beginIter, beginIter, beginIter}; @@ -4299,26 +4296,51 @@ static Continuation endIterTable[] = {endIter, endIter, endIter, endIter, endIter}; -template void commonPrefixIter(Job *job, Context *context) { - assert(NodeT::kType == job->n->getType()); - NodeT *n = static_cast(job->n); +template struct PrefixIterTable { + static constexpr Continuation table[] = { + prefixIter, prefixIter, + prefixIter, prefixIter, + prefixIter}; +}; - TaggedNodePointer *child = - getChildUpdatingMaxVersion(n, job->remaining, context->writeVersion); - if (child == nullptr) [[unlikely]] { - int prefixLen = job->commonPrefixLen - job->remaining.size(); - assert(prefixLen >= 0); - assert(job->n != nullptr); - *job->result = { - job->n, - job->begin.subspan(prefixLen, job->begin.size() - prefixLen), - job->n, - job->end.subspan(prefixLen, job->end.size() - prefixLen), - }; - MUSTTAIL return complete(job, context); +static constexpr Continuation const *prefixIterTable[] = { + PrefixIterTable::table, PrefixIterTable::table, + PrefixIterTable::table, PrefixIterTable::table, + PrefixIterTable::table, +}; + +template +void prefixIter(Job *job, Context *context) { + assert(NodeTFrom::kType == job->n->getType()); + NodeTFrom *n = static_cast(job->n); + assert(NodeTTo::kType == job->child->getType()); + NodeTTo *child = static_cast(job->child); + + auto key = job->remaining.subspan(1, job->remaining.size() - 1); + if (child->partialKeyLen > 0) { + int commonLen = std::min(child->partialKeyLen, key.size()); + int partialKeyIndex = + longestCommonPrefix(child->partialKey(), key.data(), commonLen); + if (partialKeyIndex < child->partialKeyLen) { + goto noNodeOnSearchPath; + } } - job->n = *child; - ++context->iterations; + + // child is on the search path. Commit to advancing and updating max version + job->n = child; + job->remaining = + key.subspan(child->partialKeyLen, key.size() - child->partialKeyLen); + if constexpr (std::is_same_v || + std::is_same_v) { + n->childMaxVersion[job->childIndex] = context->writeVersion; + } else if constexpr (std::is_same_v || + std::is_same_v) { + n->maxOfMax[job->childIndex >> NodeTFrom::kMaxOfMaxShift] = + std::max(n->maxOfMax[job->childIndex >> NodeTFrom::kMaxOfMaxShift], + context->writeVersion); + n->childMaxVersion[job->childIndex] = context->writeVersion; + } + if (job->remaining.size() == 0) [[unlikely]] { job->endNode = job->n; job->begin = job->begin.subspan(job->commonPrefixLen, @@ -4330,11 +4352,29 @@ template void commonPrefixIter(Job *job, Context *context) { } else { job->continuation = beginIterTable[child->getType()]; } - } else { - job->continuation = commonPrefixIterTable[child->getType()]; + MUSTTAIL return keepGoing(job, context); } - __builtin_prefetch(job->n); + + if (!job->getChildAndIndex(child, job->remaining.front())) [[unlikely]] { + goto noNodeOnSearchPath; + } + + ++context->iterations; + job->continuation = PrefixIterTable::table[job->child.getType()]; + __builtin_prefetch(job->child); MUSTTAIL return keepGoing(job, context); + +noNodeOnSearchPath: + int prefixLen = job->commonPrefixLen - job->remaining.size(); + assert(prefixLen >= 0); + assert(job->n != nullptr); + *job->result = { + job->n, + job->begin.subspan(prefixLen, job->begin.size() - prefixLen), + job->n, + job->end.subspan(prefixLen, job->end.size() - prefixLen), + }; + MUSTTAIL return complete(job, context); } template void beginIter(Job *job, Context *context) { @@ -4403,7 +4443,17 @@ void Job::init(Context *context, int index) { if (commonPrefixLen > 0) { // common prefix iter will set endNode - continuation = commonPrefixIterTable[n->getType()]; + if (!getChildAndIndex(n, remaining.front())) [[unlikely]] { + *result = { + n, + begin, + n, + end, + }; + continuation = complete; + } else { + continuation = prefixIterTable[n->getType()][child.getType()]; + } } else if (begin.size() > 0) { endNode = n; continuation = beginIterTable[n->getType()]; @@ -4417,11 +4467,11 @@ pointWrite: remaining = TrivialSpan(context->writes[index].begin.p, context->writes[index].begin.len); if (remaining.size() == 0) [[unlikely]] { - context->results[index] = {n, remaining}; + *result = {n, remaining}; continuation = complete; } else { if (!getChildAndIndex(n, remaining.front())) [[unlikely]] { - context->results[index] = {n, remaining}; + *result = {n, remaining}; continuation = complete; } else { continuation = pointIterTable[n->getType()][child.getType()];