From f8acc5ee865c473349b6be2560713b02c3504970 Mon Sep 17 00:00:00 2001 From: Andrew Noyes Date: Mon, 12 Feb 2024 15:00:46 -0800 Subject: [PATCH] Use longestCommonPrefix instead of strinc in checkRangeRead --- ConflictSet.cpp | 203 +++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 165 insertions(+), 38 deletions(-) diff --git a/ConflictSet.cpp b/ConflictSet.cpp index 9b375fd..317fce0 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -22,24 +22,6 @@ // ==================== BEGIN IMPLEMENTATION ==================== -std::span strincMutate(std::span str, bool &ok) { - int index; - for (index = str.size() - 1; index >= 0; index--) - if (str[index] != 255) - break; - - // Must not be called with a string that consists only of zero or more '\xff' - // bytes. - if (index < 0) { - ok = false; - return {}; - } - ok = true; - str = str.subspan(0, index + 1); - ++str.back(); - return str; -} - struct Entry { int64_t pointVersion; int64_t rangeVersion; @@ -762,6 +744,153 @@ Vector getSearchPath(Arena &arena, Node *n) { return result; } +#if defined(HAS_AVX) || defined(HAS_ARM_NEON) +constexpr int kStride = 64; +#else +constexpr int kStride = 16; +#endif + +constexpr int kUnrollFactor = 4; + +bool compareStride(const uint8_t *ap, const uint8_t *bp) { +#if defined(HAS_ARM_NEON) + static_assert(kStride == 64); + uint8x16_t x[4]; + for (int i = 0; i < 4; ++i) { + x[i] = vceqq_u8(vld1q_u8(ap + i * 16), vld1q_u8(bp + i * 16)); + } + auto results = vreinterpretq_u16_u8( + vandq_u8(vandq_u8(x[0], x[1]), vandq_u8(x[2], x[3]))); + bool eq = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) == + uint64_t(-1); +#elif defined(HAS_AVX) + static_assert(kStride == 64); + __m128i x[4]; + for (int i = 0; i < 4; ++i) { + x[i] = _mm_cmpeq_epi8(_mm_loadu_si128((__m128i *)(ap + i * 16)), + _mm_loadu_si128((__m128i *)(bp + i * 16))); + } + auto eq = + _mm_movemask_epi8(_mm_and_si128(_mm_and_si128(x[0], x[1]), + _mm_and_si128(x[2], x[3]))) == 0xffff; +#else + // Hope it gets vectorized + auto eq = memcmp(ap, bp, kStride) == 0; +#endif + return eq; +} + +// Precondition: ap[0:kStride] != bp[0:kStride] +int firstNeqStride(const uint8_t *ap, const uint8_t *bp) { +#if defined(HAS_AVX) + static_assert(kStride == 64); + uint64_t c[kStride / 16]; + for (int i = 0; i < kStride; i += 16) { + const auto a = _mm_loadu_si128((__m128i *)(ap + i)); + const auto b = _mm_loadu_si128((__m128i *)(bp + i)); + const auto compared = _mm_cmpeq_epi8(a, b); + c[i / 16] = _mm_movemask_epi8(compared) & 0xffff; + } + return __builtin_ctzll(~(c[0] | c[1] << 16 | c[2] << 32 | c[3] << 48)); +#elif defined(HAS_ARM_NEON) + static_assert(kStride == 64); + for (int i = 0; i < kStride; i += 16) { + // 0xff for each match + uint16x8_t results = + vreinterpretq_u16_u8(vceqq_u8(vld1q_u8(ap + i), vld1q_u8(bp + i))); + // 0xf for each mismatch + uint64_t bitfield = + ~vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0); + if (bitfield) { + return i + (__builtin_ctzll(bitfield) >> 2); + } + } + __builtin_unreachable(); +#else + int i = 0; + for (; i < kStride - 1; ++i) { + if (*ap++ != *bp++) { + break; + } + } + return i; +#endif +} + +int longestCommonPrefix(const uint8_t *ap, const uint8_t *bp, int cl) { + int i = 0; + int end; + + if (cl < 8) { + goto bytes; + } + + // Optimistic early return + { + uint64_t a; + uint64_t b; + memcpy(&a, ap, 8); + memcpy(&b, bp, 8); + const auto mismatched = a ^ b; + if (mismatched) { + return __builtin_ctzll(mismatched) / 8; + } + } + + // kStride * kUnrollCount at a time + end = cl & ~(kStride * kUnrollFactor - 1); + while (i < end) { + for (int j = 0; j < kUnrollFactor; ++j) { + if (!compareStride(ap, bp)) { + return i + firstNeqStride(ap, bp); + } + i += kStride; + ap += kStride; + bp += kStride; + } + } + + // kStride at a time + end = cl & ~(kStride - 1); + while (i < end) { + if (!compareStride(ap, bp)) { + return i + firstNeqStride(ap, bp); + } + i += kStride; + ap += kStride; + bp += kStride; + } + + // word at a time + end = cl & ~(sizeof(uint64_t) - 1); + while (i < end) { + uint64_t a; + uint64_t b; + memcpy(&a, ap, 8); + memcpy(&b, bp, 8); + const auto mismatched = a ^ b; + if (mismatched) { + return i + __builtin_ctzll(mismatched) / 8; + } + i += 8; + ap += 8; + bp += 8; + } + +bytes: + // byte at a time + while (i < cl) { + if (*ap != *bp) { + break; + } + ++ap; + ++bp; + ++i; + } + + return i; +} + bool checkRangeRead(Node *n, const std::span begin, const std::span end, int64_t readVersion, Arena &arena) { @@ -785,8 +914,16 @@ bool checkRangeRead(Node *n, const std::span begin, } bool first = true; - for (auto *iter = left.n; iter != nullptr && searchPath < end; - first = false) { + for (auto *iter = left.n; iter != nullptr; first = false) { + int cl = std::min(searchPath.size(), end.size()); + int lcp = longestCommonPrefix(searchPath.data(), end.data(), cl); + + // if (searchPath >= end) break; + if ((cl == lcp ? searchPath.size() <=> end.size() + : searchPath[lcp] <=> end[lcp]) >= 0) { + break; + } + if (iter->entryPresent) { if (!first && iter->entry.rangeVersion > readVersion) { return false; @@ -801,28 +938,18 @@ bool checkRangeRead(Node *n, const std::span begin, fprintf(stderr, "Max version of keys starting with %s: %" PRId64 "\n", printable(searchPath).c_str(), iter->maxVersion); #endif - bool ok = true; - auto rangeEnd = strincMutate(searchPath, ok); - auto c = std::strong_ordering::equal; - if (!ok) { - goto iterate; - } - c = rangeEnd <=> end; - --rangeEnd.back(); - - if (c == 0) { - return iter->maxVersion <= readVersion; - } else if (c < 0) { - if (iter->maxVersion > readVersion) { - return false; - } - iter = nextSibling(iter, searchPath); - } else { - iterate: + if (lcp == int(searchPath.size())) { + // end starts with searchPath, so end < range if (iter->maxVersion <= readVersion) { return true; } iter = nextPhysical(iter, searchPath); + } else { + // end does not start with searchPath, so range end <= end + if (iter->maxVersion > readVersion) { + return false; + } + iter = nextSibling(iter, searchPath); } } return true;