From 8e3eacb54f7196daa994c55d60b328ab7eb5f4be Mon Sep 17 00:00:00 2001 From: Andrew Noyes Date: Sun, 30 Jun 2024 13:30:44 -0700 Subject: [PATCH] Apply function multi versioning higher in call stack to save branches --- ConflictSet.cpp | 112 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 76 insertions(+), 36 deletions(-) diff --git a/ConflictSet.cpp b/ConflictSet.cpp index 6f09a34..f85a412 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -1709,11 +1709,7 @@ downLeftSpine: } #ifdef HAS_AVX -#ifndef __SANITIZE_THREAD__ -__attribute__((target("default"))) -#endif -uint32_t -compare16_32bit(const InternalVersionT *vs, InternalVersionT rv) { +uint32_t compare16_32bit(const InternalVersionT *vs, InternalVersionT rv) { uint32_t compared = 0; __m128i w[4]; memcpy(w, vs, sizeof(w)); @@ -1729,9 +1725,8 @@ compare16_32bit(const InternalVersionT *vs, InternalVersionT rv) { return compared; } -#ifndef __SANITIZE_THREAD__ __attribute__((target("avx512f"))) uint32_t -compare16_32bit(const InternalVersionT *vs, InternalVersionT rv) { +compare16_32bit_avx512(const InternalVersionT *vs, InternalVersionT rv) { __m512i w; memcpy(&w, vs, sizeof(w)); uint32_t r; @@ -1740,10 +1735,10 @@ compare16_32bit(const InternalVersionT *vs, InternalVersionT rv) { _mm512_setzero_epi32()); } #endif -#endif // Returns true if v[i] <= readVersion for all i such that begin <= is[i] < end // Preconditions: begin <= end, end - begin < 256 +template bool scan16(const InternalVersionT *vs, const uint8_t *is, int begin, int end, InternalVersionT readVersion) { @@ -1800,7 +1795,11 @@ bool scan16(const InternalVersionT *vs, const uint8_t *is, int begin, int end, uint32_t compared = 0; #if INTERNAL_VERSION_32_BIT - compared = compare16_32bit(vs, readVersion); + if constexpr (kAVX512) { + compared = compare16_32bit_avx512(vs, readVersion); + } else { + compared = compare16_32bit(vs, readVersion); + } #else for (int i = 0; i < 16; ++i) { compared |= (vs[i] > readVersion) << i; @@ -1830,6 +1829,7 @@ bool scan16(const InternalVersionT *vs, const uint8_t *is, int begin, int end, // Returns true if v[i] <= readVersion for all i such that begin <= i < end // // always_inline So that we can optimize when begin or end is a constant. +template inline __attribute((always_inline)) bool scan16(const InternalVersionT *vs, int begin, int end, InternalVersionT readVersion) { @@ -1862,7 +1862,12 @@ inline __attribute((always_inline)) bool scan16(const InternalVersionT *vs, conflict >>= begin << 2; return !conflict; #elif INTERNAL_VERSION_32_BIT && defined(HAS_AVX) - uint32_t conflict = compare16_32bit(vs, readVersion); + uint32_t conflict; + if constexpr (kAVX512) { + conflict = compare16_32bit_avx512(vs, readVersion); + } else { + conflict = compare16_32bit(vs, readVersion); + } conflict &= (1 << end) - 1; conflict >>= begin; return !conflict; @@ -1880,6 +1885,7 @@ inline __attribute((always_inline)) bool scan16(const InternalVersionT *vs, // Return whether or not the max version among all keys starting with the search // path of n + [child], where child in (begin, end) is <= readVersion. Does not // account for the range version of firstGt(searchpath(n) + [end - 1]) +template bool checkMaxBetweenExclusive(Node *n, int begin, int end, InternalVersionT readVersion) { assume(-1 <= begin); @@ -1934,7 +1940,8 @@ bool checkMaxBetweenExclusive(Node *n, int begin, int end, case Type_Node16: { auto *self = static_cast(n); - return scan16(self->childMaxVersion, self->index, begin, end, readVersion); + return scan16(self->childMaxVersion, self->index, begin, end, + readVersion); } case Type_Node48: { auto *self = static_cast(n); @@ -1942,9 +1949,10 @@ bool checkMaxBetweenExclusive(Node *n, int begin, int end, static_assert(Node48::kMaxOfMaxPageSize == 16); for (int i = 0; i < Node48::kMaxOfMaxTotalPages; ++i) { if (self->maxOfMax[i] > readVersion) { - if (!scan16(self->childMaxVersion + (i << Node48::kMaxOfMaxShift), - self->reverseIndex + (i << Node48::kMaxOfMaxShift), begin, - end, readVersion)) { + if (!scan16(self->childMaxVersion + + (i << Node48::kMaxOfMaxShift), + self->reverseIndex + (i << Node48::kMaxOfMaxShift), + begin, end, readVersion)) { return false; } } @@ -1966,31 +1974,33 @@ bool checkMaxBetweenExclusive(Node *n, int begin, int end, } const int intraPageBegin = begin & (Node256::kMaxOfMaxPageSize - 1); const int intraPageEnd = end - (lastPage << Node256::kMaxOfMaxShift); - return scan16(self->childMaxVersion + - (firstPage << Node256::kMaxOfMaxShift), - intraPageBegin, intraPageEnd, readVersion); + return scan16(self->childMaxVersion + + (firstPage << Node256::kMaxOfMaxShift), + intraPageBegin, intraPageEnd, readVersion); } // Check the first page if (self->maxOfMax[firstPage] > readVersion) { const int intraPageBegin = begin & (Node256::kMaxOfMaxPageSize - 1); - if (!scan16(self->childMaxVersion + - (firstPage << Node256::kMaxOfMaxShift), - intraPageBegin, 16, readVersion)) { + if (!scan16(self->childMaxVersion + + (firstPage << Node256::kMaxOfMaxShift), + intraPageBegin, 16, readVersion)) { return false; } } // Check the last page if (self->maxOfMax[lastPage] > readVersion) { const int intraPageEnd = end - (lastPage << Node256::kMaxOfMaxShift); - if (!scan16(self->childMaxVersion + (lastPage << Node256::kMaxOfMaxShift), - 0, intraPageEnd, readVersion)) { + if (!scan16(self->childMaxVersion + + (lastPage << Node256::kMaxOfMaxShift), + 0, intraPageEnd, readVersion)) { return false; } } // Check inner pages const int innerPageBegin = (begin >> Node256::kMaxOfMaxShift) + 1; const int innerPageEnd = (end - 1) >> Node256::kMaxOfMaxShift; - return scan16(self->maxOfMax, innerPageBegin, innerPageEnd, readVersion); + return scan16(self->maxOfMax, innerPageBegin, innerPageEnd, + readVersion); } default: // GCOVR_EXCL_LINE __builtin_unreachable(); // GCOVR_EXCL_LINE @@ -2019,6 +2029,7 @@ Vector getSearchPath(Arena &arena, Node *n) { // // Precondition: transitively, no child of n has a search path that's a longer // prefix of key than n +template bool checkRangeStartsWith(Node *n, std::span key, int begin, int end, InternalVersionT readVersion, ConflictSet::Impl *impl) { @@ -2027,7 +2038,7 @@ bool checkRangeStartsWith(Node *n, std::span key, int begin, #endif auto remaining = key; if (remaining.size() == 0) { - return checkMaxBetweenExclusive(n, begin, end, readVersion); + return checkMaxBetweenExclusive(n, begin, end, readVersion); } auto *child = getChild(n, remaining[0]); @@ -2088,9 +2099,10 @@ downLeftSpine: } } +namespace { // Return true if the max version among all keys that start with key[:prefixLen] // that are >= key is <= readVersion -struct CheckRangeLeftSide { +template struct CheckRangeLeftSide { CheckRangeLeftSide(Node *n, std::span key, int prefixLen, InternalVersionT readVersion, ConflictSet::Impl *impl) : n(n), remaining(key), prefixLen(prefixLen), readVersion(readVersion), @@ -2122,7 +2134,8 @@ struct CheckRangeLeftSide { } if (searchPathLen >= prefixLen) { - if (!checkMaxBetweenExclusive(n, remaining[0], 256, readVersion)) { + if (!checkMaxBetweenExclusive(n, remaining[0], 256, + readVersion)) { ok = false; return true; } @@ -2209,7 +2222,7 @@ struct CheckRangeLeftSide { // Return true if the max version among all keys that start with key[:prefixLen] // that are < key is <= readVersion -struct CheckRangeRightSide { +template struct CheckRangeRightSide { CheckRangeRightSide(Node *n, std::span key, int prefixLen, InternalVersionT readVersion, ConflictSet::Impl *impl) : n(n), key(key), remaining(key), prefixLen(prefixLen), @@ -2251,7 +2264,8 @@ struct CheckRangeRightSide { return true; } - if (!checkMaxBetweenExclusive(n, -1, remaining[0], readVersion)) { + if (!checkMaxBetweenExclusive(n, -1, remaining[0], + readVersion)) { ok = false; return true; } @@ -2341,10 +2355,12 @@ struct CheckRangeRightSide { } } }; +} // namespace -bool checkRangeRead(Node *n, std::span begin, - std::span end, InternalVersionT readVersion, - ConflictSet::Impl *impl) { +template +bool checkRangeReadImpl(Node *n, std::span begin, + std::span end, + InternalVersionT readVersion, ConflictSet::Impl *impl) { int lcp = longestCommonPrefix(begin.data(), end.data(), std::min(begin.size(), end.size())); if (lcp == int(begin.size()) && end.size() == begin.size() + 1 && @@ -2378,19 +2394,22 @@ bool checkRangeRead(Node *n, std::span begin, lcp -= consumed; if (lcp == int(begin.size())) { - CheckRangeRightSide checkRangeRightSide{n, end, lcp, readVersion, impl}; + CheckRangeRightSide checkRangeRightSide{n, end, lcp, readVersion, + impl}; while (!checkRangeRightSide.step()) ; return checkRangeRightSide.ok; } - if (!checkRangeStartsWith(n, begin.subspan(0, lcp), begin[lcp], end[lcp], - readVersion, impl)) { + if (!checkRangeStartsWith(n, begin.subspan(0, lcp), begin[lcp], + end[lcp], readVersion, impl)) { return false; } - CheckRangeLeftSide checkRangeLeftSide{n, begin, lcp + 1, readVersion, impl}; - CheckRangeRightSide checkRangeRightSide{n, end, lcp + 1, readVersion, impl}; + CheckRangeLeftSide checkRangeLeftSide{n, begin, lcp + 1, readVersion, + impl}; + CheckRangeRightSide checkRangeRightSide{n, end, lcp + 1, readVersion, + impl}; for (;;) { bool leftDone = checkRangeLeftSide.step(); @@ -2415,6 +2434,27 @@ bool checkRangeRead(Node *n, std::span begin, return checkRangeLeftSide.ok & checkRangeRightSide.ok; } +#if defined(__SANITIZE_THREAD__) || !defined(__x86_64__) +bool checkRangeRead(Node *n, std::span begin, + std::span end, InternalVersionT readVersion, + ConflictSet::Impl *impl) { + return checkRangeReadImpl(n, begin, end, readVersion, impl); +} +#else +__attribute__((target("default"))) bool +checkRangeRead(Node *n, std::span begin, + std::span end, InternalVersionT readVersion, + ConflictSet::Impl *impl) { + return checkRangeReadImpl(n, begin, end, readVersion, impl); +} +__attribute__((target("avx512f"))) bool +checkRangeRead(Node *n, std::span begin, + std::span end, InternalVersionT readVersion, + ConflictSet::Impl *impl) { + return checkRangeReadImpl(n, begin, end, readVersion, impl); +} +#endif + // Returns a pointer to the newly inserted node. Caller must set // `entryPresent`, `entry` fields and `maxVersion` on the result. The search // path of the result's parent will have `maxVersion` at least `writeVersion` as