/*
Copyright 2024 Andrew Noyes

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// NOTE(review): in this copy of the file the header names after the bare
// `#include` directives, and template argument lists elsewhere (e.g.
// `template struct`, `static_cast(self)`, `std::span remaining`), appear to
// have been lost - TODO restore them from the canonical source.

#include "ConflictSet.h"
#include "Internal.h"

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

#ifdef HAS_AVX
#include
#elif defined(HAS_ARM_NEON)
#include
#endif

#include

// ==================== BEGIN IMPLEMENTATION ====================

// Versions stored at a key that is present in the tree (see
// Node::entryPresent). checkPointRead compares `pointVersion` when the read
// key matches this entry exactly, and `rangeVersion` when this entry is the
// first one found after the read key.
struct Entry {
  int64_t pointVersion;
  int64_t rangeVersion;
};

// Allocator for objects of type T that keeps released objects on an intrusive
// singly linked free list (the "next" pointer is stored in the first bytes of
// the released object) so they can be reused by later allocations. The list
// holds at most kMaxFreeListSize objects; beyond that, released objects are
// returned with free().
template struct BoundedFreeListAllocator {
  // A released object must be able to hold the free list's next pointer
  static_assert(sizeof(T) >= sizeof(void *));

  // Returns a newly constructed T, reusing the head of the free list if there
  // is one and otherwise allocating with safe_malloc.
  T *allocate() {
    if (freeListSize == 0) {
      assert(freeList == nullptr);
      return new (safe_malloc(sizeof(T))) T;
    }
    assert(freeList != nullptr);
    void *buffer = freeList;
    // release() marked this memory inaccessible for valgrind, so make the
    // stored next pointer readable again before loading it
    VALGRIND_MAKE_MEM_DEFINED(freeList, sizeof(freeList));
    // Pop: the next pointer lives in the first bytes of the head object
    memcpy(&freeList, freeList, sizeof(freeList));
    --freeListSize;
    VALGRIND_MAKE_MEM_UNDEFINED(buffer, sizeof(T));
    return new (buffer) T;
  }

  // Destroys *p, then either frees its memory (free list is full) or pushes
  // it onto the free list.
  void release(T *p) {
    p->~T();
    if (freeListSize == kMaxFreeListSize) {
      return free(p);
    }
    // Push: store the current head in the first bytes of p
    memcpy((void *)p, &freeList, sizeof(freeList));
    freeList = p;
    ++freeListSize;
    VALGRIND_MAKE_MEM_NOACCESS(p, sizeof(T));
  }

  // Frees every object still on the free list.
  ~BoundedFreeListAllocator() {
    for (void *iter = freeList; iter != nullptr;) {
      VALGRIND_MAKE_MEM_DEFINED(iter, sizeof(iter));
      auto *tmp = iter;
      // Load the next pointer before freeing the object that holds it
      memcpy(&iter, iter, sizeof(void *));
      free(tmp);
    }
  }

private:
  // NOTE(review): kMemoryBound is not declared in the visible text -
  // presumably a template parameter that was stripped; confirm.
  static constexpr int kMaxFreeListSize = kMemoryBound / sizeof(T);
  int freeListSize = 0;
  void *freeList = nullptr;
};

// A set of integers in [0, 256), stored as two 128-bit words: `lo` holds bits
// 0..127 and `hi` holds bits 128..255.
struct BitSet {
  // Returns true if i is in the set
  bool test(int i) const {
    assert(0 <= i);
    assert(i < 256);
    if (i < 128) {
      return (lo >> i) & 1;
    } else {
      return (hi >> (i - 128)) & 1;
    }
  }

  // Adds i to the set
  void set(int i) {
    assert(0 <= i);
    assert(i < 256);
    if (i < 128) {
      lo |= __uint128_t(1) << i;
    } else {
      hi |= __uint128_t(1) << (i - 128);
    }
  }

  // Removes i from the set
  void reset(int i) {
    assert(0 <= i);
    assert(i < 256);
    if (i < 128) {
      lo &= ~(__uint128_t(1) << i);
    } else {
      hi &= ~(__uint128_t(1) << (i - 128));
    }
  }

  // Returns the smallest member of the set that is >= i, or -1 if there is
  // none (including when i >= 256).
  int firstSetGeq(int i) const {
    assert(0 <= i);
    if (i >= 256) {
      return -1;
    }
    if (i < 128) {
      // countr_zero yields 128 when no bit at or above i is set in `lo`
      int a = std::countr_zero(lo >> i);
      if (a < 128) {
        assert(i + a < 128);
        return i + a;
      }
      // Nothing in `lo`; continue the search from the start of `hi`
      i = 128;
    }
    int b = std::countr_zero(hi >> (i - 128));
    if (b < 128) {
      assert(i + b < 256);
      return i + b;
    }
    return -1;
  }

private:
  __uint128_t lo = 0;
  __uint128_t hi = 0;
};

// Discriminates the concrete node struct behind a Node pointer. The
// enumerator order matters: code below tests `type <= Type::Node16`.
enum class Type : int8_t {
  Node4,
  Node16,
  Node48,
  Node256,
  Invalid,
};

// Fields common to all node kinds. Everything before `type` is copied with
// memcpy when a node is replaced by a larger kind (see getOrCreateChild).
struct Node {

  /* begin section that's copied to the next node */
  Node *parent = nullptr;
  // The max write version over all keys that start with the search path up to
  // this point
  int64_t maxVersion;
  // Only meaningful when entryPresent is true
  Entry entry;
  int16_t numChildren = 0;
  bool entryPresent = false;
  // The key byte under which this node is stored in `parent`
  uint8_t parentsIndex = 0;
  constexpr static auto kPartialKeyMaxLen = 26;
  // Key bytes that follow parentsIndex on the search path to this node (see
  // getSearchPath); only the first partialKeyLen bytes are used
  uint8_t partialKey[kPartialKeyMaxLen];
  int8_t partialKeyLen = 0;
  /* end section that's copied to the next node */

  Type type = Type::Invalid;
};

// Node with up to 4 children. children[i] is the child for key byte index[i].
struct Node4 : Node {
  // Sorted
  uint8_t index[16]; // 16 so that we can use the same simd index search
                     // implementation for Node4 as Node16
  Node *children[4];
  Node4() { this->type = Type::Node4; }
};

// Node with up to 16 children. children[i] is the child for key byte index[i].
struct Node16 : Node {
  // Sorted
  uint8_t index[16];
  Node *children[16];
  Node16() { this->type = Type::Node16; }
};

// Node with up to 48 children. index[b] is the slot in `children` for key
// byte b, or -1 if there is no such child; bitSet holds the key bytes that
// have a child.
struct Node48 : Node {
  BitSet bitSet;
  Node *children[48];
  // The next unused slot in `children`
  int8_t nextFree = 0;
  int8_t index[256];
  Node48() {
    memset(index, -1, 256);
    this->type = Type::Node48;
  }
};

// Node with a child slot for every key byte. children[b] is null if there is
// no child for b; bitSet holds the key bytes that have a child.
struct Node256 : Node {
  BitSet bitSet;
  Node *children[256] = {};
  Node256() { this->type = Type::Node256; }
};

// One free-list allocator per node kind.
struct NodeAllocators {
  BoundedFreeListAllocator node4;
  BoundedFreeListAllocator node16;
  BoundedFreeListAllocator node48;
  BoundedFreeListAllocator node256;
};

// Returns the position i < self->numChildren such that self->index[i] ==
// index, or -1 if there is none. Callers below also pass a Node4 cast to
// Node16, which works because both declare a 16-byte `index` array.
int getNodeIndex(Node16 *self, uint8_t index) {
#ifdef HAS_AVX
  // Based on https://www.the-paper-trail.org/post/art-paper-notes/

  // key_vec is 16 repeated copies of the searched-for byte, one for every
  // possible position in child_keys that needs to be searched.
  __m128i key_vec = _mm_set1_epi8(index);

  // Compare all child_keys to 'index' in parallel. Don't worry if some of the
  // keys aren't valid, we'll mask the results to only consider the valid ones
  // below.
  __m128i indices;
  memcpy(&indices, self->index, sizeof(self->index));
  __m128i results = _mm_cmpeq_epi8(key_vec, indices);

  // Build a mask to select only the first node->num_children values from the
  // comparison (because the other values are meaningless)
  uint32_t mask = (1 << self->numChildren) - 1;

  // Change the results of the comparison into a bitfield, masking off any
  // invalid comparisons.
  uint32_t bitfield = _mm_movemask_epi8(results) & mask;

  // No match if there are no '1's in the bitfield.
  if (bitfield == 0)
    return -1;

  // Find the index of the first '1' in the bitfield by counting the trailing
  // zeros.
  return std::countr_zero(bitfield);
#elif defined(HAS_ARM_NEON)
  // Based on
  // https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon

  uint8x16_t indices;
  memcpy(&indices, self->index, sizeof(self->index));
  // 0xff for each match
  uint16x8_t results =
      vreinterpretq_u16_u8(vceqq_u8(vdupq_n_u8(index), indices));
  // Four mask bits per child; shifting a uint64_t by 64 would be undefined,
  // hence the special case for 16 children
  uint64_t mask = self->numChildren == 16
                      ? uint64_t(-1)
                      : (uint64_t(1) << (self->numChildren * 4)) - 1;
  // 0xf for each match in valid range
  uint64_t bitfield =
      vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) & mask;
  if (bitfield == 0)
    return -1;
  return std::countr_zero(bitfield) / 4;
#else
  for (int i = 0; i < self->numChildren; ++i) {
    if (self->index[i] == index) {
      return i;
    }
  }
  return -1;
#endif
}

// Returns a reference to the slot in `self` holding the child for key byte
// `index`. The result of getNodeIndex is used without a check, hence the
// precondition.
// Precondition - an entry for index must exist in the node
Node *&getChildExists(Node *self, uint8_t index) {
  if (self->type <= Type::Node16) {
    auto *self16 = static_cast(self);
    return self16->children[getNodeIndex(self16, index)];
  } else if (self->type == Type::Node48) {
    auto *self48 = static_cast(self);
    assert(self48->bitSet.test(index));
    return self48->children[self48->index[index]];
  } else {
    auto *self256 = static_cast(self);
    return self256->children[index];
  }
  __builtin_unreachable(); // GCOVR_EXCL_LINE
}

// Returns the child of `self` for key byte `index`, or nullptr if there is
// none.
Node *getChild(Node *self, uint8_t index) {
  if (self->type <= Type::Node16) {
    auto *self16 = static_cast(self);
    int i = getNodeIndex(self16, index);
    if (i >= 0) {
      return self16->children[i];
    }
    return nullptr;
  } else if (self->type == Type::Node48) {
    auto *self48 = static_cast(self);
    int secondIndex = self48->index[index];
    if (secondIndex >= 0) {
      return self48->children[secondIndex];
    }
    return nullptr;
  } else {
    auto *self256 = static_cast(self);
    return self256->children[index];
  }
}

// Returns the smallest key byte >= `child` for which `self` has a child, or
// -1 if there is none. `child` may be 256 (one past the last byte), in which
// case the result is -1.
int getChildGeq(Node *self, int child) {
  if (child > 255) {
    return -1;
  }
  if (self->type <= Type::Node16) {
    auto *self16 = static_cast(self);
#ifdef HAS_AVX
    __m128i key_vec = _mm_set1_epi8(child);
    __m128i indices;
    memcpy(&indices, self16->index, sizeof(self16->index));
    // min(key, index) == key exactly when key <= index (unsigned)
    __m128i results = _mm_cmpeq_epi8(key_vec, _mm_min_epu8(key_vec, indices));
    int mask = (1 << self16->numChildren) - 1;
    uint32_t bitfield = _mm_movemask_epi8(results) & mask;
    // index is sorted, so the first position with key <= index is the answer
    int result = bitfield == 0 ? -1 : self16->index[std::countr_zero(bitfield)];
    // Cross-check against the scalar implementation in debug builds
    assert(result == [&]() -> int {
      for (int i = 0; i < self16->numChildren; ++i) {
        if (self16->index[i] >= child) {
          return self16->index[i];
        }
      }
      return -1;
    }());
    return result;
#elif defined(HAS_ARM_NEON)
    uint8x16_t indices;
    memcpy(&indices, self16->index, sizeof(self16->index));
    // 0xff for each leq
    auto results = vcleq_u8(vdupq_n_u8(child), indices);
    uint64_t mask = self->numChildren == 16
                        ? uint64_t(-1)
                        : (uint64_t(1) << (self->numChildren * 4)) - 1;
    // 0xf for each 0xff (within mask)
    uint64_t bitfield =
        vget_lane_u64(
            vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)),
            0) &
        mask;
    int simd = bitfield == 0 ? -1 : self16->index[std::countr_zero(bitfield) / 4];
    // Cross-check against the scalar implementation in debug builds
    assert(simd == [&]() -> int {
      for (int i = 0; i < self->numChildren; ++i) {
        if (self16->index[i] >= child) {
          return self16->index[i];
        }
      }
      return -1;
    }());
    return simd;
#else
    for (int i = 0; i < self->numChildren; ++i) {
      if (i > 0) {
        assert(self16->index[i - 1] < self16->index[i]);
      }
      if (self16->index[i] >= child) {
        return self16->index[i];
      }
    }
#endif
  } else {
    // bitSet is at the same offset in Node48 and Node256, so a Node256 can be
    // handled through the Node48 cast below
    static_assert(offsetof(Node48, bitSet) == offsetof(Node256, bitSet));
    auto *self48 = static_cast(self);
    return self48->bitSet.firstSetGeq(child);
  }
  return -1;
}

// Points the `parent` field of every child of `node` at `node`. Used after a
// node has been replaced by a larger kind.
void setChildrenParents(Node *node) {
  for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) {
    getChildExists(node, i)->parent = node;
  }
}

// Returns a reference to the child slot of `self` for key byte `index`,
// creating the slot (initialized to nullptr) if it does not exist. If `self`
// is full it is replaced by the next larger node kind: the old node is
// released to `allocators`, `self` is updated to point at the replacement, and
// the children's parent pointers are fixed up.
//
// Caller is responsible for assigning a non-null pointer to the returned
// reference if null
Node *&getOrCreateChild(Node *&self, uint8_t index,
                        NodeAllocators *allocators) {
  if (self->type == Type::Node4) {
    auto *self4 = static_cast(self);
    {
      int i = getNodeIndex((Node16 *)self4, index);
      if (i >= 0) {
        return self4->children[i];
      }
    }
    if (self->numChildren == 4) {
      // Grow Node4 -> Node16
      auto *newSelf = allocators->node16.allocate();
      // Copy the common fields up to (but not including) `type`
      memcpy((void *)newSelf, self, offsetof(Node, type));
      memcpy(newSelf->index, self4->index, 4);
      memcpy(newSelf->children, self4->children, 4 * sizeof(void *));
      allocators->node4.release(self4);
      self = newSelf;
      setChildrenParents(self);
      goto insert16;
    } else {
      ++self->numChildren;
      // Insert at the position that keeps `index` sorted
      for (int i = 0; i < int(self->numChildren) - 1; ++i) {
        if (int(self4->index[i]) > int(index)) {
          memmove(self4->index + i + 1, self4->index + i,
                  self->numChildren - (i + 1));
          memmove(self4->children + i + 1, self4->children + i,
                  (self->numChildren - (i + 1)) * sizeof(void *));
          self4->index[i] = index;
          self4->children[i] = nullptr;
          return self4->children[i];
        }
      }
      // `index` is larger than every existing key byte; append
      self4->index[self->numChildren - 1] = index;
      self4->children[self->numChildren - 1] = nullptr;
      return self4->children[self->numChildren - 1];
    }
  } else if (self->type == Type::Node16) {
  insert16:
    auto *self16 = static_cast(self);
    {
      int i = getNodeIndex(self16, index);
      if (i >= 0) {
        return self16->children[i];
      }
    }
    if (self->numChildren == 16) {
      // Grow Node16 -> Node48
      auto *newSelf = allocators->node48.allocate();
      memcpy((void *)newSelf, self, offsetof(Node, type));
      newSelf->nextFree = 16;
      int i = 0;
      for (auto x : self16->index) {
        newSelf->bitSet.set(x);
        newSelf->children[i] = self16->children[i];
        newSelf->index[x] = i;
        ++i;
      }
      assert(i == 16);
      allocators->node16.release(self16);
      self = newSelf;
      setChildrenParents(self);
      goto insert48;
    } else {
      ++self->numChildren;
      // Insert at the position that keeps `index` sorted
      for (int i = 0; i < int(self->numChildren) - 1; ++i) {
        if (int(self16->index[i]) > int(index)) {
          memmove(self16->index + i + 1, self16->index + i,
                  self->numChildren - (i + 1));
          memmove(self16->children + i + 1, self16->children + i,
                  (self->numChildren - (i + 1)) * sizeof(void *));
          self16->index[i] = index;
          self16->children[i] = nullptr;
          return self16->children[i];
        }
      }
      // `index` is larger than every existing key byte; append
      self16->index[self->numChildren - 1] = index;
      self16->children[self->numChildren - 1] = nullptr;
      return self16->children[self->numChildren - 1];
    }
  } else if (self->type == Type::Node48) {
  insert48:
    auto *self48 = static_cast(self);
    if (self48->bitSet.test(index)) {
      return self48->children[self48->index[index]];
    }
    if (self->numChildren == 48) {
      // Grow Node48 -> Node256
      auto *newSelf = allocators->node256.allocate();
      memcpy((void *)newSelf, self, offsetof(Node, type));
      for (int i = 0; i < 256; ++i) {
        if (self48->bitSet.test(i)) {
          newSelf->bitSet.set(i);
          newSelf->children[i] = self48->children[self48->index[i]];
        }
      }
      allocators->node48.release(self48);
      self = newSelf;
      setChildrenParents(self);
      goto insert256;
    } else {
      self48->bitSet.set(index);
      ++self->numChildren;
      assert(self48->nextFree < 48);
      self48->index[index] = self48->nextFree;
      self48->children[self48->nextFree] = nullptr;
      return self48->children[self48->nextFree++];
    }
  } else {
  insert256:
    auto *self256 = static_cast(self);
    // The slot always exists in a Node256; count it as a child if it is
    // currently empty since the caller is expected to fill it
    if (!self256->children[index]) {
      ++self->numChildren;
    }
    self256->bitSet.set(index);
    return self256->children[index];
  }
}

// Releases the child of `self` for key byte `index` to `allocators` and
// removes it from `self`. If that leaves `self` with no children and no
// entry, and `self` has a parent, `self` is erased from its parent in turn.
// Precondition - an entry for index must exist in the node
void eraseChild(Node *self, uint8_t index, NodeAllocators *allocators) {
  auto *child = getChildExists(self, index);
  switch (child->type) {
  case Type::Node4:
    allocators->node4.release((Node4 *)child);
    break;
  case Type::Node16:
    allocators->node16.release((Node16 *)child);
    break;
  case Type::Node48:
    allocators->node48.release((Node48 *)child);
    break;
  case Type::Node256:
    allocators->node256.release((Node256 *)child);
    break;
  case Type::Invalid:
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }

  if (self->type <= Type::Node16) {
    auto *self16 = static_cast(self);
    int nodeIndex = getNodeIndex(self16, index);
    // Close the gap, keeping `index` sorted
    memmove(self16->index + nodeIndex, self16->index + nodeIndex + 1,
            sizeof(self16->index[0]) * (self->numChildren - (nodeIndex + 1)));
    memmove(self16->children + nodeIndex, self16->children + nodeIndex + 1,
            sizeof(self16->children[0]) * // NOLINT
                (self->numChildren - (nodeIndex + 1)));
  } else if (self->type == Type::Node48) {
    auto *self48 = static_cast(self);
    self48->bitSet.reset(index);
    int8_t toRemoveChildrenIndex = std::exchange(self48->index[index], -1);
    int8_t lastChildrenIndex = --self48->nextFree;
    assert(toRemoveChildrenIndex >= 0);
    assert(lastChildrenIndex >= 0);
    if (toRemoveChildrenIndex != lastChildrenIndex) {
      // Keep `children` dense: move the last used slot into the hole and
      // update the index entry of the child that moved
      self48->children[toRemoveChildrenIndex] =
          std::exchange(self48->children[lastChildrenIndex], nullptr);
      self48->index[self48->children[toRemoveChildrenIndex]->parentsIndex] =
          toRemoveChildrenIndex;
    }
  } else {
    auto *self256 = static_cast(self);
    self256->bitSet.reset(index);
    self256->children[index] = nullptr;
  }
  --self->numChildren;
  if (self->numChildren == 0 && !self->entryPresent &&
      self->parent != nullptr) {
    eraseChild(self->parent, self->parentsIndex, allocators);
  }
}

// Returns the node that follows `node` in a pre-order traversal (children
// visited in increasing key-byte order), or nullptr if `node` is the last one.
Node *nextPhysical(Node *node) {
  // -1 so that the first lookup considers every child of `node` itself
  int index = -1;
  for (;;) {
    auto nextChild = getChildGeq(node, index + 1);
    if (nextChild >= 0) {
      return getChildExists(node, nextChild);
    }
    // No more children here; continue after this node in its parent
    index = node->parentsIndex;
    node = node->parent;
    if (node == nullptr) {
      return nullptr;
    }
  }
}

// Returns the first node after `node` (in nextPhysical order) that has
// entryPresent set, or nullptr if there is none.
Node *nextLogical(Node *node) {
  for (node = nextPhysical(node); node != nullptr && !node->entryPresent;
       node = nextPhysical(node))
    ;
  return node;
}

// NOTE(review): not used in the visible part of the file; meaning of `cmp` to
// be confirmed against its users.
struct Iterator {
  Node *n;
  int cmp;
};

// Returns the first node after `node` in physical order that is not a
// descendant of `node`: the next child of the nearest ancestor that has one.
// Returns nullptr if there is none.
Node *nextSibling(Node *node) {
  for (;;) {
    if (node->parent == nullptr) {
      return nullptr;
    }
    auto next = getChildGeq(node->parent, node->parentsIndex + 1);
    if (next < 0) {
      node = node->parent;
    } else {
      return getChildExists(node->parent, next);
    }
  }
}

// Number of bytes compared by one call to compareStride / firstNeqStride
#if defined(HAS_AVX) || defined(HAS_ARM_NEON)
constexpr int kStride = 64;
#else
constexpr int kStride = 16;
#endif

// Number of strides per iteration of the first loop in longestCommonPrefix
constexpr int kUnrollFactor = 4;

// Returns true if the kStride bytes at `ap` equal the kStride bytes at `bp`.
bool compareStride(const uint8_t *ap, const uint8_t *bp) {
#if defined(HAS_ARM_NEON)
  static_assert(kStride == 64);
  uint8x16_t x[4];
  for (int i = 0; i < 4; ++i) {
    x[i] = vceqq_u8(vld1q_u8(ap + i * 16), vld1q_u8(bp + i * 16));
  }
  // AND the four per-byte equality masks together; all bytes are equal only
  // if every lane of the combined mask is still all ones
  auto results = vreinterpretq_u16_u8(
      vandq_u8(vandq_u8(x[0], x[1]), vandq_u8(x[2], x[3])));
  bool eq = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) ==
            uint64_t(-1);
#elif defined(HAS_AVX)
  static_assert(kStride == 64);
  __m128i x[4];
  for (int i = 0; i < 4; ++i) {
    x[i] = _mm_cmpeq_epi8(_mm_loadu_si128((__m128i *)(ap + i * 16)),
                          _mm_loadu_si128((__m128i *)(bp + i * 16)));
  }
  auto eq =
      _mm_movemask_epi8(_mm_and_si128(_mm_and_si128(x[0], x[1]),
                                      _mm_and_si128(x[2], x[3]))) == 0xffff;
#else
  // Hope it gets vectorized
  auto eq = memcmp(ap, bp, kStride) == 0;
#endif
  return eq;
}

// Returns the offset of the first byte at which the kStride bytes at `ap` and
// `bp` differ. Intended to be called only when compareStride returned false:
// the NEON path is unreachable otherwise, and the scalar path never inspects
// the last byte.
int firstNeqStride(const uint8_t *ap, const uint8_t *bp) {
#if defined(HAS_AVX)
  static_assert(kStride == 64);
  uint64_t c[kStride / 16];
  for (int i = 0; i < kStride; i += 16) {
    const auto a = _mm_loadu_si128((__m128i *)(ap + i));
    const auto b = _mm_loadu_si128((__m128i *)(bp + i));
    const auto compared = _mm_cmpeq_epi8(a, b);
    // One bit per byte, set where the bytes are equal
    c[i / 16] = _mm_movemask_epi8(compared) & 0xffff;
  }
  // Combine into a 64-bit equality mask; the first zero bit is the first
  // mismatch
  return std::countr_zero(~(c[0] | c[1] << 16 | c[2] << 32 | c[3] << 48));
#elif defined(HAS_ARM_NEON)
  static_assert(kStride == 64);
  for (int i = 0; i < kStride; i += 16) {
    // 0xff for each match
    uint16x8_t results =
        vreinterpretq_u16_u8(vceqq_u8(vld1q_u8(ap + i), vld1q_u8(bp + i)));
    // 0xf for each mismatch
    uint64_t bitfield =
        ~vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0);
    if (bitfield) {
      return i + (std::countr_zero(bitfield) >> 2);
    }
  }
  __builtin_unreachable(); // GCOVR_EXCL_LINE
#else
  int i = 0;
  for (; i < kStride - 1; ++i) {
    if (*ap++ != *bp++) {
      break;
    }
  }
  return i;
#endif
}

// Returns the length of the longest common prefix of the `cl` bytes at `ap`
// and the `cl` bytes at `bp`. `cl` must be non-negative.
int longestCommonPrefix(const uint8_t *ap, const uint8_t *bp, int cl) {
  if (cl < 0) {
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
  int i = 0;
  int end;

  if (cl < 8) {
    goto bytes;
  }

  // Optimistic early return
  {
    uint64_t a;
    uint64_t b;
    memcpy(&a, ap, 8);
    memcpy(&b, bp, 8);
    const auto mismatched = a ^ b;
    if (mismatched) {
      // The lowest set bit identifies the first differing byte only when the
      // lowest-addressed byte is the least significant one (little-endian)
      return std::countr_zero(mismatched) / 8;
    }
  }

  // kStride * kUnrollFactor at a time
  end = cl & ~(kStride * kUnrollFactor - 1);
  while (i < end) {
    for (int j = 0; j < kUnrollFactor; ++j) {
      if (!compareStride(ap, bp)) {
        return i + firstNeqStride(ap, bp);
      }
      i += kStride;
      ap += kStride;
      bp += kStride;
    }
  }

  // kStride at a time
  end = cl & ~(kStride - 1);
  while (i < end) {
    if (!compareStride(ap, bp)) {
      return i + firstNeqStride(ap, bp);
    }
    i += kStride;
    ap += kStride;
    bp += kStride;
  }

  // word at a time
  end = cl & ~(sizeof(uint64_t) - 1);
  while (i < end) {
    uint64_t a;
    uint64_t b;
    memcpy(&a, ap, 8);
    memcpy(&b, bp, 8);
    const auto mismatched = a ^ b;
    if (mismatched) {
      return i + std::countr_zero(mismatched) / 8;
    }
    i += 8;
    ap += 8;
    bp += 8;
  }

bytes:

  // byte at a time
  while (i < cl) {
    if (*ap != *bp) {
      break;
    }
    ++ap;
    ++bp;
    ++i;
  }

  return i;
}

// Same result as longestCommonPrefix, computed with a plain byte loop. Used
// below where one side is a node's partialKey (at most
// Node::kPartialKeyMaxLen bytes).
int longestCommonPrefixPartialKey(const uint8_t *ap, const uint8_t *bp,
                                  int cl) {
  int i = 0;
  for (; i < cl; ++i) {
    if (*ap++ != *bp++) {
      break;
    }
  }
  return i;
}

// Performs a physical search for remaining
//
// Each call to step() tries to descend one node. `n` is the deepest node
// reached so far and `remaining` is the part of the key not yet consumed.
struct SearchStepWise {
  Node *n;
  std::span remaining;

  SearchStepWise() {}
  SearchStepWise(Node *n, std::span remaining) : n(n), remaining(remaining) {
    assert(n->partialKeyLen == 0);
  }

  // Returns true when the search cannot advance any further: the key is
  // exhausted, there is no child for the next byte, or the child's partial key
  // is not fully matched by `remaining`. Returns false after descending into
  // the child.
  bool step() {
    if (remaining.size() == 0) {
      return true;
    }
    auto *child = getChild(n, remaining[0]);
    if (child == nullptr) {
      return true;
    }
    int cl = std::min(child->partialKeyLen, remaining.size() - 1);
    int i = longestCommonPrefixPartialKey(child->partialKey,
                                          remaining.data() + 1, cl);
    if (i != child->partialKeyLen) {
      return true;
    }
    n = child;
    // Consume the byte that selected the child plus the child's partial key
    remaining = remaining.subspan(1 + child->partialKeyLen,
                                  remaining.size() - (1 + child->partialKeyLen));
    return false;
  }
};

// Forward declaration; used by the DEBUG_VERBOSE logging below
namespace {
std::string getSearchPathPrintable(Node *n);
}

// Logically this is the same as performing firstGeq and then checking against
// point or range version according to cmp, but this version short circuits as
// soon as it can prove that there's no conflict.
// Point-read conflict check. Returns true if the version recorded for `key`
// in the tree rooted at `n` is <= readVersion: the pointVersion of the entry
// at `key` if there is one, otherwise the rangeVersion of the first entry
// after `key`. Returns true if there is no such entry.
// NOTE(review): template argument lists (e.g. on std::span, static_cast,
// std::numeric_limits, Vector) appear to have been stripped from this copy of
// the file - TODO restore from the canonical source.
bool checkPointRead(Node *n, const std::span key, int64_t readVersion) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
  fprintf(stderr, "Check point read: %s\n", printable(key).c_str());
#endif
  auto remaining = key;
  for (;;) {
    // Nothing at or below n was written after readVersion
    if (n->maxVersion <= readVersion) {
      return true;
    }
    if (remaining.size() == 0) {
      if (n->entryPresent) {
        return n->entry.pointVersion <= readVersion;
      }
      // No entry at exactly `key`; the first entry after it is the leftmost
      // entry below n
      int c = getChildGeq(n, 0);
      assert(c >= 0);
      n = getChildExists(n, c);
      goto downLeftSpine;
    } else {
      int c = getChildGeq(n, remaining[0]);
      if (c == remaining[0]) {
        n = getChildExists(n, c);
        remaining = remaining.subspan(1, remaining.size() - 1);
      } else {
        if (c >= 0) {
          // Everything under child c is greater than the key
          n = getChildExists(n, c);
          goto downLeftSpine;
        } else {
          // No child >= the next key byte; continue after n's subtree
          n = nextSibling(n);
          goto downLeftSpine;
        }
      }
    }
    if (n->partialKeyLen > 0) {
      int commonLen = std::min(n->partialKeyLen, remaining.size());
      int i = longestCommonPrefix(n->partialKey, remaining.data(), commonLen);
      if (i < commonLen) {
        auto c = n->partialKey[i] <=> remaining[i];
        if (c > 0) {
          // n's subtree is entirely greater than the key
          goto downLeftSpine;
        } else {
          // n's subtree is entirely less than the key
          n = nextSibling(n);
          goto downLeftSpine;
        }
      }
      if (commonLen == n->partialKeyLen) {
        // partial key matches
        remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
      } else if (n->partialKeyLen > int(remaining.size())) {
        // n is the first physical node greater than remaining, and there's no
        // eq node
        goto downLeftSpine;
      }
    }
  }
downLeftSpine:
  // n is null when nextSibling found nothing after the key
  if (n == nullptr) {
    return true;
  }
  // Descend through first children until reaching a node with an entry
  for (;;) {
    if (n->entryPresent) {
      return n->entry.rangeVersion <= readVersion;
    }
    int c = getChildGeq(n, 0);
    assert(c >= 0);
    n = getChildExists(n, c);
  }
}

// Returns the max of `maxVersion` over the children of `n` whose key byte c
// satisfies begin < c < end, also taking into account the rangeVersion of the
// first such child's entry if it has one. Returns the lowest int64_t value if
// there is no such child. `begin` may be -1 and `end` may be 256 to leave that
// side unbounded.
int64_t maxBetweenExclusive(Node *n, int begin, int end) {
  assert(-1 <= begin);
  assert(begin <= 256);
  assert(-1 <= end);
  assert(end <= 256);
  assert(begin < end);
  int64_t result = std::numeric_limits::lowest();
  // Node48/Node256 with fewer children than this are scanned via the bitset
  // instead of visiting every index in (begin, end)
  constexpr int kSparseThreshold = 32;
  {
    int c = getChildGeq(n, begin + 1);
    if (c >= 0 && c < end) {
      auto *child = getChildExists(n, c);
      if (child->entryPresent) {
        result = std::max(result, child->entry.rangeVersion);
      }
    }
  }
  switch (n->type) {
  case Type::Node4:
    [[fallthrough]];
  case Type::Node16: {
    auto *self = static_cast(n);
    // index is sorted, so stop at the first key byte >= end
    for (int i = 0; i < self->numChildren && self->index[i] < end; ++i) {
      if (begin < self->index[i]) {
        result = std::max(result, self->children[i]->maxVersion);
      }
    }
    break;
  }
  case Type::Node48: {
    auto *self = static_cast(n);
    if (self->numChildren < kSparseThreshold) {
      for (int i = self->bitSet.firstSetGeq(begin + 1); i < end && i >= 0;
           i = self->bitSet.firstSetGeq(i + 1)) {
        if (self->index[i] != -1) {
          result = std::max(result, self->children[self->index[i]]->maxVersion);
        }
      }
    } else {
      for (int i = begin + 1; i < end; ++i) {
        if (self->index[i] != -1) {
          result = std::max(result, self->children[self->index[i]]->maxVersion);
        }
      }
    }
    break;
  }
  case Type::Node256: {
    auto *self = static_cast(n);
    if (self->numChildren < kSparseThreshold) {
      for (int i = self->bitSet.firstSetGeq(begin + 1); i < end && i >= 0;
           i = self->bitSet.firstSetGeq(i + 1)) {
        result = std::max(result, self->children[i]->maxVersion);
      }
    } else {
      for (int i = begin + 1; i < end; ++i) {
        if (self->children[i] != nullptr) {
          result = std::max(result, self->children[i]->maxVersion);
        }
      }
    }
    break;
  }
  case Type::Invalid:
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
#if DEBUG_VERBOSE && !defined(NDEBUG)
  fprintf(stderr, "At `%s', max version in (%02x, %02x) is %" PRId64 "\n",
          getSearchPathPrintable(n).c_str(), begin, end, result);
#endif
  return result;
}

// Returns the key bytes on the path from the root to `n`, allocated from
// `arena`. For each node on the path this is its parentsIndex followed by its
// partial key (the root contributes only its partial key).
Vector getSearchPath(Arena &arena, Node *n) {
  assert(n != nullptr);
  auto result = vector(arena);
  // Collect the bytes walking up from n, i.e. in reverse order
  for (;;) {
    for (int i = n->partialKeyLen - 1; i >= 0; --i) {
      result.push_back(n->partialKey[i]);
    }
    if (n->parent == nullptr) {
      break;
    }
    result.push_back(n->parentsIndex);
    n = n->parent;
  }
  std::reverse(result.begin(), result.end());
  return result;
}

// Return true if the max version among all keys that start with key + [child],
// where begin < child < end, is <= readVersion
bool checkRangeStartsWith(Node *n, std::span key, int begin, int end,
                          int64_t readVersion) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
  fprintf(stderr, "%s(%02x,%02x)*\n", printable(key).c_str(), begin, end);
#endif
  auto remaining = key;
  for (;;) {
    if (n->maxVersion <= readVersion) {
      return true;
    }
    if (remaining.size() == 0) {
      // n's search path is exactly `key`; examine its children in (begin, end)
      return maxBetweenExclusive(n, begin, end) <= readVersion;
    }
    int c = getChildGeq(n, remaining[0]);
    if (c == remaining[0]) {
      n = getChildExists(n, c);
      remaining = remaining.subspan(1, remaining.size() - 1);
    } else {
      if (c >= 0) {
        n = getChildExists(n, c);
        goto downLeftSpine;
      } else {
        n = nextSibling(n);
        goto downLeftSpine;
      }
    }
    if (n->partialKeyLen > 0) {
      int commonLen = std::min(n->partialKeyLen, remaining.size());
      int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(),
                                            commonLen);
      if (i < commonLen) {
        auto c = n->partialKey[i] <=> remaining[i];
        if (c > 0) {
          goto downLeftSpine;
        } else {
          n = nextSibling(n);
          goto downLeftSpine;
        }
      }
      if (commonLen == n->partialKeyLen) {
        // partial key matches
        remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
      } else if (n->partialKeyLen > int(remaining.size())) {
        // `key` ends inside n's partial key. The partial key byte right after
        // the end of `key` plays the role of "child" for n's whole subtree.
        if (begin < n->partialKey[remaining.size()] &&
            n->partialKey[remaining.size()] < end) {
          if (n->entryPresent && n->entry.rangeVersion > readVersion) {
            return false;
          }
          return n->maxVersion <= readVersion;
        }
        return true;
      }
    }
  }
downLeftSpine:
  if (n == nullptr) {
    return true;
  }
  // Descend through first children until reaching a node with an entry
  for (;;) {
    if (n->entryPresent) {
      return n->entry.rangeVersion <= readVersion;
    }
    int c = getChildGeq(n, 0);
    assert(c >= 0);
    n = getChildExists(n, c);
  }
}

// Return true if the max version among all keys that start with key[:prefixLen]
// that are >= key is <= readVersion
//
// The check is performed incrementally: call step() until it returns true,
// then read the result from `ok`.
struct CheckRangeLeftSide {
  CheckRangeLeftSide(Node *n, std::span key, int prefixLen,
                     int64_t readVersion)
      : n(n), remaining(key), prefixLen(prefixLen), readVersion(readVersion) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
    fprintf(stderr, "Check range left side from %s for keys starting with %s\n",
            printable(key).c_str(),
            printable(key.subspan(0, prefixLen)).c_str());
#endif
  }

  Node *n;
  std::span remaining;
  int prefixLen;
  int64_t readVersion;
  // Number of bytes of the key matched so far
  int searchPathLen = 0;
  // Result of the check; only set once step() has returned true
  bool ok;

  enum Phase { Search, DownLeftSpine } phase = Search;

  // Advances the check by one node. Returns true when the check is finished
  // (the result is in `ok`) and false if step() needs to be called again.
  bool step() {
    switch (phase) {
    case Search: {
      if (n->maxVersion <= readVersion) {
        ok = true;
        return true;
      }
      if (remaining.size() == 0) {
        assert(searchPathLen >= prefixLen);
        ok = n->maxVersion <= readVersion;
        return true;
      }

      if (searchPathLen >= prefixLen) {
        // Children with a key byte greater than the next byte of the key hold
        // keys greater than `key` that still share the prefix
        if (maxBetweenExclusive(n, remaining[0], 256) > readVersion) {
          ok = false;
          return true;
        }
      }

      int c = getChildGeq(n, remaining[0]);
      if (c == remaining[0]) {
        n = getChildExists(n, c);
        remaining = remaining.subspan(1, remaining.size() - 1);
        ++searchPathLen;
      } else {
        if (c >= 0) {
          if (searchPathLen < prefixLen) {
            n = getChildExists(n, c);
            return downLeftSpine();
          }
          n = getChildExists(n, c);
          ok = n->maxVersion <= readVersion;
          return true;
        } else {
          n = nextSibling(n);
          return downLeftSpine();
        }
      }

      if (n->partialKeyLen > 0) {
        int commonLen = std::min(n->partialKeyLen, remaining.size());
        int i = longestCommonPrefix(n->partialKey, remaining.data(), commonLen);
        searchPathLen += i;
        if (i < commonLen) {
          auto c = n->partialKey[i] <=> remaining[i];
          if (c > 0) {
            if (searchPathLen < prefixLen) {
              return downLeftSpine();
            }
            if (n->entryPresent && n->entry.rangeVersion > readVersion) {
              ok = false;
              return true;
            }
            ok = n->maxVersion <= readVersion;
            return true;
          } else {
            n = nextSibling(n);
            return downLeftSpine();
          }
        }
        if (commonLen == n->partialKeyLen) {
          // partial key matches
          remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
        } else if (n->partialKeyLen > int(remaining.size())) {
          // The key ends inside n's partial key
          if (searchPathLen < prefixLen) {
            return downLeftSpine();
          }
          if (n->entryPresent && n->entry.rangeVersion > readVersion) {
            ok = false;
            return true;
          }
          ok = n->maxVersion <= readVersion;
          return true;
        }
      }
      break;
    }
    case DownLeftSpine:
      // Descend through first children until reaching a node with an entry
      if (n->entryPresent) {
        ok = n->entry.rangeVersion <= readVersion;
        return true;
      }
      int c = getChildGeq(n, 0);
      assert(c >= 0);
      n = getChildExists(n, c);
      break;
    }
    return false;
  }

  // Switches to the DownLeftSpine phase. Finishes immediately with ok = true
  // if there is no node to descend from.
  bool downLeftSpine() {
    phase = DownLeftSpine;
    if (n == nullptr) {
      ok = true;
      return true;
    }
    return false;
  }
};

// Return true if the max version among all keys that start with key[:prefixLen]
// that are < key is <= readVersion
//
// The check is performed incrementally: call step() until it returns true,
// then read the result from `ok`.
struct CheckRangeRightSide {
  CheckRangeRightSide(Node *n, std::span key, int prefixLen,
                      int64_t readVersion)
      : n(n), key(key), remaining(key), prefixLen(prefixLen),
        readVersion(readVersion) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
    fprintf(stderr, "Check range right side to %s for keys starting with %s\n",
            printable(key).c_str(),
            printable(key.subspan(0, prefixLen)).c_str());
#endif
  }

  Node *n;
  std::span key;
  std::span remaining;
  int prefixLen;
  int64_t readVersion;
  // Length of the search path of n; adjusted by backtrack() when moving to
  // another node
  int searchPathLen = 0;
  // Result of the check; only set once step() has returned true
  bool ok;

  enum Phase { Search, DownLeftSpine } phase = Search;

  // Advances the check by one node. Returns true when the check is finished
  // (the result is in `ok`) and false if step() needs to be called again.
  bool step() {
    switch (phase) {
    case Search: {
#if DEBUG_VERBOSE && !defined(NDEBUG)
      fprintf(
          stderr,
          "Search path: %s, searchPathLen: %d, prefixLen: %d, remaining: %s\n",
          getSearchPathPrintable(n).c_str(), searchPathLen, prefixLen,
          printable(remaining).c_str());
#endif

      assert(searchPathLen <= int(key.size()));

      if (remaining.size() == 0) {
        return downLeftSpine();
      }

      if (searchPathLen >= prefixLen) {
        // n's own key is a proper prefix of `key`, so it is less than `key`
        if (n->entryPresent && n->entry.pointVersion > readVersion) {
          ok = false;
          return true;
        }

        // Children with a key byte less than the next byte of the key hold
        // keys less than `key`
        if (maxBetweenExclusive(n, -1, remaining[0]) > readVersion) {
          ok = false;
          return true;
        }
      }

      if (searchPathLen > prefixLen && n->entryPresent &&
          n->entry.rangeVersion > readVersion) {
        ok = false;
        return true;
      }

      int c = getChildGeq(n, remaining[0]);
      if (c == remaining[0]) {
        n = getChildExists(n, c);
        remaining = remaining.subspan(1, remaining.size() - 1);
        ++searchPathLen;
      } else {
        if (c >= 0) {
          n = getChildExists(n, c);
          return downLeftSpine();
        } else {
          return backtrack();
        }
      }

      if (n->partialKeyLen > 0) {
        int commonLen = std::min(n->partialKeyLen, remaining.size());
        int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(),
                                              commonLen);
        searchPathLen += i;
        if (i < commonLen) {
          ++searchPathLen;
          auto c = n->partialKey[i] <=> remaining[i];
          if (c > 0) {
            return downLeftSpine();
          } else {
            if (searchPathLen > prefixLen && n->entryPresent &&
                n->entry.rangeVersion > readVersion) {
              ok = false;
              return true;
            }
            return backtrack();
          }
        }
        if (commonLen == n->partialKeyLen) {
          // partial key matches
          remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
        } else if (n->partialKeyLen > int(remaining.size())) {
          return downLeftSpine();
        }
      }
    } break;
    case DownLeftSpine:
      // Descend through first children until reaching a node with an entry
      if (n->entryPresent) {
        ok = n->entry.rangeVersion <= readVersion;
        return true;
      }
      int c = getChildGeq(n, 0);
      assert(c >= 0);
      n = getChildExists(n, c);
      break;
    }
    return false;
  }

  // Moves n to the next node after its subtree (like nextSibling), keeping
  // searchPathLen consistent, and checks maxVersion of each node left behind
  // while searchPathLen > prefixLen. Finishes with ok = true if the root is
  // reached; otherwise switches to the DownLeftSpine phase.
  bool backtrack() {
    for (;;) {
      if (searchPathLen > prefixLen && n->maxVersion > readVersion) {
        ok = false;
        return true;
      }
      if (n->parent == nullptr) {
        ok = true;
        return true;
      }
      auto next = getChildGeq(n->parent, n->parentsIndex + 1);
      if (next < 0) {
        // Moving up drops n's partial key and the byte that selected n
        searchPathLen -= 1 + n->partialKeyLen;
        n = n->parent;
      } else {
        // Moving sideways swaps n's partial key for the sibling's
        searchPathLen -= n->partialKeyLen;
        n = getChildExists(n->parent, next);
        searchPathLen += n->partialKeyLen;
        return downLeftSpine();
      }
    }
  }

  // Switches to the DownLeftSpine phase. Finishes immediately with ok = true
  // if there is no node to descend from.
  bool downLeftSpine() {
    phase = DownLeftSpine;
    if (n == nullptr) {
      ok = true;
      return true;
    }
    return false;
  }
};

// Range-read conflict check for the keys from `begin` up to `end` in the tree
// rooted at `n`. Returns true if no conflicting version greater than
// readVersion is found. A range whose `end` is `begin` followed by a single
// zero byte is handled as a point read of `begin`.
bool checkRangeRead(Node *n, std::span begin, std::span end,
                    int64_t readVersion) {
  int lcp = longestCommonPrefix(begin.data(), end.data(),
                                std::min(begin.size(), end.size()));
  if (lcp == int(begin.size()) && end.size() == begin.size() + 1 &&
      end.back() == 0) {
    return checkPointRead(n, begin, readVersion);
  }

  // Descend along the common prefix of begin and end as far as the tree
  // allows, stopping early if a subtree has no writes after readVersion
  SearchStepWise search{n, begin.subspan(0, lcp)};
  Arena arena;
  for (;;) {
    assert(getSearchPath(arena, search.n) <=>
               begin.subspan(0, lcp - search.remaining.size()) ==
           0);
    if (search.n->maxVersion <= readVersion) {
      return true;
    }
    if (search.step()) {
      break;
    }
  }
  assert(getSearchPath(arena, search.n) <=>
             begin.subspan(0, lcp - search.remaining.size()) ==
         0);

  // Continue relative to the node reached by the prefix search
  const int consumed = lcp - search.remaining.size();
  assert(consumed >= 0);

  begin = begin.subspan(consumed, int(begin.size()) - consumed);
  end = end.subspan(consumed, int(end.size()) - consumed);
  n = search.n;
  lcp -= consumed;

  if (lcp == int(begin.size())) {
    // begin is a prefix of end, so only the right side needs checking
    CheckRangeRightSide checkRangeRightSide{n, end, lcp, readVersion};
    while (!checkRangeRightSide.step())
      ;
    return checkRangeRightSide.ok;
  }

  // Keys that share the common prefix and whose next byte lies strictly
  // between begin[lcp] and end[lcp]
  if (!checkRangeStartsWith(n, begin.subspan(0, lcp), begin[lcp], end[lcp],
                            readVersion)) {
    return false;
  }
  CheckRangeLeftSide checkRangeLeftSide{n, begin, lcp + 1, readVersion};
  CheckRangeRightSide checkRangeRightSide{n, end, lcp + 1, readVersion};

  // Step both checks in lock step until one finishes, then run the other to
  // completion
  for (;;) {
    bool leftDone = checkRangeLeftSide.step();
    bool rightDone = checkRangeRightSide.step();
    if (!leftDone && !rightDone) {
      continue;
    }
    if (leftDone && rightDone) {
      break;
    } else if (leftDone) {
      while (!checkRangeRightSide.step())
        ;
      break;
    } else {
      assert(rightDone);
      while (!checkRangeLeftSide.step())
        ;
    }
    break;
  }

  return checkRangeLeftSide.ok & checkRangeRightSide.ok;
}

// Returns a pointer to the newly inserted node. caller is responsible for
// setting 'entry' fields and `maxVersion` on the result, which may have
// !entryPresent. The search path of the result's parent will have
// `maxVersion` at least `writeVersion` as a postcondition.
// Finds or creates the node whose search path, relative to *self_, is `key`.
//
// self_      - the slot holding the root of the subtree to insert into; the
//              slot is updated if that node has to be replaced
// key        - the key to insert, relative to *self_
// begin      - if true, maxVersion is raised to at least writeVersion on every
//              node visited including the returned one, and newly created
//              nodes start with maxVersion == writeVersion. If false, the
//              returned node's maxVersion is not raised, and newly created
//              nodes start with the lowest int64_t value.
// allocators - used to allocate (and, when nodes grow, release) nodes
//
// NOTE(review): template argument lists (e.g. on std::span, std::min,
// std::numeric_limits, vector) appear to have been stripped from this copy of
// the file - TODO restore from the canonical source.
[[nodiscard]] Node *insert(Node **self_, std::span key, int64_t writeVersion,
                           bool begin, NodeAllocators *allocators) {
  for (;;) {
    auto &self = *self_;
    // Handle an existing partial key
    int commonLen = std::min(self->partialKeyLen, key.size());
    int partialKeyIndex =
        longestCommonPrefixPartialKey(self->partialKey, key.data(), commonLen);
    if (partialKeyIndex < self->partialKeyLen) {
      // The key diverges from (or ends inside) the partial key. Split: a new
      // Node4 takes over the matching part of the partial key and the old
      // node becomes its child under the first non-matching byte.
      auto *old = self;
      self = allocators->node4.allocate();
      self->maxVersion = old->maxVersion;
      self->partialKeyLen = partialKeyIndex;
      self->parent = old->parent;
      self->parentsIndex = old->parentsIndex;
      memcpy(self->partialKey, old->partialKey, partialKeyIndex);

      getOrCreateChild(self, old->partialKey[partialKeyIndex], allocators) =
          old;
      old->parent = self;
      old->parentsIndex = old->partialKey[partialKeyIndex];

      // The old node keeps what follows the byte now used as its parentsIndex
      memmove(old->partialKey, old->partialKey + partialKeyIndex + 1,
              old->partialKeyLen - (partialKeyIndex + 1));
      old->partialKeyLen -= partialKeyIndex + 1;
    }
    key = key.subspan(partialKeyIndex, key.size() - partialKeyIndex);

    // Consider adding a partial key
    if (self->numChildren == 0 && !self->entryPresent) {
      self->partialKeyLen = std::min(key.size(), self->kPartialKeyMaxLen);
      memcpy(self->partialKey, key.data(), self->partialKeyLen);
      key = key.subspan(self->partialKeyLen, key.size() - self->partialKeyLen);
    }

    if (begin) {
      self->maxVersion = std::max(self->maxVersion, writeVersion);
    }

    if (key.size() == 0) {
      return self;
    }

    // For !begin the update happens after the return above, so it applies
    // only to nodes strictly above the result
    if (!begin) {
      self->maxVersion = std::max(self->maxVersion, writeVersion);
    }

    auto &child = getOrCreateChild(self, key.front(), allocators);
    if (!child) {
      child = allocators->node4.allocate();
      child->parent = self;
      child->parentsIndex = key.front();
      child->maxVersion =
          begin ? writeVersion : std::numeric_limits::lowest();
    }

    self_ = &child;
    key = key.subspan(1, key.size() - 1);
  }
}

// Frees `root` and all of its descendants with free(), using an explicit
// stack rather than recursion. Nodes are not returned to a NodeAllocators
// instance and their destructors are not run.
void destroyTree(Node *root) {
  Arena arena;
  auto toFree = vector(arena);
  toFree.push_back(root);
  while (toFree.size() > 0) {
    auto *n = toFree.back();
    toFree.pop_back();
    // Add all children to toFree
    for (int child = getChildGeq(n, 0); child >= 0;
         child = getChildGeq(n, child + 1)) {
      auto *c = getChildExists(n, child);
      assert(c != nullptr);
      toFree.push_back(c);
    }
    free(n);
  }
}

// Records a write to the single key `key` at `writeVersion`. If the key had
// no entry, one is created whose rangeVersion is taken from the next entry in
// the tree (or `oldestVersion` if there is none). `root` may be updated if the
// root node is replaced.
void addPointWrite(Node *&root, int64_t oldestVersion, std::span key,
                   int64_t writeVersion, NodeAllocators *allocators) {
  auto *n = insert(&root, key, writeVersion, true, allocators);
  if (!n->entryPresent) {
    // Look up the next entry before marking n present
    auto *p = nextLogical(n);
    n->entryPresent = true;
    n->entry.pointVersion = writeVersion;
    n->maxVersion = writeVersion;
    n->entry.rangeVersion =
        p != nullptr ? p->entry.rangeVersion : oldestVersion;
  } else {
    n->entry.pointVersion = std::max(n->entry.pointVersion, writeVersion);
    n->maxVersion = std::max(n->maxVersion, writeVersion);
  }
}

// Records a write to the keys from `begin` up to `end` at `writeVersion`.
// Entries are ensured at `begin` and `end`, `end`'s rangeVersion is set to
// writeVersion, and every entry strictly between them is removed. A range
// whose `end` is `begin` followed by a single zero byte is handled as a point
// write to `begin`. `root` may be updated if the root node is replaced.
void addWriteRange(Node *&root, int64_t oldestVersion, std::span begin,
                   std::span end, int64_t writeVersion,
                   NodeAllocators *allocators) {
  int lcp = longestCommonPrefix(begin.data(), end.data(),
                                std::min(begin.size(), end.size()));
  if (lcp == int(begin.size()) && end.size() == begin.size() + 1 &&
      end.back() == 0) {
    return addPointWrite(root, oldestVersion, begin, writeVersion, allocators);
  }

  // Walk down the part of the tree shared by begin and end so that both
  // inserts below can start from the same node. maxVersion is raised on the
  // nodes passed along the way.
  auto remaining = begin.subspan(0, lcp);
  auto *n = root;

  for (;;) {
    if (int(remaining.size()) <= n->partialKeyLen) {
      break;
    }
    int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(),
                                          n->partialKeyLen);
    if (i != n->partialKeyLen) {
      break;
    }

    auto *child = getChild(n, remaining[n->partialKeyLen]);
    if (child == nullptr) {
      break;
    }

    n->maxVersion = std::max(n->maxVersion, writeVersion);

    remaining = remaining.subspan(n->partialKeyLen + 1,
                                  remaining.size() - (n->partialKeyLen + 1));
    n = child;
  }

  // insert() needs the slot that holds n so it can replace n if necessary
  Node **useAsRoot =
      n->parent == nullptr ? &root : &getChildExists(n->parent, n->parentsIndex);

  int consumed = lcp - remaining.size();

  begin = begin.subspan(consumed, begin.size() - consumed);
  end = end.subspan(consumed, end.size() - consumed);

  auto *beginNode = insert(useAsRoot, begin, writeVersion, true, allocators);

  const bool insertedBegin = !std::exchange(beginNode->entryPresent, true);

  if (insertedBegin) {
    auto *p = nextLogical(beginNode);
    beginNode->entry.rangeVersion =
        p != nullptr ? p->entry.rangeVersion : oldestVersion;
    beginNode->entry.pointVersion = writeVersion;
    beginNode->maxVersion = writeVersion;
  }
  beginNode->maxVersion = std::max(beginNode->maxVersion, writeVersion);
  beginNode->entry.pointVersion =
      std::max(beginNode->entry.pointVersion, writeVersion);

  auto *endNode = insert(useAsRoot, end, writeVersion, false, allocators);

  const bool insertedEnd = !std::exchange(endNode->entryPresent, true);

  if (insertedEnd) {
    auto *p = nextLogical(endNode);
    endNode->entry.pointVersion =
        p != nullptr ? p->entry.rangeVersion : oldestVersion;
    endNode->maxVersion =
        std::max(endNode->maxVersion, endNode->entry.pointVersion);
  }
  endNode->entry.rangeVersion = writeVersion;

  if (insertedEnd) {
    // beginNode may have been invalidated
    beginNode = insert(useAsRoot, begin, writeVersion, true, allocators);
  }

  // Remove every entry strictly between begin and end. The successor is
  // fetched before `old` is modified or erased.
  for (beginNode = nextLogical(beginNode); beginNode != endNode;) {
    auto *old = beginNode;
    beginNode = nextLogical(beginNode);
    old->entryPresent = false;
    if (old->numChildren == 0 && old->parent != nullptr) {
      eraseChild(old->parent, old->parentsIndex, allocators);
    }
  }
}

struct FirstGeqStepwise {
  Node *n;
  std::span remaining;
  int cmp;

  enum Phase {
    Init,
    // Being in this phase implies that the key matches the search path exactly
    // up to this point
    Search,
    DownLeftSpine
  };
  Phase phase;

  FirstGeqStepwise(Node *n, std::span remaining)
      : n(n), remaining(remaining), phase(Init) {}

  // Not being done implies that n is not the firstGeq
  bool step() {
    switch (phase) {
    case Search:
      if (remaining.size() == 0) {
        int c =
getChildGeq(n, 0); assert(c >= 0); n = getChildExists(n, c); return downLeftSpine(); } else { int c = getChildGeq(n, remaining[0]); if (c == remaining[0]) { n = getChildExists(n, c); remaining = remaining.subspan(1, remaining.size() - 1); } else { if (c >= 0) { n = getChildExists(n, c); return downLeftSpine(); } else { n = nextSibling(n); return downLeftSpine(); } } } if (n->partialKeyLen > 0) { int commonLen = std::min(n->partialKeyLen, remaining.size()); int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(), commonLen); if (i < commonLen) { auto c = n->partialKey[i] <=> remaining[i]; if (c > 0) { return downLeftSpine(); } else { n = nextSibling(n); return downLeftSpine(); } } if (commonLen == n->partialKeyLen) { // partial key matches remaining = remaining.subspan(commonLen, remaining.size() - commonLen); } else if (n->partialKeyLen > int(remaining.size())) { // n is the first physical node greater than remaining, and there's no // eq node return downLeftSpine(); } } [[fallthrough]]; case Init: phase = Search; if (remaining.size() == 0 && n->entryPresent) { cmp = 0; return true; } return false; case DownLeftSpine: int c = getChildGeq(n, 0); assert(c >= 0); n = getChildExists(n, c); if (n->entryPresent) { cmp = 1; return true; } return false; } __builtin_unreachable(); // GCOVR_EXCL_LINE } bool downLeftSpine() { phase = DownLeftSpine; if (n == nullptr || n->entryPresent) { cmp = 1; return true; } return step(); } }; Iterator firstGeq(Node *n, const std::span key) { FirstGeqStepwise stepwise{n, key}; while (!stepwise.step()) ; return {stepwise.n, stepwise.cmp}; } Iterator firstGeq(Node *n, std::string_view key) { return firstGeq( n, std::span((const uint8_t *)key.data(), key.size())); } struct __attribute__((visibility("hidden"))) ConflictSet::Impl { void check(const ReadRange *reads, Result *result, int count) const { for (int i = 0; i < count; ++i) { const auto &r = reads[i]; auto begin = std::span(r.begin.p, r.begin.len); auto end = 
std::span(r.end.p, r.end.len); result[i] = reads[i].readVersion < oldestVersion ? TooOld : (end.size() > 0 ? checkRangeRead(root, begin, end, reads[i].readVersion) : checkPointRead(root, begin, reads[i].readVersion)) ? Commit : Conflict; } } void addWrites(const WriteRange *writes, int count) { for (int i = 0; i < count; ++i) { const auto &w = writes[i]; auto begin = std::span(w.begin.p, w.begin.len); auto end = std::span(w.end.p, w.end.len); if (w.end.len > 0) { keyUpdates += 2; addWriteRange(root, oldestVersion, begin, end, w.writeVersion, &allocators); } else { keyUpdates += 1; addPointWrite(root, oldestVersion, begin, w.writeVersion, &allocators); } } } void setOldestVersion(int64_t oldestVersion) { this->oldestVersion = oldestVersion; Node *prev = firstGeq(root, removalKey).n; // There's no way to erase removalKey without introducing a key after it assert(prev != nullptr); while (keyUpdates-- > 0) { Node *n = nextLogical(prev); if (n == nullptr) { removalKey = {}; return; } if (std::max(prev->entry.pointVersion, prev->entry.rangeVersion) <= oldestVersion) { // Any transaction prev would have prevented from committing is // going to fail with TooOld anyway. 
// There's no way to insert a range such that range version of the right // node is greater than the point version of the left node assert(n->entry.rangeVersion <= oldestVersion); prev->entryPresent = false; if (prev->numChildren == 0 && prev->parent != nullptr) { eraseChild(prev->parent, prev->parentsIndex, &allocators); } } prev = n; } removalKeyArena = Arena(); removalKey = getSearchPath(removalKeyArena, prev); } explicit Impl(int64_t oldestVersion) : oldestVersion(oldestVersion) { // Insert "" root = allocators.node4.allocate(); root->maxVersion = oldestVersion; root->entry.pointVersion = oldestVersion; root->entry.rangeVersion = oldestVersion; root->entryPresent = true; } ~Impl() { destroyTree(root); } NodeAllocators allocators; Arena removalKeyArena; std::span removalKey; int64_t keyUpdates = 0; Node *root; int64_t oldestVersion; }; // ==================== END IMPLEMENTATION ==================== // GCOVR_EXCL_START void ConflictSet::check(const ReadRange *reads, Result *results, int count) const { return impl->check(reads, results, count); } void ConflictSet::addWrites(const WriteRange *writes, int count) { return impl->addWrites(writes, count); } void ConflictSet::setOldestVersion(int64_t oldestVersion) { return impl->setOldestVersion(oldestVersion); } ConflictSet::ConflictSet(int64_t oldestVersion) : impl(new (safe_malloc(sizeof(Impl))) Impl{oldestVersion}) {} ConflictSet::~ConflictSet() { if (impl) { impl->~Impl(); free(impl); } } ConflictSet::ConflictSet(ConflictSet &&other) noexcept : impl(std::exchange(other.impl, nullptr)) {} ConflictSet &ConflictSet::operator=(ConflictSet &&other) noexcept { impl = std::exchange(other.impl, nullptr); return *this; } using ConflictSet_Result = ConflictSet::Result; using ConflictSet_Key = ConflictSet::Key; using ConflictSet_ReadRange = ConflictSet::ReadRange; using ConflictSet_WriteRange = ConflictSet::WriteRange; extern "C" { __attribute__((__visibility__("default"))) void ConflictSet_check(void *cs, const 
ConflictSet_ReadRange *reads, ConflictSet_Result *results, int count) { ((ConflictSet::Impl *)cs)->check(reads, results, count); } __attribute__((__visibility__("default"))) void ConflictSet_addWrites(void *cs, const ConflictSet_WriteRange *writes, int count) { ((ConflictSet::Impl *)cs)->addWrites(writes, count); } __attribute__((__visibility__("default"))) void ConflictSet_setOldestVersion(void *cs, int64_t oldestVersion) { ((ConflictSet::Impl *)cs)->setOldestVersion(oldestVersion); } __attribute__((__visibility__("default"))) void * ConflictSet_create(int64_t oldestVersion) { return new (safe_malloc(sizeof(ConflictSet::Impl))) ConflictSet::Impl{oldestVersion}; } __attribute__((__visibility__("default"))) void ConflictSet_destroy(void *cs) { using Impl = ConflictSet::Impl; ((Impl *)cs)->~Impl(); free(cs); } } namespace { std::string getSearchPathPrintable(Node *n) { Arena arena; if (n == nullptr) { return ""; } auto result = vector(arena); for (;;) { for (int i = n->partialKeyLen - 1; i >= 0; --i) { result.push_back(n->partialKey[i]); } if (n->parent == nullptr) { break; } result.push_back(n->parentsIndex); n = n->parent; } std::reverse(result.begin(), result.end()); if (result.size() > 0) { return printable(std::string_view((const char *)&result[0], result.size())); // NOLINT } else { return std::string(); } } std::string getPartialKeyPrintable(Node *n) { Arena arena; if (n == nullptr) { return ""; } auto result = std::string((const char *)&n->parentsIndex, n->parent == nullptr ? 0 : 1) + std::string((const char *)n->partialKey, n->partialKeyLen); return printable(result); // NOLINT } std::string strinc(std::string_view str, bool &ok) { int index; for (index = str.size() - 1; index >= 0; index--) if ((uint8_t &)(str[index]) != 255) break; // Must not be called with a string that consists only of zero or more '\xff' // bytes. 
if (index < 0) { ok = false; return {}; } ok = true; auto r = std::string(str.substr(0, index + 1)); ((uint8_t &)r[r.size() - 1])++; return r; } std::string getSearchPath(Node *n) { assert(n != nullptr); Arena arena; auto result = getSearchPath(arena, n); return std::string((const char *)result.data(), result.size()); } [[maybe_unused]] void debugPrintDot(FILE *file, Node *node) { constexpr int kSeparation = 3; struct DebugDotPrinter { explicit DebugDotPrinter(FILE *file) : file(file) {} void print(Node *n, int y = 0) { assert(n != nullptr); if (n->entryPresent) { fprintf(file, " k_%p [label=\"m=%" PRId64 " p=%" PRId64 " r=%" PRId64 "\n%s\", pos=\"%d,%d!\"];\n", (void *)n, n->maxVersion, n->entry.pointVersion, n->entry.rangeVersion, getPartialKeyPrintable(n).c_str(), x, y); } else { fprintf(file, " k_%p [label=\"m=%" PRId64 "\n%s\", pos=\"%d,%d!\"];\n", (void *)n, n->maxVersion, getPartialKeyPrintable(n).c_str(), x, y); } x += kSeparation; for (int child = getChildGeq(n, 0); child >= 0; child = getChildGeq(n, child + 1)) { auto *c = getChildExists(n, child); fprintf(file, " k_%p -> k_%p;\n", (void *)n, (void *)c); print(c, y - kSeparation); } } int x = 0; FILE *file; }; fprintf(file, "digraph ConflictSet {\n"); fprintf(file, " node [shape = box];\n"); assert(node != nullptr); DebugDotPrinter printer{file}; printer.print(node); fprintf(file, "}\n"); } void checkParentPointers(Node *node, bool &success) { for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) { auto *child = getChildExists(node, i); if (child->parent != node) { fprintf(stderr, "%s child %d has parent pointer %p. 
Expected %p\n", getSearchPathPrintable(node).c_str(), i, (void *)child->parent, (void *)node); success = false; } checkParentPointers(child, success); } } [[maybe_unused]] int64_t checkMaxVersion(Node *root, Node *node, int64_t oldestVersion, bool &success) { int64_t expected = std::numeric_limits::lowest(); if (node->entryPresent) { expected = std::max(expected, node->entry.pointVersion); } for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) { auto *child = getChildExists(node, i); expected = std::max(expected, checkMaxVersion(root, child, oldestVersion, success)); if (child->entryPresent) { expected = std::max(expected, child->entry.rangeVersion); } } auto key = getSearchPath(root); bool ok; auto inc = strinc(key, ok); if (ok) { auto borrowed = firstGeq(root, inc); if (borrowed.n != nullptr) { expected = std::max(expected, borrowed.n->entry.rangeVersion); } } if (node->maxVersion > oldestVersion && node->maxVersion != expected) { fprintf(stderr, "%s has max version %" PRId64 " . 
Expected %" PRId64 "\n", getSearchPathPrintable(node).c_str(), node->maxVersion, expected); success = false; } return expected; } [[maybe_unused]] int64_t checkEntriesExist(Node *node, bool &success) { int64_t total = node->entryPresent; for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) { auto *child = getChildExists(node, i); int64_t e = checkEntriesExist(child, success); total += e; if (e == 0) { Arena arena; fprintf(stderr, "%s has child %02x with no reachable entries\n", getSearchPathPrintable(node).c_str(), i); success = false; } } return total; } bool checkCorrectness(Node *node, int64_t oldestVersion) { bool success = true; checkParentPointers(node, success); checkMaxVersion(node, node, oldestVersion, success); checkEntriesExist(node, success); return success; } } // namespace namespace std { void __throw_length_error(const char *) { __builtin_unreachable(); } } // namespace std #ifdef ENABLE_MAIN void printTree() { int64_t writeVersion = 0; ConflictSet::Impl cs{writeVersion}; ReferenceImpl refImpl{writeVersion}; Arena arena; constexpr int kNumKeys = 5; auto *write = new (arena) ConflictSet::WriteRange[kNumKeys]; for (int i = 0; i < kNumKeys; ++i) { write[i].begin = toKey(arena, i); write[i].end.len = 0; write[i].writeVersion = ++writeVersion; } cs.addWrites(write, kNumKeys); for (int i = 0; i < kNumKeys; ++i) { write[i].writeVersion = ++writeVersion; } cs.addWrites(write, kNumKeys); debugPrintDot(stdout, cs.root); } int main(void) { printTree(); return 0; } #endif #ifdef ENABLE_FUZZ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { TestDriver driver{data, size}; static_assert(driver.kMaxKeyLen > Node::kPartialKeyMaxLen); for (;;) { bool done = driver.next(); if (!driver.ok) { debugPrintDot(stdout, driver.cs.root); fflush(stdout); abort(); } #if DEBUG_VERBOSE && !defined(NDEBUG) fprintf(stderr, "Check correctness\n"); #endif bool success = checkCorrectness(driver.cs.root, driver.cs.oldestVersion); if (!success) 
{ debugPrintDot(stdout, driver.cs.root); fflush(stdout); abort(); } if (done) { break; } } return 0; } #endif // GCOVR_EXCL_STOP