Dispatch on pair of types

This commit is contained in:
2024-10-30 22:59:00 -07:00
parent 12b82c1be5
commit 309e6ab816

View File

@@ -4098,6 +4098,8 @@ typedef PRESERVE_NONE void (*Continuation)(struct Job *, struct Context *);
struct Job {
TrivialSpan remaining;
Node *n;
TaggedNodePointer child;
int childIndex;
int index;
TrivialSpan begin; // Range write only
TrivialSpan end; // Range write only
@@ -4109,6 +4111,58 @@ struct Job {
Job *prev;
Job *next;
void init(Context *, int index);
bool getChildAndIndex(Node0 *, uint8_t) { return false; }
bool getChildAndIndex(Node3 *self, uint8_t index) {
childIndex = getNodeIndex(self, index);
if (childIndex >= 0) {
child = self->children[childIndex];
return true;
}
return false;
}
bool getChildAndIndex(Node16 *self, uint8_t index) {
childIndex = getNodeIndex(self, index);
if (childIndex >= 0) {
child = self->children[childIndex];
return true;
}
return false;
}
bool getChildAndIndex(Node48 *self, uint8_t index) {
childIndex = self->index[index];
if (childIndex >= 0) {
child = self->children[childIndex];
return true;
}
return false;
}
bool getChildAndIndex(Node256 *self, uint8_t i) {
child = self->children[i];
if (child != nullptr) {
childIndex = i;
child = self->children[childIndex];
return true;
}
return false;
}
bool getChildAndIndex(Node *self, uint8_t index) {
switch (self->getType()) {
case Type_Node0:
return getChildAndIndex(static_cast<Node0 *>(self), index);
case Type_Node3:
return getChildAndIndex(static_cast<Node3 *>(self), index);
case Type_Node16:
return getChildAndIndex(static_cast<Node16 *>(self), index);
case Type_Node48:
return getChildAndIndex(static_cast<Node48 *>(self), index);
case Type_Node256:
return getChildAndIndex(static_cast<Node256 *>(self), index);
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
};
// Result of an insertion. The search path of insertionPoint + remaining == the
@@ -4167,30 +4221,68 @@ PRESERVE_NONE void complete(Job *job, Context *context) {
}
}
template <class NodeT> PRESERVE_NONE void pointIter(Job *, Context *);
template <class NodeTFrom, class NodeTTo>
PRESERVE_NONE void pointIter(Job *, Context *);
static Continuation pointIterTable[] = {pointIter<Node0>, pointIter<Node3>,
pointIter<Node16>, pointIter<Node48>,
pointIter<Node256>};
template <class NodeTFrom> struct PointIterTable {
static constexpr Continuation table[] = {
pointIter<NodeTFrom, Node0>, pointIter<NodeTFrom, Node3>,
pointIter<NodeTFrom, Node16>, pointIter<NodeTFrom, Node48>,
pointIter<NodeTFrom, Node256>};
};
template <class NodeT> void pointIter(Job *job, Context *context) {
assert(NodeT::kType == job->n->getType());
NodeT *n = static_cast<NodeT *>(job->n);
static constexpr Continuation const *pointIterTable[] = {
PointIterTable<Node0>::table, PointIterTable<Node3>::table,
PointIterTable<Node16>::table, PointIterTable<Node48>::table,
PointIterTable<Node256>::table,
};
TaggedNodePointer *child =
getChildUpdatingMaxVersion(n, job->remaining, context->writeVersion);
if (child == nullptr) [[unlikely]] {
context->results[job->index] = {job->n, job->remaining};
MUSTTAIL return complete(job, context);
template <class NodeTFrom, class NodeTTo>
void pointIter(Job *job, Context *context) {
assert(NodeTFrom::kType == job->n->getType());
NodeTFrom *n = static_cast<NodeTFrom *>(job->n);
assert(NodeTTo::kType == job->child->getType());
NodeTTo *child = static_cast<NodeTTo *>(job->child);
auto key = job->remaining.subspan(1, job->remaining.size() - 1);
if (child->partialKeyLen > 0) {
int commonLen = std::min<int>(child->partialKeyLen, key.size());
int partialKeyIndex =
longestCommonPrefix(child->partialKey(), key.data(), commonLen);
if (partialKeyIndex < child->partialKeyLen) {
context->results[job->index] = {job->n, job->remaining};
MUSTTAIL return complete(job, context);
}
}
job->n = *child;
// child is on the search path. Commit to advancing and updating max version
job->n = child;
job->remaining =
key.subspan(child->partialKeyLen, key.size() - child->partialKeyLen);
if constexpr (std::is_same_v<NodeTFrom, Node3> ||
std::is_same_v<NodeTFrom, Node16>) {
n->childMaxVersion[job->childIndex] = context->writeVersion;
} else if constexpr (std::is_same_v<NodeTFrom, Node48> ||
std::is_same_v<NodeTFrom, Node256>) {
n->maxOfMax[job->childIndex >> NodeTFrom::kMaxOfMaxShift] =
std::max(n->maxOfMax[job->childIndex >> NodeTFrom::kMaxOfMaxShift],
context->writeVersion);
n->childMaxVersion[job->childIndex] = context->writeVersion;
}
if (job->remaining.size() == 0) [[unlikely]] {
context->results[job->index] = {job->n, job->remaining};
MUSTTAIL return complete(job, context);
}
if (!job->getChildAndIndex(child, job->remaining.front())) [[unlikely]] {
context->results[job->index] = {job->n, job->remaining};
MUSTTAIL return complete(job, context);
}
++context->iterations;
job->continuation = pointIterTable[child->getType()];
__builtin_prefetch(job->n);
job->continuation = PointIterTable<NodeTTo>::table[job->child.getType()];
__builtin_prefetch(job->child);
MUSTTAIL return keepGoing(job, context);
}
@@ -4331,7 +4423,12 @@ pointWrite:
context->results[index] = {n, remaining};
continuation = complete;
} else {
continuation = pointIterTable[n->getType()];
if (!getChildAndIndex(n, remaining.front())) [[unlikely]] {
context->results[index] = {n, remaining};
continuation = complete;
} else {
continuation = pointIterTable[n->getType()][child.getType()];
}
}
}