From d91538dcad3381deeab77929f46404f07f0e8ef7 Mon Sep 17 00:00:00 2001 From: Andrew Noyes Date: Fri, 8 Mar 2024 13:50:40 -0800 Subject: [PATCH] Variable length partial keys --- ConflictSet.cpp | 128 +++++++++++++++++++++++------------------------- 1 file changed, 62 insertions(+), 66 deletions(-) diff --git a/ConflictSet.cpp b/ConflictSet.cpp index 3d7b59c..cc07878 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -168,19 +168,11 @@ enum class Type : int8_t { Node256, }; -constexpr static int kPartialKeyMaxLenEntryPresent = 24; - struct Node { /* begin section that's copied to the next node */ Node *parent = nullptr; - union { - uint8_t partialKey[kPartialKeyMaxLenEntryPresent + sizeof(Entry)]; - struct { - uint8_t padding[kPartialKeyMaxLenEntryPresent]; - Entry entry; - }; - }; + Entry entry; int32_t partialKeyLen = 0; int16_t numChildren : 15 = 0; bool entryPresent : 1 = false; @@ -188,15 +180,14 @@ struct Node { /* end section that's copied to the next node */ Type type; + int32_t partialKeyCapacity; + + uint8_t *partialKey(); }; constexpr int kNodeCopyBegin = offsetof(Node, parent); constexpr int kNodeCopySize = offsetof(Node, type) - kNodeCopyBegin; -static_assert(offsetof(Node, entry) == - offsetof(Node, partialKey) + kPartialKeyMaxLenEntryPresent); -static_assert(std::is_trivial_v); - struct Child { int64_t childMaxVersion; Node *child; @@ -246,8 +237,25 @@ struct Node256 : Node { } }; -template NodeT *newNode() { - return new (safe_malloc(sizeof(NodeT))) NodeT; +template NodeT *newNode(int partialKeyCapacity) { + auto *result = new (safe_malloc(sizeof(NodeT) + partialKeyCapacity)) NodeT; + result->partialKeyCapacity = partialKeyCapacity; + return result; +} + +uint8_t *Node::partialKey() { + switch (type) { + case Type::Node0: + return (uint8_t *)((Node0 *)this + 1); + case Type::Node4: + return (uint8_t *)((Node4 *)this + 1); + case Type::Node16: + return (uint8_t *)((Node16 *)this + 1); + case Type::Node48: + return (uint8_t *)((Node48 *)this + 1); + case Type::Node256: + return (uint8_t *)((Node256 *)this + 1); + } } int getNodeIndex(Node16 *self, uint8_t index) { @@ -481,9 +489,10 @@ Node *&getOrCreateChild(Node *&self, uint8_t index) { if (self->type == Type::Node0) { auto *self0 = static_cast(self); - auto *newSelf = newNode(); + auto *newSelf = newNode(self->partialKeyLen); memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin, kNodeCopySize); + memcpy(newSelf->partialKey(), self->partialKey(), self->partialKeyLen); free(self0); self = newSelf; @@ -493,9 +502,10 @@ Node *&getOrCreateChild(Node *&self, uint8_t index) { auto *self4 = static_cast(self); if (self->numChildren == 4) { - auto *newSelf = newNode(); + auto *newSelf = newNode(self->partialKeyLen); memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin, kNodeCopySize); + memcpy(newSelf->partialKey(), self->partialKey(), self->partialKeyLen); // TODO replace with memcpy? for (int i = 0; i < 4; ++i) { newSelf->index[i] = self4->index[i]; @@ -512,9 +522,10 @@ Node *&getOrCreateChild(Node *&self, uint8_t index) { if (self->numChildren == 16) { auto *self16 = static_cast(self); - auto *newSelf = newNode(); + auto *newSelf = newNode(self->partialKeyLen); memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin, kNodeCopySize); + memcpy(newSelf->partialKey(), self->partialKey(), self->partialKeyLen); newSelf->nextFree = 16; int i = 0; for (auto x : self16->index) { @@ -552,9 +563,10 @@ Node *&getOrCreateChild(Node *&self, uint8_t index) { if (self->numChildren == 48) { auto *self48 = static_cast(self); - auto *newSelf = newNode(); + auto *newSelf = newNode(self->partialKeyLen); memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin, kNodeCopySize); + memcpy(newSelf->partialKey(), self->partialKey(), self->partialKeyLen); newSelf->bitSet = self48->bitSet; newSelf->bitSet.forEachInRange( [&](int i) { @@ -819,14 +831,7 @@ bytes: int longestCommonPrefixPartialKey(const uint8_t *ap, const uint8_t *bp, int cl) { - assert(cl <= kPartialKeyMaxLenEntryPresent + int(sizeof(Entry))); - int i = 0; - for (; i < cl; ++i) { - if (*ap++ != *bp++) { - break; - } - } - return i; + return longestCommonPrefix(ap, bp, cl); } // Performs a physical search for remaining @@ -849,7 +854,7 @@ struct SearchStepWise { return true; } int cl = std::min(child->partialKeyLen, remaining.size() - 1); - int i = longestCommonPrefixPartialKey(child->partialKey, + int i = longestCommonPrefixPartialKey(child->partialKey(), remaining.data() + 1, cl); if (i != child->partialKeyLen) { return true; @@ -906,10 +911,10 @@ bool checkPointRead(Node *n, const std::span key, if (n->partialKeyLen > 0) { int commonLen = std::min(n->partialKeyLen, remaining.size()); - int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(), + int i = longestCommonPrefixPartialKey(n->partialKey(), remaining.data(), commonLen); if (i < commonLen) { - auto c = n->partialKey[i] <=> remaining[i]; + auto c = n->partialKey()[i] <=> remaining[i]; if (c > 0) { goto downLeftSpine; } else { @@ -1008,7 +1013,7 @@ Vector getSearchPath(Arena &arena, Node *n) { auto result = vector(arena); for (;;) { for (int i = n->partialKeyLen - 1; i >= 0; --i) { - result.push_back(n->partialKey[i]); + result.push_back(n->partialKey()[i]); } if (n->parent == nullptr) { break; @@ -1054,10 +1059,10 @@ bool checkRangeStartsWith(Node *n, std::span key, int begin, if (n->partialKeyLen > 0) { int commonLen = std::min(n->partialKeyLen, remaining.size()); - int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(), + int i = longestCommonPrefixPartialKey(n->partialKey(), remaining.data(), commonLen); if (i < commonLen) { - auto c = n->partialKey[i] <=> remaining[i]; + auto c = n->partialKey()[i] <=> remaining[i]; if (c > 0) { goto downLeftSpine; } else { @@ -1069,8 +1074,8 @@ bool checkRangeStartsWith(Node *n, std::span key, int begin, // partial key matches remaining = remaining.subspan(commonLen, remaining.size() - commonLen); } else if (n->partialKeyLen > int(remaining.size())) { - if (begin < n->partialKey[remaining.size()] && - n->partialKey[remaining.size()] < end) { + if (begin < n->partialKey()[remaining.size()] && + n->partialKey()[remaining.size()] < end) { if (n->entryPresent && n->entry.rangeVersion > readVersion) { return false; } @@ -1161,11 +1166,11 @@ struct CheckRangeLeftSide { if (n->partialKeyLen > 0) { int commonLen = std::min(n->partialKeyLen, remaining.size()); - int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(), + int i = longestCommonPrefixPartialKey(n->partialKey(), remaining.data(), commonLen); searchPathLen += i; if (i < commonLen) { - auto c = n->partialKey[i] <=> remaining[i]; + auto c = n->partialKey()[i] <=> remaining[i]; if (c > 0) { if (searchPathLen < prefixLen) { return downLeftSpine(); @@ -1299,12 +1304,12 @@ struct CheckRangeRightSide { if (n->partialKeyLen > 0) { int commonLen = std::min(n->partialKeyLen, remaining.size()); - int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(), + int i = longestCommonPrefixPartialKey(n->partialKey(), remaining.data(), commonLen); searchPathLen += i; if (i < commonLen) { ++searchPathLen; - auto c = n->partialKey[i] <=> remaining[i]; + auto c = n->partialKey()[i] <=> remaining[i]; if (c > 0) { return downLeftSpine(); } else { @@ -1455,33 +1460,30 @@ template for (;;) { if ((*self)->partialKeyLen > 0) { - const bool wouldBePresent = - key.size() <= kPartialKeyMaxLenEntryPresent + int(sizeof(Entry)); // Handle an existing partial key int commonLen = std::min((*self)->partialKeyLen, key.size()); - if (wouldBePresent) { - commonLen = std::min(commonLen, kPartialKeyMaxLenEntryPresent); - } int partialKeyIndex = longestCommonPrefixPartialKey( - (*self)->partialKey, key.data(), commonLen); + (*self)->partialKey(), key.data(), commonLen); if (partialKeyIndex < (*self)->partialKeyLen) { auto *old = *self; int64_t oldMaxVersion = maxVersion(old, impl); - *self = newNode(); + *self = newNode(partialKeyIndex); memcpy((char *)*self + kNodeCopyBegin, (char *)old + kNodeCopyBegin, kNodeCopySize); (*self)->partialKeyLen = partialKeyIndex; (*self)->entryPresent = false; (*self)->numChildren = 0; + memcpy((*self)->partialKey(), old->partialKey(), + (*self)->partialKeyLen); - getOrCreateChild(*self, old->partialKey[partialKeyIndex]) = old; + getOrCreateChild(*self, old->partialKey()[partialKeyIndex]) = old; old->parent = *self; - old->parentsIndex = old->partialKey[partialKeyIndex]; + old->parentsIndex = old->partialKey()[partialKeyIndex]; maxVersion(old, impl) = oldMaxVersion; - memmove(old->partialKey, old->partialKey + partialKeyIndex + 1, + memmove(old->partialKey(), old->partialKey() + partialKeyIndex + 1, old->partialKeyLen - (partialKeyIndex + 1)); old->partialKeyLen -= partialKeyIndex + 1; } @@ -1490,13 +1492,9 @@ template } else { // Consider adding a partial key if ((*self)->numChildren == 0 && !(*self)->entryPresent) { - const bool willNotBePresent = - key.size() > kPartialKeyMaxLenEntryPresent + int(sizeof(Entry)); - (*self)->partialKeyLen = std::min( - key.size(), willNotBePresent - ? kPartialKeyMaxLenEntryPresent + int(sizeof(Entry)) - : kPartialKeyMaxLenEntryPresent); - memcpy((*self)->partialKey, key.data(), (*self)->partialKeyLen); + (*self)->partialKeyLen = + std::min(key.size(), (*self)->partialKeyCapacity); + memcpy((*self)->partialKey(), key.data(), (*self)->partialKeyLen); key = key.subspan((*self)->partialKeyLen, key.size() - (*self)->partialKeyLen); } @@ -1520,7 +1518,7 @@ template auto &child = getOrCreateChild(*self, key.front()); if (!child) { - child = newNode(); + child = newNode(key.size() - 1); child->parent = *self; child->parentsIndex = key.front(); maxVersion(child, impl) = @@ -1585,7 +1583,7 @@ void addWriteRange(Node *&root, int64_t oldestVersion, if (int(remaining.size()) <= n->partialKeyLen) { break; } - int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(), + int i = longestCommonPrefixPartialKey(n->partialKey(), remaining.data(), n->partialKeyLen); if (i != n->partialKeyLen) { break; @@ -1692,10 +1690,10 @@ Iterator firstGeq(Node *n, const std::span key) { if (n->partialKeyLen > 0) { int commonLen = std::min(n->partialKeyLen, remaining.size()); - int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(), + int i = longestCommonPrefixPartialKey(n->partialKey(), remaining.data(), commonLen); if (i < commonLen) { - auto c = n->partialKey[i] <=> remaining[i]; + auto c = n->partialKey()[i] <=> remaining[i]; if (c > 0) { goto downLeftSpine; } else { @@ -1799,7 +1797,7 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { explicit Impl(int64_t oldestVersion) : oldestVersion(oldestVersion) { // Insert "" - root = newNode(); + root = newNode(0); rootMaxVersion = oldestVersion; root->entry.pointVersion = oldestVersion; root->entry.rangeVersion = oldestVersion; @@ -1916,7 +1914,7 @@ std::string getSearchPathPrintable(Node *n) { auto result = vector(arena); for (;;) { for (int i = n->partialKeyLen - 1; i >= 0; --i) { - result.push_back(n->partialKey[i]); + result.push_back(n->partialKey()[i]); } if (n->parent == nullptr) { break; @@ -1940,7 +1938,7 @@ std::string getPartialKeyPrintable(Node *n) { } auto result = std::string((const char *)&n->parentsIndex, n->parent == nullptr ? 0 : 1) + - std::string((const char *)n->partialKey, n->partialKeyLen); + std::string((const char *)n->partialKey(), n->partialKeyLen); return printable(result); // NOLINT } @@ -2141,7 +2139,7 @@ int main(void) { ankerl::nanobench::Bench bench; ConflictSet::Impl cs{0}; for (int j = 0; j < 256; ++j) { - getOrCreateChild(cs.root, j) = newNode(); + getOrCreateChild(cs.root, j) = newNode(0); if (j % 10 == 0) { bench.run("MaxExclusive " + std::to_string(j), [&]() { bench.doNotOptimizeAway(maxBetweenExclusive(cs.root, 0, 256)); @@ -2155,8 +2153,6 @@ int main(void) { #ifdef ENABLE_FUZZ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { TestDriver driver{data, size}; - static_assert(driver.kMaxKeyLen > - kPartialKeyMaxLenEntryPresent + sizeof(Entry)); for (;;) { bool done = driver.next();