diff --git a/ConflictSet.cpp b/ConflictSet.cpp index 420a2a0..8e4a94f 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -2474,6 +2474,52 @@ bool checkMaxBetweenExclusive(Node256 *n, int begin, int end, return checkMaxBetweenExclusiveImpl(n, begin, end, readVersion, tls); } +#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__) +__attribute__((target("avx512f"))) bool +checkMaxBetweenExclusive(Node *n, int begin, int end, + InternalVersionT readVersion, ReadContext *tls) { + switch (n->getType()) { + case Type_Node0: + return checkMaxBetweenExclusiveImpl(static_cast(n), begin, + end, readVersion, tls); + case Type_Node3: + return checkMaxBetweenExclusiveImpl(static_cast(n), begin, + end, readVersion, tls); + case Type_Node16: + return checkMaxBetweenExclusiveImpl(static_cast(n), begin, + end, readVersion, tls); + case Type_Node48: + return checkMaxBetweenExclusiveImpl(static_cast(n), begin, + end, readVersion, tls); + case Type_Node256: + return checkMaxBetweenExclusiveImpl(static_cast(n), begin, + end, readVersion, tls); + } +} +__attribute__((target("default"))) +#endif + +bool checkMaxBetweenExclusive(Node *n, int begin, int end, + InternalVersionT readVersion, ReadContext *tls) { + switch (n->getType()) { + case Type_Node0: + return checkMaxBetweenExclusiveImpl(static_cast(n), begin, + end, readVersion, tls); + case Type_Node3: + return checkMaxBetweenExclusiveImpl(static_cast(n), begin, + end, readVersion, tls); + case Type_Node16: + return checkMaxBetweenExclusiveImpl(static_cast(n), begin, + end, readVersion, tls); + case Type_Node48: + return checkMaxBetweenExclusiveImpl(static_cast(n), begin, + end, readVersion, tls); + case Type_Node256: + return checkMaxBetweenExclusiveImpl(static_cast(n), begin, + end, readVersion, tls); + } +} + Vector getSearchPath(Arena &arena, Node *n) { assert(n != nullptr); auto result = vector(arena); @@ -2939,28 +2985,16 @@ struct CheckContext { ConflictSet::Result *results; int64_t started; ReadContext tls; -#if !__has_attribute(musttail) - CheckJob *job; - bool done; -#endif }; PRESERVE_NONE void keepGoing(CheckJob *job, CheckContext *context) { -#if __has_attribute(musttail) job = job->next; MUSTTAIL return job->continuation(job, context); -#else - context->job = job->next; - return; -#endif } PRESERVE_NONE void complete(CheckJob *job, CheckContext *context) { if (context->started == context->count) { if (job->prev == job) { -#if !__has_attribute(musttail) - context->done = true; -#endif return; } job->prev->next = job->next; @@ -3729,8 +3763,438 @@ void CheckJob::init(const ConflictSet::ReadRange *read, } } +// Sequential implementations +namespace { +// Logically this is the same as performing firstGeq and then checking against +// point or range version according to cmp, but this version short circuits as +// soon as it can prove that there's no conflict. +bool checkPointRead(Node *n, const std::span key, + InternalVersionT readVersion, ReadContext *tls) { + ++tls->point_read_accum; +#if DEBUG_VERBOSE && !defined(NDEBUG) + fprintf(stderr, "Check point read: %s\n", printable(key).c_str()); +#endif + auto remaining = key; + for (;; ++tls->point_read_iterations_accum) { + if (remaining.size() == 0) { + if (n->entryPresent) { + return n->entry.pointVersion <= readVersion; + } + n = getFirstChildExists(n); + goto downLeftSpine; + } + + auto [c, maxV] = getChildAndMaxVersion(n, remaining[0]); + Node *child = c; + if (child == nullptr) { + auto c = getChildGeq(n, remaining[0]); + if (c != nullptr) { + n = c; + goto downLeftSpine; + } else { + n = nextSibling(n); + if (n == nullptr) { + return true; + } + goto downLeftSpine; + } + } + + n = child; + remaining = remaining.subspan(1, remaining.size() - 1); + + if (n->partialKeyLen > 0) { + int commonLen = std::min(n->partialKeyLen, remaining.size()); + int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen); + if (i < commonLen) { + auto c = n->partialKey()[i] <=> remaining[i]; + if (c > 0) { + goto downLeftSpine; + } else { + n = nextSibling(n); + if (n == nullptr) { + return true; + } + goto downLeftSpine; + } + } + if (commonLen == n->partialKeyLen) { + // partial key matches + remaining = remaining.subspan(commonLen, remaining.size() - commonLen); + } else if (n->partialKeyLen > int(remaining.size())) { + // n is the first physical node greater than remaining, and there's no + // eq node + goto downLeftSpine; + } + } + + if (maxV <= readVersion) { + ++tls->point_read_short_circuit_accum; + return true; + } + } +downLeftSpine: + for (; !n->entryPresent; n = getFirstChildExists(n)) { + } + return n->entry.rangeVersion <= readVersion; +} + +// Logically this is the same as performing firstGeq and then checking against +// max version or range version if this prefix doesn't exist, but this version +// short circuits as soon as it can prove that there's no conflict. +bool checkPrefixRead(Node *n, const std::span key, + InternalVersionT readVersion, ReadContext *tls) { + ++tls->prefix_read_accum; +#if DEBUG_VERBOSE && !defined(NDEBUG) + fprintf(stderr, "Check prefix read: %s\n", printable(key).c_str()); +#endif + auto remaining = key; + for (;; ++tls->prefix_read_iterations_accum) { + if (remaining.size() == 0) { + // There's no way to encode a prefix read of "", so n is not the root + return maxVersion(n) <= readVersion; + } + + auto [c, maxV] = getChildAndMaxVersion(n, remaining[0]); + Node *child = c; + if (child == nullptr) { + auto c = getChildGeq(n, remaining[0]); + if (c != nullptr) { + n = c; + goto downLeftSpine; + } else { + n = nextSibling(n); + if (n == nullptr) { + return true; + } + goto downLeftSpine; + } + } + + n = child; + remaining = remaining.subspan(1, remaining.size() - 1); + + if (n->partialKeyLen > 0) { + int commonLen = std::min(n->partialKeyLen, remaining.size()); + int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen); + if (i < commonLen) { + auto c = n->partialKey()[i] <=> remaining[i]; + if (c > 0) { + goto downLeftSpine; + } else { + n = nextSibling(n); + if (n == nullptr) { + return true; + } + goto downLeftSpine; + } + } + if (commonLen == n->partialKeyLen) { + // partial key matches + remaining = remaining.subspan(commonLen, remaining.size() - commonLen); + } else if (n->partialKeyLen > int(remaining.size())) { + // n is the first physical node greater than remaining, and there's no + // eq node. All physical nodes that start with prefix are reachable from + // n. + if (maxVersion(n) > readVersion) { + return false; + } + goto downLeftSpine; + } + } + + if (maxV <= readVersion) { + ++tls->prefix_read_short_circuit_accum; + return true; + } + } +downLeftSpine: + for (; !n->entryPresent; n = getFirstChildExists(n)) { + } + return n->entry.rangeVersion <= readVersion; +} + +// Return true if the max version among all keys that start with key[:prefixLen] +// that are >= key is <= readVersion +bool checkRangeLeftSide(Node *n, std::span key, int prefixLen, + InternalVersionT readVersion, ReadContext *tls) { + auto remaining = key; + int searchPathLen = 0; + for (;; ++tls->range_read_iterations_accum) { + if (remaining.size() == 0) { + assert(searchPathLen >= prefixLen); + return maxVersion(n) <= readVersion; + } + + if (searchPathLen >= prefixLen) { + if (!checkMaxBetweenExclusive(n, remaining[0], 256, readVersion, tls)) { + return false; + } + } + + auto [c, maxV] = getChildAndMaxVersion(n, remaining[0]); + Node *child = c; + if (child == nullptr) { + auto c = getChildGeq(n, remaining[0]); + if (c != nullptr) { + if (searchPathLen < prefixLen) { + n = c; + goto downLeftSpine; + } + n = c; + return maxVersion(n) <= readVersion; + } else { + n = nextSibling(n); + if (n == nullptr) { + return true; + } + goto downLeftSpine; + } + } + + n = child; + remaining = remaining.subspan(1, remaining.size() - 1); + ++searchPathLen; + + if (n->partialKeyLen > 0) { + int commonLen = std::min(n->partialKeyLen, remaining.size()); + int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen); + searchPathLen += i; + if (i < commonLen) { + auto c = n->partialKey()[i] <=> remaining[i]; + if (c > 0) { + if (searchPathLen < prefixLen) { + goto downLeftSpine; + } + if (n->entryPresent && n->entry.rangeVersion > readVersion) { + return false; + } + return maxVersion(n) <= readVersion; + } else { + n = nextSibling(n); + if (n == nullptr) { + return true; + } + goto downLeftSpine; + } + } + if (commonLen == n->partialKeyLen) { + // partial key matches + remaining = remaining.subspan(commonLen, remaining.size() - commonLen); + } else if (n->partialKeyLen > int(remaining.size())) { + assert(searchPathLen >= prefixLen); + if (n->entryPresent && n->entry.rangeVersion > readVersion) { + return false; + } + return maxVersion(n) <= readVersion; + } + } + if (maxV <= readVersion) { + return true; + } + } +downLeftSpine: + for (; !n->entryPresent; n = getFirstChildExists(n)) { + } + return n->entry.rangeVersion <= readVersion; +} + +// Return true if the max version among all keys that start with key[:prefixLen] +// that are < key is <= readVersion +bool checkRangeRightSide(Node *n, std::span key, int prefixLen, + InternalVersionT readVersion, ReadContext *tls) { + auto remaining = key; + int searchPathLen = 0; + + for (;; ++tls->range_read_iterations_accum) { + assert(searchPathLen <= int(key.size())); + if (remaining.size() == 0) { + goto downLeftSpine; + } + + if (searchPathLen >= prefixLen) { + if (n->entryPresent && n->entry.pointVersion > readVersion) { + return false; + } + + if (!checkMaxBetweenExclusive(n, -1, remaining[0], readVersion, tls)) { + return false; + } + } + + if (searchPathLen > prefixLen && n->entryPresent && + n->entry.rangeVersion > readVersion) { + return false; + } + + Node *child = getChild(n, remaining[0]); + if (child == nullptr) { + auto c = getChildGeq(n, remaining[0]); + if (c != nullptr) { + n = c; + goto downLeftSpine; + } else { + goto backtrack; + } + } + + n = child; + remaining = remaining.subspan(1, remaining.size() - 1); + ++searchPathLen; + + if (n->partialKeyLen > 0) { + int commonLen = std::min(n->partialKeyLen, remaining.size()); + int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen); + searchPathLen += i; + if (i < commonLen) { + ++searchPathLen; + auto c = n->partialKey()[i] <=> remaining[i]; + if (c > 0) { + goto downLeftSpine; + } else { + if (searchPathLen > prefixLen && n->entryPresent && + n->entry.rangeVersion > readVersion) { + return false; + } + goto backtrack; + } + } + if (commonLen == n->partialKeyLen) { + // partial key matches + remaining = remaining.subspan(commonLen, remaining.size() - commonLen); + } else if (n->partialKeyLen > int(remaining.size())) { + goto downLeftSpine; + } + } + } +backtrack: + for (;;) { + // searchPathLen > prefixLen implies n is not the root + if (searchPathLen > prefixLen && maxVersion(n) > readVersion) { + return false; + } + if (n->parent == nullptr) { + return true; + } + auto next = getChildGeq(n->parent, n->parentsIndex + 1); + if (next == nullptr) { + searchPathLen -= 1 + n->partialKeyLen; + n = n->parent; + } else { + searchPathLen -= n->partialKeyLen; + n = next; + searchPathLen += n->partialKeyLen; + goto downLeftSpine; + } + } +downLeftSpine: + for (; !n->entryPresent; n = getFirstChildExists(n)) { + } + return n->entry.rangeVersion <= readVersion; +} +bool checkRangeRead(Node *n, std::span begin, + std::span end, InternalVersionT readVersion, + ReadContext *tls) { + int lcp = longestCommonPrefix(begin.data(), end.data(), + std::min(begin.size(), end.size())); + if (lcp == int(begin.size()) && end.size() == begin.size() + 1 && + end.back() == 0) { + return checkPointRead(n, begin, readVersion, tls); + } + if (lcp == int(begin.size() - 1) && end.size() == begin.size() && + int(begin.back()) + 1 == int(end.back())) { + return checkPrefixRead(n, begin, readVersion, tls); + } + + ++tls->range_read_accum; + + auto remaining = begin.subspan(0, lcp); + Arena arena; + + // Advance down common prefix, but stay on a physical path in the tree + for (;; ++tls->range_read_iterations_accum) { + assert(getSearchPath(arena, n) <=> + begin.subspan(0, lcp - remaining.size()) == + 0); + if (remaining.size() == 0) { + break; + } + auto [c, v] = getChildAndMaxVersion(n, remaining[0]); + Node *child = c; + if (child == nullptr) { + break; + } + + if (child->partialKeyLen > 0) { + int cl = std::min(child->partialKeyLen, remaining.size() - 1); + int i = + longestCommonPrefix(child->partialKey(), remaining.data() + 1, cl); + if (i != child->partialKeyLen) { + break; + } + } + if (v <= readVersion) { + ++tls->range_read_short_circuit_accum; + return true; + } + n = child; + remaining = + remaining.subspan(1 + child->partialKeyLen, + remaining.size() - (1 + child->partialKeyLen)); + } + assert(getSearchPath(arena, n) <=> begin.subspan(0, lcp - remaining.size()) == + 0); + + const int consumed = lcp - remaining.size(); + assume(consumed >= 0); + + begin = begin.subspan(consumed, int(begin.size()) - consumed); + end = end.subspan(consumed, int(end.size()) - consumed); + lcp -= consumed; + + if (lcp == int(begin.size())) { + return checkRangeRightSide(n, end, lcp, readVersion, tls); + } + + // This makes it safe to check maxVersion within checkRangeLeftSide. If this + // were false, then we would have returned above since lcp == begin.size(). + assert(!(n->parent == nullptr && begin.size() == 0)); + + return checkRangeStartsWith(n, begin.subspan(0, lcp), begin[lcp], end[lcp], + readVersion, tls) && + checkRangeLeftSide(n, begin, lcp + 1, readVersion, tls) && + checkRangeRightSide(n, end, lcp + 1, readVersion, tls); +} +} // namespace + struct __attribute__((visibility("hidden"))) ConflictSet::Impl { + // We still have the sequential implementation for compilers that don't + // support preserve_none and musttail + void useSequential(const ReadRange *reads, Result *result, int count, + CheckContext &context) { + for (int i = 0; i < count; ++i) { + if (reads[i].readVersion < oldestVersionFullPrecision) [[unlikely]] { + result[i] = TooOld; + } else { + bool ok; + if (reads[i].end.len == 0) { + ok = checkPointRead( + root, + std::span(reads[i].begin.p, reads[i].begin.len), + InternalVersionT(reads[i].readVersion), &context.tls); + } else { + ok = checkRangeRead( + root, + std::span(reads[i].begin.p, reads[i].begin.len), + std::span(reads[i].end.p, reads[i].end.len), + InternalVersionT(reads[i].readVersion), &context.tls); + } + result[i] = ok ? Commit : Conflict; + } + } + } + void check(const ReadRange *reads, Result *result, int count) { assert(oldestVersionFullPrecision >= newestVersionFullPrecision - kNominalVersionWindow); @@ -3740,40 +4204,42 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { } int64_t check_byte_accum = 0; - constexpr int kConcurrent = 16; - CheckJob inProgress[kConcurrent]; CheckContext context; context.tls.impl = this; - context.count = count; - context.oldestVersionFullPrecision = oldestVersionFullPrecision; - context.root = root; - context.queries = reads; - context.results = result; - int64_t started = std::min(kConcurrent, count); - context.started = started; - for (int i = 0; i < started; i++) { - inProgress[i].init(reads + i, result + i, root, - oldestVersionFullPrecision); - } - for (int i = 0; i < started - 1; i++) { - inProgress[i].next = inProgress + i + 1; - } - for (int i = 1; i < started; i++) { - inProgress[i].prev = inProgress + i - 1; - } - inProgress[0].prev = inProgress + started - 1; - inProgress[started - 1].next = inProgress; -#if __has_attribute(musttail) - // Kick off the sequence of tail calls that finally returns once all jobs - // are done - inProgress->continuation(inProgress, &context); -#else - context.job = inProgress; - context.done = false; - while (!context.done) { - context.job->continuation(context.job, &context); +#if __has_attribute(preserve_none) && __has_attribute(musttail) + if (count == 1) { + useSequential(reads, result, count, context); + } else { + constexpr int kConcurrent = 16; + CheckJob inProgress[kConcurrent]; + context.count = count; + context.oldestVersionFullPrecision = oldestVersionFullPrecision; + context.root = root; + context.queries = reads; + context.results = result; + int64_t started = std::min(kConcurrent, count); + context.started = started; + for (int i = 0; i < started; i++) { + inProgress[i].init(reads + i, result + i, root, + oldestVersionFullPrecision); + } + for (int i = 0; i < started - 1; i++) { + inProgress[i].next = inProgress + i + 1; + } + for (int i = 1; i < started; i++) { + inProgress[i].prev = inProgress + i - 1; + } + inProgress[0].prev = inProgress + started - 1; + inProgress[started - 1].next = inProgress; + + // Kick off the sequence of tail calls that finally returns once all jobs + // are done + inProgress->continuation(inProgress, &context); } + +#else + useSequential(reads, result, count, context); #endif for (int i = 0; i < count; ++i) { @@ -3887,17 +4353,17 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { int64_t set_oldest_iterations_accum = 0; for (; fuel > 0 && n != nullptr; ++set_oldest_iterations_accum) { rezero(n, oldestVersion); - // The "make sure gc keeps up with writes" calculations assume that we're - // scanning key by key, not node by node. Make sure we only spend fuel - // when there's a logical entry. + // The "make sure gc keeps up with writes" calculations assume that + // we're scanning key by key, not node by node. Make sure we only spend + // fuel when there's a logical entry. fuel -= n->entryPresent; if (n->entryPresent && std::max(n->entry.pointVersion, n->entry.rangeVersion) <= oldestVersion) { // Any transaction n would have prevented from committing is // going to fail with TooOld anyway. - // There's no way to insert a range such that range version of the right - // node is greater than the point version of the left node + // There's no way to insert a range such that range version of the + // right node is greater than the point version of the left node assert(n->entry.rangeVersion <= oldestVersion); n = erase(n, &tls, this, /*logical*/ false); } else { @@ -3941,9 +4407,9 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { InternalVersionT::zero = tls.zero = oldestVersion; #endif #ifdef NDEBUG - // This is here for performance reasons, since we want to amortize the cost - // of storing the search path as a string. In tests, we want to exercise the - // rest of the code often. + // This is here for performance reasons, since we want to amortize the + // cost of storing the search path as a string. In tests, we want to + // exercise the rest of the code often. if (keyUpdates < 100) { return; } @@ -4071,16 +4537,14 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { "The total number of entries inserted in the tree"); COUNTER(entries_erased_total, "The total number of entries erased from the tree"); - COUNTER( - gc_iterations_total, - "The total number of iterations of the main loop for garbage collection"); + COUNTER(gc_iterations_total, "The total number of iterations of the main " + "loop for garbage collection"); COUNTER(write_bytes_total, "Total number of key bytes in calls to addWrites"); GAUGE(oldest_version, "The lowest version that doesn't result in \"TooOld\" for checks"); GAUGE(newest_version, "The version of the most recent call to addWrites"); - GAUGE( - oldest_extant_version, - "A lower bound on the lowest version associated with an existing entry"); + GAUGE(oldest_extant_version, "A lower bound on the lowest version " + "associated with an existing entry"); // ==================== END METRICS DEFINITIONS ==================== #undef GAUGE #undef COUNTER @@ -4338,8 +4802,8 @@ std::string strinc(std::string_view str, bool &ok) { if ((uint8_t &)(str[index]) != 255) break; - // Must not be called with a string that consists only of zero or more '\xff' - // bytes. + // Must not be called with a string that consists only of zero or more + // '\xff' bytes. if (index < 0) { ok = false; return {};