diff --git a/ConflictSet.cpp b/ConflictSet.cpp index ca009b2..1bb5540 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -37,12 +37,12 @@ struct Node { Node *parent = nullptr; int64_t maxVersion = std::numeric_limits::lowest(); Entry entry; - constexpr static auto kCompressedKeyMaxLen = 18; + constexpr static auto kPartialKeyMaxLen = 18; int16_t numChildren = 0; bool entryPresent = false; uint8_t parentsIndex = 0; - uint8_t compressedKey[kCompressedKeyMaxLen]; - int8_t compressedKeyLen = 0; + uint8_t partialKey[kPartialKeyMaxLen]; + int8_t partialKeyLen = 0; /* end section that's copied to the next node */ Type type = Type::Invalid; @@ -728,8 +728,8 @@ struct Iterator { std::string_view getSearchPath(Arena &arena, Node *n) { auto result = vector(arena); for (;;) { - for (int i = n->compressedKeyLen - 1; i >= 0; --i) { - result.push_back(n->compressedKey[i]); + for (int i = n->partialKeyLen - 1; i >= 0; --i) { + result.push_back(n->partialKey[i]); } if (n->parent == nullptr) { break; @@ -738,26 +738,29 @@ std::string_view getSearchPath(Arena &arena, Node *n) { n = n->parent; } std::reverse(result.begin(), result.end()); - return std::string_view((const char *)&result[0], result.size()); // NOLINT + if (result.size() > 0) { + return std::string_view((const char *)&result[0], result.size()); // NOLINT + } else { + return std::string_view(); + } } Iterator lastLeq(Node *n, const std::span key) { auto remaining = key; for (;;) { Arena arena; - int commonLen = std::min(n->compressedKeyLen, remaining.size()); - if (commonLen > Node::kCompressedKeyMaxLen) { + int commonLen = std::min(n->partialKeyLen, remaining.size()); + if (commonLen > Node::kPartialKeyMaxLen) { __builtin_unreachable(); } - int c = memcmp(n->compressedKey, remaining.data(), commonLen); - if (c == 0 && commonLen == n->compressedKeyLen) { - // Compressed key matches + int c = memcmp(n->partialKey, remaining.data(), commonLen); + if (c == 0 && commonLen == n->partialKeyLen) { + // partial key matches remaining = remaining.subspan(commonLen, remaining.size() - commonLen); - } else if (c < 0 || - (c == 0 && n->compressedKeyLen < int(remaining.size()))) { + } else if (c < 0 || (c == 0 && n->partialKeyLen < int(remaining.size()))) { // n is the last physical node less than remaining, and there's no eq node break; - } else if (c > 0) { + } else if (c > 0 || (c == 0 && n->partialKeyLen > int(remaining.size()))) { // n is the first physical node greater than remaining, and there's no eq // node n = prevPhysical(n); @@ -805,39 +808,40 @@ Iterator lastLeq(Node *n, const std::span key) { void insert(Node **self_, std::span key, int64_t writeVersion) { for (;;) { auto &self = *self_; - self->maxVersion = std::max(self->maxVersion, writeVersion); - int commonLen = std::min(self->compressedKeyLen, key.size()); - // Handle an existing compressed key - int compressedKeyIndex = 0; - for (; compressedKeyIndex < commonLen; ++compressedKeyIndex) { - if (self->compressedKey[compressedKeyIndex] != key[compressedKeyIndex]) { + // Handle an existing partial key + int partialKeyIndex = 0; + for (; partialKeyIndex < self->partialKeyLen; ++partialKeyIndex) { + if (partialKeyIndex == int(key.size()) || + self->partialKey[partialKeyIndex] != key[partialKeyIndex]) { auto *old = self; self = newNode(); - memcpy((void *)self, old, offsetof(Node, type)); - self->entryPresent = false; + self->maxVersion = old->maxVersion; + self->partialKeyLen = partialKeyIndex; + self->parent = old->parent; + self->parentsIndex = old->parentsIndex; + memcpy(self->partialKey, old->partialKey, partialKeyIndex); - getOrCreateChild(self, old->compressedKey[compressedKeyIndex]) = old; + getOrCreateChild(self, old->partialKey[partialKeyIndex]) = old; old->parent = self; - old->parentsIndex = old->compressedKey[compressedKeyIndex]; - self->compressedKeyLen = compressedKeyIndex; + old->parentsIndex = old->partialKey[partialKeyIndex]; - memmove(old->compressedKey, old->compressedKey + compressedKeyIndex + 1, - old->compressedKeyLen - (compressedKeyIndex + 1)); - old->compressedKeyLen -= compressedKeyIndex + 1; + memmove(old->partialKey, old->partialKey + partialKeyIndex + 1, + old->partialKeyLen - (partialKeyIndex + 1)); + old->partialKeyLen -= partialKeyIndex + 1; break; } } - key = key.subspan(compressedKeyIndex, key.size() - compressedKeyIndex); + key = key.subspan(partialKeyIndex, key.size() - partialKeyIndex); - // Consider adding a compressed key + // Consider adding a partial key if (self->numChildren == 0 && !self->entryPresent) { - self->compressedKeyLen = - std::min(key.size(), self->kCompressedKeyMaxLen); - memcpy(self->compressedKey, key.data(), self->compressedKeyLen); - key = key.subspan(self->compressedKeyLen, - key.size() - self->compressedKeyLen); + self->partialKeyLen = std::min(key.size(), self->kPartialKeyMaxLen); + memcpy(self->partialKey, key.data(), self->partialKeyLen); + key = key.subspan(self->partialKeyLen, key.size() - self->partialKeyLen); } + self->maxVersion = std::max(self->maxVersion, writeVersion); + if (key.size() == 0) { auto l = lastLeq(self, key); self->entryPresent = true; @@ -1029,15 +1033,14 @@ void printLogical(std::string &result, Node *node) { void print(Node *n) { assert(n != nullptr); - auto compressedKey = - printable(Key{n->compressedKey, n->compressedKeyLen}); + auto partialKey = printable(Key{n->partialKey, n->partialKeyLen}); if (n->entryPresent) { fprintf(file, " k_%p [label=\"m=%d p=%d r=%d %s\"];\n", (void *)n, int(n->maxVersion), int(n->entry.pointVersion), - int(n->entry.rangeVersion), compressedKey.c_str()); + int(n->entry.rangeVersion), partialKey.c_str()); } else { fprintf(file, " k_%p [label=\"m=%d %s\"];\n", (void *)n, - int(n->maxVersion), compressedKey.c_str()); + int(n->maxVersion), partialKey.c_str()); } for (int child = getChildGeq(n, 0); child >= 0; child = getChildGeq(n, child + 1)) { @@ -1057,29 +1060,6 @@ void printLogical(std::string &result, Node *node) { fprintf(file, "}\n"); } -void checkCompressedKey(Node *node, bool &success) { - if (node->numChildren == 1 && - node->compressedKeyLen < node->kCompressedKeyMaxLen) { - Arena arena; - fprintf(stderr, "%s has 1 child and %d < %d compressed key bytes\n", - printable(getSearchPath(arena, node)).c_str(), - int(node->compressedKeyLen), int(node->kCompressedKeyMaxLen)); - - success = false; - } - for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) { - auto *child = getChildExists(node, i); - if (child->parent != node) { - Arena arena; - fprintf(stderr, "%s child %d has parent pointer %p. Expected %p\n", - printable(getSearchPath(arena, node)).c_str(), i, - (void *)child->parent, (void *)node); - success = false; - } - checkCompressedKey(child, success); - } -} - void checkParentPointers(Node *node, bool &success) { for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) { auto *child = getChildExists(node, i); @@ -1117,7 +1097,7 @@ bool checkCorrectness(Node *node, ReferenceImpl &refImpl) { bool success = true; checkParentPointers(node, success); - checkCompressedKey(node, success); + checkMaxVersion(node, success); std::string logicalMap; std::string referenceLogicalMap; @@ -1196,10 +1176,13 @@ int main(void) { #ifdef ENABLE_FUZZ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { TestDriver driver{data, size}; + static_assert(driver.kMaxKeyLen > Node::kPartialKeyMaxLen); do { bool success = checkCorrectness(driver.cs.root, driver.refImpl); if (!success) { + debugPrintDot(stdout, driver.cs.root); + fflush(stdout); abort(); } } while (!driver.next()); diff --git a/Internal.h b/Internal.h index bf36daa..0a08ff9 100644 --- a/Internal.h +++ b/Internal.h @@ -449,6 +449,8 @@ template struct TestDriver { ConflictSetImpl cs{writeVersion}; ReferenceImpl refImpl{writeVersion}; + constexpr static auto kMaxKeyLen = 24; + // Call until it returns true, for "done". Check internal invariants etc // between calls to next. bool next() { @@ -457,7 +459,7 @@ template struct TestDriver { } Arena arena; { - int numWrites = arbitrary.bounded(10); + int numWrites = arbitrary.bounded(kMaxKeyLen); int64_t v = ++writeVersion; auto *writes = new (arena) ConflictSet::WriteRange[numWrites]; auto keys = set(arena); @@ -465,7 +467,7 @@ template struct TestDriver { if (!arbitrary.hasEntropy()) { return true; } - int keyLen = arbitrary.bounded(8); + int keyLen = arbitrary.bounded(kMaxKeyLen); auto *begin = new (arena) uint8_t[keyLen]; arbitrary.randomBytes(begin, keyLen); keys.insert(std::string_view((const char *)begin, keyLen)); @@ -478,8 +480,9 @@ template struct TestDriver { writes[i].end.len = 0; writes[i].writeVersion = v; #if DEBUG_VERBOSE && !defined(NDEBUG) - printf("Write: {%s} -> %d\n", printable(writes[i].begin).c_str(), - int(writes[i].writeVersion)); + fprintf(stderr, "Write: {%s} -> %d\n", + printable(writes[i].begin).c_str(), + int(writes[i].writeVersion)); #endif } assert(iter == keys.end()); @@ -488,14 +491,14 @@ template struct TestDriver { } { int numReads = arbitrary.bounded(10); - int64_t v = writeVersion - arbitrary.bounded(10); + int64_t v = std::max(writeVersion - arbitrary.bounded(10), 0); auto *reads = new (arena) ConflictSet::ReadRange[numReads]; auto keys = set(arena); while (int(keys.size()) < numReads) { if (!arbitrary.hasEntropy()) { return true; } - int keyLen = arbitrary.bounded(8); + int keyLen = arbitrary.bounded(kMaxKeyLen); auto *begin = new (arena) uint8_t[keyLen]; arbitrary.randomBytes(begin, keyLen); keys.insert(std::string_view((const char *)begin, keyLen)); @@ -508,8 +511,8 @@ template struct TestDriver { reads[i].end.len = 0; reads[i].readVersion = v; #if DEBUG_VERBOSE && !defined(NDEBUG) - printf("Read: {%s} at %d\n", printable(reads[i].begin).c_str(), - int(reads[i].readVersion)); + fprintf(stderr, "Read: {%s} at %d\n", printable(reads[i].begin).c_str(), + int(reads[i].readVersion)); #endif } assert(iter == keys.end());