Variable length partial keys

This commit is contained in:
2024-03-08 13:50:40 -08:00
parent 43a768d152
commit d91538dcad

View File

@@ -168,19 +168,11 @@ enum class Type : int8_t {
Node256, Node256,
}; };
constexpr static int kPartialKeyMaxLenEntryPresent = 24;
struct Node { struct Node {
/* begin section that's copied to the next node */ /* begin section that's copied to the next node */
Node *parent = nullptr; Node *parent = nullptr;
union { Entry entry;
uint8_t partialKey[kPartialKeyMaxLenEntryPresent + sizeof(Entry)];
struct {
uint8_t padding[kPartialKeyMaxLenEntryPresent];
Entry entry;
};
};
int32_t partialKeyLen = 0; int32_t partialKeyLen = 0;
int16_t numChildren : 15 = 0; int16_t numChildren : 15 = 0;
bool entryPresent : 1 = false; bool entryPresent : 1 = false;
@@ -188,15 +180,14 @@ struct Node {
/* end section that's copied to the next node */ /* end section that's copied to the next node */
Type type; Type type;
int32_t partialKeyCapacity;
uint8_t *partialKey();
}; };
constexpr int kNodeCopyBegin = offsetof(Node, parent); constexpr int kNodeCopyBegin = offsetof(Node, parent);
constexpr int kNodeCopySize = offsetof(Node, type) - kNodeCopyBegin; constexpr int kNodeCopySize = offsetof(Node, type) - kNodeCopyBegin;
static_assert(offsetof(Node, entry) ==
offsetof(Node, partialKey) + kPartialKeyMaxLenEntryPresent);
static_assert(std::is_trivial_v<Entry>);
struct Child { struct Child {
int64_t childMaxVersion; int64_t childMaxVersion;
Node *child; Node *child;
@@ -246,8 +237,25 @@ struct Node256 : Node {
} }
}; };
template <class NodeT> NodeT *newNode() { template <class NodeT> NodeT *newNode(int partialKeyCapacity) {
return new (safe_malloc(sizeof(NodeT))) NodeT; auto *result = new (safe_malloc(sizeof(NodeT) + partialKeyCapacity)) NodeT;
result->partialKeyCapacity = partialKeyCapacity;
return result;
}
uint8_t *Node::partialKey() {
switch (type) {
case Type::Node0:
return (uint8_t *)((Node0 *)this + 1);
case Type::Node4:
return (uint8_t *)((Node4 *)this + 1);
case Type::Node16:
return (uint8_t *)((Node16 *)this + 1);
case Type::Node48:
return (uint8_t *)((Node48 *)this + 1);
case Type::Node256:
return (uint8_t *)((Node256 *)this + 1);
}
} }
int getNodeIndex(Node16 *self, uint8_t index) { int getNodeIndex(Node16 *self, uint8_t index) {
@@ -481,9 +489,10 @@ Node *&getOrCreateChild(Node *&self, uint8_t index) {
if (self->type == Type::Node0) { if (self->type == Type::Node0) {
auto *self0 = static_cast<Node0 *>(self); auto *self0 = static_cast<Node0 *>(self);
auto *newSelf = newNode<Node4>(); auto *newSelf = newNode<Node4>(self->partialKeyLen);
memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin, memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin,
kNodeCopySize); kNodeCopySize);
memcpy(newSelf->partialKey(), self->partialKey(), self->partialKeyLen);
free(self0); free(self0);
self = newSelf; self = newSelf;
@@ -493,9 +502,10 @@ Node *&getOrCreateChild(Node *&self, uint8_t index) {
auto *self4 = static_cast<Node4 *>(self); auto *self4 = static_cast<Node4 *>(self);
if (self->numChildren == 4) { if (self->numChildren == 4) {
auto *newSelf = newNode<Node16>(); auto *newSelf = newNode<Node16>(self->partialKeyLen);
memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin, memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin,
kNodeCopySize); kNodeCopySize);
memcpy(newSelf->partialKey(), self->partialKey(), self->partialKeyLen);
// TODO replace with memcpy? // TODO replace with memcpy?
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
newSelf->index[i] = self4->index[i]; newSelf->index[i] = self4->index[i];
@@ -512,9 +522,10 @@ Node *&getOrCreateChild(Node *&self, uint8_t index) {
if (self->numChildren == 16) { if (self->numChildren == 16) {
auto *self16 = static_cast<Node16 *>(self); auto *self16 = static_cast<Node16 *>(self);
auto *newSelf = newNode<Node48>(); auto *newSelf = newNode<Node48>(self->partialKeyLen);
memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin, memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin,
kNodeCopySize); kNodeCopySize);
memcpy(newSelf->partialKey(), self->partialKey(), self->partialKeyLen);
newSelf->nextFree = 16; newSelf->nextFree = 16;
int i = 0; int i = 0;
for (auto x : self16->index) { for (auto x : self16->index) {
@@ -552,9 +563,10 @@ Node *&getOrCreateChild(Node *&self, uint8_t index) {
if (self->numChildren == 48) { if (self->numChildren == 48) {
auto *self48 = static_cast<Node48 *>(self); auto *self48 = static_cast<Node48 *>(self);
auto *newSelf = newNode<Node256>(); auto *newSelf = newNode<Node256>(self->partialKeyLen);
memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin, memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin,
kNodeCopySize); kNodeCopySize);
memcpy(newSelf->partialKey(), self->partialKey(), self->partialKeyLen);
newSelf->bitSet = self48->bitSet; newSelf->bitSet = self48->bitSet;
newSelf->bitSet.forEachInRange( newSelf->bitSet.forEachInRange(
[&](int i) { [&](int i) {
@@ -819,14 +831,7 @@ bytes:
int longestCommonPrefixPartialKey(const uint8_t *ap, const uint8_t *bp, int longestCommonPrefixPartialKey(const uint8_t *ap, const uint8_t *bp,
int cl) { int cl) {
assert(cl <= kPartialKeyMaxLenEntryPresent + int(sizeof(Entry))); return longestCommonPrefix(ap, bp, cl);
int i = 0;
for (; i < cl; ++i) {
if (*ap++ != *bp++) {
break;
}
}
return i;
} }
// Performs a physical search for remaining // Performs a physical search for remaining
@@ -849,7 +854,7 @@ struct SearchStepWise {
return true; return true;
} }
int cl = std::min<int>(child->partialKeyLen, remaining.size() - 1); int cl = std::min<int>(child->partialKeyLen, remaining.size() - 1);
int i = longestCommonPrefixPartialKey(child->partialKey, int i = longestCommonPrefixPartialKey(child->partialKey(),
remaining.data() + 1, cl); remaining.data() + 1, cl);
if (i != child->partialKeyLen) { if (i != child->partialKeyLen) {
return true; return true;
@@ -906,10 +911,10 @@ bool checkPointRead(Node *n, const std::span<const uint8_t> key,
if (n->partialKeyLen > 0) { if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size()); int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(), int i = longestCommonPrefixPartialKey(n->partialKey(), remaining.data(),
commonLen); commonLen);
if (i < commonLen) { if (i < commonLen) {
auto c = n->partialKey[i] <=> remaining[i]; auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) { if (c > 0) {
goto downLeftSpine; goto downLeftSpine;
} else { } else {
@@ -1008,7 +1013,7 @@ Vector<uint8_t> getSearchPath(Arena &arena, Node *n) {
auto result = vector<uint8_t>(arena); auto result = vector<uint8_t>(arena);
for (;;) { for (;;) {
for (int i = n->partialKeyLen - 1; i >= 0; --i) { for (int i = n->partialKeyLen - 1; i >= 0; --i) {
result.push_back(n->partialKey[i]); result.push_back(n->partialKey()[i]);
} }
if (n->parent == nullptr) { if (n->parent == nullptr) {
break; break;
@@ -1054,10 +1059,10 @@ bool checkRangeStartsWith(Node *n, std::span<const uint8_t> key, int begin,
if (n->partialKeyLen > 0) { if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size()); int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(), int i = longestCommonPrefixPartialKey(n->partialKey(), remaining.data(),
commonLen); commonLen);
if (i < commonLen) { if (i < commonLen) {
auto c = n->partialKey[i] <=> remaining[i]; auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) { if (c > 0) {
goto downLeftSpine; goto downLeftSpine;
} else { } else {
@@ -1069,8 +1074,8 @@ bool checkRangeStartsWith(Node *n, std::span<const uint8_t> key, int begin,
// partial key matches // partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen); remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) { } else if (n->partialKeyLen > int(remaining.size())) {
if (begin < n->partialKey[remaining.size()] && if (begin < n->partialKey()[remaining.size()] &&
n->partialKey[remaining.size()] < end) { n->partialKey()[remaining.size()] < end) {
if (n->entryPresent && n->entry.rangeVersion > readVersion) { if (n->entryPresent && n->entry.rangeVersion > readVersion) {
return false; return false;
} }
@@ -1161,11 +1166,11 @@ struct CheckRangeLeftSide {
if (n->partialKeyLen > 0) { if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size()); int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(), int i = longestCommonPrefixPartialKey(n->partialKey(), remaining.data(),
commonLen); commonLen);
searchPathLen += i; searchPathLen += i;
if (i < commonLen) { if (i < commonLen) {
auto c = n->partialKey[i] <=> remaining[i]; auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) { if (c > 0) {
if (searchPathLen < prefixLen) { if (searchPathLen < prefixLen) {
return downLeftSpine(); return downLeftSpine();
@@ -1299,12 +1304,12 @@ struct CheckRangeRightSide {
if (n->partialKeyLen > 0) { if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size()); int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(), int i = longestCommonPrefixPartialKey(n->partialKey(), remaining.data(),
commonLen); commonLen);
searchPathLen += i; searchPathLen += i;
if (i < commonLen) { if (i < commonLen) {
++searchPathLen; ++searchPathLen;
auto c = n->partialKey[i] <=> remaining[i]; auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) { if (c > 0) {
return downLeftSpine(); return downLeftSpine();
} else { } else {
@@ -1455,33 +1460,30 @@ template <bool kBegin>
for (;;) { for (;;) {
if ((*self)->partialKeyLen > 0) { if ((*self)->partialKeyLen > 0) {
const bool wouldBePresent =
key.size() <= kPartialKeyMaxLenEntryPresent + int(sizeof(Entry));
// Handle an existing partial key // Handle an existing partial key
int commonLen = std::min<int>((*self)->partialKeyLen, key.size()); int commonLen = std::min<int>((*self)->partialKeyLen, key.size());
if (wouldBePresent) {
commonLen = std::min(commonLen, kPartialKeyMaxLenEntryPresent);
}
int partialKeyIndex = longestCommonPrefixPartialKey( int partialKeyIndex = longestCommonPrefixPartialKey(
(*self)->partialKey, key.data(), commonLen); (*self)->partialKey(), key.data(), commonLen);
if (partialKeyIndex < (*self)->partialKeyLen) { if (partialKeyIndex < (*self)->partialKeyLen) {
auto *old = *self; auto *old = *self;
int64_t oldMaxVersion = maxVersion(old, impl); int64_t oldMaxVersion = maxVersion(old, impl);
*self = newNode<Node4>(); *self = newNode<Node4>(partialKeyIndex);
memcpy((char *)*self + kNodeCopyBegin, (char *)old + kNodeCopyBegin, memcpy((char *)*self + kNodeCopyBegin, (char *)old + kNodeCopyBegin,
kNodeCopySize); kNodeCopySize);
(*self)->partialKeyLen = partialKeyIndex; (*self)->partialKeyLen = partialKeyIndex;
(*self)->entryPresent = false; (*self)->entryPresent = false;
(*self)->numChildren = 0; (*self)->numChildren = 0;
memcpy((*self)->partialKey(), old->partialKey(),
(*self)->partialKeyLen);
getOrCreateChild(*self, old->partialKey[partialKeyIndex]) = old; getOrCreateChild(*self, old->partialKey()[partialKeyIndex]) = old;
old->parent = *self; old->parent = *self;
old->parentsIndex = old->partialKey[partialKeyIndex]; old->parentsIndex = old->partialKey()[partialKeyIndex];
maxVersion(old, impl) = oldMaxVersion; maxVersion(old, impl) = oldMaxVersion;
memmove(old->partialKey, old->partialKey + partialKeyIndex + 1, memmove(old->partialKey(), old->partialKey() + partialKeyIndex + 1,
old->partialKeyLen - (partialKeyIndex + 1)); old->partialKeyLen - (partialKeyIndex + 1));
old->partialKeyLen -= partialKeyIndex + 1; old->partialKeyLen -= partialKeyIndex + 1;
} }
@@ -1490,13 +1492,9 @@ template <bool kBegin>
} else { } else {
// Consider adding a partial key // Consider adding a partial key
if ((*self)->numChildren == 0 && !(*self)->entryPresent) { if ((*self)->numChildren == 0 && !(*self)->entryPresent) {
const bool willNotBePresent = (*self)->partialKeyLen =
key.size() > kPartialKeyMaxLenEntryPresent + int(sizeof(Entry)); std::min<int>(key.size(), (*self)->partialKeyCapacity);
(*self)->partialKeyLen = std::min<int>( memcpy((*self)->partialKey(), key.data(), (*self)->partialKeyLen);
key.size(), willNotBePresent
? kPartialKeyMaxLenEntryPresent + int(sizeof(Entry))
: kPartialKeyMaxLenEntryPresent);
memcpy((*self)->partialKey, key.data(), (*self)->partialKeyLen);
key = key.subspan((*self)->partialKeyLen, key = key.subspan((*self)->partialKeyLen,
key.size() - (*self)->partialKeyLen); key.size() - (*self)->partialKeyLen);
} }
@@ -1520,7 +1518,7 @@ template <bool kBegin>
auto &child = getOrCreateChild(*self, key.front()); auto &child = getOrCreateChild(*self, key.front());
if (!child) { if (!child) {
child = newNode<Node0>(); child = newNode<Node0>(key.size() - 1);
child->parent = *self; child->parent = *self;
child->parentsIndex = key.front(); child->parentsIndex = key.front();
maxVersion(child, impl) = maxVersion(child, impl) =
@@ -1585,7 +1583,7 @@ void addWriteRange(Node *&root, int64_t oldestVersion,
if (int(remaining.size()) <= n->partialKeyLen) { if (int(remaining.size()) <= n->partialKeyLen) {
break; break;
} }
int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(), int i = longestCommonPrefixPartialKey(n->partialKey(), remaining.data(),
n->partialKeyLen); n->partialKeyLen);
if (i != n->partialKeyLen) { if (i != n->partialKeyLen) {
break; break;
@@ -1692,10 +1690,10 @@ Iterator firstGeq(Node *n, const std::span<const uint8_t> key) {
if (n->partialKeyLen > 0) { if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size()); int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(), int i = longestCommonPrefixPartialKey(n->partialKey(), remaining.data(),
commonLen); commonLen);
if (i < commonLen) { if (i < commonLen) {
auto c = n->partialKey[i] <=> remaining[i]; auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) { if (c > 0) {
goto downLeftSpine; goto downLeftSpine;
} else { } else {
@@ -1799,7 +1797,7 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
explicit Impl(int64_t oldestVersion) : oldestVersion(oldestVersion) { explicit Impl(int64_t oldestVersion) : oldestVersion(oldestVersion) {
// Insert "" // Insert ""
root = newNode<Node4>(); root = newNode<Node4>(0);
rootMaxVersion = oldestVersion; rootMaxVersion = oldestVersion;
root->entry.pointVersion = oldestVersion; root->entry.pointVersion = oldestVersion;
root->entry.rangeVersion = oldestVersion; root->entry.rangeVersion = oldestVersion;
@@ -1916,7 +1914,7 @@ std::string getSearchPathPrintable(Node *n) {
auto result = vector<char>(arena); auto result = vector<char>(arena);
for (;;) { for (;;) {
for (int i = n->partialKeyLen - 1; i >= 0; --i) { for (int i = n->partialKeyLen - 1; i >= 0; --i) {
result.push_back(n->partialKey[i]); result.push_back(n->partialKey()[i]);
} }
if (n->parent == nullptr) { if (n->parent == nullptr) {
break; break;
@@ -1940,7 +1938,7 @@ std::string getPartialKeyPrintable(Node *n) {
} }
auto result = std::string((const char *)&n->parentsIndex, auto result = std::string((const char *)&n->parentsIndex,
n->parent == nullptr ? 0 : 1) + n->parent == nullptr ? 0 : 1) +
std::string((const char *)n->partialKey, n->partialKeyLen); std::string((const char *)n->partialKey(), n->partialKeyLen);
return printable(result); // NOLINT return printable(result); // NOLINT
} }
@@ -2141,7 +2139,7 @@ int main(void) {
ankerl::nanobench::Bench bench; ankerl::nanobench::Bench bench;
ConflictSet::Impl cs{0}; ConflictSet::Impl cs{0};
for (int j = 0; j < 256; ++j) { for (int j = 0; j < 256; ++j) {
getOrCreateChild(cs.root, j) = newNode<Node0>(); getOrCreateChild(cs.root, j) = newNode<Node0>(0);
if (j % 10 == 0) { if (j % 10 == 0) {
bench.run("MaxExclusive " + std::to_string(j), [&]() { bench.run("MaxExclusive " + std::to_string(j), [&]() {
bench.doNotOptimizeAway(maxBetweenExclusive(cs.root, 0, 256)); bench.doNotOptimizeAway(maxBetweenExclusive(cs.root, 0, 256));
@@ -2155,8 +2153,6 @@ int main(void) {
#ifdef ENABLE_FUZZ #ifdef ENABLE_FUZZ
extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
TestDriver<ConflictSet::Impl> driver{data, size}; TestDriver<ConflictSet::Impl> driver{data, size};
static_assert(driver.kMaxKeyLen >
kPartialKeyMaxLenEntryPresent + sizeof(Entry));
for (;;) { for (;;) {
bool done = driver.next(); bool done = driver.next();