avx512 implementations for fixupMaxVersion
Some checks failed
Tests / Clang total: 3296, failed: 1, passed: 3295
Tests / 64 bit versions total: 3296, passed: 3296
Tests / Debug total: 3294, failed: 1, passed: 3293
Tests / SIMD fallback total: 3296, passed: 3296
Tests / Release [gcc] total: 3296, failed: 1, passed: 3295
Tests / Release [gcc,aarch64] total: 2458, passed: 2458
Tests / Coverage total: 2476, failed: 1, passed: 2475
weaselab/conflict-set/pipeline/head There was a failure building this commit
Some checks failed
Tests / Clang total: 3296, failed: 1, passed: 3295
Tests / 64 bit versions total: 3296, passed: 3296
Tests / Debug total: 3294, failed: 1, passed: 3293
Tests / SIMD fallback total: 3296, passed: 3296
Tests / Release [gcc] total: 3296, failed: 1, passed: 3295
Tests / Release [gcc,aarch64] total: 2458, passed: 2458
Tests / Coverage total: 2476, failed: 1, passed: 2475
weaselab/conflict-set/pipeline/head There was a failure building this commit
This commit is contained in:
105
ConflictSet.cpp
105
ConflictSet.cpp
@@ -1604,7 +1604,6 @@ __attribute__((target("avx512f"))) void rezero16(InternalVersionT *vs,
|
||||
_mm512_sub_epi32(_mm512_loadu_epi32(vs), zvec), _mm512_setzero_epi32());
|
||||
_mm512_mask_storeu_epi32(vs, m, zvec);
|
||||
}
|
||||
|
||||
__attribute__((target("default")))
|
||||
#endif
|
||||
|
||||
@@ -2471,6 +2470,7 @@ checkMaxBetweenExclusive(Node *n, int begin, int end,
|
||||
}
|
||||
__attribute__((target("default")))
|
||||
#endif
|
||||
|
||||
bool checkMaxBetweenExclusive(Node *n, int begin, int end,
|
||||
InternalVersionT readVersion, ReadContext *tls) {
|
||||
return checkMaxBetweenExclusiveImpl<false>(n, begin, end, readVersion, tls);
|
||||
@@ -2910,6 +2910,71 @@ void addPointWrite(Node *&root, std::span<const uint8_t> key,
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
|
||||
__attribute__((target("avx512f"))) InternalVersionT
|
||||
horizontalMaxUpTo16(InternalVersionT *vs, InternalVersionT z, int len) {
|
||||
assume(len <= 16);
|
||||
#if USE_64_BIT
|
||||
// Hope it gets vectorized
|
||||
InternalVersionT max = vs[0];
|
||||
for (int i = 1; i < len; ++i) {
|
||||
max = std::max(vs[i], max);
|
||||
}
|
||||
return max;
|
||||
#else
|
||||
uint32_t zero;
|
||||
memcpy(&zero, &z, sizeof(zero));
|
||||
auto zeroVec = _mm512_set1_epi32(zero);
|
||||
return InternalVersionT(
|
||||
zero +
|
||||
_mm512_reduce_max_epu32(_mm512_sub_epi32(
|
||||
_mm512_mask_loadu_epi32(zeroVec, _mm512_int2mask((1 << len) - 1), vs),
|
||||
zeroVec)));
|
||||
#endif
|
||||
}
|
||||
__attribute__((target("default")))
|
||||
#endif
|
||||
|
||||
InternalVersionT
|
||||
horizontalMaxUpTo16(InternalVersionT *vs, InternalVersionT, int len) {
|
||||
assume(len <= 16);
|
||||
InternalVersionT max = vs[0];
|
||||
for (int i = 1; i < len; ++i) {
|
||||
max = std::max(vs[i], max);
|
||||
}
|
||||
return max;
|
||||
}
|
||||
|
||||
#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
|
||||
__attribute__((target("avx512f"))) InternalVersionT
|
||||
horizontalMax16(InternalVersionT *vs, InternalVersionT z) {
|
||||
#if USE_64_BIT
|
||||
// Hope it gets vectorized
|
||||
InternalVersionT max = vs[0];
|
||||
for (int i = 1; i < 16; ++i) {
|
||||
max = std::max(vs[i], max);
|
||||
}
|
||||
return max;
|
||||
#else
|
||||
uint32_t zero;
|
||||
memcpy(&zero, &z, sizeof(zero));
|
||||
auto zeroVec = _mm512_set1_epi32(zero);
|
||||
return InternalVersionT(zero + _mm512_reduce_max_epu32(_mm512_sub_epi32(
|
||||
_mm512_loadu_epi32(vs), zeroVec)));
|
||||
#endif
|
||||
}
|
||||
__attribute__((target("default")))
|
||||
#endif
|
||||
|
||||
InternalVersionT
|
||||
horizontalMax16(InternalVersionT *vs, InternalVersionT) {
|
||||
InternalVersionT max = vs[0];
|
||||
for (int i = 1; i < 16; ++i) {
|
||||
max = std::max(vs[i], max);
|
||||
}
|
||||
return max;
|
||||
}
|
||||
|
||||
// Precondition: `node->entryPresent`, and node is not the root
|
||||
void fixupMaxVersion(Node *node, WriteContext *tls) {
|
||||
assert(node->parent);
|
||||
@@ -2921,15 +2986,13 @@ void fixupMaxVersion(Node *node, WriteContext *tls) {
|
||||
break;
|
||||
case Type_Node3: {
|
||||
auto *self3 = static_cast<Node3 *>(node);
|
||||
for (int i = 0; i < self3->numChildren; ++i) {
|
||||
max = std::max(self3->childMaxVersion[i], max);
|
||||
}
|
||||
max = std::max(max, horizontalMaxUpTo16(self3->childMaxVersion, tls->zero,
|
||||
self3->numChildren));
|
||||
} break;
|
||||
case Type_Node16: {
|
||||
auto *self16 = static_cast<Node16 *>(node);
|
||||
for (int i = 0; i < self16->numChildren; ++i) {
|
||||
max = std::max(self16->childMaxVersion[i], max);
|
||||
}
|
||||
max = std::max(max, horizontalMaxUpTo16(self16->childMaxVersion, tls->zero,
|
||||
self16->numChildren));
|
||||
} break;
|
||||
case Type_Node48: {
|
||||
auto *self48 = static_cast<Node48 *>(node);
|
||||
@@ -2939,9 +3002,7 @@ void fixupMaxVersion(Node *node, WriteContext *tls) {
|
||||
} break;
|
||||
case Type_Node256: {
|
||||
auto *self256 = static_cast<Node256 *>(node);
|
||||
for (auto v : self256->maxOfMax) {
|
||||
max = std::max(v, max);
|
||||
}
|
||||
max = std::max(max, horizontalMax16(self256->childMaxVersion, tls->zero));
|
||||
} break;
|
||||
default: // GCOVR_EXCL_LINE
|
||||
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
||||
@@ -4033,6 +4094,24 @@ template <int kN> void benchScan2() {
|
||||
});
|
||||
}
|
||||
|
||||
void benchHorizontal16() {
|
||||
ankerl::nanobench::Bench bench;
|
||||
InternalVersionT vs[16];
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
vs[i] = InternalVersionT(rand() % 1000 + 1000);
|
||||
}
|
||||
#if !USE_64_BIT
|
||||
InternalVersionT::zero = InternalVersionT(rand() % 1000);
|
||||
#endif
|
||||
bench.run("horizontal16", [&]() {
|
||||
bench.doNotOptimizeAway(horizontalMax16(vs, InternalVersionT::zero));
|
||||
});
|
||||
int x = rand() % 15 + 1;
|
||||
bench.run("horizontalUpTo16", [&]() {
|
||||
bench.doNotOptimizeAway(horizontalMaxUpTo16(vs, InternalVersionT::zero, x));
|
||||
});
|
||||
}
|
||||
|
||||
void benchLCP(int len) {
|
||||
ankerl::nanobench::Bench bench;
|
||||
std::vector<uint8_t> lhs(len);
|
||||
@@ -4065,11 +4144,7 @@ void printTree() {
|
||||
debugPrintDot(stdout, cs.root, &cs);
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
for (int i = 0; i < 256; ++i) {
|
||||
benchLCP(i);
|
||||
}
|
||||
}
|
||||
int main(void) { benchHorizontal16(); }
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_FUZZ
|
||||
|
Reference in New Issue
Block a user