diff --git a/ConflictSet.cpp b/ConflictSet.cpp index c2c5558..4773963 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -12,6 +12,8 @@ #include #include +#define DEBUG 0 + __attribute__((always_inline)) void *safe_malloc(size_t s) { if (void *p = malloc(s)) { return p; @@ -714,6 +716,48 @@ int getChildGeq(Node *self, int child) { return -1; } +int getChildLeq(Node *self, int child) { + if (self->type == Type::Node4) { + auto *self4 = static_cast(self); + for (int i = self->numChildren - 1; i >= 0; --i) { + if (i > 0) { + assert(self4->index[i - 1] < self4->index[i]); + } + if (self4->index[i] <= child) { + return self4->index[i]; + } + } + } else if (self->type == Type::Node16) { + auto *self16 = static_cast(self); + for (int i = self->numChildren - 1; i >= 0; --i) { + if (i > 0) { + assert(self16->index[i - 1] < self16->index[i]); + } + if (self16->index[i] <= child) { + return self16->index[i]; + } + } + } else if (self->type == Type::Node48) { + auto *self48 = static_cast(self); + // TODO simd + for (int i = child; i >= 0; --i) { + if (self48->index[i] >= 0) { + assert(self48->children[self48->index[i]] != nullptr); + return i; + } + } + } else { + auto *self256 = static_cast(self); + // TODO simd? + for (int i = child; i >= 0; --i) { + if (self256->children[i]) { + return i; + } + } + } + return -1; +} + static void setChildrenParents(Node *node) { for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) { getChildExists(node, i)->parent = node; @@ -905,15 +949,95 @@ void debugPrintDot(FILE *file, Node *node) { fprintf(file, "}\n"); } -static void insert(Node **self_, std::span key, - int64_t writeVersion) { +void printSearchPath(Node *n) { + Arena arena; + auto result = vector(arena); + for (; n->parent != nullptr; n = n->parent) { + result.push_back(n->parentsIndex); + } + std::reverse(result.begin(), result.end()); + result.push_back(0); + printf("Search path: %s\n", result.data()); +} + +Node *prevPhysical(Node *node) { + // Move up until there's a node at a lower index than the current search path + int selfIndex = 256; + for (;;) { + if (node->parent == nullptr) { + return nullptr; + } + auto prevChild = getChildLeq(node->parent, node->parentsIndex - 1); + if (prevChild >= 0) { + node = getChildExists(node->parent, prevChild); + break; + } else { + node = node->parent; + selfIndex = prevChild; + if (node->entryPresent) { + break; + } + } + } + // Move down the right spine + for (;;) { + auto rightMostChild = getChildLeq(node, selfIndex - 1); + if (rightMostChild >= 0) { + node = getChildExists(node, rightMostChild); + } else { + break; + } + } + return node; +} + +struct Iterator { + Node *n; + int cmp; +}; + +Iterator lastLeq(Node *n, std::span key) { + for (;;) { + if (key.size() == 0) { + if (n->entryPresent) { + return {n, 0}; + } else { + break; + } + } else { + int c = getChildLeq(n, key[0]); + if (c == key[0]) { + n = getChildExists(n, c); + key = key.subspan(1, key.size() - 1); + } else if (c >= 0) { + n = getChildExists(n, c); + break; + } else { + break; + } + } + } + for (;;) { + if (n->entryPresent) { + break; + } + n = prevPhysical(n); + assert(n != nullptr); + } + return {n, -1}; +} + +void insert(Node **self_, std::span key, int64_t writeVersion) { for (;;) { auto &self = *self_; self->maxVersion = writeVersion; if (key.size() == 0) { + auto l = lastLeq(self, key); self->entryPresent = true; self->entry.pointVersion = writeVersion; - // TODO set correct rangeVersion + assert(l.n != nullptr); + assert(l.n->entryPresent); + self->entry.rangeVersion = l.n->entry.rangeVersion; return; } auto &child = getOrCreateChild(self, key.front()); @@ -928,7 +1052,25 @@ static void insert(Node **self_, std::span key, } struct __attribute__((visibility("hidden"))) ConflictSet::Impl { - void check(const ReadRange *, Result *, int) const {} + void check(const ReadRange *reads, Result *result, int count) const { + for (int i = 0; i < count; ++i) { + const auto &r = reads[i]; + if (r.readVersion < oldestVersion) { + result[i] = TooOld; + continue; + } + // TODO support non-point reads + assert(r.end.len == 0); + auto [l, c] = + lastLeq(root, std::span(r.begin.p, r.begin.len)); + assert(l != nullptr); + assert(l->entryPresent); + result[i] = (c == 0 ? l->entry.pointVersion : l->entry.rangeVersion) > + r.readVersion + ? Conflict + : Commit; + } + } void addWrites(const WriteRange *writes, int count) { for (int i = 0; i < count; ++i) { const auto &w = writes[i]; @@ -1092,6 +1234,10 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { writes[i].begin.len = iter->size(); writes[i].end.len = 0; writes[i].writeVersion = v; +#if DEBUG + printf("Write: {%.*s} -> %d\n", writes[i].begin.len, writes[i].begin.p, + int(writes[i].writeVersion)); +#endif } cs.addWrites(writes, numWrites); refImpl.addWrites(writes, numWrites); @@ -1121,6 +1267,10 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { reads[i].begin.len = iter->size(); reads[i].end.len = 0; reads[i].readVersion = v; +#if DEBUG + printf("Read: {%.*s} at %d\n", reads[i].begin.len, reads[i].begin.p, + int(reads[i].readVersion)); +#endif } auto *results1 = new (arena) ConflictSet::Result[numReads]; auto *results2 = new (arena) ConflictSet::Result[numReads];