Use longestCommonPrefix instead of strinc in checkRangeRead
This commit is contained in:
203
ConflictSet.cpp
203
ConflictSet.cpp
@@ -22,24 +22,6 @@
|
|||||||
|
|
||||||
// ==================== BEGIN IMPLEMENTATION ====================
|
// ==================== BEGIN IMPLEMENTATION ====================
|
||||||
|
|
||||||
std::span<uint8_t> strincMutate(std::span<uint8_t> str, bool &ok) {
|
|
||||||
int index;
|
|
||||||
for (index = str.size() - 1; index >= 0; index--)
|
|
||||||
if (str[index] != 255)
|
|
||||||
break;
|
|
||||||
|
|
||||||
// Must not be called with a string that consists only of zero or more '\xff'
|
|
||||||
// bytes.
|
|
||||||
if (index < 0) {
|
|
||||||
ok = false;
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
ok = true;
|
|
||||||
str = str.subspan(0, index + 1);
|
|
||||||
++str.back();
|
|
||||||
return str;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Entry {
|
struct Entry {
|
||||||
int64_t pointVersion;
|
int64_t pointVersion;
|
||||||
int64_t rangeVersion;
|
int64_t rangeVersion;
|
||||||
@@ -762,6 +744,153 @@ Vector<uint8_t> getSearchPath(Arena &arena, Node *n) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(HAS_AVX) || defined(HAS_ARM_NEON)
|
||||||
|
constexpr int kStride = 64;
|
||||||
|
#else
|
||||||
|
constexpr int kStride = 16;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
constexpr int kUnrollFactor = 4;
|
||||||
|
|
||||||
|
bool compareStride(const uint8_t *ap, const uint8_t *bp) {
|
||||||
|
#if defined(HAS_ARM_NEON)
|
||||||
|
static_assert(kStride == 64);
|
||||||
|
uint8x16_t x[4];
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
x[i] = vceqq_u8(vld1q_u8(ap + i * 16), vld1q_u8(bp + i * 16));
|
||||||
|
}
|
||||||
|
auto results = vreinterpretq_u16_u8(
|
||||||
|
vandq_u8(vandq_u8(x[0], x[1]), vandq_u8(x[2], x[3])));
|
||||||
|
bool eq = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) ==
|
||||||
|
uint64_t(-1);
|
||||||
|
#elif defined(HAS_AVX)
|
||||||
|
static_assert(kStride == 64);
|
||||||
|
__m128i x[4];
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
x[i] = _mm_cmpeq_epi8(_mm_loadu_si128((__m128i *)(ap + i * 16)),
|
||||||
|
_mm_loadu_si128((__m128i *)(bp + i * 16)));
|
||||||
|
}
|
||||||
|
auto eq =
|
||||||
|
_mm_movemask_epi8(_mm_and_si128(_mm_and_si128(x[0], x[1]),
|
||||||
|
_mm_and_si128(x[2], x[3]))) == 0xffff;
|
||||||
|
#else
|
||||||
|
// Hope it gets vectorized
|
||||||
|
auto eq = memcmp(ap, bp, kStride) == 0;
|
||||||
|
#endif
|
||||||
|
return eq;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Precondition: ap[0:kStride] != bp[0:kStride]
|
||||||
|
int firstNeqStride(const uint8_t *ap, const uint8_t *bp) {
|
||||||
|
#if defined(HAS_AVX)
|
||||||
|
static_assert(kStride == 64);
|
||||||
|
uint64_t c[kStride / 16];
|
||||||
|
for (int i = 0; i < kStride; i += 16) {
|
||||||
|
const auto a = _mm_loadu_si128((__m128i *)(ap + i));
|
||||||
|
const auto b = _mm_loadu_si128((__m128i *)(bp + i));
|
||||||
|
const auto compared = _mm_cmpeq_epi8(a, b);
|
||||||
|
c[i / 16] = _mm_movemask_epi8(compared) & 0xffff;
|
||||||
|
}
|
||||||
|
return __builtin_ctzll(~(c[0] | c[1] << 16 | c[2] << 32 | c[3] << 48));
|
||||||
|
#elif defined(HAS_ARM_NEON)
|
||||||
|
static_assert(kStride == 64);
|
||||||
|
for (int i = 0; i < kStride; i += 16) {
|
||||||
|
// 0xff for each match
|
||||||
|
uint16x8_t results =
|
||||||
|
vreinterpretq_u16_u8(vceqq_u8(vld1q_u8(ap + i), vld1q_u8(bp + i)));
|
||||||
|
// 0xf for each mismatch
|
||||||
|
uint64_t bitfield =
|
||||||
|
~vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0);
|
||||||
|
if (bitfield) {
|
||||||
|
return i + (__builtin_ctzll(bitfield) >> 2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__builtin_unreachable();
|
||||||
|
#else
|
||||||
|
int i = 0;
|
||||||
|
for (; i < kStride - 1; ++i) {
|
||||||
|
if (*ap++ != *bp++) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return i;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
int longestCommonPrefix(const uint8_t *ap, const uint8_t *bp, int cl) {
|
||||||
|
int i = 0;
|
||||||
|
int end;
|
||||||
|
|
||||||
|
if (cl < 8) {
|
||||||
|
goto bytes;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optimistic early return
|
||||||
|
{
|
||||||
|
uint64_t a;
|
||||||
|
uint64_t b;
|
||||||
|
memcpy(&a, ap, 8);
|
||||||
|
memcpy(&b, bp, 8);
|
||||||
|
const auto mismatched = a ^ b;
|
||||||
|
if (mismatched) {
|
||||||
|
return __builtin_ctzll(mismatched) / 8;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// kStride * kUnrollCount at a time
|
||||||
|
end = cl & ~(kStride * kUnrollFactor - 1);
|
||||||
|
while (i < end) {
|
||||||
|
for (int j = 0; j < kUnrollFactor; ++j) {
|
||||||
|
if (!compareStride(ap, bp)) {
|
||||||
|
return i + firstNeqStride(ap, bp);
|
||||||
|
}
|
||||||
|
i += kStride;
|
||||||
|
ap += kStride;
|
||||||
|
bp += kStride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// kStride at a time
|
||||||
|
end = cl & ~(kStride - 1);
|
||||||
|
while (i < end) {
|
||||||
|
if (!compareStride(ap, bp)) {
|
||||||
|
return i + firstNeqStride(ap, bp);
|
||||||
|
}
|
||||||
|
i += kStride;
|
||||||
|
ap += kStride;
|
||||||
|
bp += kStride;
|
||||||
|
}
|
||||||
|
|
||||||
|
// word at a time
|
||||||
|
end = cl & ~(sizeof(uint64_t) - 1);
|
||||||
|
while (i < end) {
|
||||||
|
uint64_t a;
|
||||||
|
uint64_t b;
|
||||||
|
memcpy(&a, ap, 8);
|
||||||
|
memcpy(&b, bp, 8);
|
||||||
|
const auto mismatched = a ^ b;
|
||||||
|
if (mismatched) {
|
||||||
|
return i + __builtin_ctzll(mismatched) / 8;
|
||||||
|
}
|
||||||
|
i += 8;
|
||||||
|
ap += 8;
|
||||||
|
bp += 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
bytes:
|
||||||
|
// byte at a time
|
||||||
|
while (i < cl) {
|
||||||
|
if (*ap != *bp) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
++ap;
|
||||||
|
++bp;
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
|
||||||
bool checkRangeRead(Node *n, const std::span<const uint8_t> begin,
|
bool checkRangeRead(Node *n, const std::span<const uint8_t> begin,
|
||||||
const std::span<const uint8_t> end, int64_t readVersion,
|
const std::span<const uint8_t> end, int64_t readVersion,
|
||||||
Arena &arena) {
|
Arena &arena) {
|
||||||
@@ -785,8 +914,16 @@ bool checkRangeRead(Node *n, const std::span<const uint8_t> begin,
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool first = true;
|
bool first = true;
|
||||||
for (auto *iter = left.n; iter != nullptr && searchPath < end;
|
for (auto *iter = left.n; iter != nullptr; first = false) {
|
||||||
first = false) {
|
int cl = std::min(searchPath.size(), end.size());
|
||||||
|
int lcp = longestCommonPrefix(searchPath.data(), end.data(), cl);
|
||||||
|
|
||||||
|
// if (searchPath >= end) break;
|
||||||
|
if ((cl == lcp ? searchPath.size() <=> end.size()
|
||||||
|
: searchPath[lcp] <=> end[lcp]) >= 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
if (iter->entryPresent) {
|
if (iter->entryPresent) {
|
||||||
if (!first && iter->entry.rangeVersion > readVersion) {
|
if (!first && iter->entry.rangeVersion > readVersion) {
|
||||||
return false;
|
return false;
|
||||||
@@ -801,28 +938,18 @@ bool checkRangeRead(Node *n, const std::span<const uint8_t> begin,
|
|||||||
fprintf(stderr, "Max version of keys starting with %s: %" PRId64 "\n",
|
fprintf(stderr, "Max version of keys starting with %s: %" PRId64 "\n",
|
||||||
printable(searchPath).c_str(), iter->maxVersion);
|
printable(searchPath).c_str(), iter->maxVersion);
|
||||||
#endif
|
#endif
|
||||||
bool ok = true;
|
if (lcp == int(searchPath.size())) {
|
||||||
auto rangeEnd = strincMutate(searchPath, ok);
|
// end starts with searchPath, so end < range
|
||||||
auto c = std::strong_ordering::equal;
|
|
||||||
if (!ok) {
|
|
||||||
goto iterate;
|
|
||||||
}
|
|
||||||
c = rangeEnd <=> end;
|
|
||||||
--rangeEnd.back();
|
|
||||||
|
|
||||||
if (c == 0) {
|
|
||||||
return iter->maxVersion <= readVersion;
|
|
||||||
} else if (c < 0) {
|
|
||||||
if (iter->maxVersion > readVersion) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
iter = nextSibling(iter, searchPath);
|
|
||||||
} else {
|
|
||||||
iterate:
|
|
||||||
if (iter->maxVersion <= readVersion) {
|
if (iter->maxVersion <= readVersion) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
iter = nextPhysical(iter, searchPath);
|
iter = nextPhysical(iter, searchPath);
|
||||||
|
} else {
|
||||||
|
// end does not start with searchPath, so range end <= end
|
||||||
|
if (iter->maxVersion > readVersion) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
iter = nextSibling(iter, searchPath);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
Reference in New Issue
Block a user