From 309e6ab8163262aec2760f1060efb4111204e235 Mon Sep 17 00:00:00 2001 From: Andrew Noyes Date: Wed, 30 Oct 2024 22:59:00 -0700 Subject: [PATCH] Dispatch on pair of types --- ConflictSet.cpp | 129 ++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 113 insertions(+), 16 deletions(-) diff --git a/ConflictSet.cpp b/ConflictSet.cpp index 7debe81..6b04956 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -4098,6 +4098,8 @@ typedef PRESERVE_NONE void (*Continuation)(struct Job *, struct Context *); struct Job { TrivialSpan remaining; Node *n; + TaggedNodePointer child; + int childIndex; int index; TrivialSpan begin; // Range write only TrivialSpan end; // Range write only @@ -4109,6 +4111,58 @@ struct Job { Job *prev; Job *next; void init(Context *, int index); + + bool getChildAndIndex(Node0 *, uint8_t) { return false; } + bool getChildAndIndex(Node3 *self, uint8_t index) { + childIndex = getNodeIndex(self, index); + if (childIndex >= 0) { + child = self->children[childIndex]; + return true; + } + return false; + } + bool getChildAndIndex(Node16 *self, uint8_t index) { + childIndex = getNodeIndex(self, index); + if (childIndex >= 0) { + child = self->children[childIndex]; + return true; + } + return false; + } + bool getChildAndIndex(Node48 *self, uint8_t index) { + childIndex = self->index[index]; + if (childIndex >= 0) { + child = self->children[childIndex]; + return true; + } + return false; + } + bool getChildAndIndex(Node256 *self, uint8_t i) { + child = self->children[i]; + if (child != nullptr) { + childIndex = i; + child = self->children[childIndex]; + return true; + } + return false; + } + + bool getChildAndIndex(Node *self, uint8_t index) { + switch (self->getType()) { + case Type_Node0: + return getChildAndIndex(static_cast(self), index); + case Type_Node3: + return getChildAndIndex(static_cast(self), index); + case Type_Node16: + return getChildAndIndex(static_cast(self), index); + case Type_Node48: + return getChildAndIndex(static_cast(self), index); + case Type_Node256: + return getChildAndIndex(static_cast(self), index); + default: // GCOVR_EXCL_LINE + __builtin_unreachable(); // GCOVR_EXCL_LINE + } + } }; // Result of an insertion. The search path of insertionPoint + remaining == the @@ -4167,30 +4221,68 @@ PRESERVE_NONE void complete(Job *job, Context *context) { } } -template PRESERVE_NONE void pointIter(Job *, Context *); +template +PRESERVE_NONE void pointIter(Job *, Context *); -static Continuation pointIterTable[] = {pointIter, pointIter, - pointIter, pointIter, - pointIter}; +template struct PointIterTable { + static constexpr Continuation table[] = { + pointIter, pointIter, + pointIter, pointIter, + pointIter}; +}; -template void pointIter(Job *job, Context *context) { - assert(NodeT::kType == job->n->getType()); - NodeT *n = static_cast(job->n); +static constexpr Continuation const *pointIterTable[] = { + PointIterTable::table, PointIterTable::table, + PointIterTable::table, PointIterTable::table, + PointIterTable::table, +}; - TaggedNodePointer *child = - getChildUpdatingMaxVersion(n, job->remaining, context->writeVersion); - if (child == nullptr) [[unlikely]] { - context->results[job->index] = {job->n, job->remaining}; - MUSTTAIL return complete(job, context); +template +void pointIter(Job *job, Context *context) { + assert(NodeTFrom::kType == job->n->getType()); + NodeTFrom *n = static_cast(job->n); + assert(NodeTTo::kType == job->child->getType()); + NodeTTo *child = static_cast(job->child); + + auto key = job->remaining.subspan(1, job->remaining.size() - 1); + if (child->partialKeyLen > 0) { + int commonLen = std::min(child->partialKeyLen, key.size()); + int partialKeyIndex = + longestCommonPrefix(child->partialKey(), key.data(), commonLen); + if (partialKeyIndex < child->partialKeyLen) { + context->results[job->index] = {job->n, job->remaining}; + MUSTTAIL return complete(job, context); + } } - job->n = *child; + + // child is on the search path. Commit to advancing and updating max version + job->n = child; + job->remaining = + key.subspan(child->partialKeyLen, key.size() - child->partialKeyLen); + if constexpr (std::is_same_v || + std::is_same_v) { + n->childMaxVersion[job->childIndex] = context->writeVersion; + } else if constexpr (std::is_same_v || + std::is_same_v) { + n->maxOfMax[job->childIndex >> NodeTFrom::kMaxOfMaxShift] = + std::max(n->maxOfMax[job->childIndex >> NodeTFrom::kMaxOfMaxShift], + context->writeVersion); + n->childMaxVersion[job->childIndex] = context->writeVersion; + } + if (job->remaining.size() == 0) [[unlikely]] { context->results[job->index] = {job->n, job->remaining}; MUSTTAIL return complete(job, context); } + + if (!job->getChildAndIndex(child, job->remaining.front())) [[unlikely]] { + context->results[job->index] = {job->n, job->remaining}; + MUSTTAIL return complete(job, context); + } + ++context->iterations; - job->continuation = pointIterTable[child->getType()]; - __builtin_prefetch(job->n); + job->continuation = PointIterTable::table[job->child.getType()]; + __builtin_prefetch(job->child); MUSTTAIL return keepGoing(job, context); } @@ -4331,7 +4423,12 @@ pointWrite: context->results[index] = {n, remaining}; continuation = complete; } else { - continuation = pointIterTable[n->getType()]; + if (!getChildAndIndex(n, remaining.front())) [[unlikely]] { + context->results[index] = {n, remaining}; + continuation = complete; + } else { + continuation = pointIterTable[n->getType()][child.getType()]; + } } }