Compare commits
2 Commits
0abf6a1ecf
...
84c4d0fcba
| Author | SHA1 | Date | |
|---|---|---|---|
| 84c4d0fcba | |||
| 6241533dfb |
@@ -776,6 +776,18 @@ int getNodeIndex(Node3 *self, uint8_t index) {
|
|||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Precondition - an entry for index must exist in the node
//
// Returns the position i (0 <= i < self->numChildren) for which
// self->index[i] == index. This variant has no "not found" result: if the
// precondition is violated, control reaches __builtin_unreachable() and the
// behavior is undefined.
int getNodeIndexExists(Node3 *self, uint8_t index) {
  // NOTE(review): assume() is a project macro defined outside this view;
  // presumably it passes these bounds to the optimizer - confirm.
  assume(self->numChildren >= 1);
  assume(self->numChildren <= 3);
  for (int i = 0; i < self->numChildren; ++i) {
    if (self->index[i] == index) {
      return i;
    }
  }
  // Only reachable when no entry matched, i.e. the precondition was broken.
  __builtin_unreachable(); // GCOVR_EXCL_LINE
}
|
||||||
|
|
||||||
int getNodeIndex(Node16 *self, uint8_t index) {
|
int getNodeIndex(Node16 *self, uint8_t index) {
|
||||||
|
|
||||||
#ifdef HAS_AVX
|
#ifdef HAS_AVX
|
||||||
@@ -836,13 +848,52 @@ int getNodeIndex(Node16 *self, uint8_t index) {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Precondition - an entry for index must exist in the node
//
// Returns the position of the first entry among the first
// self->numChildren slots of self->index that equals index. One of three
// implementations is selected at compile time (AVX, ARM NEON, or a plain
// scalar scan).
int getNodeIndexExists(Node16 *self, uint8_t index) {

#ifdef HAS_AVX
  // Load all Node16::kMaxNodes index bytes and compare every lane against the
  // byte being searched for.
  __m128i haystack;
  memcpy(&haystack, self->index, Node16::kMaxNodes);
  const __m128i needle = _mm_set1_epi8(index);
  const __m128i equalLanes = _mm_cmpeq_epi8(needle, haystack);
  // movemask yields one bit per lane; discard lanes at or beyond numChildren.
  const uint32_t liveLanes = (1 << self->numChildren) - 1;
  const uint32_t hits = _mm_movemask_epi8(equalLanes) & liveLanes;
  assume(hits != 0);
  return std::countr_zero(hits);
#elif defined(HAS_ARM_NEON)
  // Based on
  // https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon

  uint8x16_t haystack;
  memcpy(&haystack, self->index, Node16::kMaxNodes);
  // 0xff for each match
  const uint8x16_t equalLanes = vceqq_u8(vdupq_n_u8(index), haystack);
  // Shift-right-and-narrow packs the comparison into 4 bits per lane.
  const uint8x8_t nibbles = vshrn_n_u16(vreinterpretq_u16_u8(equalLanes), 4);
  assume(self->numChildren <= Node16::kMaxNodes);
  // Four mask bits per valid lane. A full node is handled separately because
  // shifting a 64-bit value by 64 is not allowed.
  uint64_t liveLanes;
  if (self->numChildren == 16) {
    liveLanes = uint64_t(-1);
  } else {
    liveLanes = (uint64_t(1) << (self->numChildren * 4)) - 1;
  }
  // 0xf for each match in valid range
  const uint64_t hits =
      vget_lane_u64(vreinterpret_u64_u8(nibbles), 0) & liveLanes;
  assume(hits != 0);
  return std::countr_zero(hits) / 4;
#else
  // Scalar fallback: scan the occupied slots in order.
  for (int slot = 0; slot < self->numChildren; ++slot) {
    if (self->index[slot] == index) {
      return slot;
    }
  }
  // Only reachable when no entry matched, i.e. the precondition was broken.
  __builtin_unreachable(); // GCOVR_EXCL_LINE
#endif
}
|
||||||
|
|
||||||
// Precondition - an entry for index must exist in the node
|
// Precondition - an entry for index must exist in the node
|
||||||
Node *&getChildExists(Node3 *self, uint8_t index) {
|
Node *&getChildExists(Node3 *self, uint8_t index) {
|
||||||
return self->children[getNodeIndex(self, index)];
|
return self->children[getNodeIndexExists(self, index)];
|
||||||
}
|
}
|
||||||
// Precondition - an entry for index must exist in the node
|
// Precondition - an entry for index must exist in the node
|
||||||
Node *&getChildExists(Node16 *self, uint8_t index) {
|
Node *&getChildExists(Node16 *self, uint8_t index) {
|
||||||
return self->children[getNodeIndex(self, index)];
|
return self->children[getNodeIndexExists(self, index)];
|
||||||
}
|
}
|
||||||
// Precondition - an entry for index must exist in the node
|
// Precondition - an entry for index must exist in the node
|
||||||
Node *&getChildExists(Node48 *self, uint8_t index) {
|
Node *&getChildExists(Node48 *self, uint8_t index) {
|
||||||
@@ -887,12 +938,12 @@ InternalVersionT maxVersion(Node *n) {
|
|||||||
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
||||||
case Type_Node3: {
|
case Type_Node3: {
|
||||||
auto *n3 = static_cast<Node3 *>(n);
|
auto *n3 = static_cast<Node3 *>(n);
|
||||||
int i = getNodeIndex(n3, index);
|
int i = getNodeIndexExists(n3, index);
|
||||||
return n3->childMaxVersion[i];
|
return n3->childMaxVersion[i];
|
||||||
}
|
}
|
||||||
case Type_Node16: {
|
case Type_Node16: {
|
||||||
auto *n16 = static_cast<Node16 *>(n);
|
auto *n16 = static_cast<Node16 *>(n);
|
||||||
int i = getNodeIndex(n16, index);
|
int i = getNodeIndexExists(n16, index);
|
||||||
return n16->childMaxVersion[i];
|
return n16->childMaxVersion[i];
|
||||||
}
|
}
|
||||||
case Type_Node48: {
|
case Type_Node48: {
|
||||||
@@ -920,12 +971,12 @@ InternalVersionT exchangeMaxVersion(Node *n, InternalVersionT newMax) {
|
|||||||
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
||||||
case Type_Node3: {
|
case Type_Node3: {
|
||||||
auto *n3 = static_cast<Node3 *>(n);
|
auto *n3 = static_cast<Node3 *>(n);
|
||||||
int i = getNodeIndex(n3, index);
|
int i = getNodeIndexExists(n3, index);
|
||||||
return std::exchange(n3->childMaxVersion[i], newMax);
|
return std::exchange(n3->childMaxVersion[i], newMax);
|
||||||
}
|
}
|
||||||
case Type_Node16: {
|
case Type_Node16: {
|
||||||
auto *n16 = static_cast<Node16 *>(n);
|
auto *n16 = static_cast<Node16 *>(n);
|
||||||
int i = getNodeIndex(n16, index);
|
int i = getNodeIndexExists(n16, index);
|
||||||
return std::exchange(n16->childMaxVersion[i], newMax);
|
return std::exchange(n16->childMaxVersion[i], newMax);
|
||||||
}
|
}
|
||||||
case Type_Node48: {
|
case Type_Node48: {
|
||||||
@@ -954,13 +1005,13 @@ void setMaxVersion(Node *n, InternalVersionT newMax) {
|
|||||||
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
||||||
case Type_Node3: {
|
case Type_Node3: {
|
||||||
auto *n3 = static_cast<Node3 *>(n);
|
auto *n3 = static_cast<Node3 *>(n);
|
||||||
int i = getNodeIndex(n3, index);
|
int i = getNodeIndexExists(n3, index);
|
||||||
n3->childMaxVersion[i] = newMax;
|
n3->childMaxVersion[i] = newMax;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
case Type_Node16: {
|
case Type_Node16: {
|
||||||
auto *n16 = static_cast<Node16 *>(n);
|
auto *n16 = static_cast<Node16 *>(n);
|
||||||
int i = getNodeIndex(n16, index);
|
int i = getNodeIndexExists(n16, index);
|
||||||
n16->childMaxVersion[i] = newMax;
|
n16->childMaxVersion[i] = newMax;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user