From 797e6b4a3e69313c091f659a22a63853dd3a9565 Mon Sep 17 00:00:00 2001 From: Andrew Noyes Date: Wed, 13 Mar 2024 10:52:18 -0700 Subject: [PATCH] Use switch for type dispatch throughout --- ConflictSet.cpp | 150 +++++++++++++++++++++++++++--------------------- 1 file changed, 85 insertions(+), 65 deletions(-) diff --git a/ConflictSet.cpp b/ConflictSet.cpp index 616ee09..6448795 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -439,36 +439,26 @@ int getNodeIndex(Node16 *self, uint8_t index) { // Precondition - an entry for index must exist in the node Node *&getChildExists(Node *self, uint8_t index) { - if (self->type <= Type::Node16) { + switch (self->type) { + case Type::Node0: + __builtin_unreachable(); // GCOVR_EXCL_LINE + case Type::Node4: + [[fallthrough]]; + case Type::Node16: { auto *self16 = static_cast(self); return self16->children[getNodeIndex(self16, index)].child; - } else if (self->type == Type::Node48) { + } + case Type::Node48: { auto *self48 = static_cast(self); assert(self48->bitSet.test(index)); return self48->children[self48->index[index]].child; - } else { + } + case Type::Node256: { auto *self256 = static_cast(self); assert(self256->bitSet.test(index)); return self256->children[index].child; } - __builtin_unreachable(); // GCOVR_EXCL_LINE -} - -// Precondition - an entry for index must exist in the node -int64_t getChildMaxVersion(Node *self, uint8_t index) { - if (self->type <= Type::Node16) { - auto *self16 = static_cast(self); - return self16->children[getNodeIndex(self16, index)].childMaxVersion; - } else if (self->type == Type::Node48) { - auto *self48 = static_cast(self); - assert(self48->bitSet.test(index)); - return self48->children[self48->index[index]].childMaxVersion; - } else { - auto *self256 = static_cast(self); - assert(self256->bitSet.test(index)); - return self256->children[index].childMaxVersion; } - __builtin_unreachable(); // GCOVR_EXCL_LINE } // Precondition - an entry for index must exist in the node @@ -477,31 +467,38 @@ int64_t &maxVersion(Node *n, ConflictSet::Impl *); Node *&getInTree(Node *n, ConflictSet::Impl *); Node *getChild(Node *self, uint8_t index) { - if (self->type <= Type::Node16) { + switch (self->type) { + case Type::Node0: + return nullptr; + case Type::Node4: + [[fallthrough]]; + case Type::Node16: { auto *self16 = static_cast(self); int i = getNodeIndex(self16, index); - if (i >= 0) { - return self16->children[i].child; - } - return nullptr; - } else if (self->type == Type::Node48) { + return i < 0 ? nullptr : self16->children[i].child; + } + case Type::Node48: { auto *self48 = static_cast(self); - int secondIndex = self48->index[index]; - if (secondIndex >= 0) { - return self48->children[secondIndex].child; - } - return nullptr; - } else { + int i = self48->index[index]; + return i < 0 ? nullptr : self48->children[i].child; + } + case Type::Node256: { auto *self256 = static_cast(self); return self256->children[index].child; } + } } int getChildGeq(Node *self, int child) { if (child > 255) { return -1; } - if (self->type <= Type::Node16) { + switch (self->type) { + case Type::Node0: + return -1; + case Type::Node4: + [[fallthrough]]; + case Type::Node16: { auto *self16 = static_cast(self); #ifdef HAS_AVX __m128i key_vec = _mm_set1_epi8(child); @@ -554,13 +551,17 @@ int getChildGeq(Node *self, int child) { return self16->index[i]; } } + return -1; #endif - } else { + } + case Type::Node48: + [[fallthrough]]; + case Type::Node256: { static_assert(offsetof(Node48, bitSet) == offsetof(Node256, bitSet)); auto *self48 = static_cast(self); return self48->bitSet.firstSetGeq(child); } - return -1; + } } void setChildrenParents(Node4 *n) { @@ -591,26 +592,35 @@ Node *&getOrCreateChild(Node *&self, uint8_t index, NodeAllocators *allocators) { // Fast path for if it exists already - if (self->type <= Type::Node16) { + switch (self->type) { + case Type::Node0: + break; + case Type::Node4: + [[fallthrough]]; + case Type::Node16: { auto *self16 = static_cast(self); int i = getNodeIndex(self16, index); if (i >= 0) { return self16->children[i].child; } - } else if (self->type == Type::Node48) { + } break; + case Type::Node48: { auto *self48 = static_cast(self); int secondIndex = self48->index[index]; if (secondIndex >= 0) { return self48->children[secondIndex].child; } - } else { + } break; + case Type::Node256: { auto *self256 = static_cast(self); if (auto &result = self256->children[index].child; result != nullptr) { return result; } + } break; } - if (self->type == Type::Node0) { + switch (self->type) { + case Type::Node0: { auto *self0 = static_cast(self); auto *newSelf = allocators->node4.allocate(self->partialKeyLen); @@ -621,8 +631,8 @@ Node *&getOrCreateChild(Node *&self, uint8_t index, self = newSelf; goto insert16; - - } else if (self->type == Type::Node4) { + } + case Type::Node4: { auto *self4 = static_cast(self); if (self->numChildren == 4) { @@ -641,9 +651,8 @@ Node *&getOrCreateChild(Node *&self, uint8_t index, } goto insert16; - - } else if (self->type == Type::Node16) { - + } + case Type::Node16: { if (self->numChildren == 16) { auto *self16 = static_cast(self); auto *newSelf = allocators->node48.allocate(self->partialKeyLen); @@ -683,7 +692,8 @@ Node *&getOrCreateChild(Node *&self, uint8_t index, auto &result = self16->children[i].child; result = nullptr; return result; - } else if (self->type == Type::Node48) { + } + case Type::Node48: { if (self->numChildren == 48) { auto *self48 = static_cast(self); @@ -713,14 +723,15 @@ Node *&getOrCreateChild(Node *&self, uint8_t index, auto &result = self48->children[nextFree].child; result = nullptr; return result; - } else { - assert(self->type == Type::Node256); + } + case Type::Node256: { insert256: auto *self256 = static_cast(self); ++self->numChildren; self256->bitSet.set(index); return self256->children[index].child; } + } } Node *nextPhysical(Node *node) { @@ -1040,7 +1051,12 @@ Node *erase(Node *self, NodeAllocators *allocators, ConflictSet::Impl *impl, break; } - if (parent->type <= Type::Node16) { + switch (parent->type) { + case Type::Node0: + __builtin_unreachable(); // GCOVR_EXCL_LINE + case Type::Node4: + [[fallthrough]]; + case Type::Node16: { auto *parent16 = static_cast(parent); int nodeIndex = getNodeIndex(parent16, parentsIndex); assert(nodeIndex >= 0); @@ -1050,7 +1066,8 @@ Node *erase(Node *self, NodeAllocators *allocators, ConflictSet::Impl *impl, memmove(parent16->children + nodeIndex, parent16->children + nodeIndex + 1, sizeof(parent16->children[0]) * (parent->numChildren - (nodeIndex + 1))); - } else if (parent->type == Type::Node48) { + } break; + case Type::Node48: { auto *parent48 = static_cast(parent); parent48->bitSet.reset(parentsIndex); int8_t toRemoveChildrenIndex = @@ -1064,11 +1081,14 @@ Node *erase(Node *self, NodeAllocators *allocators, ConflictSet::Impl *impl, parent48->index[parent48->children[toRemoveChildrenIndex] .child->parentsIndex] = toRemoveChildrenIndex; } - } else { + } break; + case Type::Node256: { auto *parent256 = static_cast(parent); parent256->bitSet.reset(parentsIndex); parent256->children[parentsIndex].child = nullptr; + } break; } + --parent->numChildren; if (parent->numChildren == 0 && !parent->entryPresent && parent->parent != nullptr) { @@ -1357,7 +1377,8 @@ downLeftSpine: } // Return the max version among all keys starting with the search path of n + -// [child], where child in (begin, end) +// [child], where child in (begin, end). Does not account for the range version +// of firstGt(searchpath(n) + [end - 1]) int64_t maxBetweenExclusive(Node *n, int begin, int end) { assume(-1 <= begin); assume(begin <= 256); @@ -1379,7 +1400,8 @@ int64_t maxBetweenExclusive(Node *n, int begin, int end) { } switch (n->type) { case Type::Node0: - [[fallthrough]]; + // We would have returned above, after not finding a child + __builtin_unreachable(); // GCOVR_EXCL_LINE case Type::Node4: [[fallthrough]]; case Type::Node16: { @@ -2239,19 +2261,27 @@ int64_t &maxVersion(Node *n, ConflictSet::Impl *impl) { if (n == nullptr) { return impl->rootMaxVersion; } - if (n->type <= Type::Node16) { + switch (n->type) { + case Type::Node0: + __builtin_unreachable(); // GCOVR_EXCL_LINE + case Type::Node4: + [[fallthrough]]; + case Type::Node16: { auto *n16 = static_cast(n); int i = getNodeIndex(n16, index); return n16->children[i].childMaxVersion; - } else if (n->type == Type::Node48) { + } + case Type::Node48: { auto *n48 = static_cast(n); assert(n48->bitSet.test(index)); return n48->children[n48->index[index]].childMaxVersion; - } else { + } + case Type::Node256: { auto *n256 = static_cast(n); assert(n256->bitSet.test(index)); return n256->children[index].childMaxVersion; } + } } Node *&getInTree(Node *n, ConflictSet::Impl *impl) { @@ -2494,16 +2524,6 @@ Iterator firstGeq(Node *n, std::string_view key) { expected = std::max(expected, borrowed.n->entry.rangeVersion); } } - if (node->parent != nullptr && - getChildMaxVersion(node->parent, node->parentsIndex) != - maxVersion(node, impl)) { - fprintf(stderr, - "%s has max version %" PRId64 - " . But parent has child max version %" PRId64 "\n", - getSearchPathPrintable(node).c_str(), maxVersion(node, impl), - getChildMaxVersion(node->parent, node->parentsIndex)); - success = false; - } if (maxVersion(node, impl) > oldestVersion && maxVersion(node, impl) != expected) { fprintf(stderr, "%s has max version %" PRId64 " . Expected %" PRId64 "\n",