From c86e407985682f9742cf8e5d872c79b398e6c908 Mon Sep 17 00:00:00 2001 From: Andrew Noyes Date: Thu, 1 Aug 2024 13:53:18 -0700 Subject: [PATCH] Return Node from getChildGeq It seems all callers ultimately want this --- ConflictSet.cpp | 168 +++++++++++++++++++++++------------------------- 1 file changed, 79 insertions(+), 89 deletions(-) diff --git a/ConflictSet.cpp b/ConflictSet.cpp index 9d23348..2272d95 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -923,7 +923,7 @@ Node *getChild(Node *self, uint8_t index) { } } -template int getChildGeqSimd(NodeT *self, int child) { +template Node *getChildGeqSimd(NodeT *self, int child) { static_assert(std::is_same_v || std::is_same_v); // cachegrind says the plain loop is fewer instructions and more mis-predicted @@ -934,13 +934,13 @@ template int getChildGeqSimd(NodeT *self, int child) { Node3 *n = (Node3 *)self; for (int i = 0; i < n->numChildren; ++i) { if (n->index[i] >= child) { - return n->index[i]; + return n->children[i]; } } - return -1; + return nullptr; } if (child > 255) { - return -1; + return nullptr; } #ifdef HAS_AVX @@ -950,16 +950,7 @@ template int getChildGeqSimd(NodeT *self, int child) { __m128i results = _mm_cmpeq_epi8(key_vec, _mm_min_epu8(key_vec, indices)); int mask = (1 << self->numChildren) - 1; uint32_t bitfield = _mm_movemask_epi8(results) & mask; - int result = bitfield == 0 ? -1 : self->index[std::countr_zero(bitfield)]; - assert(result == [&]() -> int { - for (int i = 0; i < self->numChildren; ++i) { - if (self->index[i] >= child) { - return self->index[i]; - } - } - return -1; - }()); - return result; + return bitfield == 0 ? nullptr : self->children[std::countr_zero(bitfield)]; #elif defined(HAS_ARM_NEON) uint8x16_t indices; memcpy(&indices, self->index, sizeof(self->index)); @@ -968,7 +959,7 @@ template int getChildGeqSimd(NodeT *self, int child) { static_assert(NodeT::kMaxNodes <= 16); assume(self->numChildren <= NodeT::kMaxNodes); uint64_t mask = self->numChildren == 16 - ? 
uint64_t(-1) + ? uint64_t(-1) : (uint64_t(1) << (self->numChildren * 4)) - 1; // 0xf for each 0xff (within mask) uint64_t bitfield = @@ -976,43 +967,44 @@ template int getChildGeqSimd(NodeT *self, int child) { vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)), 0) & mask; - int simd = bitfield == 0 ? -1 : self->index[std::countr_zero(bitfield) / 4]; - assert(simd == [&]() -> int { - for (int i = 0; i < self->numChildren; ++i) { - if (self->index[i] >= child) { - return self->index[i]; - } - } - return -1; - }()); - return simd; + return bitfield == 0 ? nullptr + : self->children[std::countr_zero(bitfield) / 4]; #else for (int i = 0; i < self->numChildren; ++i) { if (i > 0) { assert(self->index[i - 1] < self->index[i]); } if (self->index[i] >= child) { - return self->index[i]; + return self->children[i]; } } - return -1; + return nullptr; #endif } -int getChildGeq(Node0 *, int) { return -1; } -int getChildGeq(Node3 *self, int child) { return getChildGeqSimd(self, child); } -int getChildGeq(Node16 *self, int child) { +Node *getChildGeq(Node0 *, int) { return nullptr; } +Node *getChildGeq(Node3 *self, int child) { return getChildGeqSimd(self, child); } -int getChildGeq(Node48 *self, int child) { - return self->bitSet.firstSetGeq(child); +Node *getChildGeq(Node16 *self, int child) { + return getChildGeqSimd(self, child); } -int getChildGeq(Node256 *self, int child) { - static_assert(offsetof(Node48, bitSet) == offsetof(Node256, bitSet)); - return getChildGeq(reinterpret_cast(self), child); +Node *getChildGeq(Node48 *self, int child) { + int c = self->bitSet.firstSetGeq(child); + if (c < 0) { + return nullptr; + } + return self->children[self->index[c]]; +} +Node *getChildGeq(Node256 *self, int child) { + int c = self->bitSet.firstSetGeq(child); + if (c < 0) { + return nullptr; + } + return self->children[c]; } -int getChildGeq(Node *self, int child) { +Node *getChildGeq(Node *self, int child) { switch (self->getType()) { case Type_Node0: return 
getChildGeq(static_cast(self), child); @@ -1269,36 +1261,36 @@ Node *nextPhysical(Node *node) { case Type_Node0: { auto *n = static_cast(node); auto nextChild = getChildGeq(n, index + 1); - if (nextChild >= 0) { - return getChildExists(n, nextChild); + if (nextChild != nullptr) { + return nextChild; } } break; case Type_Node3: { auto *n = static_cast(node); auto nextChild = getChildGeq(n, index + 1); - if (nextChild >= 0) { - return getChildExists(n, nextChild); + if (nextChild != nullptr) { + return nextChild; } } break; case Type_Node16: { auto *n = static_cast(node); auto nextChild = getChildGeq(n, index + 1); - if (nextChild >= 0) { - return getChildExists(n, nextChild); + if (nextChild != nullptr) { + return nextChild; } } break; case Type_Node48: { auto *n = static_cast(node); auto nextChild = getChildGeq(n, index + 1); - if (nextChild >= 0) { - return getChildExists(n, nextChild); + if (nextChild != nullptr) { + return nextChild; } } break; case Type_Node256: { auto *n = static_cast(node); auto nextChild = getChildGeq(n, index + 1); - if (nextChild >= 0) { - return getChildExists(n, nextChild); + if (nextChild != nullptr) { + return nextChild; } } break; default: // GCOVR_EXCL_LINE @@ -1686,10 +1678,10 @@ Node *nextSibling(Node *node) { return nullptr; } auto next = getChildGeq(node->parent, node->parentsIndex + 1); - if (next < 0) { + if (next == nullptr) { node = node->parent; } else { - return getChildExists(node->parent, next); + return next; } } } @@ -1888,9 +1880,9 @@ bool checkPointRead(Node *n, const std::span key, auto *child = getChild(n, remaining[0]); if (child == nullptr) { - int c = getChildGeq(n, remaining[0]); - if (c >= 0) { - n = getChildExists(n, c); + auto c = getChildGeq(n, remaining[0]); + if (c != nullptr) { + n = c; goto downLeftSpine; } else { n = nextSibling(n); @@ -1959,9 +1951,9 @@ bool checkPrefixRead(Node *n, const std::span key, auto *child = getChild(n, remaining[0]); if (child == nullptr) { - int c = getChildGeq(n, 
remaining[0]); - if (c >= 0) { - n = getChildExists(n, c); + auto c = getChildGeq(n, remaining[0]); + if (c != nullptr) { + n = c; goto downLeftSpine; } else { n = nextSibling(n); @@ -2460,9 +2452,9 @@ bool checkRangeStartsWith(Node *n, std::span key, int begin, auto *child = getChild(n, remaining[0]); if (child == nullptr) { - int c = getChildGeq(n, remaining[0]); - if (c >= 0) { - n = getChildExists(n, c); + auto c = getChildGeq(n, remaining[0]); + if (c != nullptr) { + n = c; goto downLeftSpine; } else { n = nextSibling(n); @@ -2556,13 +2548,13 @@ template struct CheckRangeLeftSide { auto *child = getChild(n, remaining[0]); if (child == nullptr) { - int c = getChildGeq(n, remaining[0]); - if (c >= 0) { + auto c = getChildGeq(n, remaining[0]); + if (c != nullptr) { if (searchPathLen < prefixLen) { - n = getChildExists(n, c); + n = c; return downLeftSpine(); } - n = getChildExists(n, c); + n = c; ok = maxVersion(n, impl) <= readVersion; return true; } else { @@ -2688,9 +2680,9 @@ template struct CheckRangeRightSide { auto *child = getChild(n, remaining[0]); if (child == nullptr) { - int c = getChildGeq(n, remaining[0]); - if (c >= 0) { - n = getChildExists(n, c); + auto c = getChildGeq(n, remaining[0]); + if (c != nullptr) { + n = c; return downLeftSpine(); } else { return backtrack(); @@ -2740,12 +2732,12 @@ template struct CheckRangeRightSide { return true; } auto next = getChildGeq(n->parent, n->parentsIndex + 1); - if (next < 0) { + if (next == nullptr) { searchPathLen -= 1 + n->partialKeyLen; n = n->parent; } else { searchPathLen -= n->partialKeyLen; - n = getChildExists(n->parent, next); + n = next; searchPathLen += n->partialKeyLen; return downLeftSpine(); } @@ -3004,9 +2996,8 @@ void destroyTree(Node *root) { auto *n = toFree.back(); toFree.pop_back(); // Add all children to toFree - for (int child = getChildGeq(n, 0); child >= 0; - child = getChildGeq(n, child + 1)) { - auto *c = getChildExists(n, child); + for (auto c = getChildGeq(n, 0); c != nullptr; + 
c = getChildGeq(n, c->parentsIndex + 1)) { assert(c != nullptr); toFree.push_back(c); } @@ -3144,9 +3135,9 @@ Node *firstGeqPhysical(Node *n, const std::span key) { auto *child = getChild(n, remaining[0]); if (child == nullptr) { - int c = getChildGeq(n, remaining[0]); - if (c >= 0) { - n = getChildExists(n, c); + auto c = getChildGeq(n, remaining[0]); + if (c != nullptr) { + n = c; return n; } else { n = nextSibling(n); @@ -3673,9 +3664,9 @@ Node *firstGeqLogical(Node *n, const std::span key) { auto *child = getChild(n, remaining[0]); if (child == nullptr) { - int c = getChildGeq(n, remaining[0]); - if (c >= 0) { - n = getChildExists(n, c); + auto c = getChildGeq(n, remaining[0]); + if (c != nullptr) { + n = c; goto downLeftSpine; } else { n = nextSibling(n); @@ -3888,9 +3879,8 @@ std::string getSearchPath(Node *n) { getPartialKeyPrintable(n).c_str(), x, y); } x += kSeparation; - for (int child = getChildGeq(n, 0); child >= 0; - child = getChildGeq(n, child + 1)) { - auto *c = getChildExists(n, child); + for (auto c = getChildGeq(n, 0); c != nullptr; + c = getChildGeq(n, c->parentsIndex + 1)) { fprintf(file, " k_%p -> k_%p;\n", (void *)n, (void *)c); print(c, y - kSeparation); } @@ -3909,12 +3899,12 @@ std::string getSearchPath(Node *n) { } void checkParentPointers(Node *node, bool &success) { - for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) { - auto *child = getChildExists(node, i); + for (auto child = getChildGeq(node, 0); child != nullptr; + child = getChildGeq(node, child->parentsIndex + 1)) { if (child->parent != node) { fprintf(stderr, "%s child %d has parent pointer %p. 
Expected %p\n", - getSearchPathPrintable(node).c_str(), i, (void *)child->parent, - (void *)node); + getSearchPathPrintable(node).c_str(), child->parentsIndex, + (void *)child->parent, (void *)node); success = false; } checkParentPointers(child, success); @@ -3979,8 +3969,8 @@ checkMaxVersion(Node *root, Node *node, InternalVersionT oldestVersion, if (node->entryPresent) { expected = std::max(expected, node->entry.pointVersion); } - for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) { - auto *child = getChildExists(node, i); + for (auto child = getChildGeq(node, 0); child != nullptr; + child = getChildGeq(node, child->parentsIndex + 1)) { expected = std::max( expected, checkMaxVersion(root, child, oldestVersion, success, impl)); if (child->entryPresent) { @@ -4008,14 +3998,14 @@ checkMaxVersion(Node *root, Node *node, InternalVersionT oldestVersion, [[maybe_unused]] int64_t checkEntriesExist(Node *node, bool &success) { int64_t total = node->entryPresent; - for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) { - auto *child = getChildExists(node, i); + for (auto child = getChildGeq(node, 0); child != nullptr; + child = getChildGeq(node, child->parentsIndex + 1)) { int64_t e = checkEntriesExist(child, success); total += e; if (e == 0) { Arena arena; fprintf(stderr, "%s has child %02x with no reachable entries\n", - getSearchPathPrintable(node).c_str(), i); + getSearchPathPrintable(node).c_str(), child->parentsIndex); success = false; } } @@ -4052,8 +4042,8 @@ checkMaxVersion(Node *root, Node *node, InternalVersionT oldestVersion, success = false; } // TODO check that the max capacity property eventually holds - for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) { - auto *child = getChildExists(node, i); + for (auto child = getChildGeq(node, 0); child != nullptr; + child = getChildGeq(node, child->parentsIndex + 1)) { checkMemoryBoundInvariants(child, success); } }