diff --git a/ConflictSet.cpp b/ConflictSet.cpp index d35a52f..ca009b2 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -37,9 +37,12 @@ struct Node { Node *parent = nullptr; int64_t maxVersion = std::numeric_limits::lowest(); Entry entry; + constexpr static auto kCompressedKeyMaxLen = 18; int16_t numChildren = 0; bool entryPresent = false; uint8_t parentsIndex = 0; + uint8_t compressedKey[kCompressedKeyMaxLen]; + int8_t compressedKeyLen = 0; /* end section that's copied to the next node */ Type type = Type::Invalid; @@ -723,12 +726,16 @@ struct Iterator { }; std::string_view getSearchPath(Arena &arena, Node *n) { - if (n->parent == nullptr) { - return {}; - } auto result = vector(arena); - for (; n->parent != nullptr; n = n->parent) { + for (;;) { + for (int i = n->compressedKeyLen - 1; i >= 0; --i) { + result.push_back(n->compressedKey[i]); + } + if (n->parent == nullptr) { + break; + } result.push_back(n->parentsIndex); + n = n->parent; } std::reverse(result.begin(), result.end()); return std::string_view((const char *)&result[0], result.size()); // NOLINT @@ -738,6 +745,24 @@ Iterator lastLeq(Node *n, const std::span key) { auto remaining = key; for (;;) { Arena arena; + int commonLen = std::min(n->compressedKeyLen, remaining.size()); + if (commonLen > Node::kCompressedKeyMaxLen) { + __builtin_unreachable(); + } + int c = memcmp(n->compressedKey, remaining.data(), commonLen); + if (c == 0 && commonLen == n->compressedKeyLen) { + // Compressed key matches + remaining = remaining.subspan(commonLen, remaining.size() - commonLen); + } else if (c < 0 || + (c == 0 && n->compressedKeyLen < int(remaining.size()))) { + // n is the last physical node less than remaining, and there's no eq node + break; + } else if (c > 0) { + // n is the first physical node greater than remaining, and there's no eq + // node + n = prevPhysical(n); + break; + } assert((std::string(getSearchPath(arena, n)) + std::string((const char *)remaining.data(), remaining.size())) .ends_with(std::string((const char *)key.data(), key.size()))); @@ -781,6 +806,38 @@ void insert(Node **self_, std::span key, int64_t writeVersion) { for (;;) { auto &self = *self_; self->maxVersion = std::max(self->maxVersion, writeVersion); + int commonLen = std::min(self->compressedKeyLen, key.size()); + // Handle an existing compressed key + int compressedKeyIndex = 0; + for (; compressedKeyIndex < commonLen; ++compressedKeyIndex) { + if (self->compressedKey[compressedKeyIndex] != key[compressedKeyIndex]) { + auto *old = self; + self = newNode(); + memcpy((void *)self, old, offsetof(Node, type)); + self->entryPresent = false; + + getOrCreateChild(self, old->compressedKey[compressedKeyIndex]) = old; + old->parent = self; + old->parentsIndex = old->compressedKey[compressedKeyIndex]; + self->compressedKeyLen = compressedKeyIndex; + + memmove(old->compressedKey, old->compressedKey + compressedKeyIndex + 1, + old->compressedKeyLen - (compressedKeyIndex + 1)); + old->compressedKeyLen -= compressedKeyIndex + 1; + break; + } + } + key = key.subspan(compressedKeyIndex, key.size() - compressedKeyIndex); + + // Consider adding a compressed key + if (self->numChildren == 0 && !self->entryPresent) { + self->compressedKeyLen = + std::min(key.size(), self->kCompressedKeyMaxLen); + memcpy(self->compressedKey, key.data(), self->compressedKeyLen); + key = key.subspan(self->compressedKeyLen, + key.size() - self->compressedKeyLen); + } + if (key.size() == 0) { auto l = lastLeq(self, key); self->entryPresent = true; @@ -972,18 +1029,20 @@ void printLogical(std::string &result, Node *node) { void print(Node *n) { assert(n != nullptr); + auto compressedKey = + printable(Key{n->compressedKey, n->compressedKeyLen}); if (n->entryPresent) { - fprintf(file, " k_%p [label=\"m=%d p=%d r=%d\"];\n", (void *)n, + fprintf(file, " k_%p [label=\"m=%d p=%d r=%d %s\"];\n", (void *)n, int(n->maxVersion), int(n->entry.pointVersion), - int(n->entry.rangeVersion)); + int(n->entry.rangeVersion), compressedKey.c_str()); } else { - fprintf(file, " k_%p [label=\"m=%d\"];\n", (void *)n, - int(n->maxVersion)); + fprintf(file, " k_%p [label=\"m=%d %s\"];\n", (void *)n, + int(n->maxVersion), compressedKey.c_str()); } for (int child = getChildGeq(n, 0); child >= 0; child = getChildGeq(n, child + 1)) { auto *c = getChildExists(n, child); - fprintf(file, " k_%p -> k_%p [label=\"'%02x'\"];\n", (void *)n, + fprintf(file, " k_%p -> k_%p [label=\"x%02x\"];\n", (void *)n, (void *)c, child); print(c); } @@ -998,6 +1057,29 @@ void printLogical(std::string &result, Node *node) { fprintf(file, "}\n"); } +void checkCompressedKey(Node *node, bool &success) { + if (node->numChildren == 1 && + node->compressedKeyLen < node->kCompressedKeyMaxLen) { + Arena arena; + fprintf(stderr, "%s has 1 child and %d < %d compressed key bytes\n", + printable(getSearchPath(arena, node)).c_str(), + int(node->compressedKeyLen), int(node->kCompressedKeyMaxLen)); + + success = false; + } + for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) { + auto *child = getChildExists(node, i); + if (child->parent != node) { + Arena arena; + fprintf(stderr, "%s child %d has parent pointer %p. Expected %p\n", + printable(getSearchPath(arena, node)).c_str(), i, + (void *)child->parent, (void *)node); + success = false; + } + checkCompressedKey(child, success); + } +} + void checkParentPointers(Node *node, bool &success) { for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) { auto *child = getChildExists(node, i); @@ -1035,6 +1117,7 @@ bool checkCorrectness(Node *node, ReferenceImpl &refImpl) { bool success = true; checkParentPointers(node, success); + checkCompressedKey(node, success); std::string logicalMap; std::string referenceLogicalMap; @@ -1096,6 +1179,10 @@ void printTree() { write[i].writeVersion = ++writeVersion; } cs.addWrites(write, kNumKeys); + for (int i = 0; i < kNumKeys; ++i) { + write[i].writeVersion = ++writeVersion; + } + cs.addWrites(write, kNumKeys); debugPrintDot(stdout, cs.root); }