Add getChildNodeGeq, use in nextLogical

This commit is contained in:
2024-03-07 12:12:02 -08:00
parent 0f360fa806
commit 53bc36f628

View File

@@ -484,6 +484,59 @@ int getChildGeq(Node *self, int child) {
return -1;
}
Node *getChildNodeGeq(Node *self, int child) {
if (child > 255) {
return nullptr;
}
if (self->type <= Type::Node16) {
auto *self16 = static_cast<Node16 *>(self);
#ifdef HAS_AVX
__m128i key_vec = _mm_set1_epi8(child);
__m128i indices;
memcpy(&indices, self16->index, sizeof(self16->index));
__m128i results = _mm_cmpeq_epi8(key_vec, _mm_min_epu8(key_vec, indices));
int mask = (1 << self16->numChildren) - 1;
uint32_t bitfield = _mm_movemask_epi8(results) & mask;
return bitfield == 0 ? nullptr
: self16->children[std::countr_zero(bitfield)].child;
#elif defined(HAS_ARM_NEON)
uint8x16_t indices;
memcpy(&indices, self16->index, sizeof(self16->index));
// 0xff for each leq
auto results = vcleq_u8(vdupq_n_u8(child), indices);
uint64_t mask = self->numChildren == 16
? uint64_t(-1)
: (uint64_t(1) << (self->numChildren * 4)) - 1;
// 0xf for each 0xff (within mask)
uint64_t bitfield =
vget_lane_u64(
vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)),
0) &
mask;
return bitfield == 0
? nullptr
: self16->children[std::countr_zero(bitfield) / 4].child;
#else
for (int i = 0; i < self->numChildren; ++i) {
if (self16->index[i] >= child) {
return self16->children[i].child;
}
}
return nullptr;
#endif
} else if (self->type == Type::Node48) {
auto *self48 = static_cast<Node48 *>(self);
int c = self48->bitSet.firstSetGeq(child);
return c >= 0 ? self48->children[self48->index[c]].child : nullptr;
} else {
assert(self->type == Type::Node256);
auto *self256 = static_cast<Node256 *>(self);
int c = self256->bitSet.firstSetGeq(child);
return c >= 0 ? self256->children[c].child : nullptr;
}
return nullptr;
}
void setChildrenParents(Node4 *n) {
for (int i = 0; i < n->numChildren; ++i) {
n->children[i].child->parent = n;
@@ -698,9 +751,9 @@ void eraseChild(Node *self, uint8_t index, NodeAllocators *allocators) {
Node *nextPhysical(Node *node) {
int index = -1;
for (;;) {
auto nextChild = getChildGeq(node, index + 1);
if (nextChild >= 0) {
return getChildExists(node, nextChild);
Node *nextChild = getChildNodeGeq(node, index + 1);
if (nextChild != nullptr) {
return nextChild;
}
index = node->parentsIndex;
node = node->parent;