SIMD tinkering
This commit is contained in:
@@ -502,7 +502,7 @@ struct Node256 : Node {
|
|||||||
Node256() { this->type = Type::Node256; }
|
Node256() { this->type = Type::Node256; }
|
||||||
};
|
};
|
||||||
|
|
||||||
static int getNodeIndex(Node4 *self, uint8_t index) {
|
int getNodeIndex(Node4 *self, uint8_t index) {
|
||||||
for (int i = 0; i < self->numChildren; ++i) {
|
for (int i = 0; i < self->numChildren; ++i) {
|
||||||
if (self->index[i] == index) {
|
if (self->index[i] == index) {
|
||||||
return i;
|
return i;
|
||||||
@@ -511,7 +511,7 @@ static int getNodeIndex(Node4 *self, uint8_t index) {
|
|||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
static int getNodeIndex(Node16 *self, uint8_t index) {
|
int getNodeIndex(Node16 *self, uint8_t index) {
|
||||||
#ifdef HAS_AVX
|
#ifdef HAS_AVX
|
||||||
// Based on https://www.the-paper-trail.org/post/art-paper-notes/
|
// Based on https://www.the-paper-trail.org/post/art-paper-notes/
|
||||||
|
|
||||||
@@ -580,6 +580,17 @@ int firstNonNeg1(const int8_t x[16]) {
|
|||||||
return -1;
|
return -1;
|
||||||
return __builtin_clz(bitfield);
|
return __builtin_clz(bitfield);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the largest i in [0, 16) such that x[i] != -1, or -1 if all sixteen
// entries are -1. (SSE2 implementation.)
int lastNonNeg1(const int8_t x[16]) {
  __m128i key_vec = _mm_set1_epi8(-1);
  __m128i indices;
  // memcpy rather than a cast so that x needs no particular alignment.
  memcpy(&indices, x, 16);
  // Each byte lane becomes 0xff where x[i] == -1 and 0x00 otherwise.
  __m128i results = _mm_cmpeq_epi8(key_vec, indices);
  // movemask packs the top bit of lane i into bit i. Flipping the low 16 bits
  // leaves bit i set exactly when x[i] != -1.
  uint32_t bitfield = _mm_movemask_epi8(results) ^ 0xffff;
  if (bitfield == 0)
    return -1;
  // The last matching element is the highest set bit. bitfield is a 32-bit
  // value, so that bit's position is 31 - clz (clz is undefined for 0, which
  // was handled above).
  return 31 - __builtin_clz(bitfield);
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef HAS_ARM_NEON
|
#ifdef HAS_ARM_NEON
|
||||||
@@ -593,6 +604,17 @@ int firstNonNeg1(const int8_t x[16]) {
|
|||||||
return -1;
|
return -1;
|
||||||
return __builtin_ctzll(bitfield) / 4;
|
return __builtin_ctzll(bitfield) / 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the largest i in [0, 16) such that x[i] != -1, or -1 if all sixteen
// entries are -1. (ARM NEON implementation.)
int lastNonNeg1(const int8_t x[16]) {
  uint8x16_t lanes;
  // memcpy rather than a cast so that x needs no particular alignment.
  memcpy(&lanes, x, 16);
  // Each byte lane becomes 0xff where x[i] == -1 and 0x00 otherwise.
  const uint8x16_t isNeg1 = vceqq_u8(vdupq_n_u8(-1), lanes);
  // Shift-right-and-narrow by 4 squeezes the 128-bit comparison result into
  // 64 bits, leaving one nibble per input byte.
  const uint8x8_t nibbles = vshrn_n_u16(vreinterpretq_u16_u8(isNeg1), 4);
  // After inversion, nibble i is non-zero exactly when x[i] != -1.
  const uint64_t occupied = ~vget_lane_u64(vreinterpret_u64_u8(nibbles), 0);
  // The highest non-zero nibble corresponds to the last matching element.
  return occupied == 0 ? -1 : 15 - __builtin_clzll(occupied) / 4;
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
Node *getChild(Node *self, uint8_t index) {
|
Node *getChild(Node *self, uint8_t index) {
|
||||||
@@ -624,7 +646,7 @@ Node *getChild(Node *self, uint8_t index) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Precondition - an entry for index must exist in the node
|
// Precondition - an entry for index must exist in the node
|
||||||
static Node *&getChildExists(Node *self, uint8_t index) {
|
Node *&getChildExists(Node *self, uint8_t index) {
|
||||||
if (self->type == Type::Node4) {
|
if (self->type == Type::Node4) {
|
||||||
auto *self4 = static_cast<Node4 *>(self);
|
auto *self4 = static_cast<Node4 *>(self);
|
||||||
return self4->children[getNodeIndex(self4, index)];
|
return self4->children[getNodeIndex(self4, index)];
|
||||||
@@ -724,7 +746,6 @@ int getChildGeq(Node *self, int child) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int getChildLeq(Node *self, int child) {
|
int getChildLeq(Node *self, int child) {
|
||||||
// TODO simd
|
|
||||||
if (self->type == Type::Node4) {
|
if (self->type == Type::Node4) {
|
||||||
auto *self4 = static_cast<Node4 *>(self);
|
auto *self4 = static_cast<Node4 *>(self);
|
||||||
for (int i = self->numChildren - 1; i >= 0; --i) {
|
for (int i = self->numChildren - 1; i >= 0; --i) {
|
||||||
@@ -747,25 +768,75 @@ int getChildLeq(Node *self, int child) {
|
|||||||
}
|
}
|
||||||
} else if (self->type == Type::Node48) {
|
} else if (self->type == Type::Node48) {
|
||||||
auto *self48 = static_cast<Node48 *>(self);
|
auto *self48 = static_cast<Node48 *>(self);
|
||||||
|
// TODO: The SIMD path below is disabled (#if 0); the plain loop appears to be faster - benchmark to confirm.
|
||||||
|
#if 0 && (defined(HAS_AVX) || defined(HAS_ARM_NEON))
|
||||||
|
int i = child;
|
||||||
|
if (i < 0) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
for (; (i & 0xf) != 0; --i) {
|
||||||
|
if (self48->index[i] >= 0) {
|
||||||
|
assert(self48->children[self48->index[i]] != nullptr);
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (self48->index[i] >= 0) {
|
||||||
|
assert(self48->children[self48->index[i]] != nullptr);
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
i -= 16;
|
||||||
|
for (; i >= 0; i -= 16) {
|
||||||
|
auto result = lastNonNeg1(self48->index + i);
|
||||||
|
if (result != -1) {
|
||||||
|
return i + result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
for (int i = child; i >= 0; --i) {
|
for (int i = child; i >= 0; --i) {
|
||||||
if (self48->index[i] >= 0) {
|
if (self48->index[i] >= 0) {
|
||||||
assert(self48->children[self48->index[i]] != nullptr);
|
assert(self48->children[self48->index[i]] != nullptr);
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
} else {
|
} else {
|
||||||
auto *self256 = static_cast<Node256 *>(self);
|
auto *self256 = static_cast<Node256 *>(self);
|
||||||
// TODO simd?
|
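// TODO: The unrolled path below is disabled (#if 0); the plain loop appears to be faster - benchmark to confirm.
// NOTE(review): before enabling, check `i + 8 - __builtin_clzll(word) / 8` - it looks off by one (expected i + 7 - ...).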
|
||||||
|
#if 0 && defined(__clang__)
|
||||||
|
int i = child;
|
||||||
|
constexpr int kUnrollCount = 8; // Must be a power of two and <= 8
|
||||||
|
for (; (i & (kUnrollCount - 1)) != 0; --i) {
|
||||||
|
if (self256->children[i]) {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (self256->children[i]) {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
i -= kUnrollCount;
|
||||||
|
for (; i >= 0; i -= kUnrollCount) {
|
||||||
|
uint8_t nonNull[kUnrollCount];
|
||||||
|
for (int j = 0; j < kUnrollCount; ++j) {
|
||||||
|
nonNull[j] = self256->children[i + j] != nullptr ? 0xff : 0;
|
||||||
|
}
|
||||||
|
uint64_t word;
|
||||||
|
memcpy(&word, nonNull, kUnrollCount);
|
||||||
|
if (word) {
|
||||||
|
return i + 8 - __builtin_clzll(word) / 8;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
for (int i = child; i >= 0; --i) {
|
for (int i = child; i >= 0; --i) {
|
||||||
if (self256->children[i]) {
|
if (self256->children[i]) {
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void setChildrenParents(Node *node) {
|
void setChildrenParents(Node *node) {
|
||||||
for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) {
|
for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) {
|
||||||
getChildExists(node, i)->parent = node;
|
getChildExists(node, i)->parent = node;
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user