Implement point reads

This commit is contained in:
2024-01-23 11:48:12 -08:00
parent 3a720ca3ec
commit 407b9af750

View File

@@ -12,6 +12,8 @@
#include <utility>
#include <vector>
#define DEBUG 0
__attribute__((always_inline)) void *safe_malloc(size_t s) {
if (void *p = malloc(s)) {
return p;
@@ -714,6 +716,48 @@ int getChildGeq(Node *self, int child) {
return -1;
}
int getChildLeq(Node *self, int child) {
if (self->type == Type::Node4) {
auto *self4 = static_cast<Node4 *>(self);
for (int i = self->numChildren - 1; i >= 0; --i) {
if (i > 0) {
assert(self4->index[i - 1] < self4->index[i]);
}
if (self4->index[i] <= child) {
return self4->index[i];
}
}
} else if (self->type == Type::Node16) {
auto *self16 = static_cast<Node16 *>(self);
for (int i = self->numChildren - 1; i >= 0; --i) {
if (i > 0) {
assert(self16->index[i - 1] < self16->index[i]);
}
if (self16->index[i] <= child) {
return self16->index[i];
}
}
} else if (self->type == Type::Node48) {
auto *self48 = static_cast<Node48 *>(self);
// TODO simd
for (int i = child; i >= 0; --i) {
if (self48->index[i] >= 0) {
assert(self48->children[self48->index[i]] != nullptr);
return i;
}
}
} else {
auto *self256 = static_cast<Node256 *>(self);
// TODO simd?
for (int i = child; i >= 0; --i) {
if (self256->children[i]) {
return i;
}
}
}
return -1;
}
static void setChildrenParents(Node *node) {
for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) {
getChildExists(node, i)->parent = node;
@@ -905,15 +949,95 @@ void debugPrintDot(FILE *file, Node *node) {
fprintf(file, "}\n");
}
static void insert(Node **self_, std::span<const uint8_t> key,
int64_t writeVersion) {
void printSearchPath(Node *n) {
Arena arena;
auto result = vector<char>(arena);
for (; n->parent != nullptr; n = n->parent) {
result.push_back(n->parentsIndex);
}
std::reverse(result.begin(), result.end());
result.push_back(0);
printf("Search path: %s\n", result.data());
}
Node *prevPhysical(Node *node) {
// Move up until there's a node at a lower index than the current search path
int selfIndex = 256;
for (;;) {
if (node->parent == nullptr) {
return nullptr;
}
auto prevChild = getChildLeq(node->parent, node->parentsIndex - 1);
if (prevChild >= 0) {
node = getChildExists(node->parent, prevChild);
break;
} else {
node = node->parent;
selfIndex = prevChild;
if (node->entryPresent) {
break;
}
}
}
// Move down the right spine
for (;;) {
auto rightMostChild = getChildLeq(node, selfIndex - 1);
if (rightMostChild >= 0) {
node = getChildExists(node, rightMostChild);
} else {
break;
}
}
return node;
}
struct Iterator {
Node *n;
int cmp;
};
Iterator lastLeq(Node *n, std::span<const uint8_t> key) {
for (;;) {
if (key.size() == 0) {
if (n->entryPresent) {
return {n, 0};
} else {
break;
}
} else {
int c = getChildLeq(n, key[0]);
if (c == key[0]) {
n = getChildExists(n, c);
key = key.subspan(1, key.size() - 1);
} else if (c >= 0) {
n = getChildExists(n, c);
break;
} else {
break;
}
}
}
for (;;) {
if (n->entryPresent) {
break;
}
n = prevPhysical(n);
assert(n != nullptr);
}
return {n, -1};
}
void insert(Node **self_, std::span<const uint8_t> key, int64_t writeVersion) {
for (;;) {
auto &self = *self_;
self->maxVersion = writeVersion;
if (key.size() == 0) {
auto l = lastLeq(self, key);
self->entryPresent = true;
self->entry.pointVersion = writeVersion;
// TODO set correct rangeVersion
assert(l.n != nullptr);
assert(l.n->entryPresent);
self->entry.rangeVersion = l.n->entry.rangeVersion;
return;
}
auto &child = getOrCreateChild(self, key.front());
@@ -928,7 +1052,25 @@ static void insert(Node **self_, std::span<const uint8_t> key,
}
struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
void check(const ReadRange *, Result *, int) const {}
void check(const ReadRange *reads, Result *result, int count) const {
for (int i = 0; i < count; ++i) {
const auto &r = reads[i];
if (r.readVersion < oldestVersion) {
result[i] = TooOld;
continue;
}
// TODO support non-point reads
assert(r.end.len == 0);
auto [l, c] =
lastLeq(root, std::span<const uint8_t>(r.begin.p, r.begin.len));
assert(l != nullptr);
assert(l->entryPresent);
result[i] = (c == 0 ? l->entry.pointVersion : l->entry.rangeVersion) >
r.readVersion
? Conflict
: Commit;
}
}
void addWrites(const WriteRange *writes, int count) {
for (int i = 0; i < count; ++i) {
const auto &w = writes[i];
@@ -1092,6 +1234,10 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
writes[i].begin.len = iter->size();
writes[i].end.len = 0;
writes[i].writeVersion = v;
#if DEBUG
printf("Write: {%.*s} -> %d\n", writes[i].begin.len, writes[i].begin.p,
int(writes[i].writeVersion));
#endif
}
cs.addWrites(writes, numWrites);
refImpl.addWrites(writes, numWrites);
@@ -1121,6 +1267,10 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
reads[i].begin.len = iter->size();
reads[i].end.len = 0;
reads[i].readVersion = v;
#if DEBUG
printf("Read: {%.*s} at %d\n", reads[i].begin.len, reads[i].begin.p,
int(reads[i].readVersion));
#endif
}
auto *results1 = new (arena) ConflictSet::Result[numReads];
auto *results2 = new (arena) ConflictSet::Result[numReads];