From 0814822d821092cb0e2aa67f4afa418ce8d773e0 Mon Sep 17 00:00:00 2001
From: Andrew Noyes
Date: Fri, 13 Sep 2024 22:01:56 -0700
Subject: [PATCH] avx512 implementations for fixupMaxVersion

---
Review notes (text between "---" and the first "diff --git" is ignored by
`git am`):

* fixupMaxVersion, case Type_Node256: the loop this patch removes folded
  self256->maxOfMax, but the submitted replacement passed
  self256->childMaxVersion to horizontalMax16(), which reads exactly
  vs[0..15]. Any childMaxVersion entry past index 15 would have been ignored
  (presumably Node256 has more than 16 child slots - confirm). The call now
  passes self256->maxOfMax so the result matches the loop it replaces.
  NOTE(review): assumes maxOfMax is a 16-entry InternalVersionT array that
  decays to a pointer the same way childMaxVersion does - confirm against the
  Node256 definition.
* The scalar loops in horizontalMaxUpTo16 read vs[0] unconditionally, while
  the masked AVX-512 path returns `z` when len == 0. Callers presumably
  guarantee numChildren >= 1 for Node3/Node16 - confirm.
* This copy of the patch had its line breaks collapsed and appears to have
  lost text inside angle brackets (author e-mail, template arguments such as
  the static_cast target types). Line structure is restored here and
  indentation is conventional; the missing tokens are left as found and must
  be restored from the original commit before applying.

 ConflictSet.cpp | 105 +++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 90 insertions(+), 15 deletions(-)

diff --git a/ConflictSet.cpp b/ConflictSet.cpp
index 2e2d6e9..3997c61 100644
--- a/ConflictSet.cpp
+++ b/ConflictSet.cpp
@@ -1604,7 +1604,6 @@ __attribute__((target("avx512f"))) void rezero16(InternalVersionT *vs,
       _mm512_sub_epi32(_mm512_loadu_epi32(vs), zvec), _mm512_setzero_epi32());
   _mm512_mask_storeu_epi32(vs, m, zvec);
 }
-
 __attribute__((target("default")))
 #endif
 
@@ -2471,6 +2470,7 @@ checkMaxBetweenExclusive(Node *n, int begin, int end,
 }
 __attribute__((target("default")))
 #endif
+
 bool checkMaxBetweenExclusive(Node *n, int begin, int end,
                               InternalVersionT readVersion, ReadContext *tls) {
   return checkMaxBetweenExclusiveImpl(n, begin, end, readVersion, tls);
@@ -2910,6 +2910,71 @@ void addPointWrite(Node *&root, std::span key,
   }
 }
 
+#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
+__attribute__((target("avx512f"))) InternalVersionT
+horizontalMaxUpTo16(InternalVersionT *vs, InternalVersionT z, int len) {
+  assume(len <= 16);
+#if USE_64_BIT
+  // Hope it gets vectorized
+  InternalVersionT max = vs[0];
+  for (int i = 1; i < len; ++i) {
+    max = std::max(vs[i], max);
+  }
+  return max;
+#else
+  uint32_t zero;
+  memcpy(&zero, &z, sizeof(zero));
+  auto zeroVec = _mm512_set1_epi32(zero);
+  return InternalVersionT(
+      zero +
+      _mm512_reduce_max_epu32(_mm512_sub_epi32(
+          _mm512_mask_loadu_epi32(zeroVec, _mm512_int2mask((1 << len) - 1), vs),
+          zeroVec)));
+#endif
+}
+__attribute__((target("default")))
+#endif
+
+InternalVersionT
+horizontalMaxUpTo16(InternalVersionT *vs, InternalVersionT, int len) {
+  assume(len <= 16);
+  InternalVersionT max = vs[0];
+  for (int i = 1; i < len; ++i) {
+    max = std::max(vs[i], max);
+  }
+  return max;
+}
+
+#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
+__attribute__((target("avx512f"))) InternalVersionT
+horizontalMax16(InternalVersionT *vs, InternalVersionT z) {
+#if USE_64_BIT
+  // Hope it gets vectorized
+  InternalVersionT max = vs[0];
+  for (int i = 1; i < 16; ++i) {
+    max = std::max(vs[i], max);
+  }
+  return max;
+#else
+  uint32_t zero;
+  memcpy(&zero, &z, sizeof(zero));
+  auto zeroVec = _mm512_set1_epi32(zero);
+  return InternalVersionT(zero + _mm512_reduce_max_epu32(_mm512_sub_epi32(
+                                     _mm512_loadu_epi32(vs), zeroVec)));
+#endif
+}
+__attribute__((target("default")))
+#endif
+
+InternalVersionT
+horizontalMax16(InternalVersionT *vs, InternalVersionT) {
+  InternalVersionT max = vs[0];
+  for (int i = 1; i < 16; ++i) {
+    max = std::max(vs[i], max);
+  }
+  return max;
+}
+
 // Precondition: `node->entryPresent`, and node is not the root
 void fixupMaxVersion(Node *node, WriteContext *tls) {
   assert(node->parent);
@@ -2921,15 +2986,13 @@ void fixupMaxVersion(Node *node, WriteContext *tls) {
     break;
   case Type_Node3: {
     auto *self3 = static_cast(node);
-    for (int i = 0; i < self3->numChildren; ++i) {
-      max = std::max(self3->childMaxVersion[i], max);
-    }
+    max = std::max(max, horizontalMaxUpTo16(self3->childMaxVersion, tls->zero,
+                                            self3->numChildren));
   } break;
   case Type_Node16: {
     auto *self16 = static_cast(node);
-    for (int i = 0; i < self16->numChildren; ++i) {
-      max = std::max(self16->childMaxVersion[i], max);
-    }
+    max = std::max(max, horizontalMaxUpTo16(self16->childMaxVersion, tls->zero,
+                                            self16->numChildren));
   } break;
   case Type_Node48: {
     auto *self48 = static_cast(node);
@@ -2939,9 +3002,7 @@ void fixupMaxVersion(Node *node, WriteContext *tls) {
   } break;
   case Type_Node256: {
     auto *self256 = static_cast(node);
-    for (auto v : self256->maxOfMax) {
-      max = std::max(v, max);
-    }
+    max = std::max(max, horizontalMax16(self256->maxOfMax, tls->zero));
   } break;
   default:                   // GCOVR_EXCL_LINE
     __builtin_unreachable(); // GCOVR_EXCL_LINE
@@ -4033,6 +4094,24 @@ template void benchScan2() {
   });
 }
 
+void benchHorizontal16() {
+  ankerl::nanobench::Bench bench;
+  InternalVersionT vs[16];
+  for (int i = 0; i < 16; ++i) {
+    vs[i] = InternalVersionT(rand() % 1000 + 1000);
+  }
+#if !USE_64_BIT
+  InternalVersionT::zero = InternalVersionT(rand() % 1000);
+#endif
+  bench.run("horizontal16", [&]() {
+    bench.doNotOptimizeAway(horizontalMax16(vs, InternalVersionT::zero));
+  });
+  int x = rand() % 15 + 1;
+  bench.run("horizontalUpTo16", [&]() {
+    bench.doNotOptimizeAway(horizontalMaxUpTo16(vs, InternalVersionT::zero, x));
+  });
+}
+
 void benchLCP(int len) {
   ankerl::nanobench::Bench bench;
   std::vector lhs(len);
@@ -4065,11 +4144,7 @@ void printTree() {
   debugPrintDot(stdout, cs.root, &cs);
 }
 
-int main(void) {
-  for (int i = 0; i < 256; ++i) {
-    benchLCP(i);
-  }
-}
+int main(void) { benchHorizontal16(); }
 #endif
 
 #ifdef ENABLE_FUZZ