Implement setOldestVersion

This commit is contained in:
2024-02-19 15:58:59 -08:00
parent 939b791e01
commit c9baa80212
3 changed files with 202 additions and 137 deletions

View File

@@ -3,6 +3,7 @@
#include <byteswap.h> #include <byteswap.h>
#include <cstdint> #include <cstdint>
#include <cstring> #include <cstring>
#include <string>
#define ANKERL_NANOBENCH_IMPLEMENT #define ANKERL_NANOBENCH_IMPLEMENT
#include "third_party/nanobench.h" #include "third_party/nanobench.h"
@@ -552,10 +553,9 @@ struct SkipListConflictSet {
auto *ranges = new (arena) ReadConflictRange[count]; auto *ranges = new (arena) ReadConflictRange[count];
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
ranges[i].begin = {reads[i].begin.p, size_t(reads[i].begin.len)}; ranges[i].begin = {reads[i].begin.p, size_t(reads[i].begin.len)};
ranges[i].end = reads[i].end.len > 0 ranges[i].end = {reads[i].end.p,
? std::span<const uint8_t>{reads[i].end.p, size_t(reads[i].end.len == 0 ? reads[i].begin.len + 1
size_t(reads[i].end.len)} : reads[i].end.len)};
: keyAfter(arena, ranges[i].begin);
ranges[i].version = reads[i].readVersion; ranges[i].version = reads[i].readVersion;
if (reads[i].readVersion < oldestVersion) { if (reads[i].readVersion < oldestVersion) {
results[i] = ConflictSet::TooOld; results[i] = ConflictSet::TooOld;
@@ -581,12 +581,10 @@ struct SkipListConflictSet {
for (int i = 0; i * 2 < ss; ++i) { for (int i = 0; i * 2 < ss; ++i) {
const auto &w = writes[s * stripeSize / 2 + i]; const auto &w = writes[s * stripeSize / 2 + i];
values[i * 2] = {w.begin.p, size_t(w.begin.len)}; values[i * 2] = {w.begin.p, size_t(w.begin.len)};
if (w.end.len > 0) { values[i * 2 + 1] = {
values[i * 2 + 1] = {w.end.p, size_t(w.end.len)}; w.end.p, size_t(w.end.len == 0 ? w.begin.len + 1 : w.end.len)};
} else {
values[i * 2 + 1] = keyAfter(arena, values[i * 2]);
}
writeVersions[i] = w.writeVersion; writeVersions[i] = w.writeVersion;
keyUpdates += 2;
} }
skipList.find(values, fingers, temp, ss); skipList.find(values, fingers, temp, ss);
skipList.addConflictRanges(fingers, ss / 2, writeVersions); skipList.addConflictRanges(fingers, ss / 2, writeVersions);
@@ -594,11 +592,32 @@ struct SkipListConflictSet {
} }
} }
void setOldestVersion(int64_t oldestVersion) {
this->oldestVersion = oldestVersion;
SkipList::Finger finger;
int temp;
std::span<const uint8_t> key = removalKey;
skipList.find(&key, &finger, &temp, 1);
skipList.removeBefore(oldestVersion, finger, std::exchange(keyUpdates, 0));
removalKey = std::basic_string<uint8_t>(finger.getValue().data(),
finger.getValue().size());
}
private: private:
int64_t keyUpdates = 0;
std::basic_string<uint8_t> removalKey;
int64_t oldestVersion; int64_t oldestVersion;
SkipList skipList; SkipList skipList;
}; };
ConflictSet::ReadRange singleton(Arena &arena, std::span<const uint8_t> key) {
auto r =
std::span<uint8_t>(new (arena) uint8_t[key.size() + 1], key.size() + 1);
memcpy(r.data(), key.data(), key.size());
r[key.size()] = 0;
return {key.data(), int(key.size()), r.data(), int(r.size())};
}
ConflictSet::ReadRange prefixRange(Arena &arena, std::span<const uint8_t> key) { ConflictSet::ReadRange prefixRange(Arena &arena, std::span<const uint8_t> key) {
int index; int index;
for (index = key.size() - 1; index >= 0; index--) for (index = key.size() - 1; index >= 0; index--)
@@ -625,6 +644,8 @@ constexpr int kOpsPerTx = 100;
constexpr int kPrefixLen = 0; constexpr int kPrefixLen = 0;
constexpr int kMvccWindow = 100000;
std::span<const uint8_t> makeKey(Arena &arena, int index) { std::span<const uint8_t> makeKey(Arena &arena, int index) {
auto result = auto result =
@@ -652,12 +673,13 @@ template <class ConflictSet_> void benchConflictSet(const std::string &name) {
writes.reserve(kNumKeys); writes.reserve(kNumKeys);
for (int i = 0; i < kNumKeys; ++i) { for (int i = 0; i < kNumKeys; ++i) {
auto key = makeKey(arena, i); auto key = makeKey(arena, i);
ConflictSet::WriteRange conflict; ConflictSet::WriteRange w;
conflict.begin.p = key.data(); auto r = singleton(arena, key);
conflict.begin.len = key.size(); w.begin.p = r.begin.p;
conflict.end.len = 0; w.begin.len = r.begin.len;
conflict.writeVersion = version + 1; w.end.p = r.end.p;
writes.push_back(conflict); w.end.len = 0;
writes.push_back(w);
} }
cs.addWrites(writes.data(), writes.size()); cs.addWrites(writes.data(), writes.size());
++version; ++version;
@@ -701,9 +723,7 @@ template <class ConflictSet_> void benchConflictSet(const std::string &name) {
std::vector<ConflictSet::ReadRange> reads; std::vector<ConflictSet::ReadRange> reads;
auto iter = points.begin(); auto iter = points.begin();
for (int i = 0; i < kOpsPerTx; ++i) { for (int i = 0; i < kOpsPerTx; ++i) {
ConflictSet::ReadRange r; auto r = singleton(arena, *iter);
r.begin.p = iter->data();
r.begin.len = iter->size();
r.end.len = 0; r.end.len = 0;
r.readVersion = version - 1; r.readVersion = version - 1;
reads.push_back(r); reads.push_back(r);
@@ -758,8 +778,10 @@ template <class ConflictSet_> void benchConflictSet(const std::string &name) {
auto iter = points.begin(); auto iter = points.begin();
for (int i = 0; i < kOpsPerTx; ++i) { for (int i = 0; i < kOpsPerTx; ++i) {
ConflictSet::WriteRange w; ConflictSet::WriteRange w;
w.begin.p = iter->data(); auto r = singleton(arena, *iter);
w.begin.len = iter->size(); w.begin.p = r.begin.p;
w.begin.len = r.begin.len;
w.end.p = r.end.p;
w.end.len = 0; w.end.len = 0;
writes.push_back(w); writes.push_back(w);
++iter; ++iter;
@@ -771,6 +793,7 @@ template <class ConflictSet_> void benchConflictSet(const std::string &name) {
w.writeVersion = v; w.writeVersion = v;
} }
cs.addWrites(writes.data(), writes.size()); cs.addWrites(writes.data(), writes.size());
cs.setOldestVersion(std::max<int64_t>(version - kMvccWindow, 0));
}); });
} }
@@ -794,6 +817,7 @@ template <class ConflictSet_> void benchConflictSet(const std::string &name) {
w.writeVersion = v; w.writeVersion = v;
} }
cs.addWrites(writes.data(), writes.size()); cs.addWrites(writes.data(), writes.size());
cs.setOldestVersion(std::max<int64_t>(version - kMvccWindow, 0));
}); });
} }
@@ -817,6 +841,7 @@ template <class ConflictSet_> void benchConflictSet(const std::string &name) {
w.writeVersion = v; w.writeVersion = v;
} }
cs.addWrites(writes.data(), writes.size()); cs.addWrites(writes.data(), writes.size());
cs.setOldestVersion(std::max<int64_t>(version - kMvccWindow, 0));
}); });
} }
} }

View File

@@ -1380,6 +1380,114 @@ void addWriteRange(Node *&root, int64_t oldestVersion,
} }
} }
struct FirstGeqStepwise {
Node *n;
std::span<const uint8_t> remaining;
int cmp;
enum Phase {
Init,
// Being in this phase implies that the key matches the search path exactly
// up to this point
Search,
DownLeftSpine
};
Phase phase;
FirstGeqStepwise(Node *n, std::span<const uint8_t> remaining)
: n(n), remaining(remaining), phase(Init) {}
// Not being done implies that n is not the firstGeq
bool step() {
switch (phase) {
case Search:
if (remaining.size() == 0) {
int c = getChildGeq(n, 0);
assert(c >= 0);
n = getChildExists(n, c);
return downLeftSpine();
} else {
int c = getChildGeq(n, remaining[0]);
if (c == remaining[0]) {
n = getChildExists(n, c);
remaining = remaining.subspan(1, remaining.size() - 1);
} else {
if (c >= 0) {
n = getChildExists(n, c);
return downLeftSpine();
} else {
n = nextSibling(n);
return downLeftSpine();
}
}
}
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
for (int i = 0; i < commonLen; ++i) {
auto c = n->partialKey[i] <=> remaining[i];
if (c == 0) {
continue;
}
if (c > 0) {
return downLeftSpine();
} else {
n = nextSibling(n);
return downLeftSpine();
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining =
remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// n is the first physical node greater than remaining, and there's no
// eq node
return downLeftSpine();
}
}
[[fallthrough]];
case Init:
phase = Search;
if (remaining.size() == 0 && n->entryPresent) {
cmp = 0;
return true;
}
return false;
case DownLeftSpine:
int c = getChildGeq(n, 0);
assert(c >= 0);
n = getChildExists(n, c);
if (n->entryPresent) {
cmp = 1;
return true;
}
return false;
}
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
bool downLeftSpine() {
phase = DownLeftSpine;
if (n == nullptr || n->entryPresent) {
cmp = 1;
return true;
}
return step();
}
};
Iterator firstGeq(Node *n, const std::span<const uint8_t> key) {
FirstGeqStepwise stepwise{n, key};
while (!stepwise.step())
;
return {stepwise.n, stepwise.cmp};
}
Iterator firstGeq(Node *n, std::string_view key) {
return firstGeq(
n, std::span<const uint8_t>((const uint8_t *)key.data(), key.size()));
}
struct __attribute__((visibility("hidden"))) ConflictSet::Impl { struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
void check(const ReadRange *reads, Result *result, int count) const { void check(const ReadRange *reads, Result *result, int count) const {
@@ -1407,8 +1515,10 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
const auto &w = writes[i]; const auto &w = writes[i];
if (w.end.len > 0) { if (w.end.len > 0) {
keyUpdates += 2;
addWriteRange(root, oldestVersion, w); addWriteRange(root, oldestVersion, w);
} else { } else {
keyUpdates += 1;
auto *n = auto *n =
insert(&root, std::span<const uint8_t>(w.begin.p, w.begin.len), insert(&root, std::span<const uint8_t>(w.begin.p, w.begin.len),
w.writeVersion, true); w.writeVersion, true);
@@ -1427,9 +1537,39 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
} }
} }
} }
void setOldestVersion(int64_t oldestVersion) { void setOldestVersion(int64_t oldestVersion) {
this->oldestVersion = oldestVersion; this->oldestVersion = oldestVersion;
Node *prev = firstGeq(root, removalKey).n;
assert(prev != nullptr);
while (keyUpdates-- > 0) {
Node *n = nextLogical(prev);
if (n == nullptr) {
removalKey = {};
return;
}
if (std::max(prev->entry.pointVersion, prev->entry.rangeVersion) <=
oldestVersion) {
// Any transaction prev would have prevented from committing is
// going to fail with TooOld anyway.
// We still need to make sure that we don't introduce false positives by
// just removing it though.
if (n->entry.rangeVersion <= oldestVersion) {
prev->entryPresent = false;
if (prev->numChildren == 0 && prev->parent != nullptr) {
eraseChild(prev->parent, prev->parentsIndex);
}
}
}
prev = n;
}
removalKeyArena = Arena();
removalKey = getSearchPath(removalKeyArena, prev);
} }
explicit Impl(int64_t oldestVersion) : oldestVersion(oldestVersion) { explicit Impl(int64_t oldestVersion) : oldestVersion(oldestVersion) {
// Insert "" // Insert ""
root = newNode(); root = newNode();
@@ -1440,6 +1580,10 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
} }
~Impl() { destroyTree(root); } ~Impl() { destroyTree(root); }
Arena removalKeyArena;
std::span<const uint8_t> removalKey;
int64_t keyUpdates = 0;
Node *root; Node *root;
int64_t oldestVersion; int64_t oldestVersion;
}; };
@@ -1629,123 +1773,16 @@ void checkParentPointers(Node *node, bool &success) {
} }
} }
struct FirstGeqStepwise {
Node *n;
std::span<const uint8_t> remaining;
int cmp;
enum Phase {
Init,
// Being in this phase implies that the key matches the search path exactly
// up to this point
Search,
DownLeftSpine
};
Phase phase;
FirstGeqStepwise(Node *n, std::span<const uint8_t> remaining)
: n(n), remaining(remaining), phase(Init) {}
// Not being done implies that n is not the firstGeq
bool step() {
switch (phase) {
case Search:
if (remaining.size() == 0) {
int c = getChildGeq(n, 0);
assert(c >= 0);
n = getChildExists(n, c);
return downLeftSpine();
} else {
int c = getChildGeq(n, remaining[0]);
if (c == remaining[0]) {
n = getChildExists(n, c);
remaining = remaining.subspan(1, remaining.size() - 1);
} else {
if (c >= 0) {
n = getChildExists(n, c);
return downLeftSpine();
} else {
n = nextSibling(n);
return downLeftSpine();
}
}
}
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
for (int i = 0; i < commonLen; ++i) {
auto c = n->partialKey[i] <=> remaining[i];
if (c == 0) {
continue;
}
if (c > 0) {
return downLeftSpine();
} else {
n = nextSibling(n);
return downLeftSpine();
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining =
remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// n is the first physical node greater than remaining, and there's no
// eq node
return downLeftSpine();
}
}
[[fallthrough]];
case Init:
phase = Search;
if (remaining.size() == 0 && n->entryPresent) {
cmp = 0;
return true;
}
return false;
case DownLeftSpine:
int c = getChildGeq(n, 0);
assert(c >= 0);
n = getChildExists(n, c);
if (n->entryPresent) {
cmp = 1;
return true;
}
return false;
}
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
bool downLeftSpine() {
phase = DownLeftSpine;
if (n == nullptr || n->entryPresent) {
cmp = 1;
return true;
}
return step();
}
};
Iterator firstGeq(Node *n, const std::span<const uint8_t> key) {
FirstGeqStepwise stepwise{n, key};
while (!stepwise.step())
;
return {stepwise.n, stepwise.cmp};
}
Iterator firstGeq(Node *n, std::string_view key) {
return firstGeq(
n, std::span<const uint8_t>((const uint8_t *)key.data(), key.size()));
}
[[maybe_unused]] int64_t checkMaxVersion(Node *root, Node *node, [[maybe_unused]] int64_t checkMaxVersion(Node *root, Node *node,
bool &success) { int64_t oldestVersion, bool &success) {
int64_t expected = std::numeric_limits<int64_t>::lowest(); int64_t expected = std::numeric_limits<int64_t>::lowest();
if (node->entryPresent) { if (node->entryPresent) {
expected = std::max(expected, node->entry.pointVersion); expected = std::max(expected, node->entry.pointVersion);
} }
for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) { for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) {
auto *child = getChildExists(node, i); auto *child = getChildExists(node, i);
expected = std::max(expected, checkMaxVersion(root, child, success)); expected = std::max(expected,
checkMaxVersion(root, child, oldestVersion, success));
if (child->entryPresent) { if (child->entryPresent) {
expected = std::max(expected, child->entry.rangeVersion); expected = std::max(expected, child->entry.rangeVersion);
} }
@@ -1759,7 +1796,7 @@ Iterator firstGeq(Node *n, std::string_view key) {
expected = std::max(expected, borrowed.n->entry.rangeVersion); expected = std::max(expected, borrowed.n->entry.rangeVersion);
} }
} }
if (node->maxVersion != expected) { if (node->maxVersion > oldestVersion && node->maxVersion != expected) {
fprintf(stderr, "%s has max version %" PRId64 " . Expected %" PRId64 "\n", fprintf(stderr, "%s has max version %" PRId64 " . Expected %" PRId64 "\n",
getSearchPathPrintable(node).c_str(), node->maxVersion, expected); getSearchPathPrintable(node).c_str(), node->maxVersion, expected);
success = false; success = false;
@@ -1783,11 +1820,11 @@ Iterator firstGeq(Node *n, std::string_view key) {
return total; return total;
} }
bool checkCorrectness(Node *node) { bool checkCorrectness(Node *node, int64_t oldestVersion) {
bool success = true; bool success = true;
checkParentPointers(node, success); checkParentPointers(node, success);
checkMaxVersion(node, node, success); checkMaxVersion(node, node, oldestVersion, success);
checkEntriesExist(node, success); checkEntriesExist(node, success);
return success; return success;
@@ -1842,7 +1879,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
#if DEBUG_VERBOSE && !defined(NDEBUG) #if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check correctness\n"); fprintf(stderr, "Check correctness\n");
#endif #endif
bool success = checkCorrectness(driver.cs.root); bool success = checkCorrectness(driver.cs.root, driver.cs.oldestVersion);
if (!success) { if (!success) {
debugPrintDot(stdout, driver.cs.root); debugPrintDot(stdout, driver.cs.root);
fflush(stdout); fflush(stdout);

View File

@@ -470,14 +470,14 @@ inline std::string printable(std::span<const uint8_t> key) {
namespace { namespace {
template <class ConflictSetImpl> struct TestDriver { template <class ConflictSetImpl> struct TestDriver {
// TODO call setOldestVersion
Arbitrary arbitrary; Arbitrary arbitrary;
explicit TestDriver(const uint8_t *data, size_t size) explicit TestDriver(const uint8_t *data, size_t size)
: arbitrary({data, size}) {} : arbitrary({data, size}) {}
int64_t writeVersion = 0; int64_t writeVersion = 0;
ConflictSetImpl cs{writeVersion}; int64_t oldestVersion = 0;
ReferenceImpl refImpl{writeVersion}; ConflictSetImpl cs{oldestVersion};
ReferenceImpl refImpl{oldestVersion};
constexpr static auto kMaxKeyLen = 32; constexpr static auto kMaxKeyLen = 32;
@@ -659,6 +659,9 @@ template <class ConflictSetImpl> struct TestDriver {
} }
} }
} }
oldestVersion += arbitrary.bounded(2);
cs.setOldestVersion(oldestVersion);
refImpl.setOldestVersion(oldestVersion);
return false; return false;
} }
}; };