SIMD for Node16 in lastLeq and firstGeq
ConflictSet.cpp
@@ -686,6 +686,9 @@ Node *&getChildExists(Node *self, uint8_t index) {
}

int getChildGeq(Node *self, int child) {
  if (child > 255) {
    return -1;
  }
  if (self->type == Type::Node4) {
    auto *self4 = static_cast<Node4 *>(self);
    for (int i = 0; i < self->numChildren; ++i) {
@@ -698,6 +701,34 @@ int getChildGeq(Node *self, int child) {
    }
  } else if (self->type == Type::Node16) {
    auto *self16 = static_cast<Node16 *>(self);
#ifdef HAS_AVX
    // TODO
#elif defined(HAS_ARM_NEON)
    uint8x16_t indices;
    memcpy(&indices, self16->index, sizeof(self16->index));
    // 0xff for each leq
    auto results = vcleq_u8(vdupq_n_u8(child), indices);
    uint64_t mask = self->numChildren == 16
                        ? uint64_t(-1)
                        : (uint64_t(1) << (self->numChildren * 4)) - 1;
    // 0xf for each 0xff (within mask)
    uint64_t bitfield =
        vget_lane_u64(
            vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)),
            0) &
        mask;
    int simd =
        bitfield == 0 ? -1 : self16->index[__builtin_ctzll(bitfield) / 4];
    assert(simd == [&]() -> int {
      for (int i = 0; i < self->numChildren; ++i) {
        if (self16->index[i] >= child) {
          return self16->index[i];
        }
      }
      return -1;
    }());
    return simd;
#else
    for (int i = 0; i < self->numChildren; ++i) {
      if (i > 0) {
        assert(self16->index[i - 1] < self16->index[i]);
@@ -706,6 +737,7 @@ int getChildGeq(Node *self, int child) {
        return self16->index[i];
      }
    }
#endif
  } else if (self->type == Type::Node48) {
    auto *self48 = static_cast<Node48 *>(self);
#if defined(HAS_AVX) || defined(HAS_ARM_NEON)
@@ -732,39 +764,19 @@ int getChildGeq(Node *self, int child) {
#endif
  } else {
    auto *self256 = static_cast<Node256 *>(self);
    // For some reason gcc can't auto vectorize this, and the plain loop is
    // faster.
#if defined(__clang__)
    int i = child;
    constexpr int kUnrollCount = 8; // Must be a power of two and <= 8
    for (; (i & (kUnrollCount - 1)) != 0; ++i) {
      if (self256->children[i]) {
        return i;
      }
    }
    for (; i < 256; i += kUnrollCount) {
      uint8_t nonNull[kUnrollCount];
      for (int j = 0; j < kUnrollCount; ++j) {
        nonNull[j] = self256->children[i + j] != nullptr ? 0xff : 0;
      }
      uint64_t word;
      memcpy(&word, nonNull, kUnrollCount);
      if (word) {
        return i + __builtin_ctzll(word) / 8;
      }
    }
#else
    for (int i = child; i < 256; ++i) {
      if (self256->children[i]) {
        return i;
      }
    }
#endif
  }
  return -1;
}
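The new Node16 fast path works around NEON's lack of an x86-style movemask: after the byte-wise compare, vshrn_n_u16 narrows the 128-bit result to a 64-bit value with one nibble per lane, and __builtin_ctzll on that bitfield locates the first key that compares true. Below is a minimal standalone sketch of the ascending case; the name firstGeq16 and the free-standing signature are illustrative rather than part of ConflictSet.cpp, and it assumes an AArch64 target with <arm_neon.h> available.

#include <arm_neon.h>

#include <cstdint>
#include <cstring>

// Illustrative helper: returns the first key >= child among the first
// numChildren entries of a sorted 16-byte key array, or -1 if none exists.
int firstGeq16(const uint8_t keys[16], int numChildren, uint8_t child) {
  uint8x16_t indices;
  memcpy(&indices, keys, sizeof(indices));
  // 0xff in lane i when child <= keys[i], i.e. keys[i] >= child.
  uint8x16_t results = vcleq_u8(vdupq_n_u8(child), indices);
  // vshrn_n_u16(..., 4) shifts each 16-bit lane right by 4 and keeps the low
  // byte, collapsing the 128-bit mask into 64 bits with one nibble per lane.
  uint64_t bitfield = vget_lane_u64(
      vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)), 0);
  // Keep only the nibbles for populated slots (4 bits per lane).
  uint64_t mask = numChildren == 16
                      ? uint64_t(-1)
                      : (uint64_t(1) << (numChildren * 4)) - 1;
  bitfield &= mask;
  // The lowest set nibble marks the first matching lane.
  return bitfield == 0 ? -1 : keys[__builtin_ctzll(bitfield) / 4];
}

The masking step is what lets the same 16-lane compare serve partially filled nodes, and it is also why the diff keeps the scalar loop around inside an assert as a cross-check.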

int getChildLeq(Node *self, int child) {
  if (child < 0) {
    return -1;
  }
  if (self->type == Type::Node4) {
    auto *self4 = static_cast<Node4 *>(self);
    for (int i = self->numChildren - 1; i >= 0; --i) {
@@ -777,14 +789,41 @@ int getChildLeq(Node *self, int child) {
    }
  } else if (self->type == Type::Node16) {
    auto *self16 = static_cast<Node16 *>(self);
    for (int i = self->numChildren - 1; i >= 0; --i) {
      if (i > 0) {
        assert(self16->index[i - 1] < self16->index[i]);
#ifdef HAS_AVX
    // TODO
#elif defined(HAS_ARM_NEON)
    uint8x16_t indices;
    memcpy(&indices, self16->index, sizeof(self16->index));
    // 0xff for each leq
    auto results = vcleq_u8(indices, vdupq_n_u8(child));
    uint64_t mask = self->numChildren == 16
                        ? uint64_t(-1)
                        : (uint64_t(1) << (self->numChildren * 4)) - 1;
    // 0xf for each 0xff (within mask)
    uint64_t bitfield =
        vget_lane_u64(
            vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)),
            0) &
        mask;
    int simd =
        bitfield == 0 ? -1 : self16->index[15 - __builtin_clzll(bitfield) / 4];
    assert(simd == [&]() -> int {
      for (int i = self->numChildren - 1; i >= 0; --i) {
        if (self16->index[i] <= child) {
          return self16->index[i];
        }
      }
      return -1;
    }());
    return simd;
#else
    for (int i = self->numChildren - 1; i >= 0; --i) {
      if (self16->index[i] <= child) {
        return self16->index[i];
      }
    }
    return -1;
#endif
  } else if (self->type == Type::Node48) {
    auto *self48 = static_cast<Node48 *>(self);
    // TODO the plain loop is faster?
@@ -820,37 +859,11 @@ int getChildLeq(Node *self, int child) {
#endif
  } else {
    auto *self256 = static_cast<Node256 *>(self);
    // TODO: The plain loop is faster?
#if defined(__clang__)
    int i = child;
    constexpr int kUnrollCount = 8; // Must be a power of two and <= 8
    for (; (i & (kUnrollCount - 1)) != 0; --i) {
      if (self256->children[i]) {
        return i;
      }
    }
    if (self256->children[i]) {
      return i;
    }
    i -= kUnrollCount;
    for (; i >= 0; i -= kUnrollCount) {
      uint8_t nonNull[kUnrollCount];
      for (int j = 0; j < kUnrollCount; ++j) {
        nonNull[j] = self256->children[i + j] != nullptr ? 0xff : 0;
      }
      uint64_t word;
      memcpy(&word, nonNull, kUnrollCount);
      if (word) {
        return i + 7 - __builtin_clzll(word) / 8;
      }
    }
#else
    for (int i = child; i >= 0; --i) {
      if (self256->children[i]) {
        return i;
      }
    }
#endif
  }
  return -1;
}
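The descending case in getChildLeq uses the same nibble bitfield but picks the highest set nibble with __builtin_clzll; the index arithmetic 15 - clz / 4 works because every matching lane contributes a full 0xf nibble, so the leading-zero count always lands inside the topmost matching nibble. A standalone sketch follows, again with an illustrative name (lastLeq16) and the same AArch64/NEON assumption.

#include <arm_neon.h>

#include <cstdint>
#include <cstring>

// Illustrative helper: returns the last key <= child among the first
// numChildren entries of a sorted 16-byte key array, or -1 if none exists.
int lastLeq16(const uint8_t keys[16], int numChildren, uint8_t child) {
  uint8x16_t indices;
  memcpy(&indices, keys, sizeof(indices));
  // 0xff in lane i when keys[i] <= child.
  uint8x16_t results = vcleq_u8(indices, vdupq_n_u8(child));
  // Same narrowing trick as the ascending case: one nibble per lane.
  uint64_t bitfield = vget_lane_u64(
      vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)), 0);
  uint64_t mask = numChildren == 16
                      ? uint64_t(-1)
                      : (uint64_t(1) << (numChildren * 4)) - 1;
  bitfield &= mask;
  if (bitfield == 0) {
    return -1;
  }
  // clz / 4 counts the empty lanes above the topmost match, so 15 - clz / 4
  // is the lane index of the last key satisfying keys[i] <= child.
  return keys[15 - __builtin_clzll(bitfield) / 4];
}

As in the diff, the scalar descending loop stays as the reference implementation, both under #else for builds without SIMD support and inside the assert that validates the SIMD result.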