Bring back sequential read implementation

This will be used if there's only one check to perform, or if the
compiler does not support musttail and preserve_none.
This commit is contained in:
2024-10-15 17:12:11 -07:00
parent 84942a5bf8
commit 769cf8de9a

View File

@@ -2474,6 +2474,52 @@ bool checkMaxBetweenExclusive(Node256 *n, int begin, int end,
return checkMaxBetweenExclusiveImpl<false>(n, begin, end, readVersion, tls);
}
#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
// avx512f-targeted overload of checkMaxBetweenExclusive for a type-erased
// Node. Inspects the node's dynamic type and forwards to
// checkMaxBetweenExclusiveImpl<true> on the matching concrete node type.
// (The `true` template argument presumably selects the vectorized code path;
// confirm against checkMaxBetweenExclusiveImpl.)
//
// n: node to inspect; begin/end, readVersion and tls are forwarded unchanged.
// Returns whatever the concrete implementation returns.
__attribute__((target("avx512f"))) bool
checkMaxBetweenExclusive(Node *n, int begin, int end,
                         InternalVersionT readVersion, ReadContext *tls) {
  switch (n->getType()) {
  case Type_Node0:
    return checkMaxBetweenExclusiveImpl<true>(static_cast<Node0 *>(n), begin,
                                              end, readVersion, tls);
  case Type_Node3:
    return checkMaxBetweenExclusiveImpl<true>(static_cast<Node3 *>(n), begin,
                                              end, readVersion, tls);
  case Type_Node16:
    return checkMaxBetweenExclusiveImpl<true>(static_cast<Node16 *>(n), begin,
                                              end, readVersion, tls);
  case Type_Node48:
    return checkMaxBetweenExclusiveImpl<true>(static_cast<Node48 *>(n), begin,
                                              end, readVersion, tls);
  case Type_Node256:
    return checkMaxBetweenExclusiveImpl<true>(static_cast<Node256 *>(n), begin,
                                              end, readVersion, tls);
  }
  // Every case above returns. Marking the fall-through as unreachable avoids
  // flowing off the end of a non-void function, and keeping it outside the
  // switch (rather than in a `default:`) preserves the compiler's
  // missing-enumerator warning.
  __builtin_unreachable();
}
__attribute__((target("default")))
#endif
// Overload of checkMaxBetweenExclusive for a type-erased Node. Inspects the
// node's dynamic type and forwards to checkMaxBetweenExclusiveImpl<false> on
// the matching concrete node type. (The `false` template argument presumably
// selects the non-vectorized code path; confirm against
// checkMaxBetweenExclusiveImpl.)
//
// n: node to inspect; begin/end, readVersion and tls are forwarded unchanged.
// Returns whatever the concrete implementation returns.
bool checkMaxBetweenExclusive(Node *n, int begin, int end,
                              InternalVersionT readVersion, ReadContext *tls) {
  switch (n->getType()) {
  case Type_Node0:
    return checkMaxBetweenExclusiveImpl<false>(static_cast<Node0 *>(n), begin,
                                               end, readVersion, tls);
  case Type_Node3:
    return checkMaxBetweenExclusiveImpl<false>(static_cast<Node3 *>(n), begin,
                                               end, readVersion, tls);
  case Type_Node16:
    return checkMaxBetweenExclusiveImpl<false>(static_cast<Node16 *>(n), begin,
                                               end, readVersion, tls);
  case Type_Node48:
    return checkMaxBetweenExclusiveImpl<false>(static_cast<Node48 *>(n), begin,
                                               end, readVersion, tls);
  case Type_Node256:
    return checkMaxBetweenExclusiveImpl<false>(static_cast<Node256 *>(n), begin,
                                               end, readVersion, tls);
  }
  // Every case above returns. Marking the fall-through as unreachable avoids
  // flowing off the end of a non-void function, and keeping it outside the
  // switch (rather than in a `default:`) preserves the compiler's
  // missing-enumerator warning.
  __builtin_unreachable();
}
Vector<uint8_t> getSearchPath(Arena &arena, Node *n) {
assert(n != nullptr);
auto result = vector<uint8_t>(arena);
@@ -2939,28 +2985,16 @@ struct CheckContext {
ConflictSet::Result *results;
int64_t started;
ReadContext tls;
#if !__has_attribute(musttail)
CheckJob *job;
bool done;
#endif
};
// Hands control to the job linked from `job` via job->next.
// With musttail support, the next job's continuation is invoked as a tail
// call. Without it, the next job is stored in context->job and the function
// simply returns.
PRESERVE_NONE void keepGoing(CheckJob *job, CheckContext *context) {
#if __has_attribute(musttail)
// Advance to the next job and jump straight into its continuation.
job = job->next;
MUSTTAIL return job->continuation(job, context);
#else
// No guaranteed tail calls: record the next job in the context and return.
context->job = job->next;
return;
#endif
}
PRESERVE_NONE void complete(CheckJob *job, CheckContext *context) {
if (context->started == context->count) {
if (job->prev == job) {
#if !__has_attribute(musttail)
context->done = true;
#endif
return;
}
job->prev->next = job->next;
@@ -3729,8 +3763,438 @@ void CheckJob::init(const ConflictSet::ReadRange *read,
}
}
// Sequential implementations
namespace {
// Logically this is the same as performing firstGeq and then checking against
// point or range version according to cmp, but this version short circuits as
// soon as it can prove that there's no conflict.
//
// n: node the search starts from.
// key: bytes of the key being read.
// readVersion: version that the versions found in the tree are compared with.
// tls: context whose point_read_* counters are incremented as a side effect.
// Returns true when the version that ends up being examined is <= readVersion
// (or when nextSibling reports that no further node exists), false otherwise.
bool checkPointRead(Node *n, const std::span<const uint8_t> key,
InternalVersionT readVersion, ReadContext *tls) {
++tls->point_read_accum;
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check point read: %s\n", printable(key).c_str());
#endif
// Suffix of key that has not been matched against the tree yet.
auto remaining = key;
// Each pass descends one child; the iteration counter is bumped whenever the
// loop goes around again.
for (;; ++tls->point_read_iterations_accum) {
if (remaining.size() == 0) {
// Every byte of key has been matched on the way down to n.
if (n->entryPresent) {
return n->entry.pointVersion <= readVersion;
}
// n carries no entry: continue from its first child and use the first
// entry found below it.
n = getFirstChildExists(n);
goto downLeftSpine;
}
auto [c, maxV] = getChildAndMaxVersion(n, remaining[0]);
Node *child = c;
if (child == nullptr) {
// No child for the next key byte. Fall back to what getChildGeq returns
// for that byte (presumably the first child at or after it), or to
// nextSibling(n) when there is none.
auto c = getChildGeq(n, remaining[0]);
if (c != nullptr) {
n = c;
goto downLeftSpine;
} else {
n = nextSibling(n);
if (n == nullptr) {
// nextSibling found nothing; report no conflict.
return true;
}
goto downLeftSpine;
}
}
n = child;
// Drop the byte that selected the child.
remaining = remaining.subspan(1, remaining.size() - 1);
if (n->partialKeyLen > 0) {
// Compare the child's partial key with the next bytes of the key, over
// the shorter of the two lengths.
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
if (i < commonLen) {
// The partial key and the key differ at position i.
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
// n's partial key byte is larger than the key byte at the mismatch.
goto downLeftSpine;
} else {
// n's partial key byte is smaller than the key byte at the mismatch.
n = nextSibling(n);
if (n == nullptr) {
return true;
}
goto downLeftSpine;
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// n is the first physical node greater than remaining, and there's no
// eq node
goto downLeftSpine;
}
}
// maxV came back together with the child from getChildAndMaxVersion. If it
// is already <= readVersion, stop descending and report no conflict.
if (maxV <= readVersion) {
++tls->point_read_short_circuit_accum;
return true;
}
}
downLeftSpine:
// Follow getFirstChildExists until a node with an entry is reached, then
// compare that entry's range version with readVersion.
for (; !n->entryPresent; n = getFirstChildExists(n)) {
}
return n->entry.rangeVersion <= readVersion;
}
// Logically this is the same as performing firstGeq and then checking against
// max version or range version if this prefix doesn't exist, but this version
// short circuits as soon as it can prove that there's no conflict.
//
// n: node the search starts from.
// key: bytes of the prefix being read.
// readVersion: version that the versions found in the tree are compared with.
// tls: context whose prefix_read_* counters are incremented as a side effect.
// Returns true when the version that ends up being examined is <= readVersion
// (or when nextSibling reports that no further node exists), false otherwise.
bool checkPrefixRead(Node *n, const std::span<const uint8_t> key,
InternalVersionT readVersion, ReadContext *tls) {
++tls->prefix_read_accum;
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check prefix read: %s\n", printable(key).c_str());
#endif
// Suffix of key that has not been matched against the tree yet.
auto remaining = key;
// Each pass descends one child; the iteration counter is bumped whenever the
// loop goes around again.
for (;; ++tls->prefix_read_iterations_accum) {
if (remaining.size() == 0) {
// There's no way to encode a prefix read of "", so n is not the root
return maxVersion(n) <= readVersion;
}
auto [c, maxV] = getChildAndMaxVersion(n, remaining[0]);
Node *child = c;
if (child == nullptr) {
// No child for the next key byte. Fall back to what getChildGeq returns
// for that byte (presumably the first child at or after it), or to
// nextSibling(n) when there is none.
auto c = getChildGeq(n, remaining[0]);
if (c != nullptr) {
n = c;
goto downLeftSpine;
} else {
n = nextSibling(n);
if (n == nullptr) {
// nextSibling found nothing; report no conflict.
return true;
}
goto downLeftSpine;
}
}
n = child;
// Drop the byte that selected the child.
remaining = remaining.subspan(1, remaining.size() - 1);
if (n->partialKeyLen > 0) {
// Compare the child's partial key with the next bytes of the key, over
// the shorter of the two lengths.
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
if (i < commonLen) {
// The partial key and the key differ at position i.
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
// n's partial key byte is larger than the key byte at the mismatch.
goto downLeftSpine;
} else {
// n's partial key byte is smaller than the key byte at the mismatch.
n = nextSibling(n);
if (n == nullptr) {
return true;
}
goto downLeftSpine;
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// n is the first physical node greater than remaining, and there's no
// eq node. All physical nodes that start with prefix are reachable from
// n.
if (maxVersion(n) > readVersion) {
return false;
}
goto downLeftSpine;
}
}
// maxV came back together with the child from getChildAndMaxVersion. If it
// is already <= readVersion, stop descending and report no conflict.
if (maxV <= readVersion) {
++tls->prefix_read_short_circuit_accum;
return true;
}
}
downLeftSpine:
// Follow getFirstChildExists until a node with an entry is reached, then
// compare that entry's range version with readVersion.
for (; !n->entryPresent; n = getFirstChildExists(n)) {
}
return n->entry.rangeVersion <= readVersion;
}
// Return true if the max version among all keys that start with key[:prefixLen]
// that are >= key is <= readVersion
//
// n: node the search starts from.
// key: left bound of the range, relative to n.
// prefixLen: number of leading key bytes that form the prefix named above.
// readVersion: version that the versions found in the tree are compared with.
// tls: context whose range_read_iterations_accum counter is incremented and
//      which is forwarded to checkMaxBetweenExclusive.
bool checkRangeLeftSide(Node *n, std::span<const uint8_t> key, int prefixLen,
InternalVersionT readVersion, ReadContext *tls) {
// Suffix of key that has not been matched against the tree yet.
auto remaining = key;
// Number of key bytes matched so far on the way down to n.
int searchPathLen = 0;
for (;; ++tls->range_read_iterations_accum) {
if (remaining.size() == 0) {
// The whole key has been matched; the answer is n's max version.
assert(searchPathLen >= prefixLen);
return maxVersion(n) <= readVersion;
}
if (searchPathLen >= prefixLen) {
// Past the prefix: ask checkMaxBetweenExclusive about the child range
// bounded by remaining[0] and 256, and fail if it reports a conflict.
if (!checkMaxBetweenExclusive(n, remaining[0], 256, readVersion, tls)) {
return false;
}
}
auto [c, maxV] = getChildAndMaxVersion(n, remaining[0]);
Node *child = c;
if (child == nullptr) {
// No child for the next key byte. Fall back to what getChildGeq returns
// for that byte, or to nextSibling(n) when there is none.
auto c = getChildGeq(n, remaining[0]);
if (c != nullptr) {
if (searchPathLen < prefixLen) {
// Still inside the prefix: only the first entry under c is examined.
n = c;
goto downLeftSpine;
}
// Past the prefix: the answer is c's max version.
n = c;
return maxVersion(n) <= readVersion;
} else {
n = nextSibling(n);
if (n == nullptr) {
// nextSibling found nothing; report no conflict.
return true;
}
goto downLeftSpine;
}
}
n = child;
// Drop the byte that selected the child and account for it.
remaining = remaining.subspan(1, remaining.size() - 1);
++searchPathLen;
if (n->partialKeyLen > 0) {
// Compare the child's partial key with the next bytes of the key, over
// the shorter of the two lengths; the matched bytes extend the path.
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
searchPathLen += i;
if (i < commonLen) {
// The partial key and the key differ at position i.
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
// n's partial key byte is larger than the key byte at the mismatch.
if (searchPathLen < prefixLen) {
goto downLeftSpine;
}
// Past the prefix: check n's range version (if it has an entry) and
// then its max version.
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
return false;
}
return maxVersion(n) <= readVersion;
} else {
// n's partial key byte is smaller than the key byte at the mismatch.
n = nextSibling(n);
if (n == nullptr) {
return true;
}
goto downLeftSpine;
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// The key ran out inside n's partial key: check n's range version (if
// it has an entry) and then its max version.
assert(searchPathLen >= prefixLen);
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
return false;
}
return maxVersion(n) <= readVersion;
}
}
// maxV came back together with the child from getChildAndMaxVersion. If it
// is already <= readVersion, stop descending and report no conflict.
if (maxV <= readVersion) {
return true;
}
}
downLeftSpine:
// Follow getFirstChildExists until a node with an entry is reached, then
// compare that entry's range version with readVersion.
for (; !n->entryPresent; n = getFirstChildExists(n)) {
}
return n->entry.rangeVersion <= readVersion;
}
// Return true if the max version among all keys that start with key[:prefixLen]
// that are < key is <= readVersion
//
// n: node the search starts from.
// key: right bound of the range, relative to n.
// prefixLen: number of leading key bytes that form the prefix named above.
// readVersion: version that the versions found in the tree are compared with.
// tls: context whose range_read_iterations_accum counter is incremented and
//      which is forwarded to checkMaxBetweenExclusive.
bool checkRangeRightSide(Node *n, std::span<const uint8_t> key, int prefixLen,
InternalVersionT readVersion, ReadContext *tls) {
// Suffix of key that has not been matched against the tree yet.
auto remaining = key;
// Number of key bytes accounted for so far on the way down to n.
int searchPathLen = 0;
for (;; ++tls->range_read_iterations_accum) {
assert(searchPathLen <= int(key.size()));
if (remaining.size() == 0) {
// The whole key has been matched; finish with the first entry at or
// below n.
goto downLeftSpine;
}
if (searchPathLen >= prefixLen) {
// At or past the prefix: n's own point version must not exceed
// readVersion, and checkMaxBetweenExclusive must accept the child range
// bounded by -1 and remaining[0].
if (n->entryPresent && n->entry.pointVersion > readVersion) {
return false;
}
if (!checkMaxBetweenExclusive(n, -1, remaining[0], readVersion, tls)) {
return false;
}
}
// Strictly past the prefix, n's range version is checked as well.
if (searchPathLen > prefixLen && n->entryPresent &&
n->entry.rangeVersion > readVersion) {
return false;
}
Node *child = getChild(n, remaining[0]);
if (child == nullptr) {
// No child for the next key byte. Fall back to what getChildGeq returns
// for that byte, or backtrack towards the root when there is none.
auto c = getChildGeq(n, remaining[0]);
if (c != nullptr) {
n = c;
goto downLeftSpine;
} else {
goto backtrack;
}
}
n = child;
// Drop the byte that selected the child and account for it.
remaining = remaining.subspan(1, remaining.size() - 1);
++searchPathLen;
if (n->partialKeyLen > 0) {
// Compare the child's partial key with the next bytes of the key, over
// the shorter of the two lengths; the matched bytes extend the path.
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
searchPathLen += i;
if (i < commonLen) {
// The partial key and the key differ at position i; the differing byte
// is counted in searchPathLen too.
++searchPathLen;
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
// n's partial key byte is larger than the key byte at the mismatch.
goto downLeftSpine;
} else {
// n's partial key byte is smaller than the key byte at the mismatch:
// check n's range version when past the prefix, then backtrack.
if (searchPathLen > prefixLen && n->entryPresent &&
n->entry.rangeVersion > readVersion) {
return false;
}
goto backtrack;
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// The key ran out inside n's partial key.
goto downLeftSpine;
}
}
}
backtrack:
// Climb towards the root until a node with a later sibling is found, checking
// max versions on the way while still past the prefix.
for (;;) {
// searchPathLen > prefixLen implies n is not the root
if (searchPathLen > prefixLen && maxVersion(n) > readVersion) {
return false;
}
if (n->parent == nullptr) {
// Reached a node without a parent; report no conflict.
return true;
}
// Look for a child of the parent at an index after n's own.
auto next = getChildGeq(n->parent, n->parentsIndex + 1);
if (next == nullptr) {
// None: move up one level, removing n's index byte and partial key from
// the path length.
searchPathLen -= 1 + n->partialKeyLen;
n = n->parent;
} else {
// Switch to that sibling, swapping n's partial key length for the
// sibling's in the path length, and finish from there.
searchPathLen -= n->partialKeyLen;
n = next;
searchPathLen += n->partialKeyLen;
goto downLeftSpine;
}
}
downLeftSpine:
// Follow getFirstChildExists until a node with an entry is reached, then
// compare that entry's range version with readVersion.
for (; !n->entryPresent; n = getFirstChildExists(n)) {
}
return n->entry.rangeVersion <= readVersion;
}
// Checks a read of the key range bounded by begin and end against the tree
// rooted at n.
//
// Two shapes of range are delegated to cheaper checks:
//  - end is begin followed by a single 0 byte -> checkPointRead(begin)
//  - begin and end have equal length and differ only in the last byte, with
//    end's last byte one greater -> checkPrefixRead(begin)
// Otherwise the function walks down the common prefix of begin and end, then
// combines checkRangeStartsWith, checkRangeLeftSide and checkRangeRightSide.
//
// readVersion: version that the versions found in the tree are compared with.
// tls: context whose range_read_* counters are incremented as a side effect.
// Returns the boolean result of the delegated checks, or true when the walk
// down the common prefix short circuits.
bool checkRangeRead(Node *n, std::span<const uint8_t> begin,
std::span<const uint8_t> end, InternalVersionT readVersion,
ReadContext *tls) {
// Length of the longest common prefix of begin and end.
int lcp = longestCommonPrefix(begin.data(), end.data(),
std::min(begin.size(), end.size()));
if (lcp == int(begin.size()) && end.size() == begin.size() + 1 &&
end.back() == 0) {
return checkPointRead(n, begin, readVersion, tls);
}
if (lcp == int(begin.size() - 1) && end.size() == begin.size() &&
int(begin.back()) + 1 == int(end.back())) {
return checkPrefixRead(n, begin, readVersion, tls);
}
++tls->range_read_accum;
// Part of the common prefix that has not been matched against the tree yet.
auto remaining = begin.subspan(0, lcp);
// Only used by the getSearchPath calls inside the asserts below.
Arena arena;
// Advance down common prefix, but stay on a physical path in the tree
for (;; ++tls->range_read_iterations_accum) {
assert(getSearchPath(arena, n) <=>
begin.subspan(0, lcp - remaining.size()) ==
0);
if (remaining.size() == 0) {
break;
}
auto [c, v] = getChildAndMaxVersion(n, remaining[0]);
Node *child = c;
if (child == nullptr) {
break;
}
if (child->partialKeyLen > 0) {
// Stop unless the child's entire partial key matches the bytes of the
// common prefix that follow the child's index byte.
int cl = std::min<int>(child->partialKeyLen, remaining.size() - 1);
int i =
longestCommonPrefix(child->partialKey(), remaining.data() + 1, cl);
if (i != child->partialKeyLen) {
break;
}
}
// v came back together with the child from getChildAndMaxVersion. If it is
// already <= readVersion, report no conflict without going further.
if (v <= readVersion) {
++tls->range_read_short_circuit_accum;
return true;
}
n = child;
// Consume the child's index byte and its partial key.
remaining =
remaining.subspan(1 + child->partialKeyLen,
remaining.size() - (1 + child->partialKeyLen));
}
assert(getSearchPath(arena, n) <=> begin.subspan(0, lcp - remaining.size()) ==
0);
// Strip the bytes matched on the way down to n so that begin, end and lcp
// are all relative to n from here on.
const int consumed = lcp - remaining.size();
assume(consumed >= 0);
begin = begin.subspan(consumed, int(begin.size()) - consumed);
end = end.subspan(consumed, int(end.size()) - consumed);
lcp -= consumed;
if (lcp == int(begin.size())) {
// begin is entirely part of the common prefix, so only the right side
// check is performed.
return checkRangeRightSide(n, end, lcp, readVersion, tls);
}
// This makes it safe to check maxVersion within checkRangeLeftSide. If this
// were false, then we would have returned above since lcp == begin.size().
assert(!(n->parent == nullptr && begin.size() == 0));
// All three checks must pass; && stops at the first one that fails.
return checkRangeStartsWith(n, begin.subspan(0, lcp), begin[lcp], end[lcp],
readVersion, tls) &&
checkRangeLeftSide(n, begin, lcp + 1, readVersion, tls) &&
checkRangeRightSide(n, end, lcp + 1, readVersion, tls);
}
} // namespace
struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
// We still have the sequential implementation for compilers that don't
// support preserve_none and musttail
void useSequential(const ReadRange *reads, Result *result, int count,
CheckContext &context) {
for (int i = 0; i < count; ++i) {
if (reads[i].readVersion < oldestVersionFullPrecision) [[unlikely]] {
result[i] = TooOld;
} else {
bool ok;
if (reads[i].end.len == 0) {
ok = checkPointRead(
root,
std::span<const uint8_t>(reads[i].begin.p, reads[i].begin.len),
InternalVersionT(reads[i].readVersion), &context.tls);
} else {
ok = checkRangeRead(
root,
std::span<const uint8_t>(reads[i].begin.p, reads[i].begin.len),
std::span<const uint8_t>(reads[i].end.p, reads[i].end.len),
InternalVersionT(reads[i].readVersion), &context.tls);
}
result[i] = ok ? Commit : Conflict;
}
}
}
void check(const ReadRange *reads, Result *result, int count) {
assert(oldestVersionFullPrecision >=
newestVersionFullPrecision - kNominalVersionWindow);
@@ -3740,10 +4204,15 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
}
int64_t check_byte_accum = 0;
constexpr int kConcurrent = 16;
CheckJob inProgress[kConcurrent];
CheckContext context;
context.tls.impl = this;
#if __has_attribute(preserve_none) && __has_attribute(musttail)
if (count == 1) {
useSequential(reads, result, count, context);
} else {
constexpr int kConcurrent = 16;
CheckJob inProgress[kConcurrent];
context.count = count;
context.oldestVersionFullPrecision = oldestVersionFullPrecision;
context.root = root;
@@ -3764,16 +4233,13 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
inProgress[0].prev = inProgress + started - 1;
inProgress[started - 1].next = inProgress;
#if __has_attribute(musttail)
// Kick off the sequence of tail calls that finally returns once all jobs
// are done
inProgress->continuation(inProgress, &context);
#else
context.job = inProgress;
context.done = false;
while (!context.done) {
context.job->continuation(context.job, &context);
}
#else
useSequential(reads, result, count, context);
#endif
for (int i = 0; i < count; ++i) {
@@ -3887,17 +4353,17 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
int64_t set_oldest_iterations_accum = 0;
for (; fuel > 0 && n != nullptr; ++set_oldest_iterations_accum) {
rezero(n, oldestVersion);
// The "make sure gc keeps up with writes" calculations assume that we're
// scanning key by key, not node by node. Make sure we only spend fuel
// when there's a logical entry.
// The "make sure gc keeps up with writes" calculations assume that
// we're scanning key by key, not node by node. Make sure we only spend
// fuel when there's a logical entry.
fuel -= n->entryPresent;
if (n->entryPresent && std::max(n->entry.pointVersion,
n->entry.rangeVersion) <= oldestVersion) {
// Any transaction n would have prevented from committing is
// going to fail with TooOld anyway.
// There's no way to insert a range such that range version of the right
// node is greater than the point version of the left node
// There's no way to insert a range such that range version of the
// right node is greater than the point version of the left node
assert(n->entry.rangeVersion <= oldestVersion);
n = erase(n, &tls, this, /*logical*/ false);
} else {
@@ -3941,9 +4407,9 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
InternalVersionT::zero = tls.zero = oldestVersion;
#endif
#ifdef NDEBUG
// This is here for performance reasons, since we want to amortize the cost
// of storing the search path as a string. In tests, we want to exercise the
// rest of the code often.
// This is here for performance reasons, since we want to amortize the
// cost of storing the search path as a string. In tests, we want to
// exercise the rest of the code often.
if (keyUpdates < 100) {
return;
}
@@ -4071,16 +4537,14 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
"The total number of entries inserted in the tree");
COUNTER(entries_erased_total,
"The total number of entries erased from the tree");
COUNTER(
gc_iterations_total,
"The total number of iterations of the main loop for garbage collection");
COUNTER(gc_iterations_total, "The total number of iterations of the main "
"loop for garbage collection");
COUNTER(write_bytes_total, "Total number of key bytes in calls to addWrites");
GAUGE(oldest_version,
"The lowest version that doesn't result in \"TooOld\" for checks");
GAUGE(newest_version, "The version of the most recent call to addWrites");
GAUGE(
oldest_extant_version,
"A lower bound on the lowest version associated with an existing entry");
GAUGE(oldest_extant_version, "A lower bound on the lowest version "
"associated with an existing entry");
// ==================== END METRICS DEFINITIONS ====================
#undef GAUGE
#undef COUNTER
@@ -4338,8 +4802,8 @@ std::string strinc(std::string_view str, bool &ok) {
if ((uint8_t &)(str[index]) != 255)
break;
// Must not be called with a string that consists only of zero or more '\xff'
// bytes.
// Must not be called with a string that consists only of zero or more
// '\xff' bytes.
if (index < 0) {
ok = false;
return {};