SIMD for Node16 in lastLeq and firstGeq

This commit is contained in:
2024-01-25 12:40:11 -08:00
parent 35cf3f3132
commit b15bec6b38

View File

@@ -686,6 +686,9 @@ Node *&getChildExists(Node *self, uint8_t index) {
}
int getChildGeq(Node *self, int child) {
if (child > 255) {
return -1;
}
if (self->type == Type::Node4) {
auto *self4 = static_cast<Node4 *>(self);
for (int i = 0; i < self->numChildren; ++i) {
@@ -698,6 +701,34 @@ int getChildGeq(Node *self, int child) {
}
} else if (self->type == Type::Node16) {
auto *self16 = static_cast<Node16 *>(self);
#ifdef HAS_AVX
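// TODO: no AVX implementation yet. With HAS_AVX defined this Node16 branch is empty, so control falls through to the final `return -1`.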
#elif defined(HAS_ARM_NEON)
uint8x16_t indices;
memcpy(&indices, self16->index, sizeof(self16->index));
// 0xff for each lane where child <= index (i.e. index >= child)
auto results = vcleq_u8(vdupq_n_u8(child), indices);
uint64_t mask = self->numChildren == 16
? uint64_t(-1)
: (uint64_t(1) << (self->numChildren * 4)) - 1;
// 0xf for each 0xff (within mask)
uint64_t bitfield =
vget_lane_u64(
vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)),
0) &
mask;
int simd =
bitfield == 0 ? -1 : self16->index[__builtin_ctzll(bitfield) / 4];
assert(simd == [&]() -> int {
for (int i = 0; i < self->numChildren; ++i) {
if (self16->index[i] >= child) {
return self16->index[i];
}
}
return -1;
}());
return simd;
#else
for (int i = 0; i < self->numChildren; ++i) {
if (i > 0) {
assert(self16->index[i - 1] < self16->index[i]);
@@ -706,6 +737,7 @@ int getChildGeq(Node *self, int child) {
return self16->index[i];
}
}
#endif
} else if (self->type == Type::Node48) {
auto *self48 = static_cast<Node48 *>(self);
#if defined(HAS_AVX) || defined(HAS_ARM_NEON)
@@ -732,39 +764,19 @@ int getChildGeq(Node *self, int child) {
#endif
} else {
auto *self256 = static_cast<Node256 *>(self);
// For some reason gcc can't auto vectorize this, and the plain loop is
// faster.
#if defined(__clang__)
int i = child;
constexpr int kUnrollCount = 8; // Must be a power of two and <= 8
for (; (i & (kUnrollCount - 1)) != 0; ++i) {
if (self256->children[i]) {
return i;
}
}
for (; i < 256; i += kUnrollCount) {
uint8_t nonNull[kUnrollCount];
for (int j = 0; j < kUnrollCount; ++j) {
nonNull[j] = self256->children[i + j] != nullptr ? 0xff : 0;
}
uint64_t word;
memcpy(&word, nonNull, kUnrollCount);
if (word) {
return i + __builtin_ctzll(word) / 8;
}
}
#else
for (int i = child; i < 256; ++i) {
if (self256->children[i]) {
return i;
}
}
#endif
}
return -1;
}
int getChildLeq(Node *self, int child) {
if (child < 0) {
return -1;
}
if (self->type == Type::Node4) {
auto *self4 = static_cast<Node4 *>(self);
for (int i = self->numChildren - 1; i >= 0; --i) {
@@ -777,14 +789,41 @@ int getChildLeq(Node *self, int child) {
}
} else if (self->type == Type::Node16) {
auto *self16 = static_cast<Node16 *>(self);
for (int i = self->numChildren - 1; i >= 0; --i) {
if (i > 0) {
assert(self16->index[i - 1] < self16->index[i]);
#ifdef HAS_AVX
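// TODO: no AVX implementation yet. With HAS_AVX defined this Node16 branch is empty, so control falls through to the final `return -1`.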
#elif defined(HAS_ARM_NEON)
uint8x16_t indices;
memcpy(&indices, self16->index, sizeof(self16->index));
// 0xff for each leq
auto results = vcleq_u8(indices, vdupq_n_u8(child));
uint64_t mask = self->numChildren == 16
? uint64_t(-1)
: (uint64_t(1) << (self->numChildren * 4)) - 1;
// 0xf for each 0xff (within mask)
uint64_t bitfield =
vget_lane_u64(
vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)),
0) &
mask;
int simd =
bitfield == 0 ? -1 : self16->index[15 - __builtin_clzll(bitfield) / 4];
assert(simd == [&]() -> int {
for (int i = self->numChildren - 1; i >= 0; --i) {
if (self16->index[i] <= child) {
return self16->index[i];
}
}
return -1;
}());
return simd;
#else
for (int i = self->numChildren - 1; i >= 0; --i) {
if (self16->index[i] <= child) {
return self16->index[i];
}
}
return -1;
#endif
} else if (self->type == Type::Node48) {
auto *self48 = static_cast<Node48 *>(self);
// TODO the plain loop is faster?
@@ -820,37 +859,11 @@ int getChildLeq(Node *self, int child) {
#endif
} else {
auto *self256 = static_cast<Node256 *>(self);
// TODO: The plain loop is faster?
#if defined(__clang__)
int i = child;
constexpr int kUnrollCount = 8; // Must be a power of two and <= 8
for (; (i & (kUnrollCount - 1)) != 0; --i) {
if (self256->children[i]) {
return i;
}
}
if (self256->children[i]) {
return i;
}
i -= kUnrollCount;
for (; i >= 0; i -= kUnrollCount) {
uint8_t nonNull[kUnrollCount];
for (int j = 0; j < kUnrollCount; ++j) {
nonNull[j] = self256->children[i + j] != nullptr ? 0xff : 0;
}
uint64_t word;
memcpy(&word, nonNull, kUnrollCount);
if (word) {
return i + 7 - __builtin_clzll(word) / 8;
}
}
#else
for (int i = child; i >= 0; --i) {
if (self256->children[i]) {
return i;
}
}
#endif
}
return -1;
}