Don't plumb impl and ReadContext

This commit is contained in:
2024-07-17 18:12:45 -07:00
parent 640c1ca9dd
commit 12f361f33a

View File

@@ -718,6 +718,7 @@ struct ReadContext {
double prefix_read_iterations_accum = 0; double prefix_read_iterations_accum = 0;
double range_read_iterations_accum = 0; double range_read_iterations_accum = 0;
double range_read_node_scan_accum = 0; double range_read_node_scan_accum = 0;
ConflictSet::Impl *impl;
}; };
// A type that's plumbed along the non-const call tree. Same lifetime as // A type that's plumbed along the non-const call tree. Same lifetime as
@@ -1769,13 +1770,13 @@ struct SearchStepWise {
// point or range version according to cmp, but this version short circuits as // point or range version according to cmp, but this version short circuits as
// soon as it can prove that there's no conflict. // soon as it can prove that there's no conflict.
bool checkPointRead(Node *n, const std::span<const uint8_t> key, bool checkPointRead(Node *n, const std::span<const uint8_t> key,
InternalVersionT readVersion, ConflictSet::Impl *impl, InternalVersionT readVersion, ReadContext *tls) {
ReadContext *tls) {
++tls->point_read_accum; ++tls->point_read_accum;
#if DEBUG_VERBOSE && !defined(NDEBUG) #if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check point read: %s\n", printable(key).c_str()); fprintf(stderr, "Check point read: %s\n", printable(key).c_str());
#endif #endif
auto remaining = key; auto remaining = key;
auto *impl = tls->impl;
for (;; ++tls->point_read_iterations_accum) { for (;; ++tls->point_read_iterations_accum) {
if (maxVersion(n, impl) <= readVersion) { if (maxVersion(n, impl) <= readVersion) {
++tls->point_read_short_circuit_accum; ++tls->point_read_short_circuit_accum;
@@ -1849,13 +1850,13 @@ downLeftSpine:
// max version or range version if this prefix doesn't exist, but this version // max version or range version if this prefix doesn't exist, but this version
// short circuits as soon as it can prove that there's no conflict. // short circuits as soon as it can prove that there's no conflict.
bool checkPrefixRead(Node *n, const std::span<const uint8_t> key, bool checkPrefixRead(Node *n, const std::span<const uint8_t> key,
InternalVersionT readVersion, ConflictSet::Impl *impl, InternalVersionT readVersion, ReadContext *tls) {
ReadContext *tls) {
++tls->prefix_read_accum; ++tls->prefix_read_accum;
#if DEBUG_VERBOSE && !defined(NDEBUG) #if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check prefix read: %s\n", printable(key).c_str()); fprintf(stderr, "Check prefix read: %s\n", printable(key).c_str());
#endif #endif
auto remaining = key; auto remaining = key;
auto *impl = tls->impl;
for (;; ++tls->prefix_read_iterations_accum) { for (;; ++tls->prefix_read_iterations_accum) {
auto m = maxVersion(n, impl); auto m = maxVersion(n, impl);
if (remaining.size() == 0) { if (remaining.size() == 0) {
@@ -2363,11 +2364,12 @@ Vector<uint8_t> getSearchPath(Arena &arena, Node *n) {
template <bool kAVX512> template <bool kAVX512>
bool checkRangeStartsWith(Node *n, std::span<const uint8_t> key, int begin, bool checkRangeStartsWith(Node *n, std::span<const uint8_t> key, int begin,
int end, InternalVersionT readVersion, int end, InternalVersionT readVersion,
ConflictSet::Impl *impl, ReadContext *tls) { ReadContext *tls) {
#if DEBUG_VERBOSE && !defined(NDEBUG) #if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "%s(%02x,%02x)*\n", printable(key).c_str(), begin, end); fprintf(stderr, "%s(%02x,%02x)*\n", printable(key).c_str(), begin, end);
#endif #endif
auto remaining = key; auto remaining = key;
auto *impl = tls->impl;
if (remaining.size() == 0) { if (remaining.size() == 0) {
return checkMaxBetweenExclusive<kAVX512>(n, begin, end, readVersion, tls); return checkMaxBetweenExclusive<kAVX512>(n, begin, end, readVersion, tls);
} }
@@ -2435,10 +2437,9 @@ namespace {
// that are >= key is <= readVersion // that are >= key is <= readVersion
template <bool kAVX512> struct CheckRangeLeftSide { template <bool kAVX512> struct CheckRangeLeftSide {
CheckRangeLeftSide(Node *n, std::span<const uint8_t> key, int prefixLen, CheckRangeLeftSide(Node *n, std::span<const uint8_t> key, int prefixLen,
InternalVersionT readVersion, ConflictSet::Impl *impl, InternalVersionT readVersion, ReadContext *tls)
ReadContext *tls)
: n(n), remaining(key), prefixLen(prefixLen), readVersion(readVersion), : n(n), remaining(key), prefixLen(prefixLen), readVersion(readVersion),
impl(impl), tls(tls) { impl(tls->impl), tls(tls) {
#if DEBUG_VERBOSE && !defined(NDEBUG) #if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check range left side from %s for keys starting with %s\n", fprintf(stderr, "Check range left side from %s for keys starting with %s\n",
printable(key).c_str(), printable(key).c_str(),
@@ -2557,10 +2558,9 @@ template <bool kAVX512> struct CheckRangeLeftSide {
// that are < key is <= readVersion // that are < key is <= readVersion
template <bool kAVX512> struct CheckRangeRightSide { template <bool kAVX512> struct CheckRangeRightSide {
CheckRangeRightSide(Node *n, std::span<const uint8_t> key, int prefixLen, CheckRangeRightSide(Node *n, std::span<const uint8_t> key, int prefixLen,
InternalVersionT readVersion, ConflictSet::Impl *impl, InternalVersionT readVersion, ReadContext *tls)
ReadContext *tls)
: n(n), key(key), remaining(key), prefixLen(prefixLen), : n(n), key(key), remaining(key), prefixLen(prefixLen),
readVersion(readVersion), impl(impl), tls(tls) { readVersion(readVersion), impl(tls->impl), tls(tls) {
#if DEBUG_VERBOSE && !defined(NDEBUG) #if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check range right side to %s for keys starting with %s\n", fprintf(stderr, "Check range right side to %s for keys starting with %s\n",
printable(key).c_str(), printable(key).c_str(),
@@ -2695,23 +2695,23 @@ template <bool kAVX512> struct CheckRangeRightSide {
template <bool kAVX512> template <bool kAVX512>
bool checkRangeReadImpl(Node *n, std::span<const uint8_t> begin, bool checkRangeReadImpl(Node *n, std::span<const uint8_t> begin,
std::span<const uint8_t> end, std::span<const uint8_t> end,
InternalVersionT readVersion, ConflictSet::Impl *impl, InternalVersionT readVersion, ReadContext *tls) {
ReadContext *tls) {
int lcp = longestCommonPrefix(begin.data(), end.data(), int lcp = longestCommonPrefix(begin.data(), end.data(),
std::min(begin.size(), end.size())); std::min(begin.size(), end.size()));
if (lcp == int(begin.size()) && end.size() == begin.size() + 1 && if (lcp == int(begin.size()) && end.size() == begin.size() + 1 &&
end.back() == 0) { end.back() == 0) {
return checkPointRead(n, begin, readVersion, impl, tls); return checkPointRead(n, begin, readVersion, tls);
} }
if (lcp == int(begin.size() - 1) && end.size() == begin.size() && if (lcp == int(begin.size() - 1) && end.size() == begin.size() &&
int(begin.back()) + 1 == int(end.back())) { int(begin.back()) + 1 == int(end.back())) {
return checkPrefixRead(n, begin, readVersion, impl, tls); return checkPrefixRead(n, begin, readVersion, tls);
} }
++tls->range_read_accum; ++tls->range_read_accum;
SearchStepWise search{n, begin.subspan(0, lcp)}; SearchStepWise search{n, begin.subspan(0, lcp)};
Arena arena; Arena arena;
auto *impl = tls->impl;
for (;; ++tls->range_read_iterations_accum) { for (;; ++tls->range_read_iterations_accum) {
assert(getSearchPath(arena, search.n) <=> assert(getSearchPath(arena, search.n) <=>
begin.subspan(0, lcp - search.remaining.size()) == begin.subspan(0, lcp - search.remaining.size()) ==
@@ -2737,22 +2737,22 @@ bool checkRangeReadImpl(Node *n, std::span<const uint8_t> begin,
lcp -= consumed; lcp -= consumed;
if (lcp == int(begin.size())) { if (lcp == int(begin.size())) {
CheckRangeRightSide<kAVX512> checkRangeRightSide{n, end, lcp, CheckRangeRightSide<kAVX512> checkRangeRightSide{n, end, lcp, readVersion,
readVersion, impl, tls}; tls};
while (!checkRangeRightSide.step()) while (!checkRangeRightSide.step())
; ;
return checkRangeRightSide.ok; return checkRangeRightSide.ok;
} }
if (!checkRangeStartsWith<kAVX512>(n, begin.subspan(0, lcp), begin[lcp], if (!checkRangeStartsWith<kAVX512>(n, begin.subspan(0, lcp), begin[lcp],
end[lcp], readVersion, impl, tls)) { end[lcp], readVersion, tls)) {
return false; return false;
} }
CheckRangeLeftSide<kAVX512> checkRangeLeftSide{n, begin, lcp + 1, CheckRangeLeftSide<kAVX512> checkRangeLeftSide{n, begin, lcp + 1, readVersion,
readVersion, impl, tls}; tls};
CheckRangeRightSide<kAVX512> checkRangeRightSide{n, end, lcp + 1, CheckRangeRightSide<kAVX512> checkRangeRightSide{n, end, lcp + 1, readVersion,
readVersion, impl, tls}; tls};
for (;;) { for (;;) {
bool leftDone = checkRangeLeftSide.step(); bool leftDone = checkRangeLeftSide.step();
@@ -2796,7 +2796,7 @@ checkMaxBetweenExclusive<true>(Node *n, int begin, int end,
template __attribute__((target("avx512f"))) bool template __attribute__((target("avx512f"))) bool
checkRangeStartsWith<true>(Node *n, std::span<const uint8_t> key, int begin, checkRangeStartsWith<true>(Node *n, std::span<const uint8_t> key, int begin,
int end, InternalVersionT readVersion, int end, InternalVersionT readVersion,
ConflictSet::Impl *impl); ReadContext *);
template __attribute__((target("avx512f"))) bool template __attribute__((target("avx512f"))) bool
CheckRangeLeftSide<true>::step(); CheckRangeLeftSide<true>::step();
template __attribute__((target("avx512f"))) bool template __attribute__((target("avx512f"))) bool
@@ -2804,27 +2804,27 @@ CheckRangeRightSide<true>::step();
template __attribute__((target("avx512f"))) bool template __attribute__((target("avx512f"))) bool
checkRangeReadImpl<true>(Node *n, std::span<const uint8_t> begin, checkRangeReadImpl<true>(Node *n, std::span<const uint8_t> begin,
std::span<const uint8_t> end, std::span<const uint8_t> end,
InternalVersionT readVersion, ConflictSet::Impl *impl); InternalVersionT readVersion, ReadContext *);
#endif #endif
#if defined(__SANITIZE_THREAD__) || !defined(__x86_64__) #if defined(__SANITIZE_THREAD__) || !defined(__x86_64__)
bool checkRangeRead(Node *n, std::span<const uint8_t> begin, bool checkRangeRead(Node *n, std::span<const uint8_t> begin,
std::span<const uint8_t> end, InternalVersionT readVersion, std::span<const uint8_t> end, InternalVersionT readVersion,
ConflictSet::Impl *impl, ReadContext *tls) { ReadContext *tls) {
return checkRangeReadImpl<false>(n, begin, end, readVersion, impl, tls); return checkRangeReadImpl<false>(n, begin, end, readVersion, tls);
} }
#else #else
__attribute__((target("default"))) bool __attribute__((target("default"))) bool
checkRangeRead(Node *n, std::span<const uint8_t> begin, checkRangeRead(Node *n, std::span<const uint8_t> begin,
std::span<const uint8_t> end, InternalVersionT readVersion, std::span<const uint8_t> end, InternalVersionT readVersion,
ConflictSet::Impl *impl) { ReadContext *tls) {
return checkRangeReadImpl<false>(n, begin, end, readVersion, impl); return checkRangeReadImpl<false>(n, begin, end, readVersion, tls);
} }
__attribute__((target("avx512f"))) bool __attribute__((target("avx512f"))) bool
checkRangeRead(Node *n, std::span<const uint8_t> begin, checkRangeRead(Node *n, std::span<const uint8_t> begin,
std::span<const uint8_t> end, InternalVersionT readVersion, std::span<const uint8_t> end, InternalVersionT readVersion,
ConflictSet::Impl *impl) { ReadContext *tls) {
return checkRangeReadImpl<true>(n, begin, end, readVersion, impl); return checkRangeReadImpl<true>(n, begin, end, readVersion, tls);
} }
#endif #endif
@@ -3127,6 +3127,7 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
clock_gettime(CLOCK_THREAD_CPUTIME_ID, &ts_begin); clock_gettime(CLOCK_THREAD_CPUTIME_ID, &ts_begin);
#endif #endif
ReadContext tls; ReadContext tls;
tls.impl = this;
int commits_accum = 0; int commits_accum = 0;
int conflicts_accum = 0; int conflicts_accum = 0;
int too_olds_accum = 0; int too_olds_accum = 0;
@@ -3142,11 +3143,9 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
reads[i].readVersion < oldestVersionFullPrecision ? TooOld reads[i].readVersion < oldestVersionFullPrecision ? TooOld
: (end.size() > 0 : (end.size() > 0
? checkRangeRead(root, begin, end, ? checkRangeRead(root, begin, end,
InternalVersionT(reads[i].readVersion), this, InternalVersionT(reads[i].readVersion), &tls)
&tls)
: checkPointRead(root, begin, : checkPointRead(root, begin,
InternalVersionT(reads[i].readVersion), this, InternalVersionT(reads[i].readVersion), &tls))
&tls))
? Commit ? Commit
: Conflict; : Conflict;
commits_accum += result[i] == Commit; commits_accum += result[i] == Commit;