From 567d385fbd32e3bb35c534685a2bb30336824410 Mon Sep 17 00:00:00 2001 From: Andrew Noyes Date: Mon, 12 Aug 2024 16:11:16 -0700 Subject: [PATCH] WIP create child in getOrCreateChild --- ConflictSet.cpp | 52 +++++++++++++++++++++++++++++-------------- LongestCommonPrefix.h | 8 +++++++ 2 files changed, 43 insertions(+), 17 deletions(-) diff --git a/ConflictSet.cpp b/ConflictSet.cpp index 499f1cf..b54593b 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -40,11 +40,13 @@ limitations under the License. #include #endif +#ifndef __SANITIZE_THREAD__ #if defined(__has_feature) #if __has_feature(thread_sanitizer) #define __SANITIZE_THREAD__ #endif #endif +#endif #include @@ -1126,9 +1128,12 @@ Node *getFirstChildExists(Node *self) { // Caller is responsible for assigning a non-null pointer to the returned // reference if null. Updates child's max version to `newMaxVersion` if child // exists but does not have a partial key. -Node *&getOrCreateChild(Node *&self, uint8_t index, +Node *&getOrCreateChild(Node *&self, std::span &key, InternalVersionT newMaxVersion, WriteContext *tls) { + int index = key.front(); + key = key.subspan(1, key.size() - 1); + // Fast path for if it exists already switch (self->getType()) { case Type_Node0: @@ -1181,6 +1186,14 @@ Node *&getOrCreateChild(Node *&self, uint8_t index, __builtin_unreachable(); // GCOVR_EXCL_LINE } + auto *newChild = tls->allocate(key.size()); + newChild->numChildren = 0; + newChild->entryPresent = false; + newChild->partialKeyLen = key.size(); + newChild->parentsIndex = index; + memcpy(newChild->partialKey(), key.data(), key.size()); + key = {}; + switch (self->getType()) { case Type_Node0: { auto *self0 = static_cast(self); @@ -1215,8 +1228,10 @@ Node *&getOrCreateChild(Node *&self, uint8_t index, } self3->index[i + 1] = index; auto &result = self3->children[i + 1]; - result = nullptr; + self3->childMaxVersion[i + 1] = newMaxVersion; + result = newChild; ++self->numChildren; + newChild->parent = self; return result; } case Type_Node16: { @@ -1243,8 +1258,10 @@ Node *&getOrCreateChild(Node *&self, uint8_t index, } self16->index[i + 1] = index; auto &result = self16->children[i + 1]; - result = nullptr; + self16->childMaxVersion[i + 1] = newMaxVersion; + result = newChild; ++self->numChildren; + newChild->parent = self; return result; } case Type_Node48: { @@ -1267,7 +1284,11 @@ Node *&getOrCreateChild(Node *&self, uint8_t index, self48->index[index] = nextFree; self48->reverseIndex[nextFree] = index; auto &result = self48->children[nextFree]; - result = nullptr; + self48->childMaxVersion[nextFree] = newMaxVersion; + self48->maxOfMax[nextFree >> Node48::kMaxOfMaxShift] = std::max( + newMaxVersion, self48->maxOfMax[nextFree >> Node48::kMaxOfMaxShift]); + result = newChild; + newChild->parent = self; return result; } case Type_Node256: { @@ -1276,7 +1297,13 @@ Node *&getOrCreateChild(Node *&self, uint8_t index, auto *self256 = static_cast(self); ++self->numChildren; self256->bitSet.set(index); - return self256->children[index]; + auto &result = self256->children[index]; + self256->childMaxVersion[index] = newMaxVersion; + self256->maxOfMax[index >> Node256::kMaxOfMaxShift] = std::max( + newMaxVersion, self256->maxOfMax[index >> Node256::kMaxOfMaxShift]); + result = newChild; + newChild->parent = self; + return result; } default: // GCOVR_EXCL_LINE __builtin_unreachable(); // GCOVR_EXCL_LINE @@ -2779,18 +2806,9 @@ Node **insert(Node **self, std::span key, return self; } - int index = key.front(); - key = key.subspan(1, key.size() - 1); - auto &child = getOrCreateChild(*self, index, writeVersion, tls); - if (!child) { - child = tls->allocate(key.size()); - child->numChildren = 0; - child->entryPresent = false; - child->partialKeyLen = key.size(); - child->parent = *self; - child->parentsIndex = index; - setMaxVersion(child, impl, writeVersion); - memcpy(child->partialKey(), key.data(), child->partialKeyLen); + bool existed = getChild(*self, key.front()); + auto &child = getOrCreateChild(*self, key, writeVersion, tls); + if (!existed) { return &child; } diff --git a/LongestCommonPrefix.h b/LongestCommonPrefix.h index 2d1d3b9..da3c824 100644 --- a/LongestCommonPrefix.h +++ b/LongestCommonPrefix.h @@ -11,6 +11,14 @@ #include #endif +#ifndef __SANITIZE_THREAD__ +#if defined(__has_feature) +#if __has_feature(thread_sanitizer) +#define __SANITIZE_THREAD__ +#endif +#endif +#endif + #if defined(HAS_AVX) || defined(HAS_ARM_NEON) constexpr int kStride = 64; #else