Implement new checkRangeRead idea

This commit is contained in:
2024-02-15 11:37:08 -08:00
parent 98dccf5c23
commit c2193fdba0
2 changed files with 171 additions and 342 deletions

View File

@@ -660,6 +660,9 @@ std::string getSearchPathPrintable(Node *n);
// soon as it can prove that there's no conflict.
bool checkPointRead(Node *n, const std::span<const uint8_t> key,
int64_t readVersion) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check point read: %s\n", printable(key).c_str());
#endif
auto remaining = key;
for (;;) {
if (n->maxVersion <= readVersion) {
@@ -726,59 +729,6 @@ downLeftSpine:
}
}
// Returns the largest maxVersion among the child of `n` at index `begin` and
// every later child of `n` (enumerated with getChildGeq).
// Precondition: `n` has a child at index `begin`.
int64_t maxRightOf(Node *n, int begin) {
  int64_t best = std::numeric_limits<int64_t>::lowest();
  // A negative index from getChildGeq ends the walk.
  for (int idx = begin; idx >= 0; idx = getChildGeq(n, idx + 1)) {
    best = std::max(best, getChildExists(n, idx)->maxVersion);
  }
  return best;
}
// Return the maximum version among all keys starting with the search path of
// `n` + a child > `begin`.
//
// Every child at an index greater than `begin` contributes its maxVersion.
// A child's entry.rangeVersion is additionally counted only when the child
// has an entry and its index is greater than `begin + 1`.
int64_t maxRightOfExclusive(Node *n, int begin) {
  int64_t best = std::numeric_limits<int64_t>::lowest();
  for (int idx = getChildGeq(n, begin + 1); idx >= 0;
       idx = getChildGeq(n, idx + 1)) {
    auto *kid = getChildExists(n, idx);
    const bool pastAdjacent = idx > begin + 1;
    if (pastAdjacent && kid->entryPresent) {
      best = std::max(best, kid->entry.rangeVersion);
    }
    best = std::max(best, kid->maxVersion);
  }
#if DEBUG_VERBOSE && !defined(NDEBUG)
  fprintf(stderr, "At `%s', max version right of %02x is %" PRId64 "\n",
          getSearchPathPrintable(n).c_str(), begin, best);
#endif
  return best;
}
// Return the maximum version among all keys starting with the search path of
// `n` + a child < `end`.
//
// Only the children's maxVersion is consulted; the walk stops at the first
// child index that is not below `end`, or when no further child exists.
int64_t maxLeftOfExclusive(Node *n, int end) {
  int64_t best = std::numeric_limits<int64_t>::lowest();
  for (int idx = getChildGeq(n, 0); idx >= 0 && idx < end;
       idx = getChildGeq(n, idx + 1)) {
    best = std::max(best, getChildExists(n, idx)->maxVersion);
  }
#if DEBUG_VERBOSE && !defined(NDEBUG)
  fprintf(stderr, "At `%s', max version left of %02x is %" PRId64 "\n",
          getSearchPathPrintable(n).c_str(), end, best);
#endif
  return best;
}
// Precondition: child exists at `end`
int64_t maxBetweenExclusive(Node *n, int begin, int end) {
int64_t result = std::numeric_limits<int64_t>::lowest();
int next = begin;
@@ -787,7 +737,11 @@ int64_t maxBetweenExclusive(Node *n, int begin, int end) {
if (next < 0 || next >= end) {
break;
}
result = std::max(result, getChildExists(n, next)->maxVersion);
auto *child = getChildExists(n, next);
if (child->entryPresent) {
result = std::max(result, child->entry.rangeVersion);
}
result = std::max(result, child->maxVersion);
}
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "At `%s', max version in (%02x, %02x) is %" PRId64 "\n",
@@ -796,139 +750,6 @@ int64_t maxBetweenExclusive(Node *n, int begin, int end) {
return result;
}
// Returns true if the version of all keys >= key in the subtree rooted at n is
// <= readVersion.
//
// Walks down from `n`, matching one byte of `key` (plus the child's partial
// key) per iteration. `remaining` is the suffix of `key` not yet matched.
// Returns false as soon as one of the inspected versions exceeds
// `readVersion`.
bool checkLeftOfPyramid(Node *n, const std::span<const uint8_t> key,
int64_t readVersion) {
auto remaining = key;
for (;;) {
// Early out: n's maxVersion does not exceed readVersion.
if (n->maxVersion <= readVersion) {
return true;
}
if (remaining.size() == 0) {
// The check above already failed for this n, so this evaluates to false.
return n->maxVersion <= readVersion;
}
// Max over the children of n at indices strictly greater than the next key
// byte.
auto v = maxRightOfExclusive(n, remaining[0]);
if (v > readVersion) {
return false;
};
{
// Separately test the range version of the first child past the next key
// byte, if there is one and it has an entry.
int c = getChildGeq(n, int(remaining[0]) + 1);
if (c >= 0) {
auto *child = getChildExists(n, c);
if (child->entryPresent && child->entry.rangeVersion > readVersion) {
return false;
}
}
}
int c = getChildGeq(n, remaining[0]);
if (c == remaining[0]) {
// Exact child for the next key byte: descend and consume that byte.
n = getChildExists(n, c);
remaining = remaining.subspan(1, remaining.size() - 1);
} else {
if (c >= 0) {
// No exact child, but a later one exists: the answer is decided by that
// child's range version and maxVersion.
n = getChildExists(n, c);
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
return false;
}
return n->maxVersion <= readVersion;
} else {
// No child at or after the next key byte.
return true;
}
}
if (n->partialKeyLen > 0) {
// Compare the node's partial key against the unmatched key bytes.
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
for (int i = 0; i < commonLen; ++i) {
auto c = n->partialKey[i] <=> remaining[i];
if (c == 0) {
continue;
}
if (c < 0) {
// First differing byte: the node's partial key byte is smaller than the
// key's.
return true;
} else {
// First differing byte: the node's partial key byte is larger than the
// key's, so the node's own range version and maxVersion decide.
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
return false;
}
return n->maxVersion <= readVersion;
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// The key ran out inside the node's partial key (key is a proper prefix of
// the node's path); the node's range version and maxVersion decide.
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
return false;
}
return n->maxVersion <= readVersion;
}
}
}
}
// Returns true if the version of all keys < key in n is <= readVersion.
//
// Walks down from `n`, matching one byte of `key` (plus the child's partial
// key) per iteration. `remaining` is the suffix of `key` not yet matched.
// Precondition (asserted): `key` is non-empty.
bool checkRightOfPyramid(Node *n, const std::span<const uint8_t> key,
int64_t readVersion) {
assert(key.size() > 0);
auto remaining = key;
// `first` is true only for the starting node; its range version is not
// checked below.
for (bool first = true;; first = false) {
// Early out: n's maxVersion does not exceed readVersion.
if (n->maxVersion <= readVersion) {
return true;
}
if (!first && n->entryPresent && n->entry.rangeVersion > readVersion) {
return false;
}
// Key fully matched: n's point version and children are not examined.
if (remaining.size() == 0) {
return true;
}
if (n->entryPresent && n->entry.pointVersion > readVersion) {
return false;
}
// Max over the children of n at indices strictly below the next key byte.
auto v = maxLeftOfExclusive(n, remaining[0]);
if (v > readVersion) {
return false;
}
int c = getChildGeq(n, remaining[0]);
if (c == remaining[0]) {
// Exact child for the next key byte: descend and consume that byte.
n = getChildExists(n, c);
remaining = remaining.subspan(1, remaining.size() - 1);
} else {
if (c >= 0) {
// Only a child past the next key byte exists; the smaller children were
// already covered by maxLeftOfExclusive above.
return true;
}
// No child at or after the next key byte. n->maxVersion > readVersion is
// already known here, so this evaluates to false.
return n->maxVersion <= readVersion;
}
if (n->partialKeyLen > 0) {
// Compare the node's partial key against the unmatched key bytes.
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
for (int i = 0; i < commonLen; ++i) {
auto c = n->partialKey[i] <=> remaining[i];
if (c == 0) {
continue;
}
if (c > 0) {
// First differing byte: the node's partial key byte is larger than the
// key's.
return true;
}
// First differing byte: the node's partial key byte is smaller than the
// key's, so the node's maxVersion decides.
return n->maxVersion <= readVersion;
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// The key ran out inside the node's partial key.
return true;
}
}
}
}
Vector<uint8_t> getSearchPath(Arena &arena, Node *n) {
assert(n != nullptr);
auto result = vector<uint8_t>(arena);
@@ -1093,165 +914,186 @@ bytes:
return i;
}
// Bookkeeping for stepping from `oldNode` up to its parent: drops the bytes
// `oldNode` contributed (one index byte plus its partial key) from `depth`
// and `searchPath`, and caps `lcp` at the new depth.
__attribute__((always_inline)) inline void
ascend(int &depth, int &lcp, Node *oldNode, Vector<uint8_t> &searchPath) {
  const auto contributed = 1 + oldNode->partialKeyLen;
  depth -= contributed;
  searchPath.resize(depth);
  if (depth < lcp) {
    lcp = depth;
  }
}
bool checkRangeStartsWith(Node *n, std::span<const uint8_t> key,
int64_t readVersion) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "%s*\n", printable(key).c_str());
#endif
auto remaining = key;
__attribute__((always_inline)) inline void
descend(int &depth, int &lcp, Node *newNode, std::span<const uint8_t> end,
Vector<uint8_t> &searchPath) {
if (depth == lcp) {
if (lcp < int(end.size()) && newNode->parentsIndex == end[lcp]) {
++lcp;
for (int i = 0; i < newNode->partialKeyLen && lcp < int(end.size());
++i) {
if (newNode->partialKey[i] == end[lcp]) {
++lcp;
} else {
break;
for (;;) {
if (n->maxVersion <= readVersion) {
return true;
}
if (remaining.size() == 0) {
return n->maxVersion <= readVersion;
}
int c = getChildGeq(n, remaining[0]);
if (c == remaining[0]) {
n = getChildExists(n, c);
remaining = remaining.subspan(1, remaining.size() - 1);
} else {
if (c >= 0) {
n = getChildExists(n, c);
goto downLeftSpine;
} else {
n = nextSibling(n);
goto downLeftSpine;
}
}
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
for (int i = 0; i < commonLen; ++i) {
auto c = n->partialKey[i] <=> remaining[i];
if (c == 0) {
continue;
}
if (c > 0) {
goto downLeftSpine;
} else {
n = nextSibling(n);
goto downLeftSpine;
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
return false;
}
return n->maxVersion <= readVersion;
}
}
}
depth += 1 + newNode->partialKeyLen;
searchPath.push_back(newNode->parentsIndex);
searchPath.insert(searchPath.end(), newNode->partialKey,
newNode->partialKey + newNode->partialKeyLen);
downLeftSpine:
if (n == nullptr) {
return true;
}
for (;;) {
if (n->entryPresent) {
return n->entry.rangeVersion <= readVersion;
}
int c = getChildGeq(n, 0);
assert(c >= 0);
n = getChildExists(n, c);
}
}
// Return true if the max version among all keys that start with key + [child],
// where begin < child < end, is <= readVersion
//
// Walks down from `n` matching `key` one byte (plus the child's partial key)
// per iteration; `remaining` is the suffix of `key` not yet matched. When the
// walk falls off the exact path it jumps to `downLeftSpine`, which decides by
// the range version of the first entry found by following smallest children.
bool checkRangeStartsWith(Node *n, std::span<const uint8_t> key, int begin,
int end, int64_t readVersion) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "%s(%02x,%02x)*\n", printable(key).c_str(), begin, end);
#endif
auto remaining = key;
for (;;) {
// Early out: n's maxVersion does not exceed readVersion.
if (n->maxVersion <= readVersion) {
return true;
}
if (remaining.size() == 0) {
// `key` fully matched: inspect the children of n strictly between begin
// and end.
return maxBetweenExclusive(n, begin, end) <= readVersion;
}
int c = getChildGeq(n, remaining[0]);
if (c == remaining[0]) {
// Exact child for the next key byte: descend and consume that byte.
n = getChildExists(n, c);
remaining = remaining.subspan(1, remaining.size() - 1);
} else {
if (c >= 0) {
// Only a child past the next key byte exists; continue from it.
n = getChildExists(n, c);
goto downLeftSpine;
} else {
// No child at or after the next key byte; continue from nextSibling(n),
// which may be null (handled at the label).
n = nextSibling(n);
goto downLeftSpine;
}
}
if (n->partialKeyLen > 0) {
// Compare the node's partial key against the unmatched key bytes.
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
for (int i = 0; i < commonLen; ++i) {
auto c = n->partialKey[i] <=> remaining[i];
if (c == 0) {
continue;
}
if (c > 0) {
// Node's partial key byte is larger than the key's: continue from this
// node.
goto downLeftSpine;
} else {
// Node's partial key byte is smaller than the key's: continue from
// nextSibling(n).
n = nextSibling(n);
goto downLeftSpine;
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// The key ran out inside the node's partial key. The partial key byte right
// after the matched part plays the role of [child]: the node only counts
// if that byte lies strictly between begin and end.
if (begin < n->partialKey[remaining.size()] &&
n->partialKey[remaining.size()] < end) {
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
return false;
}
return n->maxVersion <= readVersion;
}
return true;
}
}
}
downLeftSpine:
// A null node here means there is nothing further to check.
if (n == nullptr) {
return true;
}
// Follow the smallest child at each level until a node with an entry is
// reached; that entry's range version decides the result.
for (;;) {
if (n->entryPresent) {
return n->entry.rangeVersion <= readVersion;
}
int c = getChildGeq(n, 0);
// A node without an entry is expected to have at least one child.
assert(c >= 0);
n = getChildExists(n, c);
}
}
bool checkRangeRead(Node *n, std::span<const uint8_t> begin,
std::span<const uint8_t> end, int64_t readVersion,
Arena &arena) {
std::span<const uint8_t> end, int64_t readVersion) {
int lcp = longestCommonPrefix(begin.data(), end.data(),
std::min(begin.size(), end.size()));
SearchStepWise search{n, begin.subspan(0, lcp)};
for (;;) {
assert(getSearchPath(arena, search.n) <=>
begin.subspan(0, lcp - search.remaining.size()) ==
0);
if (search.n->maxVersion <= readVersion) {
return true;
if (lcp == int(begin.size())) {
for (int i = lcp; i < int(end.size()); ++i) {
if (!checkPointRead(n, end.subspan(0, i), readVersion)) {
return false;
}
if (!checkRangeStartsWith(n, end.subspan(0, i), -1, end[i],
readVersion)) {
return false;
}
}
if (search.step()) {
break;
}
}
assert(getSearchPath(arena, search.n) <=>
begin.subspan(0, lcp - search.remaining.size()) ==
0);
// Check that we can start FirstGeq where Search left off
const int consumed = lcp - search.remaining.size();
assert(consumed >= 0);
auto trimmedBegin = begin.subspan(consumed, int(begin.size()) - consumed);
auto trimmedEnd = end.subspan(consumed, int(end.size()) - consumed);
auto left =
firstGeq(search.n, begin.subspan(consumed, int(begin.size()) - consumed));
#ifndef NDEBUG
auto iter = firstGeq(n, begin);
assert(left.cmp == iter.cmp);
assert(left.n == iter.n);
#endif
if (left.n == nullptr) {
return true;
}
auto searchPath = getSearchPath(arena, left.n);
if (left.cmp != 0 && left.n->entry.rangeVersion > readVersion) {
if (!checkRangeStartsWith(n, begin, readVersion)) {
return false;
}
int depth = searchPath.size();
lcp = longestCommonPrefix(searchPath.data(), end.data(),
std::min(searchPath.size(), end.size()));
bool first = true;
for (auto *iter = left.n; iter != nullptr; first = false) {
const int cl = std::min(searchPath.size(), end.size());
assert(depth == int(searchPath.size()));
assert(lcp == longestCommonPrefix(searchPath.data(), end.data(), cl));
// if (searchPath >= end) break;
if ((cl == lcp ? searchPath.size() <=> end.size()
: searchPath[lcp] <=> end[lcp]) >= 0) {
break;
for (int i = begin.size() - 1; i >= lcp + 1; --i) {
if (!checkRangeStartsWith(n, begin.subspan(0, i), int(begin[i]), 256,
readVersion)) {
return false;
}
if (iter->entryPresent) {
if (!first && iter->entry.rangeVersion > readVersion) {
return false;
}
if (iter->entry.pointVersion > readVersion) {
return false;
}
}
if (!checkRangeStartsWith(n, begin.subspan(0, lcp), begin[lcp], end[lcp],
readVersion)) {
return false;
}
for (int i = lcp + 1; i < int(end.size()); ++i) {
if (!checkPointRead(n, end.subspan(0, i), readVersion)) {
return false;
}
assert(searchPath == getSearchPath(arena, iter));
if (lcp == depth) {
// end starts with searchPath, so end < range
if (iter->maxVersion <= readVersion) {
return true;
}
int index = -1;
for (;;) {
auto nextChild = getChildGeq(iter, index + 1);
if (nextChild >= 0) {
auto *result = getChildExists(iter, nextChild);
iter = result;
descend(depth, lcp, iter, end, searchPath);
break;
}
if (iter->parent == nullptr) {
iter = nullptr;
break;
}
ascend(depth, lcp, iter, searchPath);
index = iter->parentsIndex;
iter = iter->parent;
}
} else {
// end does not start with searchPath, so range end <= end
if (iter->maxVersion > readVersion) {
return false;
}
for (;;) {
if (iter->parent == nullptr) {
assert(searchPath.size() == 0);
iter = nullptr;
break;
}
auto next = getChildGeq(iter->parent, iter->parentsIndex + 1);
if (next < 0) {
ascend(depth, lcp, iter, searchPath);
iter = iter->parent;
} else {
ascend(depth, lcp, iter, searchPath);
iter = iter->parent;
if (depth - iter->partialKeyLen - lcp > 1) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "%s, right of %02x\n",
printable(searchPath).c_str(), next);
#endif
if (maxRightOf(iter, next) > readVersion) {
return false;
}
} else {
iter = getChildExists(iter, next);
descend(depth, lcp, iter, end, searchPath);
break;
}
}
}
if (!checkRangeStartsWith(n, end.subspan(0, i), -1, end[i], readVersion)) {
return false;
}
}
return true;
@@ -1354,26 +1196,13 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
reads[i].begin.len),
std::span<const uint8_t>(reads[i].end.p,
reads[i].end.len),
reads[i].readVersion, arena)
reads[i].readVersion)
: checkPointRead(root,
std::span<const uint8_t>(reads[i].begin.p,
reads[i].begin.len),
reads[i].readVersion))
? Commit
: Conflict;
auto k = std::span<const uint8_t>(reads[i].begin.p, reads[i].begin.len);
if (k.size() > 0) {
bool expected = checkRangeRead(root, k, std::vector<uint8_t>(33, 0xff),
reads[i].readVersion, arena);
bool actual = checkLeftOfPyramid(root, k, reads[i].readVersion);
if (expected != actual) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Expected %d, got %d for [%s,)\n", int(expected),
int(actual), printable(k).c_str());
#endif
result[i] = TooOld;
}
}
}
}