diff --git a/ConflictSet.cpp b/ConflictSet.cpp index c4898c5..63d5eb4 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -35,7 +36,7 @@ enum class Type : int8_t { struct Node { /* begin section that's copied to the next node */ Node *parent = nullptr; - int64_t maxVersion = std::numeric_limits::lowest(); + int64_t maxVersion; Entry entry; int16_t numChildren = 0; bool entryPresent = false; @@ -85,6 +86,16 @@ struct BitSet { } } + void reset(int i) { + assert(0 <= i); + assert(i < 256); + if (i < 128) { + lo &= ~(__uint128_t(1) << i); + } else { + hi &= ~(__uint128_t(1) << (i - 128)); + } + } + int firstSetGeq(int i) const { if (i < 128) { int a = std::countr_zero(lo >> i); @@ -552,6 +563,7 @@ Node *&getOrCreateChild(Node *&self, uint8_t index) { (self->numChildren - (nodeIndex + 1))); } else if (self->type == Type::Node48) { auto *self48 = static_cast(self); + self48->bitSet.reset(index); int8_t toRemoveChildrenIndex = std::exchange(self48->index[index], -1); int8_t lastChildrenIndex = --self48->nextFree; assert(toRemoveChildrenIndex >= 0); @@ -564,6 +576,7 @@ Node *&getOrCreateChild(Node *&self, uint8_t index) { } } else { auto *self256 = static_cast(self); + self256->bitSet.reset(index); self256->children[index] = nullptr; } --self->numChildren; @@ -762,6 +775,7 @@ outerLoop: auto &child = getOrCreateChild(self, key.front()); if (!child) { child = newNode(); + child->maxVersion = writeVersion; child->parent = self; child->parentsIndex = key.front(); } @@ -816,18 +830,45 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { void addWrites(const WriteRange *writes, int count) { for (int i = 0; i < count; ++i) { const auto &w = writes[i]; - // TODO support non-point writes - assert(w.end.len == 0); - auto *n = insert(&root, std::span(w.begin.p, w.begin.len), - w.writeVersion); - if (!n->entryPresent) { - auto *p = prevLogical(n); - assert(p != nullptr); + if (w.end.len > 0) { + auto *n = insert(&root, std::span(w.end.p, w.end.len), + std::numeric_limits::lowest()); + if (!n->entryPresent) { + auto *p = prevLogical(n); + assert(p != nullptr); + n->entryPresent = true; + n->entry.pointVersion = p->entry.rangeVersion; + n->entry.rangeVersion = p->entry.rangeVersion; + n->maxVersion = p->entry.rangeVersion; + } + + auto *end = n; + n = insert(&root, std::span(w.begin.p, w.begin.len), + w.writeVersion); n->entryPresent = true; n->entry.pointVersion = w.writeVersion; - n->entry.rangeVersion = p->entry.rangeVersion; + n->entry.rangeVersion = w.writeVersion; + for (n = nextLogical(n); n != end;) { + auto *old = n; + n = nextLogical(n); + if (old->numChildren == 0 && old->parent != nullptr) { + eraseChild(old->parent, old->parentsIndex); + } + } } else { - n->entry.pointVersion = std::max(n->entry.pointVersion, w.writeVersion); + auto *n = + insert(&root, std::span(w.begin.p, w.begin.len), + w.writeVersion); + if (!n->entryPresent) { + auto *p = prevLogical(n); + assert(p != nullptr); + n->entryPresent = true; + n->entry.pointVersion = w.writeVersion; + n->entry.rangeVersion = p->entry.rangeVersion; + } else { + n->entry.pointVersion = + std::max(n->entry.pointVersion, w.writeVersion); + } } } } @@ -952,12 +993,14 @@ void printLogical(std::string &result, Node *node) { assert(n != nullptr); auto partialKey = printable(Key{n->partialKey, n->partialKeyLen}); if (n->entryPresent) { - fprintf(file, " k_%p [label=\"m=%d p=%d r=%d %s\"];\n", (void *)n, - int(n->maxVersion), int(n->entry.pointVersion), - int(n->entry.rangeVersion), partialKey.c_str()); + fprintf(file, + " k_%p [label=\"m=%" PRId64 " p=%" PRId64 " r=%" PRId64 + " %s\"];\n", + (void *)n, n->maxVersion, n->entry.pointVersion, + n->entry.rangeVersion, partialKey.c_str()); } else { - fprintf(file, " k_%p [label=\"m=%d %s\"];\n", (void *)n, - int(n->maxVersion), partialKey.c_str()); + fprintf(file, " k_%p [label=\"m=%" PRId64 " %s\"];\n", (void *)n, + n->maxVersion, partialKey.c_str()); } for (int child = getChildGeq(n, 0); child >= 0; child = getChildGeq(n, child + 1)) { @@ -1002,9 +1045,9 @@ void checkParentPointers(Node *node, bool &success) { } if (node->maxVersion != expected) { Arena arena; - fprintf(stderr, "%s has max version %d. Expected %d\n", - printable(getSearchPath(arena, node)).c_str(), - int(node->maxVersion), int(expected)); + fprintf(stderr, "%s has max version %" PRId64 " . Expected %" PRId64 "\n", + printable(getSearchPath(arena, node)).c_str(), node->maxVersion, + expected); success = false; } return expected; diff --git a/Internal.h b/Internal.h index 0a08ff9..e1f73d4 100644 --- a/Internal.h +++ b/Internal.h @@ -14,7 +14,7 @@ #include #include -#define DEBUG_VERBOSE 0 +#define DEBUG_VERBOSE 1 // This header contains code that we want to reuse outside of ConflictSet.cpp or // want to exclude from coverage since it's only testing related. @@ -459,11 +459,11 @@ template struct TestDriver { } Arena arena; { - int numWrites = arbitrary.bounded(kMaxKeyLen); + int numWriteKeys = arbitrary.bounded(10); int64_t v = ++writeVersion; - auto *writes = new (arena) ConflictSet::WriteRange[numWrites]; + auto *writes = new (arena) ConflictSet::WriteRange[numWriteKeys]; auto keys = set(arena); - while (int(keys.size()) < numWrites) { + while (int(keys.size()) < numWriteKeys) { if (!arbitrary.hasEntropy()) { return true; } @@ -473,16 +473,31 @@ template struct TestDriver { keys.insert(std::string_view((const char *)begin, keyLen)); } auto iter = keys.begin(); - for (int i = 0; i < numWrites; ++i) { - writes[i].begin.p = (const uint8_t *)iter->data(); - writes[i].begin.len = iter->size(); + int numWrites = 0; + for (int i = 0; i < numWriteKeys; ++i, ++numWrites) { + writes[numWrites].begin.p = (const uint8_t *)iter->data(); + writes[numWrites].begin.len = iter->size(); ++iter; - writes[i].end.len = 0; - writes[i].writeVersion = v; + if (i + 1 < numWriteKeys && arbitrary.bounded(2)) { + ++i; + writes[numWrites].end.p = (const uint8_t *)iter->data(); + writes[numWrites].end.len = iter->size(); + ++iter; + } else { + writes[numWrites].end.len = 0; + } + writes[numWrites].writeVersion = v; #if DEBUG_VERBOSE && !defined(NDEBUG) - fprintf(stderr, "Write: {%s} -> %d\n", - printable(writes[i].begin).c_str(), - int(writes[i].writeVersion)); + if (writes[numWrites].end.len == 0) { + fprintf(stderr, "Write: {%s} -> %d\n", + printable(writes[numWrites].begin).c_str(), + int(writes[numWrites].writeVersion)); + } else { + fprintf(stderr, "Write: [%s, %s) -> %d\n", + printable(writes[numWrites].begin).c_str(), + printable(writes[numWrites].end).c_str(), + int(writes[numWrites].writeVersion)); + } #endif } assert(iter == keys.end());