SIMD for Node16 in lastLeq and firstGeq
This commit is contained in:
117
ConflictSet.cpp
117
ConflictSet.cpp
@@ -686,6 +686,9 @@ Node *&getChildExists(Node *self, uint8_t index) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int getChildGeq(Node *self, int child) {
|
int getChildGeq(Node *self, int child) {
|
||||||
|
if (child > 255) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
if (self->type == Type::Node4) {
|
if (self->type == Type::Node4) {
|
||||||
auto *self4 = static_cast<Node4 *>(self);
|
auto *self4 = static_cast<Node4 *>(self);
|
||||||
for (int i = 0; i < self->numChildren; ++i) {
|
for (int i = 0; i < self->numChildren; ++i) {
|
||||||
@@ -698,6 +701,34 @@ int getChildGeq(Node *self, int child) {
|
|||||||
}
|
}
|
||||||
} else if (self->type == Type::Node16) {
|
} else if (self->type == Type::Node16) {
|
||||||
auto *self16 = static_cast<Node16 *>(self);
|
auto *self16 = static_cast<Node16 *>(self);
|
||||||
|
#ifdef HAS_AVX
|
||||||
|
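// TODO: implement the AVX path. NOTE(review): with HAS_AVX defined this Node16 branch is empty and appears to fall through to the final `return -1` without searching -- confirm.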
|
||||||
|
#elif defined(HAS_ARM_NEON)
|
||||||
|
uint8x16_t indices;
|
||||||
|
memcpy(&indices, self16->index, sizeof(self16->index));
|
||||||
|
// 0xff in each lane where child <= index, i.e. the index is >= child
|
||||||
|
auto results = vcleq_u8(vdupq_n_u8(child), indices);
|
||||||
|
uint64_t mask = self->numChildren == 16
|
||||||
|
? uint64_t(-1)
|
||||||
|
: (uint64_t(1) << (self->numChildren * 4)) - 1;
|
||||||
|
// Narrow to one nibble (0xf) per 0xff lane; mask keeps only the first numChildren lanes
|
||||||
|
uint64_t bitfield =
|
||||||
|
vget_lane_u64(
|
||||||
|
vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)),
|
||||||
|
0) &
|
||||||
|
mask;
|
||||||
|
int simd =
|
||||||
|
bitfield == 0 ? -1 : self16->index[__builtin_ctzll(bitfield) / 4];
|
||||||
|
assert(simd == [&]() -> int {
|
||||||
|
for (int i = 0; i < self->numChildren; ++i) {
|
||||||
|
if (self16->index[i] >= child) {
|
||||||
|
return self16->index[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}());
|
||||||
|
return simd;
|
||||||
|
#else
|
||||||
for (int i = 0; i < self->numChildren; ++i) {
|
for (int i = 0; i < self->numChildren; ++i) {
|
||||||
if (i > 0) {
|
if (i > 0) {
|
||||||
assert(self16->index[i - 1] < self16->index[i]);
|
assert(self16->index[i - 1] < self16->index[i]);
|
||||||
@@ -706,6 +737,7 @@ int getChildGeq(Node *self, int child) {
|
|||||||
return self16->index[i];
|
return self16->index[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
} else if (self->type == Type::Node48) {
|
} else if (self->type == Type::Node48) {
|
||||||
auto *self48 = static_cast<Node48 *>(self);
|
auto *self48 = static_cast<Node48 *>(self);
|
||||||
#if defined(HAS_AVX) || defined(HAS_ARM_NEON)
|
#if defined(HAS_AVX) || defined(HAS_ARM_NEON)
|
||||||
@@ -732,39 +764,19 @@ int getChildGeq(Node *self, int child) {
|
|||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
auto *self256 = static_cast<Node256 *>(self);
|
auto *self256 = static_cast<Node256 *>(self);
|
||||||
// For some reason gcc can't auto vectorize this, and the plain loop is
|
|
||||||
// faster.
|
|
||||||
#if defined(__clang__)
|
|
||||||
int i = child;
|
|
||||||
constexpr int kUnrollCount = 8; // Must be a power of two and <= 8
|
|
||||||
for (; (i & (kUnrollCount - 1)) != 0; ++i) {
|
|
||||||
if (self256->children[i]) {
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (; i < 256; i += kUnrollCount) {
|
|
||||||
uint8_t nonNull[kUnrollCount];
|
|
||||||
for (int j = 0; j < kUnrollCount; ++j) {
|
|
||||||
nonNull[j] = self256->children[i + j] != nullptr ? 0xff : 0;
|
|
||||||
}
|
|
||||||
uint64_t word;
|
|
||||||
memcpy(&word, nonNull, kUnrollCount);
|
|
||||||
if (word) {
|
|
||||||
return i + __builtin_ctzll(word) / 8;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
for (int i = child; i < 256; ++i) {
|
for (int i = child; i < 256; ++i) {
|
||||||
if (self256->children[i]) {
|
if (self256->children[i]) {
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
int getChildLeq(Node *self, int child) {
|
int getChildLeq(Node *self, int child) {
|
||||||
|
if (child < 0) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
if (self->type == Type::Node4) {
|
if (self->type == Type::Node4) {
|
||||||
auto *self4 = static_cast<Node4 *>(self);
|
auto *self4 = static_cast<Node4 *>(self);
|
||||||
for (int i = self->numChildren - 1; i >= 0; --i) {
|
for (int i = self->numChildren - 1; i >= 0; --i) {
|
||||||
@@ -777,14 +789,41 @@ int getChildLeq(Node *self, int child) {
|
|||||||
}
|
}
|
||||||
} else if (self->type == Type::Node16) {
|
} else if (self->type == Type::Node16) {
|
||||||
auto *self16 = static_cast<Node16 *>(self);
|
auto *self16 = static_cast<Node16 *>(self);
|
||||||
for (int i = self->numChildren - 1; i >= 0; --i) {
|
#ifdef HAS_AVX
|
||||||
if (i > 0) {
|
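// TODO: implement the AVX path. NOTE(review): with HAS_AVX defined this Node16 branch is empty and appears to fall through to the final `return -1` without searching -- confirm.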
|
||||||
assert(self16->index[i - 1] < self16->index[i]);
|
#elif defined(HAS_ARM_NEON)
|
||||||
|
uint8x16_t indices;
|
||||||
|
memcpy(&indices, self16->index, sizeof(self16->index));
|
||||||
|
// 0xff in each lane where index <= child
|
||||||
|
auto results = vcleq_u8(indices, vdupq_n_u8(child));
|
||||||
|
uint64_t mask = self->numChildren == 16
|
||||||
|
? uint64_t(-1)
|
||||||
|
: (uint64_t(1) << (self->numChildren * 4)) - 1;
|
||||||
|
// Narrow to one nibble (0xf) per 0xff lane; mask keeps only the first numChildren lanes
|
||||||
|
uint64_t bitfield =
|
||||||
|
vget_lane_u64(
|
||||||
|
vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)),
|
||||||
|
0) &
|
||||||
|
mask;
|
||||||
|
int simd =
|
||||||
|
bitfield == 0 ? -1 : self16->index[15 - __builtin_clzll(bitfield) / 4];
|
||||||
|
assert(simd == [&]() -> int {
|
||||||
|
for (int i = self->numChildren - 1; i >= 0; --i) {
|
||||||
|
if (self16->index[i] <= child) {
|
||||||
|
return self16->index[i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
return -1;
|
||||||
|
}());
|
||||||
|
return simd;
|
||||||
|
#else
|
||||||
|
for (int i = self->numChildren - 1; i >= 0; --i) {
|
||||||
if (self16->index[i] <= child) {
|
if (self16->index[i] <= child) {
|
||||||
return self16->index[i];
|
return self16->index[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return -1;
|
||||||
|
#endif
|
||||||
} else if (self->type == Type::Node48) {
|
} else if (self->type == Type::Node48) {
|
||||||
auto *self48 = static_cast<Node48 *>(self);
|
auto *self48 = static_cast<Node48 *>(self);
|
||||||
// TODO the plain loop is faster?
|
// TODO the plain loop is faster?
|
||||||
@@ -820,37 +859,11 @@ int getChildLeq(Node *self, int child) {
|
|||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
auto *self256 = static_cast<Node256 *>(self);
|
auto *self256 = static_cast<Node256 *>(self);
|
||||||
// TODO: The plain loop is faster?
|
|
||||||
#if defined(__clang__)
|
|
||||||
int i = child;
|
|
||||||
constexpr int kUnrollCount = 8; // Must be a power of two and <= 8
|
|
||||||
for (; (i & (kUnrollCount - 1)) != 0; --i) {
|
|
||||||
if (self256->children[i]) {
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (self256->children[i]) {
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
i -= kUnrollCount;
|
|
||||||
for (; i >= 0; i -= kUnrollCount) {
|
|
||||||
uint8_t nonNull[kUnrollCount];
|
|
||||||
for (int j = 0; j < kUnrollCount; ++j) {
|
|
||||||
nonNull[j] = self256->children[i + j] != nullptr ? 0xff : 0;
|
|
||||||
}
|
|
||||||
uint64_t word;
|
|
||||||
memcpy(&word, nonNull, kUnrollCount);
|
|
||||||
if (word) {
|
|
||||||
return i + 7 - __builtin_clzll(word) / 8;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
for (int i = child; i >= 0; --i) {
|
for (int i = child; i >= 0; --i) {
|
||||||
if (self256->children[i]) {
|
if (self256->children[i]) {
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user