Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d9e4a7d1b6 | |||
| 52201fa4c7 | |||
| 0814822d82 | |||
| 41df2398e8 | |||
| 84c4d0fcba | |||
| 6241533dfb | |||
| 0abf6a1ecf |
+154
-23
@@ -766,6 +766,8 @@ private:
|
|||||||
|
|
||||||
int getNodeIndex(Node3 *self, uint8_t index) {
|
int getNodeIndex(Node3 *self, uint8_t index) {
|
||||||
Node3 *n = (Node3 *)self;
|
Node3 *n = (Node3 *)self;
|
||||||
|
assume(n->numChildren >= 1);
|
||||||
|
assume(n->numChildren <= 3);
|
||||||
for (int i = 0; i < n->numChildren; ++i) {
|
for (int i = 0; i < n->numChildren; ++i) {
|
||||||
if (n->index[i] == index) {
|
if (n->index[i] == index) {
|
||||||
return i;
|
return i;
|
||||||
@@ -774,6 +776,18 @@ int getNodeIndex(Node3 *self, uint8_t index) {
|
|||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int getNodeIndexExists(Node3 *self, uint8_t index) {
|
||||||
|
Node3 *n = (Node3 *)self;
|
||||||
|
assume(n->numChildren >= 1);
|
||||||
|
assume(n->numChildren <= 3);
|
||||||
|
for (int i = 0; i < n->numChildren; ++i) {
|
||||||
|
if (n->index[i] == index) {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
||||||
|
}
|
||||||
|
|
||||||
int getNodeIndex(Node16 *self, uint8_t index) {
|
int getNodeIndex(Node16 *self, uint8_t index) {
|
||||||
|
|
||||||
#ifdef HAS_AVX
|
#ifdef HAS_AVX
|
||||||
@@ -834,13 +848,52 @@ int getNodeIndex(Node16 *self, uint8_t index) {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int getNodeIndexExists(Node16 *self, uint8_t index) {
|
||||||
|
|
||||||
|
#ifdef HAS_AVX
|
||||||
|
__m128i key_vec = _mm_set1_epi8(index);
|
||||||
|
__m128i indices;
|
||||||
|
memcpy(&indices, self->index, Node16::kMaxNodes);
|
||||||
|
__m128i results = _mm_cmpeq_epi8(key_vec, indices);
|
||||||
|
uint32_t mask = (1 << self->numChildren) - 1;
|
||||||
|
uint32_t bitfield = _mm_movemask_epi8(results) & mask;
|
||||||
|
assume(bitfield != 0);
|
||||||
|
return std::countr_zero(bitfield);
|
||||||
|
#elif defined(HAS_ARM_NEON)
|
||||||
|
// Based on
|
||||||
|
// https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
|
||||||
|
|
||||||
|
uint8x16_t indices;
|
||||||
|
memcpy(&indices, self->index, Node16::kMaxNodes);
|
||||||
|
// 0xff for each match
|
||||||
|
uint16x8_t results =
|
||||||
|
vreinterpretq_u16_u8(vceqq_u8(vdupq_n_u8(index), indices));
|
||||||
|
assume(self->numChildren <= Node16::kMaxNodes);
|
||||||
|
uint64_t mask = self->numChildren == 16
|
||||||
|
? uint64_t(-1)
|
||||||
|
: (uint64_t(1) << (self->numChildren * 4)) - 1;
|
||||||
|
// 0xf for each match in valid range
|
||||||
|
uint64_t bitfield =
|
||||||
|
vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) & mask;
|
||||||
|
assume(bitfield != 0);
|
||||||
|
return std::countr_zero(bitfield) / 4;
|
||||||
|
#else
|
||||||
|
for (int i = 0; i < self->numChildren; ++i) {
|
||||||
|
if (self->index[i] == index) {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
// Precondition - an entry for index must exist in the node
|
// Precondition - an entry for index must exist in the node
|
||||||
Node *&getChildExists(Node3 *self, uint8_t index) {
|
Node *&getChildExists(Node3 *self, uint8_t index) {
|
||||||
return self->children[getNodeIndex(self, index)];
|
return self->children[getNodeIndexExists(self, index)];
|
||||||
}
|
}
|
||||||
// Precondition - an entry for index must exist in the node
|
// Precondition - an entry for index must exist in the node
|
||||||
Node *&getChildExists(Node16 *self, uint8_t index) {
|
Node *&getChildExists(Node16 *self, uint8_t index) {
|
||||||
return self->children[getNodeIndex(self, index)];
|
return self->children[getNodeIndexExists(self, index)];
|
||||||
}
|
}
|
||||||
// Precondition - an entry for index must exist in the node
|
// Precondition - an entry for index must exist in the node
|
||||||
Node *&getChildExists(Node48 *self, uint8_t index) {
|
Node *&getChildExists(Node48 *self, uint8_t index) {
|
||||||
@@ -885,12 +938,12 @@ InternalVersionT maxVersion(Node *n) {
|
|||||||
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
||||||
case Type_Node3: {
|
case Type_Node3: {
|
||||||
auto *n3 = static_cast<Node3 *>(n);
|
auto *n3 = static_cast<Node3 *>(n);
|
||||||
int i = getNodeIndex(n3, index);
|
int i = getNodeIndexExists(n3, index);
|
||||||
return n3->childMaxVersion[i];
|
return n3->childMaxVersion[i];
|
||||||
}
|
}
|
||||||
case Type_Node16: {
|
case Type_Node16: {
|
||||||
auto *n16 = static_cast<Node16 *>(n);
|
auto *n16 = static_cast<Node16 *>(n);
|
||||||
int i = getNodeIndex(n16, index);
|
int i = getNodeIndexExists(n16, index);
|
||||||
return n16->childMaxVersion[i];
|
return n16->childMaxVersion[i];
|
||||||
}
|
}
|
||||||
case Type_Node48: {
|
case Type_Node48: {
|
||||||
@@ -918,12 +971,12 @@ InternalVersionT exchangeMaxVersion(Node *n, InternalVersionT newMax) {
|
|||||||
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
||||||
case Type_Node3: {
|
case Type_Node3: {
|
||||||
auto *n3 = static_cast<Node3 *>(n);
|
auto *n3 = static_cast<Node3 *>(n);
|
||||||
int i = getNodeIndex(n3, index);
|
int i = getNodeIndexExists(n3, index);
|
||||||
return std::exchange(n3->childMaxVersion[i], newMax);
|
return std::exchange(n3->childMaxVersion[i], newMax);
|
||||||
}
|
}
|
||||||
case Type_Node16: {
|
case Type_Node16: {
|
||||||
auto *n16 = static_cast<Node16 *>(n);
|
auto *n16 = static_cast<Node16 *>(n);
|
||||||
int i = getNodeIndex(n16, index);
|
int i = getNodeIndexExists(n16, index);
|
||||||
return std::exchange(n16->childMaxVersion[i], newMax);
|
return std::exchange(n16->childMaxVersion[i], newMax);
|
||||||
}
|
}
|
||||||
case Type_Node48: {
|
case Type_Node48: {
|
||||||
@@ -952,13 +1005,13 @@ void setMaxVersion(Node *n, InternalVersionT newMax) {
|
|||||||
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
||||||
case Type_Node3: {
|
case Type_Node3: {
|
||||||
auto *n3 = static_cast<Node3 *>(n);
|
auto *n3 = static_cast<Node3 *>(n);
|
||||||
int i = getNodeIndex(n3, index);
|
int i = getNodeIndexExists(n3, index);
|
||||||
n3->childMaxVersion[i] = newMax;
|
n3->childMaxVersion[i] = newMax;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
case Type_Node16: {
|
case Type_Node16: {
|
||||||
auto *n16 = static_cast<Node16 *>(n);
|
auto *n16 = static_cast<Node16 *>(n);
|
||||||
int i = getNodeIndex(n16, index);
|
int i = getNodeIndexExists(n16, index);
|
||||||
n16->childMaxVersion[i] = newMax;
|
n16->childMaxVersion[i] = newMax;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -1070,6 +1123,8 @@ ChildAndMaxVersion getChildAndMaxVersion(Node *self, uint8_t index) {
|
|||||||
Node *getChildGeq(Node0 *, int) { return nullptr; }
|
Node *getChildGeq(Node0 *, int) { return nullptr; }
|
||||||
|
|
||||||
Node *getChildGeq(Node3 *n, int child) {
|
Node *getChildGeq(Node3 *n, int child) {
|
||||||
|
assume(n->numChildren >= 1);
|
||||||
|
assume(n->numChildren <= 3);
|
||||||
for (int i = 0; i < n->numChildren; ++i) {
|
for (int i = 0; i < n->numChildren; ++i) {
|
||||||
if (n->index[i] >= child) {
|
if (n->index[i] >= child) {
|
||||||
return n->children[i];
|
return n->children[i];
|
||||||
@@ -1549,7 +1604,6 @@ __attribute__((target("avx512f"))) void rezero16(InternalVersionT *vs,
|
|||||||
_mm512_sub_epi32(_mm512_loadu_epi32(vs), zvec), _mm512_setzero_epi32());
|
_mm512_sub_epi32(_mm512_loadu_epi32(vs), zvec), _mm512_setzero_epi32());
|
||||||
_mm512_mask_storeu_epi32(vs, m, zvec);
|
_mm512_mask_storeu_epi32(vs, m, zvec);
|
||||||
}
|
}
|
||||||
|
|
||||||
__attribute__((target("default")))
|
__attribute__((target("default")))
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@@ -2416,6 +2470,7 @@ checkMaxBetweenExclusive(Node *n, int begin, int end,
|
|||||||
}
|
}
|
||||||
__attribute__((target("default")))
|
__attribute__((target("default")))
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
bool checkMaxBetweenExclusive(Node *n, int begin, int end,
|
bool checkMaxBetweenExclusive(Node *n, int begin, int end,
|
||||||
InternalVersionT readVersion, ReadContext *tls) {
|
InternalVersionT readVersion, ReadContext *tls) {
|
||||||
return checkMaxBetweenExclusiveImpl<false>(n, begin, end, readVersion, tls);
|
return checkMaxBetweenExclusiveImpl<false>(n, begin, end, readVersion, tls);
|
||||||
@@ -2855,6 +2910,72 @@ void addPointWrite(Node *&root, std::span<const uint8_t> key,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
|
||||||
|
__attribute__((target("avx512f"))) InternalVersionT
|
||||||
|
horizontalMaxUpTo16(InternalVersionT *vs, InternalVersionT z, int len) {
|
||||||
|
assume(len <= 16);
|
||||||
|
#if USE_64_BIT
|
||||||
|
// Hope it gets vectorized
|
||||||
|
InternalVersionT max = vs[0];
|
||||||
|
for (int i = 1; i < len; ++i) {
|
||||||
|
max = std::max(vs[i], max);
|
||||||
|
}
|
||||||
|
return max;
|
||||||
|
#else
|
||||||
|
uint32_t zero;
|
||||||
|
memcpy(&zero, &z, sizeof(zero));
|
||||||
|
auto zeroVec = _mm512_set1_epi32(zero);
|
||||||
|
auto max = InternalVersionT(
|
||||||
|
zero +
|
||||||
|
_mm512_reduce_max_epi32(_mm512_sub_epi32(
|
||||||
|
_mm512_mask_loadu_epi32(zeroVec, _mm512_int2mask((1 << len) - 1), vs),
|
||||||
|
zeroVec)));
|
||||||
|
return max;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
__attribute__((target("default")))
|
||||||
|
#endif
|
||||||
|
|
||||||
|
InternalVersionT
|
||||||
|
horizontalMaxUpTo16(InternalVersionT *vs, InternalVersionT, int len) {
|
||||||
|
assume(len <= 16);
|
||||||
|
InternalVersionT max = vs[0];
|
||||||
|
for (int i = 1; i < len; ++i) {
|
||||||
|
max = std::max(vs[i], max);
|
||||||
|
}
|
||||||
|
return max;
|
||||||
|
}
|
||||||
|
|
||||||
|
#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
|
||||||
|
__attribute__((target("avx512f"))) InternalVersionT
|
||||||
|
horizontalMax16(InternalVersionT *vs, InternalVersionT z) {
|
||||||
|
#if USE_64_BIT
|
||||||
|
// Hope it gets vectorized
|
||||||
|
InternalVersionT max = vs[0];
|
||||||
|
for (int i = 1; i < 16; ++i) {
|
||||||
|
max = std::max(vs[i], max);
|
||||||
|
}
|
||||||
|
return max;
|
||||||
|
#else
|
||||||
|
uint32_t zero;
|
||||||
|
memcpy(&zero, &z, sizeof(zero));
|
||||||
|
auto zeroVec = _mm512_set1_epi32(zero);
|
||||||
|
return InternalVersionT(zero + _mm512_reduce_max_epi32(_mm512_sub_epi32(
|
||||||
|
_mm512_loadu_epi32(vs), zeroVec)));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
__attribute__((target("default")))
|
||||||
|
#endif
|
||||||
|
|
||||||
|
InternalVersionT
|
||||||
|
horizontalMax16(InternalVersionT *vs, InternalVersionT) {
|
||||||
|
InternalVersionT max = vs[0];
|
||||||
|
for (int i = 1; i < 16; ++i) {
|
||||||
|
max = std::max(vs[i], max);
|
||||||
|
}
|
||||||
|
return max;
|
||||||
|
}
|
||||||
|
|
||||||
// Precondition: `node->entryPresent`, and node is not the root
|
// Precondition: `node->entryPresent`, and node is not the root
|
||||||
void fixupMaxVersion(Node *node, WriteContext *tls) {
|
void fixupMaxVersion(Node *node, WriteContext *tls) {
|
||||||
assert(node->parent);
|
assert(node->parent);
|
||||||
@@ -2866,15 +2987,13 @@ void fixupMaxVersion(Node *node, WriteContext *tls) {
|
|||||||
break;
|
break;
|
||||||
case Type_Node3: {
|
case Type_Node3: {
|
||||||
auto *self3 = static_cast<Node3 *>(node);
|
auto *self3 = static_cast<Node3 *>(node);
|
||||||
for (int i = 0; i < self3->numChildren; ++i) {
|
max = std::max(max, horizontalMaxUpTo16(self3->childMaxVersion, tls->zero,
|
||||||
max = std::max(self3->childMaxVersion[i], max);
|
self3->numChildren));
|
||||||
}
|
|
||||||
} break;
|
} break;
|
||||||
case Type_Node16: {
|
case Type_Node16: {
|
||||||
auto *self16 = static_cast<Node16 *>(node);
|
auto *self16 = static_cast<Node16 *>(node);
|
||||||
for (int i = 0; i < self16->numChildren; ++i) {
|
max = std::max(max, horizontalMaxUpTo16(self16->childMaxVersion, tls->zero,
|
||||||
max = std::max(self16->childMaxVersion[i], max);
|
self16->numChildren));
|
||||||
}
|
|
||||||
} break;
|
} break;
|
||||||
case Type_Node48: {
|
case Type_Node48: {
|
||||||
auto *self48 = static_cast<Node48 *>(node);
|
auto *self48 = static_cast<Node48 *>(node);
|
||||||
@@ -2884,9 +3003,7 @@ void fixupMaxVersion(Node *node, WriteContext *tls) {
|
|||||||
} break;
|
} break;
|
||||||
case Type_Node256: {
|
case Type_Node256: {
|
||||||
auto *self256 = static_cast<Node256 *>(node);
|
auto *self256 = static_cast<Node256 *>(node);
|
||||||
for (auto v : self256->maxOfMax) {
|
max = std::max(max, horizontalMax16(self256->childMaxVersion, tls->zero));
|
||||||
max = std::max(v, max);
|
|
||||||
}
|
|
||||||
} break;
|
} break;
|
||||||
default: // GCOVR_EXCL_LINE
|
default: // GCOVR_EXCL_LINE
|
||||||
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
||||||
@@ -3978,6 +4095,24 @@ template <int kN> void benchScan2() {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void benchHorizontal16() {
|
||||||
|
ankerl::nanobench::Bench bench;
|
||||||
|
InternalVersionT vs[16];
|
||||||
|
for (int i = 0; i < 16; ++i) {
|
||||||
|
vs[i] = InternalVersionT(rand() % 1000 + 1000);
|
||||||
|
}
|
||||||
|
#if !USE_64_BIT
|
||||||
|
InternalVersionT::zero = InternalVersionT(rand() % 1000);
|
||||||
|
#endif
|
||||||
|
bench.run("horizontal16", [&]() {
|
||||||
|
bench.doNotOptimizeAway(horizontalMax16(vs, InternalVersionT::zero));
|
||||||
|
});
|
||||||
|
int x = rand() % 15 + 1;
|
||||||
|
bench.run("horizontalUpTo16", [&]() {
|
||||||
|
bench.doNotOptimizeAway(horizontalMaxUpTo16(vs, InternalVersionT::zero, x));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
void benchLCP(int len) {
|
void benchLCP(int len) {
|
||||||
ankerl::nanobench::Bench bench;
|
ankerl::nanobench::Bench bench;
|
||||||
std::vector<uint8_t> lhs(len);
|
std::vector<uint8_t> lhs(len);
|
||||||
@@ -4010,11 +4145,7 @@ void printTree() {
|
|||||||
debugPrintDot(stdout, cs.root, &cs);
|
debugPrintDot(stdout, cs.root, &cs);
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(void) {
|
int main(void) { benchHorizontal16(); }
|
||||||
for (int i = 0; i < 256; ++i) {
|
|
||||||
benchLCP(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef ENABLE_FUZZ
|
#ifdef ENABLE_FUZZ
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user