SIMD tinkering

This commit is contained in:
2024-01-23 15:56:57 -08:00
parent 122cddb54d
commit c775fccf6f

View File

@@ -502,7 +502,7 @@ struct Node256 : Node {
Node256() { this->type = Type::Node256; }
};
static int getNodeIndex(Node4 *self, uint8_t index) {
int getNodeIndex(Node4 *self, uint8_t index) {
for (int i = 0; i < self->numChildren; ++i) {
if (self->index[i] == index) {
return i;
@@ -511,7 +511,7 @@ static int getNodeIndex(Node4 *self, uint8_t index) {
return -1;
}
static int getNodeIndex(Node16 *self, uint8_t index) {
int getNodeIndex(Node16 *self, uint8_t index) {
#ifdef HAS_AVX
// Based on https://www.the-paper-trail.org/post/art-paper-notes/
@@ -580,6 +580,17 @@ int firstNonNeg1(const int8_t x[16]) {
return -1;
return __builtin_clz(bitfield);
}
int lastNonNeg1(const int8_t x[16]) {
__m128i key_vec = _mm_set1_epi8(-1);
__m128i indices;
memcpy(&indices, x, 16);
__m128i results = _mm_cmpeq_epi8(key_vec, indices);
uint32_t bitfield = _mm_movemask_epi8(results) ^ 0xffff;
if (bitfield == 0)
return -1;
return 15 - __builtin_ctz(bitfield);
}
#endif
#ifdef HAS_ARM_NEON
@@ -593,6 +604,17 @@ int firstNonNeg1(const int8_t x[16]) {
return -1;
return __builtin_ctzll(bitfield) / 4;
}
int lastNonNeg1(const int8_t x[16]) {
uint8x16_t indices;
memcpy(&indices, x, 16);
uint16x8_t results = vreinterpretq_u16_u8(vceqq_u8(vdupq_n_u8(-1), indices));
uint64_t bitfield =
~vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0);
if (bitfield == 0)
return -1;
return 15 - __builtin_clzll(bitfield) / 4;
}
#endif
Node *getChild(Node *self, uint8_t index) {
@@ -624,7 +646,7 @@ Node *getChild(Node *self, uint8_t index) {
}
// Precondition - an entry for index must exist in the node
static Node *&getChildExists(Node *self, uint8_t index) {
Node *&getChildExists(Node *self, uint8_t index) {
if (self->type == Type::Node4) {
auto *self4 = static_cast<Node4 *>(self);
return self4->children[getNodeIndex(self4, index)];
@@ -724,7 +746,6 @@ int getChildGeq(Node *self, int child) {
}
int getChildLeq(Node *self, int child) {
// TODO simd
if (self->type == Type::Node4) {
auto *self4 = static_cast<Node4 *>(self);
for (int i = self->numChildren - 1; i >= 0; --i) {
@@ -747,25 +768,75 @@ int getChildLeq(Node *self, int child) {
}
} else if (self->type == Type::Node48) {
auto *self48 = static_cast<Node48 *>(self);
// TODO the plain loop is faster?
#if 0 && (defined(HAS_AVX) || defined(HAS_ARM_NEON))
int i = child;
if (i < 0) {
return -1;
}
for (; (i & 0xf) != 0; --i) {
if (self48->index[i] >= 0) {
assert(self48->children[self48->index[i]] != nullptr);
return i;
}
}
if (self48->index[i] >= 0) {
assert(self48->children[self48->index[i]] != nullptr);
return i;
}
i -= 16;
for (; i >= 0; i -= 16) {
auto result = lastNonNeg1(self48->index + i);
if (result != -1) {
return i + result;
}
}
#else
for (int i = child; i >= 0; --i) {
if (self48->index[i] >= 0) {
assert(self48->children[self48->index[i]] != nullptr);
return i;
}
}
#endif
} else {
auto *self256 = static_cast<Node256 *>(self);
// TODO simd?
// TODO: The plain loop is faster?
#if 0 && defined(__clang__)
int i = child;
constexpr int kUnrollCount = 8; // Must be a power of two and <= 8
for (; (i & (kUnrollCount - 1)) != 0; --i) {
if (self256->children[i]) {
return i;
}
}
if (self256->children[i]) {
return i;
}
i -= kUnrollCount;
for (; i >= 0; i -= kUnrollCount) {
uint8_t nonNull[kUnrollCount];
for (int j = 0; j < kUnrollCount; ++j) {
nonNull[j] = self256->children[i + j] != nullptr ? 0xff : 0;
}
uint64_t word;
memcpy(&word, nonNull, kUnrollCount);
if (word) {
return i + 8 - __builtin_clzll(word) / 8;
}
}
#else
for (int i = child; i >= 0; --i) {
if (self256->children[i]) {
return i;
}
}
#endif
}
return -1;
}
static void setChildrenParents(Node *node) {
void setChildrenParents(Node *node) {
for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) {
getChildExists(node, i)->parent = node;
}