diff --git a/ConflictSet.cpp b/ConflictSet.cpp index eb95808..5637eed 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -569,6 +569,195 @@ Node *nextSibling(Node *node) { } } +#if defined(HAS_AVX) || defined(HAS_ARM_NEON) +constexpr int kStride = 64; +#else +constexpr int kStride = 16; +#endif + +constexpr int kUnrollFactor = 4; + +bool compareStride(const uint8_t *ap, const uint8_t *bp) { +#if defined(HAS_ARM_NEON) + static_assert(kStride == 64); + uint8x16_t x[4]; + for (int i = 0; i < 4; ++i) { + x[i] = vceqq_u8(vld1q_u8(ap + i * 16), vld1q_u8(bp + i * 16)); + } + auto results = vreinterpretq_u16_u8( + vandq_u8(vandq_u8(x[0], x[1]), vandq_u8(x[2], x[3]))); + bool eq = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) == + uint64_t(-1); +#elif defined(HAS_AVX) + static_assert(kStride == 64); + __m128i x[4]; + for (int i = 0; i < 4; ++i) { + x[i] = _mm_cmpeq_epi8(_mm_loadu_si128((__m128i *)(ap + i * 16)), + _mm_loadu_si128((__m128i *)(bp + i * 16))); + } + auto eq = + _mm_movemask_epi8(_mm_and_si128(_mm_and_si128(x[0], x[1]), + _mm_and_si128(x[2], x[3]))) == 0xffff; +#else + // Hope it gets vectorized + auto eq = memcmp(ap, bp, kStride) == 0; +#endif + return eq; +} + +int firstNeqStride(const uint8_t *ap, const uint8_t *bp) { +#if defined(HAS_AVX) + static_assert(kStride == 64); + uint64_t c[kStride / 16]; + for (int i = 0; i < kStride; i += 16) { + const auto a = _mm_loadu_si128((__m128i *)(ap + i)); + const auto b = _mm_loadu_si128((__m128i *)(bp + i)); + const auto compared = _mm_cmpeq_epi8(a, b); + c[i / 16] = _mm_movemask_epi8(compared) & 0xffff; + } + return std::countr_zero(~(c[0] | c[1] << 16 | c[2] << 32 | c[3] << 48)); +#elif defined(HAS_ARM_NEON) + static_assert(kStride == 64); + for (int i = 0; i < kStride; i += 16) { + // 0xff for each match + uint16x8_t results = + vreinterpretq_u16_u8(vceqq_u8(vld1q_u8(ap + i), vld1q_u8(bp + i))); + // 0xf for each mismatch + uint64_t bitfield = + ~vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0); + if (bitfield) { + return i + (std::countr_zero(bitfield) >> 2); + } + } + __builtin_unreachable(); // GCOVR_EXCL_LINE +#else + int i = 0; + for (; i < kStride - 1; ++i) { + if (*ap++ != *bp++) { + break; + } + } + return i; +#endif +} + +int longestCommonPrefix(const uint8_t *ap, const uint8_t *bp, int cl) { + if (cl < 0) { + __builtin_unreachable(); // GCOVR_EXCL_LINE + } + int i = 0; + int end; + + if (cl < 8) { + goto bytes; + } + + // Optimistic early return + { + uint64_t a; + uint64_t b; + memcpy(&a, ap, 8); + memcpy(&b, bp, 8); + const auto mismatched = a ^ b; + if (mismatched) { + return std::countr_zero(mismatched) / 8; + } + } + + // kStride * kUnrollCount at a time + end = cl & ~(kStride * kUnrollFactor - 1); + while (i < end) { + for (int j = 0; j < kUnrollFactor; ++j) { + if (!compareStride(ap, bp)) { + return i + firstNeqStride(ap, bp); + } + i += kStride; + ap += kStride; + bp += kStride; + } + } + + // kStride at a time + end = cl & ~(kStride - 1); + while (i < end) { + if (!compareStride(ap, bp)) { + return i + firstNeqStride(ap, bp); + } + i += kStride; + ap += kStride; + bp += kStride; + } + + // word at a time + end = cl & ~(sizeof(uint64_t) - 1); + while (i < end) { + uint64_t a; + uint64_t b; + memcpy(&a, ap, 8); + memcpy(&b, bp, 8); + const auto mismatched = a ^ b; + if (mismatched) { + return i + std::countr_zero(mismatched) / 8; + } + i += 8; + ap += 8; + bp += 8; + } + +bytes: + // byte at a time + while (i < cl) { + if (*ap != *bp) { + break; + } + ++ap; + ++bp; + ++i; + } + + return i; +} + +int longestCommonPrefixPartialKey(const uint8_t *ap, const uint8_t *bp, + int cl) { + if (cl > Node::kPartialKeyMaxLen) { + __builtin_unreachable(); // GCOVR_EXCL_LINE + } + return longestCommonPrefix(ap, bp, cl); +#if 0 +static_assert(Node::kPartialKeyMaxLen == 16); + // SOMEDAY: use masked loads (requires avx-512/sve2) +#if defined(HAS_AVX) + __m128i a; + memcpy(&a, ap, cl); + __m128i b; + memcpy(&b, bp, cl); + const auto compared = _mm_cmpeq_epi8(a, b); + int mask = (1 << cl) - 1; + auto c = = _mm_movemask_epi8(compared) & mask; + return std::countr_zero(~c); +#elif defined(HAS_ARM_NEON) + uint8x16_t a; + memcpy(&a, ap, cl); + uint8x16_t b; + memcpy(&b, bp, cl); + uint16x8_t results = vreinterpretq_u16_u8(vceqq_u8(a, b)); + uint64_t mask = cl == 16 ? uint64_t(-1) : (uint64_t(1) << (cl * 4)) - 1; + uint64_t bitfield = + vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) & mask; + return std::countr_zero(~bitfield) >> 2; +#else + int i = 0; + for (; i < 16; ++i) { + if (*ap++ != *bp++) { + break; + } + } + return i; +#endif +#endif +} + // Performs a physical search for remaining struct SearchStepWise { Node *n; @@ -651,11 +840,9 @@ bool checkPointRead(Node *n, const std::span key, } if (n->partialKeyLen > 0) { int commonLen = std::min(n->partialKeyLen, remaining.size()); - for (int i = 0; i < commonLen; ++i) { + int i = longestCommonPrefix(n->partialKey, remaining.data(), commonLen); + if (i < commonLen) { auto c = n->partialKey[i] <=> remaining[i]; - if (c == 0) { - continue; - } if (c > 0) { goto downLeftSpine; } else { @@ -777,153 +964,6 @@ Vector getSearchPath(Arena &arena, Node *n) { return result; } -#if defined(HAS_AVX) || defined(HAS_ARM_NEON) -constexpr int kStride = 64; -#else -constexpr int kStride = 16; -#endif - -constexpr int kUnrollFactor = 4; - -bool compareStride(const uint8_t *ap, const uint8_t *bp) { -#if defined(HAS_ARM_NEON) - static_assert(kStride == 64); - uint8x16_t x[4]; - for (int i = 0; i < 4; ++i) { - x[i] = vceqq_u8(vld1q_u8(ap + i * 16), vld1q_u8(bp + i * 16)); - } - auto results = vreinterpretq_u16_u8( - vandq_u8(vandq_u8(x[0], x[1]), vandq_u8(x[2], x[3]))); - bool eq = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) == - uint64_t(-1); -#elif defined(HAS_AVX) - static_assert(kStride == 64); - __m128i x[4]; - for (int i = 0; i < 4; ++i) { - x[i] = _mm_cmpeq_epi8(_mm_loadu_si128((__m128i *)(ap + i * 16)), - _mm_loadu_si128((__m128i *)(bp + i * 16))); - } - auto eq = - _mm_movemask_epi8(_mm_and_si128(_mm_and_si128(x[0], x[1]), - _mm_and_si128(x[2], x[3]))) == 0xffff; -#else - // Hope it gets vectorized - auto eq = memcmp(ap, bp, kStride) == 0; -#endif - return eq; -} - -// Precondition: ap[0:kStride] != bp[0:kStride] -int firstNeqStride(const uint8_t *ap, const uint8_t *bp) { -#if defined(HAS_AVX) - static_assert(kStride == 64); - uint64_t c[kStride / 16]; - for (int i = 0; i < kStride; i += 16) { - const auto a = _mm_loadu_si128((__m128i *)(ap + i)); - const auto b = _mm_loadu_si128((__m128i *)(bp + i)); - const auto compared = _mm_cmpeq_epi8(a, b); - c[i / 16] = _mm_movemask_epi8(compared) & 0xffff; - } - return std::countr_zero(~(c[0] | c[1] << 16 | c[2] << 32 | c[3] << 48)); -#elif defined(HAS_ARM_NEON) - static_assert(kStride == 64); - for (int i = 0; i < kStride; i += 16) { - // 0xff for each match - uint16x8_t results = - vreinterpretq_u16_u8(vceqq_u8(vld1q_u8(ap + i), vld1q_u8(bp + i))); - // 0xf for each mismatch - uint64_t bitfield = - ~vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0); - if (bitfield) { - return i + (std::countr_zero(bitfield) >> 2); - } - } - __builtin_unreachable(); // GCOVR_EXCL_LINE -#else - int i = 0; - for (; i < kStride - 1; ++i) { - if (*ap++ != *bp++) { - break; - } - } - return i; -#endif -} - -int longestCommonPrefix(const uint8_t *ap, const uint8_t *bp, int cl) { - int i = 0; - int end; - - if (cl < 8) { - goto bytes; - } - - // Optimistic early return - { - uint64_t a; - uint64_t b; - memcpy(&a, ap, 8); - memcpy(&b, bp, 8); - const auto mismatched = a ^ b; - if (mismatched) { - return std::countr_zero(mismatched) / 8; - } - } - - // kStride * kUnrollCount at a time - end = cl & ~(kStride * kUnrollFactor - 1); - while (i < end) { - for (int j = 0; j < kUnrollFactor; ++j) { - if (!compareStride(ap, bp)) { - return i + firstNeqStride(ap, bp); - } - i += kStride; - ap += kStride; - bp += kStride; - } - } - - // kStride at a time - end = cl & ~(kStride - 1); - while (i < end) { - if (!compareStride(ap, bp)) { - return i + firstNeqStride(ap, bp); - } - i += kStride; - ap += kStride; - bp += kStride; - } - - // word at a time - end = cl & ~(sizeof(uint64_t) - 1); - while (i < end) { - uint64_t a; - uint64_t b; - memcpy(&a, ap, 8); - memcpy(&b, bp, 8); - const auto mismatched = a ^ b; - if (mismatched) { - return i + std::countr_zero(mismatched) / 8; - } - i += 8; - ap += 8; - bp += 8; - } - -bytes: - // byte at a time - while (i < cl) { - if (*ap != *bp) { - break; - } - ++ap; - ++bp; - ++i; - } - - return i; -} - // Return true if the max version among all keys that start with key + [child], // where begin < child < end, is <= readVersion bool checkRangeStartsWith(Node *n, std::span key, int begin, @@ -956,11 +996,10 @@ bool checkRangeStartsWith(Node *n, std::span key, int begin, if (n->partialKeyLen > 0) { int commonLen = std::min(n->partialKeyLen, remaining.size()); - for (int i = 0; i < commonLen; ++i) { + int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(), + commonLen); + if (i < commonLen) { auto c = n->partialKey[i] <=> remaining[i]; - if (c == 0) { - continue; - } if (c > 0) { goto downLeftSpine; } else { @@ -1061,12 +1100,10 @@ struct CheckRangeLeftSide { if (n->partialKeyLen > 0) { int commonLen = std::min(n->partialKeyLen, remaining.size()); - for (int i = 0; i < commonLen; ++i) { + int i = longestCommonPrefix(n->partialKey, remaining.data(), commonLen); + searchPathLen += i; + if (i < commonLen) { auto c = n->partialKey[i] <=> remaining[i]; - if (c == 0) { - ++searchPathLen; - continue; - } if (c > 0) { if (searchPathLen < prefixLen) { return downLeftSpine(); @@ -1198,12 +1235,12 @@ struct CheckRangeRightSide { if (n->partialKeyLen > 0) { int commonLen = std::min(n->partialKeyLen, remaining.size()); - for (int i = 0; i < commonLen; ++i) { + int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(), + commonLen); + searchPathLen += i; + if (i < commonLen) { ++searchPathLen; auto c = n->partialKey[i] <=> remaining[i]; - if (c == 0) { - continue; - } if (c > 0) { return downLeftSpine(); } else { @@ -1453,18 +1490,14 @@ void addPointWrite(Node *&root, int64_t oldestVersion, } void addWriteRange(Node *&root, int64_t oldestVersion, - const ConflictSet::WriteRange &w, - NodeAllocators *allocators) { - - auto begin = std::span(w.begin.p, w.begin.len); - auto end = std::span(w.end.p, w.end.len); + std::span begin, std::span end, + int64_t writeVersion, NodeAllocators *allocators) { int lcp = longestCommonPrefix(begin.data(), end.data(), std::min(begin.size(), end.size())); if (lcp == int(begin.size()) && end.size() == begin.size() + 1 && end.back() == 0) { - return addPointWrite(root, oldestVersion, begin, w.writeVersion, - allocators); + return addPointWrite(root, oldestVersion, begin, writeVersion, allocators); } auto remaining = begin.subspan(0, lcp); @@ -1489,7 +1522,7 @@ void addWriteRange(Node *&root, int64_t oldestVersion, break; } - n->maxVersion = std::max(n->maxVersion, w.writeVersion); + n->maxVersion = std::max(n->maxVersion, writeVersion); remaining = remaining.subspan(n->partialKeyLen + 1, remaining.size() - (n->partialKeyLen + 1)); @@ -1505,7 +1538,7 @@ void addWriteRange(Node *&root, int64_t oldestVersion, begin = begin.subspan(consumed, begin.size() - consumed); end = end.subspan(consumed, end.size() - consumed); - auto *beginNode = insert(useAsRoot, begin, w.writeVersion, true, allocators); + auto *beginNode = insert(useAsRoot, begin, writeVersion, true, allocators); const bool insertedBegin = !std::exchange(beginNode->entryPresent, true); @@ -1513,14 +1546,14 @@ void addWriteRange(Node *&root, int64_t oldestVersion, auto *p = nextLogical(beginNode); beginNode->entry.rangeVersion = p != nullptr ? p->entry.rangeVersion : oldestVersion; - beginNode->entry.pointVersion = w.writeVersion; - beginNode->maxVersion = w.writeVersion; + beginNode->entry.pointVersion = writeVersion; + beginNode->maxVersion = writeVersion; } - beginNode->maxVersion = std::max(beginNode->maxVersion, w.writeVersion); + beginNode->maxVersion = std::max(beginNode->maxVersion, writeVersion); beginNode->entry.pointVersion = - std::max(beginNode->entry.pointVersion, w.writeVersion); + std::max(beginNode->entry.pointVersion, writeVersion); - auto *endNode = insert(useAsRoot, end, w.writeVersion, false, allocators); + auto *endNode = insert(useAsRoot, end, writeVersion, false, allocators); const bool insertedEnd = !std::exchange(endNode->entryPresent, true); @@ -1531,11 +1564,11 @@ void addWriteRange(Node *&root, int64_t oldestVersion, endNode->maxVersion = std::max(endNode->maxVersion, endNode->entry.pointVersion); } - endNode->entry.rangeVersion = w.writeVersion; + endNode->entry.rangeVersion = writeVersion; if (insertedEnd) { // beginNode may have been invalidated - beginNode = insert(useAsRoot, begin, w.writeVersion, true, allocators); + beginNode = insert(useAsRoot, begin, writeVersion, true, allocators); } for (beginNode = nextLogical(beginNode); beginNode != endNode;) { @@ -1591,11 +1624,10 @@ struct FirstGeqStepwise { } if (n->partialKeyLen > 0) { int commonLen = std::min(n->partialKeyLen, remaining.size()); - for (int i = 0; i < commonLen; ++i) { + int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(), + commonLen); + if (i < commonLen) { auto c = n->partialKey[i] <=> remaining[i]; - if (c == 0) { - continue; - } if (c > 0) { return downLeftSpine(); } else { @@ -1660,19 +1692,14 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { void check(const ReadRange *reads, Result *result, int count) const { for (int i = 0; i < count; ++i) { + const auto &r = reads[i]; + auto begin = std::span(r.begin.p, r.begin.len); + auto end = std::span(r.end.p, r.end.len); result[i] = reads[i].readVersion < oldestVersion ? TooOld - : (reads[i].end.len > 0 - ? checkRangeRead(root, - std::span(reads[i].begin.p, - reads[i].begin.len), - std::span(reads[i].end.p, - reads[i].end.len), - reads[i].readVersion) - : checkPointRead(root, - std::span(reads[i].begin.p, - reads[i].begin.len), - reads[i].readVersion)) + : (end.size() > 0 + ? checkRangeRead(root, begin, end, reads[i].readVersion) + : checkPointRead(root, begin, reads[i].readVersion)) ? Commit : Conflict; } @@ -1681,14 +1708,15 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { void addWrites(const WriteRange *writes, int count) { for (int i = 0; i < count; ++i) { const auto &w = writes[i]; + auto begin = std::span(w.begin.p, w.begin.len); + auto end = std::span(w.end.p, w.end.len); if (w.end.len > 0) { keyUpdates += 2; - addWriteRange(root, oldestVersion, w, &allocators); + addWriteRange(root, oldestVersion, begin, end, w.writeVersion, + &allocators); } else { keyUpdates += 1; - addPointWrite(root, oldestVersion, - std::span(w.begin.p, w.begin.len), - w.writeVersion, &allocators); + addPointWrite(root, oldestVersion, begin, w.writeVersion, &allocators); } } }