/*
Copyright 2024 Andrew Noyes

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include "ConflictSet.h"
#include "Internal.h"

// NOTE(review): In this copy of the file every bare `#include` below is
// missing its header name, and several template argument lists appear to be
// missing as well (e.g. `template struct BoundedFreeListAllocator`,
// `static_cast(self)`, `std::span remaining`,
// `std::min(child->partialKeyLen, remaining.size() - 1)`). `T`, `F`,
// `kMemoryBound` and `kMaxIndividual` are used below without a visible
// declaration. Restore these from version control; the code cannot compile
// as written.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

#ifdef HAS_AVX
#include
#elif defined(HAS_ARM_NEON)
#include
#endif

#include

// Use assert for checking potentially complex properties during tests.
// Use assume to hint simple properties to the optimizer.

// TODO use the c++23 version when that's available
#ifdef NDEBUG
#if __has_builtin(__builtin_assume)
#define assume(e) __builtin_assume(e)
#else
#define assume(e) __attribute__((assume(e)))
#endif
#else
// In debug builds assumptions are checked like assertions.
#define assume assert
#endif

// ==================== BEGIN IMPLEMENTATION ====================

// Versions recorded for a key that is present in the tree (only meaningful
// when the owning node's `entryPresent` is true).
struct Entry {
  // Compared against the read version by a point read of exactly this key.
  int64_t pointVersion;
  // Compared against the read version when a read lands before this key and
  // this is the first entry found after it (presumably the version of writes
  // to the gap ending at this key -- TODO confirm).
  int64_t rangeVersion;
};

// A set of 256 bits, one per possible key byte, stored as four 64-bit words.
// test/set/reset require 0 <= i < 256.
struct BitSet {
  bool test(int i) const;
  void set(int i);
  void reset(int i);
  // Returns the smallest set index >= i, or -1 if there is none. `i` may be
  // >= 256, in which case -1 is returned.
  int firstSetGeq(int i) const;

  // Calls `f` with the index of each bit set in [begin, end)
  template void forEachInRange(F f, int begin, int end) {
    // See section 3.1 in https://arxiv.org/pdf/1709.07821.pdf for details about
    // this approach

    // Whole range lies inside one word: mask off bits outside [begin, end).
    if ((begin >> 6) == (end >> 6)) {
      uint64_t word = words[begin >> 6] & (uint64_t(-1) << (begin & 63)) &
                      ~(uint64_t(-1) << (end & 63));
      while (word) {
        // `word & -word` isolates the lowest set bit.
        uint64_t temp = word & -word;
        int index = (begin & ~63) + std::countr_zero(word);
        f(index);
        word ^= temp;
      }
      return;
    }

    // Check begin partial word
    if (begin & 63) {
      uint64_t word = words[begin >> 6] & (uint64_t(-1) << (begin & 63));
      if (std::popcount(word) + (begin & 63) == 64) {
        // Every bit from `begin` to the end of the word is set; this loop also
        // leaves `begin` aligned to the next word.
        while (begin & 63) {
          f(begin++);
        }
      } else {
        while (word) {
          uint64_t temp = word & -word;
          int index = (begin & ~63) + std::countr_zero(word);
          f(index);
          word ^= temp;
        }
        begin &= ~63;
        begin += 64;
      }
    }

    // Check inner, full words
    while (begin != (end & ~63)) {
      uint64_t word = words[begin >> 6];
      if (word == uint64_t(-1)) {
        // Fully-set word: enumerate without per-bit extraction.
        for (int i = 0; i < 64; ++i) {
          f(begin + i);
        }
      } else {
        while (word) {
          uint64_t temp = word & -word;
          int index = begin + std::countr_zero(word);
          f(index);
          word ^= temp;
        }
      }
      begin += 64;
    }

    if (end & 63) {
      // Check end partial word
      uint64_t word = words[end >> 6] & ~(uint64_t(-1) << (end & 63));
      if (std::popcount(word) == (end & 63)) {
        while (begin < end) {
          f(begin++);
        }
      } else {
        while (word) {
          uint64_t temp = word & -word;
          int index = begin + std::countr_zero(word);
          f(index);
          word ^= temp;
        }
      }
    }
  }

private:
  uint64_t words[4] = {};
};

bool BitSet::test(int i) const {
  assert(0 <= i);
  assert(i < 256);
  return words[i >> 6] & (uint64_t(1) << (i & 63));
}

void BitSet::set(int i) {
  assert(0 <= i);
  assert(i < 256);
  words[i >> 6] |= uint64_t(1) << (i & 63);
}

void BitSet::reset(int i) {
  assert(0 <= i);
  assert(i < 256);
  words[i >> 6] &= ~(uint64_t(1) << (i & 63));
}

int BitSet::firstSetGeq(int i) const {
  assume(0 <= i);
  // i may be >= 256
  uint64_t mask = uint64_t(-1) << (i & 63);
  for (int j = i >> 6; j < 4; ++j) {
    uint64_t masked = mask & words[j];
    if (masked) {
      return (j << 6) + std::countr_zero(masked);
    }
    // Only the first word examined is partially masked.
    mask = -1;
  }
  return -1;
}

// Concrete node kinds, ordered by capacity. Code relies on this order, e.g.
// `type <= Type::Node16` selects the kinds that share the sorted `index` array
// layout.
enum class Type : int8_t {
  Node0,
  Node4,
  Node16,
  Node48,
  Node256,
};

// Header shared by every node kind. A node's partial key bytes are stored
// immediately after the concrete node struct in the same allocation (see the
// `partialKey()` accessors and `BoundedFreeListAllocator::allocate`).
struct Node {

  /* begin section that's copied to the next node */
  // nullptr for the root.
  Node *parent = nullptr;
  // Only meaningful when `entryPresent` is true.
  Entry entry;
  // Number of partial key bytes in use.
  int32_t partialKeyLen = 0;
  int16_t numChildren = 0;
  bool entryPresent = false;
  // The key byte under which this node is stored in `parent`.
  uint8_t parentsIndex = 0;
  /* end section that's copied to the next node */

  Type type;

  // Leaving this uninitialized is intentional and necessary for correctness.
  // Basically it needs to be preserved when going to the free list and back.
  int32_t partialKeyCapacity;

  uint8_t *partialKey();
};

// Byte range [parent, type) of the header that is memcpy'd verbatim when a
// node is replaced by a larger node kind in getOrCreateChild.
constexpr int kNodeCopyBegin = offsetof(Node, parent);
constexpr int kNodeCopySize = offsetof(Node, type) - kNodeCopyBegin;

// A child pointer plus the version tracked alongside it (read by
// getChildMaxVersion and maxBetweenExclusive; presumably the max version in
// the child's subtree -- TODO confirm where it is maintained).
struct Child {
  int64_t childMaxVersion;
  Node *child;
};

struct Node0 : Node {
  // Sorted
  uint8_t index[16]; // 16 so that we can use the same simd index search
                     // implementation as Node16

  Node0() { this->type = Type::Node0; }
  uint8_t *partialKey() { return (uint8_t *)(this + 1); }
};

struct Node4 : Node {
  // Sorted
  uint8_t index[16]; // 16 so that we can use the same simd index search
                     // implementation as Node16
  Child children[4];
  Node4() { this->type = Type::Node4; }
  uint8_t *partialKey() { return (uint8_t *)(this + 1); }
};

struct Node16 : Node {
  // Sorted
  uint8_t index[16];
  Child children[16];
  Node16() { this->type = Type::Node16; }
  uint8_t *partialKey() { return (uint8_t *)(this + 1); }
};

struct Node48 : Node {
  // Bit b is set iff key byte b has a child.
  BitSet bitSet;
  // Slots [0, nextFree) are occupied; eraseChild keeps them compact by moving
  // the last slot into the removed one.
  Child children[48];
  int8_t nextFree = 0;
  // index[b] is the slot in `children` for key byte b, or -1 if absent.
  int8_t index[256];
  Node48() {
    memset(index, -1, 256);
    this->type = Type::Node48;
  }
  uint8_t *partialKey() { return (uint8_t *)(this + 1); }
};

struct Node256 : Node {
  // Bit b is set iff key byte b has a child.
  BitSet bitSet;
  // Indexed directly by key byte; `child` is nullptr when absent.
  Child children[256];
  Node256() {
    this->type = Type::Node256;
    for (int i = 0; i < 256; ++i) {
      children[i].child = nullptr;
    }
  }
  uint8_t *partialKey() { return (uint8_t *)(this + 1); }
};

// Bound memory usage following the analysis in the ART paper

constexpr int kBytesPerKey = 86;
constexpr int kMinChildrenNode4 = 2;
constexpr int kMinChildrenNode16 = 5;
constexpr int kMinChildrenNode48 = 17;
constexpr int kMinChildrenNode256 = 49;

static_assert(sizeof(Node256) < kMinChildrenNode256 * kBytesPerKey);
static_assert(sizeof(Node48) < kMinChildrenNode48 * kBytesPerKey);
static_assert(sizeof(Node16) < kMinChildrenNode16 * kBytesPerKey);
static_assert(sizeof(Node4) < kMinChildrenNode4 * kBytesPerKey);

// Bounds memory usage in free list, but does not account for memory for partial
// keys.
//
// Released nodes form an intrusive singly linked free list: the "next" pointer
// is stored in the first bytes of each freed node (hence the
// sizeof(T) >= sizeof(void *) assertion).
template struct BoundedFreeListAllocator {
  static_assert(sizeof(T) >= sizeof(void *));
  static_assert(std::derived_from);

  // Returns a default-constructed T with room for at least
  // `partialKeyCapacity` partial key bytes after it. Only the head of the free
  // list is considered; a too-small head is freed rather than kept.
  T *allocate(int partialKeyCapacity) {
#if SHOW_MEMORY
    ++liveAllocations;
    maxLiveAllocations = std::max(maxLiveAllocations, liveAllocations);
#endif
    if (freeList != nullptr) {
      T *n = (T *)freeList;
      // Tell Valgrind the node is usable again, but keep partialKeyCapacity
      // (preserved across the free list) and the embedded next pointer defined.
      VALGRIND_MAKE_MEM_UNDEFINED(n, sizeof(T));
      VALGRIND_MAKE_MEM_DEFINED(&n->partialKeyCapacity,
                                sizeof(n->partialKeyCapacity));
      VALGRIND_MAKE_MEM_DEFINED(freeList, sizeof(freeList));
      // Pop the head: the next pointer is stored at the start of the node.
      memcpy(&freeList, freeList, sizeof(freeList));
      freeListBytes -= sizeof(T) + n->partialKeyCapacity;
      if (n->partialKeyCapacity >= partialKeyCapacity) {
        // Placement-new leaves partialKeyCapacity untouched (no initializer).
        return new (n) T;
      } else {
        // The intent is to filter out too-small nodes in the freelist
        free(n);
      }
    }

    auto *result = new (safe_malloc(sizeof(T) + partialKeyCapacity)) T;
    result->partialKeyCapacity = partialKeyCapacity;
    return result;
  }

  // Pushes `p` onto the free list, or frees it outright if the node plus its
  // partial key capacity exceeds kMaxIndividual bytes or the free list already
  // holds at least kMemoryBound bytes.
  void release(T *p) {
#if SHOW_MEMORY
    --liveAllocations;
#endif
    // No destructor is run before reuse/free, so it must be trivial.
    static_assert(std::is_trivially_destructible_v);
    if (sizeof(T) + p->partialKeyCapacity > kMaxIndividual ||
        freeListBytes >= kMemoryBound) {
      return free(p);
    }
    memcpy((void *)p, &freeList, sizeof(freeList));
    freeList = p;
    freeListBytes += sizeof(T) + p->partialKeyCapacity;
    VALGRIND_MAKE_MEM_NOACCESS(freeList, sizeof(T));
  }

  // Frees every node still on the free list.
  ~BoundedFreeListAllocator() {
    for (void *iter = freeList; iter != nullptr;) {
      VALGRIND_MAKE_MEM_DEFINED(iter, sizeof(iter));
      auto *tmp = iter;
      memcpy(&iter, iter, sizeof(void *));
      free(tmp);
    }
  }

#if SHOW_MEMORY
  int64_t highWaterMarkBytes() const { return maxLiveAllocations * sizeof(T); }
#endif

private:
  // Bytes (node + partial key capacity) currently held on the free list.
  int64_t freeListBytes = 0;
  void *freeList = nullptr;
#if SHOW_MEMORY
  // TODO Track partial key bytes
  int64_t maxLiveAllocations = 0;
  int64_t liveAllocations = 0;
#endif
};

// Dispatches to the concrete type's accessor, since the partial key begins
// right after the concrete struct and so its offset depends on the type.
uint8_t *Node::partialKey() {
  switch (type) {
  case Type::Node0:
    return ((Node0 *)this)->partialKey();
  case Type::Node4:
    return ((Node4 *)this)->partialKey();
  case Type::Node16:
    return ((Node16 *)this)->partialKey();
  case Type::Node48:
    return ((Node48 *)this)->partialKey();
  case Type::Node256:
    return ((Node256 *)this)->partialKey();
  }
  __builtin_unreachable(); // GCOVR_EXCL_LINE
}

// One free-list allocator per node kind.
struct NodeAllocators {
  BoundedFreeListAllocator node0;
  BoundedFreeListAllocator node4;
  BoundedFreeListAllocator node16;
  BoundedFreeListAllocator node48;
  BoundedFreeListAllocator node256;
};

// Returns the position i < numChildren with self->index[i] == index, or -1.
// Node0 and Node4 nodes are also searched through this Node16 view (callers
// cast when `type <= Type::Node16`); only the leading `index[16]` is read.
int getNodeIndex(Node16 *self, uint8_t index) {
#ifdef HAS_AVX
  // Based on https://www.the-paper-trail.org/post/art-paper-notes/

  // key_vec is 16 repeated copies of the searched-for byte, one for every
  // possible position in child_keys that needs to be searched.
  __m128i key_vec = _mm_set1_epi8(index);

  // Compare all child_keys to 'index' in parallel. Don't worry if some of the
  // keys aren't valid, we'll mask the results to only consider the valid ones
  // below.
  __m128i indices;
  memcpy(&indices, self->index, sizeof(self->index));
  __m128i results = _mm_cmpeq_epi8(key_vec, indices);

  // Build a mask to select only the first node->num_children values from the
  // comparison (because the other values are meaningless)
  uint32_t mask = (1 << self->numChildren) - 1;

  // Change the results of the comparison into a bitfield, masking off any
  // invalid comparisons.
  uint32_t bitfield = _mm_movemask_epi8(results) & mask;

  // No match if there are no '1's in the bitfield.
  if (bitfield == 0)
    return -1;

  // Find the index of the first '1' in the bitfield by counting the trailing
  // zeros.
  return std::countr_zero(bitfield);
#elif defined(HAS_ARM_NEON)
  // Based on
  // https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon

  uint8x16_t indices;
  memcpy(&indices, self->index, sizeof(self->index));
  // 0xff for each match
  uint16x8_t results =
      vreinterpretq_u16_u8(vceqq_u8(vdupq_n_u8(index), indices));
  // Each byte lane becomes 4 bits after the narrowing shift below, so the
  // valid range is the low 4 * numChildren bits.
  uint64_t mask = self->numChildren == 16
                      ? uint64_t(-1)
                      : (uint64_t(1) << (self->numChildren * 4)) - 1;
  // 0xf for each match in valid range
  uint64_t bitfield =
      vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) & mask;
  if (bitfield == 0)
    return -1;
  return std::countr_zero(bitfield) / 4;
#else
  for (int i = 0; i < self->numChildren; ++i) {
    if (self->index[i] == index) {
      return i;
    }
  }
  return -1;
#endif
}

// Returns a reference to the child pointer slot for key byte `index`.
// Precondition - an entry for index must exist in the node
Node *&getChildExists(Node *self, uint8_t index) {
  if (self->type <= Type::Node16) {
    auto *self16 = static_cast(self);
    return self16->children[getNodeIndex(self16, index)].child;
  } else if (self->type == Type::Node48) {
    auto *self48 = static_cast(self);
    assert(self48->bitSet.test(index));
    return self48->children[self48->index[index]].child;
  } else {
    auto *self256 = static_cast(self);
    assert(self256->bitSet.test(index));
    return self256->children[index].child;
  }
  __builtin_unreachable(); // GCOVR_EXCL_LINE
}

// Returns the childMaxVersion stored alongside the child for key byte `index`.
// Precondition - an entry for index must exist in the node
int64_t getChildMaxVersion(Node *self, uint8_t index) {
  if (self->type <= Type::Node16) {
    auto *self16 = static_cast(self);
    return self16->children[getNodeIndex(self16, index)].childMaxVersion;
  } else if (self->type == Type::Node48) {
    auto *self48 = static_cast(self);
    assert(self48->bitSet.test(index));
    return self48->children[self48->index[index]].childMaxVersion;
  } else {
    auto *self256 = static_cast(self);
    assert(self256->bitSet.test(index));
    return self256->children[index].childMaxVersion;
  }
  __builtin_unreachable(); // GCOVR_EXCL_LINE
}

// Defined elsewhere. Returns a mutable reference to a version tracked for `n`;
// the read checks treat `maxVersion(n, impl) <= readVersion` as sufficient to
// stop searching below `n` (presumably the max version of anything under `n`
// -- TODO confirm at the definition).
int64_t &maxVersion(Node *n, ConflictSet::Impl *);

// Returns the child for key byte `index`, or nullptr if there is none.
Node *getChild(Node *self, uint8_t index) {
  if (self->type <= Type::Node16) {
    auto *self16 = static_cast(self);
    int i = getNodeIndex(self16, index);
    if (i >= 0) {
      return self16->children[i].child;
    }
    return nullptr;
  } else if (self->type == Type::Node48) {
    auto *self48 = static_cast(self);
    int secondIndex = self48->index[index];
    if (secondIndex >= 0) {
      return self48->children[secondIndex].child;
    }
    return nullptr;
  } else {
    auto *self256 = static_cast(self);
    return self256->children[index].child;
  }
}

// Returns the smallest key byte >= `child` that has a child in `self`, or -1
// if there is none (including when child > 255).
int getChildGeq(Node *self, int child) {
  if (child > 255) {
    return -1;
  }
  if (self->type <= Type::Node16) {
    auto *self16 = static_cast(self);
#ifdef HAS_AVX
    __m128i key_vec = _mm_set1_epi8(child);
    __m128i indices;
    memcpy(&indices, self16->index, sizeof(self16->index));
    // key == min(key, index) holds exactly when key <= index (unsigned).
    __m128i results = _mm_cmpeq_epi8(key_vec, _mm_min_epu8(key_vec, indices));
    int mask = (1 << self16->numChildren) - 1;
    uint32_t bitfield = _mm_movemask_epi8(results) & mask;
    // `index` is sorted, so the lowest qualifying position is the answer.
    int result = bitfield == 0 ? -1 : self16->index[std::countr_zero(bitfield)];
    assert(result == [&]() -> int {
      for (int i = 0; i < self16->numChildren; ++i) {
        if (self16->index[i] >= child) {
          return self16->index[i];
        }
      }
      return -1;
    }());
    return result;
#elif defined(HAS_ARM_NEON)
    uint8x16_t indices;
    memcpy(&indices, self16->index, sizeof(self16->index));
    // 0xff for each leq
    auto results = vcleq_u8(vdupq_n_u8(child), indices);
    uint64_t mask = self->numChildren == 16
                        ? uint64_t(-1)
                        : (uint64_t(1) << (self->numChildren * 4)) - 1;
    // 0xf for each 0xff (within mask)
    uint64_t bitfield =
        vget_lane_u64(
            vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)),
            0) &
        mask;
    int simd = bitfield == 0 ? -1 : self16->index[std::countr_zero(bitfield) / 4];
    assert(simd == [&]() -> int {
      for (int i = 0; i < self->numChildren; ++i) {
        if (self16->index[i] >= child) {
          return self16->index[i];
        }
      }
      return -1;
    }());
    return simd;
#else
    for (int i = 0; i < self->numChildren; ++i) {
      if (i > 0) {
        assert(self16->index[i - 1] < self16->index[i]);
      }
      if (self16->index[i] >= child) {
        return self16->index[i];
      }
    }
#endif
  } else {
    // Node48 and Node256 keep their BitSet at the same offset, so one view
    // serves both.
    static_assert(offsetof(Node48, bitSet) == offsetof(Node256, bitSet));
    auto *self48 = static_cast(self);
    return self48->bitSet.firstSetGeq(child);
  }
  return -1;
}

// Point every child's `parent` back at `n` after the children were copied
// into a newly allocated node.
void setChildrenParents(Node16 *n) {
  for (int i = 0; i < n->numChildren; ++i) {
    n->children[i].child->parent = n;
  }
}

void setChildrenParents(Node48 *n) {
  n->bitSet.forEachInRange(
      [&](int i) { n->children[n->index[i]].child->parent = n; }, 0, 256);
}

void setChildrenParents(Node256 *n) {
  n->bitSet.forEachInRange([&](int i) { n->children[i].child->parent = n; },
                           0, 256);
}

// Returns a reference to the child slot for key byte `index` in `self`,
// creating an empty (nullptr) slot if none exists. If `self` has no room it is
// replaced by the next larger node kind (Node0 -> Node4 always, Node4 -> Node16
// at 4 children, Node16 -> Node48 at 16, Node48 -> Node256 at 48). The header
// section and partial key are copied, existing children are moved and
// re-parented, the old node is released, and `self` is updated to the new node.
// Caller is responsible for assigning a non-null pointer to the returned
// reference if null
Node *&getOrCreateChild(Node *&self, uint8_t index,
                        NodeAllocators *allocators) {

  // Fast path for if it exists already
  if (self->type <= Type::Node16) {
    auto *self16 = static_cast(self);
    int i = getNodeIndex(self16, index);
    if (i >= 0) {
      return self16->children[i].child;
    }
  } else if (self->type == Type::Node48) {
    auto *self48 = static_cast(self);
    int secondIndex = self48->index[index];
    if (secondIndex >= 0) {
      return self48->children[secondIndex].child;
    }
  } else {
    auto *self256 = static_cast(self);
    if (auto &result = self256->children[index].child; result != nullptr) {
      return result;
    }
  }

  if (self->type == Type::Node0) {
    // Node0 has no child slots at all, so any insertion upgrades it.
    auto *self0 = static_cast(self);

    auto *newSelf = allocators->node4.allocate(self->partialKeyLen);
    memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin,
           kNodeCopySize);
    memcpy(newSelf->partialKey(), self0->partialKey(), self->partialKeyLen);
    allocators->node0.release(self0);
    self = newSelf;

    goto insert16;

  } else if (self->type == Type::Node4) {
    auto *self4 = static_cast(self);

    if (self->numChildren == 4) {
      auto *newSelf = allocators->node16.allocate(self->partialKeyLen);
      memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin,
             kNodeCopySize);
      memcpy(newSelf->partialKey(), self4->partialKey(), self->partialKeyLen);
      // TODO replace with memcpy?
      for (int i = 0; i < 4; ++i) {
        newSelf->index[i] = self4->index[i];
        newSelf->children[i] = self4->children[i];
      }
      allocators->node4.release(self4);
      setChildrenParents(newSelf);
      self = newSelf;
    }

    // Node4 and Node16 share the sorted-array insertion below.
    goto insert16;

  } else if (self->type == Type::Node16) {

    if (self->numChildren == 16) {
      auto *self16 = static_cast(self);
      auto *newSelf = allocators->node48.allocate(self->partialKeyLen);
      memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin,
             kNodeCopySize);
      memcpy(newSelf->partialKey(), self16->partialKey(), self->partialKeyLen);
      newSelf->nextFree = 16;
      int i = 0;
      for (auto x : self16->index) {
        newSelf->bitSet.set(x);
        newSelf->children[i] = self16->children[i];
        newSelf->index[x] = i;
        ++i;
      }
      assert(i == 16);
      allocators->node16.release(self16);
      setChildrenParents(newSelf);
      self = newSelf;
      goto insert48;
    }

  insert16:
    // Insert `index` into the sorted prefix of `index[]`, shifting later keys
    // and their children one position to the right.
    auto *self16 = static_cast(self);

    ++self->numChildren;

    int i = 0;
    for (; i < int(self->numChildren) - 1; ++i) {
      if (int(self16->index[i]) > int(index)) {
        memmove(self16->index + i + 1, self16->index + i,
                self->numChildren - (i + 1));
        memmove(self16->children + i + 1, self16->children + i,
                (self->numChildren - (i + 1)) * sizeof(Child));
        break;
      }
    }
    self16->index[i] = index;
    auto &result = self16->children[i].child;
    result = nullptr;
    return result;

  } else if (self->type == Type::Node48) {

    if (self->numChildren == 48) {
      auto *self48 = static_cast(self);

      auto *newSelf = allocators->node256.allocate(self->partialKeyLen);
      memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin,
             kNodeCopySize);
      memcpy(newSelf->partialKey(), self48->partialKey(), self->partialKeyLen);
      newSelf->bitSet = self48->bitSet;
      newSelf->bitSet.forEachInRange(
          [&](int i) {
            newSelf->children[i] = self48->children[self48->index[i]];
          },
          0, 256);
      allocators->node48.release(self48);
      setChildrenParents(newSelf);
      self = newSelf;
      goto insert256;
    }

  insert48:
    // Append into the next unused slot of `children`.
    auto *self48 = static_cast(self);
    self48->bitSet.set(index);
    ++self->numChildren;
    assert(self48->nextFree < 48);
    int nextFree = self48->nextFree++;
    self48->index[index] = nextFree;
    auto &result = self48->children[nextFree].child;
    result = nullptr;
    return result;
  } else {
    assert(self->type == Type::Node256);

  insert256:
    // The slot for `index` is already nullptr here (fast path did not hit).
    auto *self256 = static_cast(self);
    ++self->numChildren;
    self256->bitSet.set(index);
    return self256->children[index].child;
  }
}

// Removes the child stored under key byte `index` from `self` and returns the
// child node to its allocator. The child's own children are not visited here
// (NOTE(review): presumably callers only erase childless nodes -- confirm).
// If this leaves `self` with no children and no entry, and `self` is not the
// root, `self` is in turn erased from its parent. Node kinds are never shrunk.
// Precondition - an entry for index must exist in the node
void eraseChild(Node *self, uint8_t index, NodeAllocators *allocators) {
  auto *child = getChildExists(self, index);
  switch (child->type) {
  case Type::Node0:
    allocators->node0.release((Node0 *)child);
    break;
  case Type::Node4:
    allocators->node4.release((Node4 *)child);
    break;
  case Type::Node16:
    allocators->node16.release((Node16 *)child);
    break;
  case Type::Node48:
    allocators->node48.release((Node48 *)child);
    break;
  case Type::Node256:
    allocators->node256.release((Node256 *)child);
    break;
  }

  if (self->type <= Type::Node16) {
    // Close the gap in the sorted arrays.
    auto *self16 = static_cast(self);
    int nodeIndex = getNodeIndex(self16, index);
    memmove(self16->index + nodeIndex, self16->index + nodeIndex + 1,
            sizeof(self16->index[0]) * (self->numChildren - (nodeIndex + 1)));
    memmove(self16->children + nodeIndex, self16->children + nodeIndex + 1,
            sizeof(self16->children[0]) * // NOLINT
                (self->numChildren - (nodeIndex + 1)));
  } else if (self->type == Type::Node48) {
    auto *self48 = static_cast(self);
    self48->bitSet.reset(index);
    int8_t toRemoveChildrenIndex = std::exchange(self48->index[index], -1);
    int8_t lastChildrenIndex = --self48->nextFree;
    assert(toRemoveChildrenIndex >= 0);
    assert(lastChildrenIndex >= 0);
    if (toRemoveChildrenIndex != lastChildrenIndex) {
      // Keep slots compact: move the last slot into the hole and redirect the
      // moved child's key byte (its parentsIndex) to the new slot.
      self48->children[toRemoveChildrenIndex] =
          self48->children[lastChildrenIndex];
      self48
          ->index[self48->children[toRemoveChildrenIndex].child->parentsIndex] =
          toRemoveChildrenIndex;
    }
  } else {
    auto *self256 = static_cast(self);
    self256->bitSet.reset(index);
    self256->children[index].child = nullptr;
  }

  --self->numChildren;
  if (self->numChildren == 0 && !self->entryPresent &&
      self->parent != nullptr) {
    eraseChild(self->parent, self->parentsIndex, allocators);
  }
}

// Returns the node after `node` in pre-order (a node before its children,
// children in ascending key-byte order), or nullptr if `node` is the last one.
Node *nextPhysical(Node *node) {
  int index = -1;
  for (;;) {
    // First try children of `node` after `index`; starting from -1 means the
    // first child.
    auto nextChild = getChildGeq(node, index + 1);
    if (nextChild >= 0) {
      return getChildExists(node, nextChild);
    }
    // Exhausted: continue with the next sibling of `node` in its parent.
    index = node->parentsIndex;
    node = node->parent;
    if (node == nullptr) {
      return nullptr;
    }
  }
}

// Like nextPhysical, but skips nodes that have no entry.
Node *nextLogical(Node *node) {
  for (node = nextPhysical(node); node != nullptr && !node->entryPresent;
       node = nextPhysical(node))
    ;
  return node;
}

// NOTE(review): not referenced in the visible part of this file; the meaning
// of `cmp` is not established here.
struct Iterator {
  Node *n;
  int cmp;
};

// Returns the first node after the whole subtree rooted at `node` in
// pre-order: the next sibling of `node`, or of its nearest ancestor that has
// one. Returns nullptr if there is none.
Node *nextSibling(Node *node) {
  for (;;) {
    if (node->parent == nullptr) {
      return nullptr;
    }
    auto next = getChildGeq(node->parent, node->parentsIndex + 1);
    if (next < 0) {
      node = node->parent;
    } else {
      return getChildExists(node->parent, next);
    }
  }
}

// Bytes compared per compareStride/firstNeqStride call.
#if defined(HAS_AVX) || defined(HAS_ARM_NEON)
constexpr int kStride = 64;
#else
constexpr int kStride = 16;
#endif

// Strides compared per iteration of longestCommonPrefix's unrolled loop.
constexpr int kUnrollFactor = 4;

// Returns whether the kStride bytes at `ap` and `bp` are all equal.
bool compareStride(const uint8_t *ap, const uint8_t *bp) {
#if defined(HAS_ARM_NEON)
  static_assert(kStride == 64);
  uint8x16_t x[4];
  for (int i = 0; i < 4; ++i) {
    x[i] = vceqq_u8(vld1q_u8(ap + i * 16), vld1q_u8(bp + i * 16));
  }
  auto results = vreinterpretq_u16_u8(
      vandq_u8(vandq_u8(x[0], x[1]), vandq_u8(x[2], x[3])));
  bool eq = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) ==
            uint64_t(-1);
#elif defined(HAS_AVX)
  static_assert(kStride == 64);
  __m128i x[4];
  for (int i = 0; i < 4; ++i) {
    x[i] = _mm_cmpeq_epi8(_mm_loadu_si128((__m128i *)(ap + i * 16)),
                          _mm_loadu_si128((__m128i *)(bp + i * 16)));
  }
  auto eq =
      _mm_movemask_epi8(_mm_and_si128(_mm_and_si128(x[0], x[1]),
                                      _mm_and_si128(x[2], x[3]))) == 0xffff;
#else
  // Hope it gets vectorized
  auto eq = memcmp(ap, bp, kStride) == 0;
#endif
  return eq;
}

// Returns the offset (< kStride) of the first byte where `ap` and `bp` differ.
// Precondition: ap[:kStride] != bp[:kStride]
int firstNeqStride(const uint8_t *ap, const uint8_t *bp) {
#if defined(HAS_AVX)
  static_assert(kStride == 64);
  // One 16-bit equality mask per 16-byte chunk, combined into 64 bits so the
  // first zero bit is the first mismatching byte.
  uint64_t c[kStride / 16];
  for (int i = 0; i < kStride; i += 16) {
    const auto a = _mm_loadu_si128((__m128i *)(ap + i));
    const auto b = _mm_loadu_si128((__m128i *)(bp + i));
    const auto compared = _mm_cmpeq_epi8(a, b);
    c[i / 16] = _mm_movemask_epi8(compared) & 0xffff;
  }
  return std::countr_zero(~(c[0] | c[1] << 16 | c[2] << 32 | c[3] << 48));
#elif defined(HAS_ARM_NEON)
  static_assert(kStride == 64);
  for (int i = 0; i < kStride; i += 16) {
    // 0xff for each match
    uint16x8_t results =
        vreinterpretq_u16_u8(vceqq_u8(vld1q_u8(ap + i), vld1q_u8(bp + i)));
    // 0xf for each mismatch
    uint64_t bitfield =
        ~vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0);
    if (bitfield) {
      return i + (std::countr_zero(bitfield) >> 2);
    }
  }
  __builtin_unreachable(); // GCOVR_EXCL_LINE
#else
  // The last byte is never compared: by the precondition, if the first
  // kStride - 1 bytes match, the mismatch must be at kStride - 1.
  int i = 0;
  for (; i < kStride - 1; ++i) {
    if (*ap++ != *bp++) {
      break;
    }
  }
  return i;
#endif
}

// Returns the length of the longest common prefix of ap[0, cl) and bp[0, cl).
// The 8-byte comparisons map `countr_zero(a ^ b) / 8` to the first differing
// byte, which assumes a little-endian target.
int longestCommonPrefix(const uint8_t *ap, const uint8_t *bp, int cl) {
  assume(cl >= 0);
  int i = 0;
  int end;

  if (cl < 8) {
    goto bytes;
  }

  // Optimistic early return
  {
    uint64_t a;
    uint64_t b;
    memcpy(&a, ap, 8);
    memcpy(&b, bp, 8);
    const auto mismatched = a ^ b;
    if (mismatched) {
      return std::countr_zero(mismatched) / 8;
    }
  }

  // kStride * kUnrollCount at a time
  end = cl & ~(kStride * kUnrollFactor - 1);
  while (i < end) {
    for (int j = 0; j < kUnrollFactor; ++j) {
      if (!compareStride(ap, bp)) {
        return i + firstNeqStride(ap, bp);
      }
      i += kStride;
      ap += kStride;
      bp += kStride;
    }
  }

  // kStride at a time
  end = cl & ~(kStride - 1);
  while (i < end) {
    if (!compareStride(ap, bp)) {
      return i + firstNeqStride(ap, bp);
    }
    i += kStride;
    ap += kStride;
    bp += kStride;
  }

  // word at a time
  end = cl & ~(sizeof(uint64_t) - 1);
  while (i < end) {
    uint64_t a;
    uint64_t b;
    memcpy(&a, ap, 8);
    memcpy(&b, bp, 8);
    const auto mismatched = a ^ b;
    if (mismatched) {
      return i + std::countr_zero(mismatched) / 8;
    }
    i += 8;
    ap += 8;
    bp += 8;
  }

bytes:
  // byte at a time
  while (i < cl) {
    if (*ap != *bp) {
      break;
    }
    ++ap;
    ++bp;
    ++i;
  }

  return i;
}

// Performs a physical search for remaining
//
// Each step() descends one level: it consumes one byte (the child's key byte)
// plus the child's partial key from `remaining`. step() returns true when the
// walk cannot continue (remaining is empty, no child exists for the next byte,
// or the child's partial key is not fully matched by what remains), and false
// after advancing `n`.
struct SearchStepWise {
  Node *n;
  std::span remaining;

  SearchStepWise() {}
  // The starting node must have an empty partial key.
  SearchStepWise(Node *n, std::span remaining)
      : n(n), remaining(remaining) {
    assert(n->partialKeyLen == 0);
  }

  bool step() {
    if (remaining.size() == 0) {
      return true;
    }
    auto *child = getChild(n, remaining[0]);
    if (child == nullptr) {
      return true;
    }
    int cl = std::min(child->partialKeyLen, remaining.size() - 1);
    int i = longestCommonPrefix(child->partialKey(), remaining.data() + 1, cl);
    if (i != child->partialKeyLen) {
      return true;
    }
    n = child;
    remaining = remaining.subspan(1 + child->partialKeyLen,
                                  remaining.size() - (1 + child->partialKeyLen));
    return false;
  }
};

namespace {
// Defined elsewhere; used for debug output.
std::string getSearchPathPrintable(Node *n);
}

// Logically this is the same as performing firstGeq and then checking against
// point or range version according to cmp, but this version short circuits as
// soon as it can prove that there's no conflict.
// Returns true iff a point read of `key` at `readVersion` is conflict-free.
// If `key` itself has an entry, its pointVersion decides; otherwise the first
// entry after `key` decides via its rangeVersion, and the absence of any later
// entry means no conflict. The search stops early once
// maxVersion(n, impl) <= readVersion.
// NOTE(review): in this copy `std::span` and `std::min` below appear to be
// missing their template arguments (presumably `std::span<const uint8_t>` and
// `std::min<int>`) -- restore from version control.
bool checkPointRead(Node *n, const std::span key,
                    int64_t readVersion, ConflictSet::Impl *impl) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
  fprintf(stderr, "Check point read: %s\n", printable(key).c_str());
#endif
  auto remaining = key;
  for (;;) {
    if (maxVersion(n, impl) <= readVersion) {
      return true;
    }
    if (remaining.size() == 0) {
      // Search path equals `key`.
      if (n->entryPresent) {
        return n->entry.pointVersion <= readVersion;
      }
      // No entry here: assumes such a node has at least one child, whose
      // leftmost entry is the first key after `key`.
      int c = getChildGeq(n, 0);
      assert(c >= 0);
      n = getChildExists(n, c);
      goto downLeftSpine;
    }

    auto *child = getChild(n, remaining[0]);
    if (child == nullptr) {
      // `key` is absent; find the first subtree after it.
      int c = getChildGeq(n, remaining[0]);
      if (c >= 0) {
        n = getChildExists(n, c);
        goto downLeftSpine;
      } else {
        n = nextSibling(n);
        goto downLeftSpine;
      }
    }

    n = child;
    remaining = remaining.subspan(1, remaining.size() - 1);

    if (n->partialKeyLen > 0) {
      int commonLen = std::min(n->partialKeyLen, remaining.size());
      int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
      if (i < commonLen) {
        // Partial key diverges from `key`: n's subtree is entirely after
        // `key` (c > 0) or entirely before it.
        auto c = n->partialKey()[i] <=> remaining[i];
        if (c > 0) {
          goto downLeftSpine;
        } else {
          n = nextSibling(n);
          goto downLeftSpine;
        }
      }
      if (commonLen == n->partialKeyLen) {
        // partial key matches
        remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
      } else if (n->partialKeyLen > int(remaining.size())) {
        // n is the first physical node greater than remaining, and there's no
        // eq node
        goto downLeftSpine;
      }
    }
  }
downLeftSpine:
  // Descend to the smallest entry at or below `n`; nullptr means no later
  // entry exists.
  if (n == nullptr) {
    return true;
  }
  for (;;) {
    if (n->entryPresent) {
      return n->entry.rangeVersion <= readVersion;
    }
    int c = getChildGeq(n, 0);
    assert(c >= 0);
    n = getChildExists(n, c);
  }
}

// Return the max version among all keys starting with the search path of n +
// [child], where child in (begin, end)
//
// `begin` may be -1 and `end` may be 256 to leave that side unbounded. If no
// child lies in (begin, end), the lowest representable int64_t is returned.
// Besides each qualifying child's childMaxVersion, the rangeVersion of the
// first qualifying child's own entry (if present) is included.
int64_t maxBetweenExclusive(Node *n, int begin, int end) {
  assume(-1 <= begin);
  assume(begin <= 256);
  assume(-1 <= end);
  assume(end <= 256);
  assume(begin < end);
  int64_t result = std::numeric_limits::lowest();
  {
    int c = getChildGeq(n, begin + 1);
    if (c >= 0 && c < end) {
      auto *child = getChildExists(n, c);
      if (child->entryPresent) {
        result = std::max(result, child->entry.rangeVersion);
      }
      // From here on the range is treated as [c, end).
      begin = c;
    } else {
      return result;
    }
  }
  switch (n->type) {
  case Type::Node0:
    [[fallthrough]];
  case Type::Node4:
    [[fallthrough]];
  case Type::Node16: {
    auto *self = static_cast(n);
    // `index` is sorted, so stop at the first key byte >= end.
    for (int i = 0; i < self->numChildren && self->index[i] < end; ++i) {
      if (begin <= self->index[i]) {
        result = std::max(result, self->children[i].childMaxVersion);
      }
    }
    break;
  }
  case Type::Node48: {
    auto *self = static_cast(n);
    self->bitSet.forEachInRange(
        [&](int i) {
          result =
              std::max(result, self->children[self->index[i]].childMaxVersion);
        },
        begin, end);
    break;
  }
  case Type::Node256: {
    auto *self = static_cast(n);
    self->bitSet.forEachInRange(
        [&](int i) {
          result = std::max(result, self->children[i].childMaxVersion);
        },
        begin, end);
    break;
  }
  }
#if DEBUG_VERBOSE && !defined(NDEBUG)
  // Note: `begin` has already been advanced to the first child found above.
  fprintf(stderr, "At `%s', max version in (%02x, %02x) is %" PRId64 "\n",
          getSearchPathPrintable(n).c_str(), begin, end, result);
#endif
  return result;
}

// Reconstructs the full key of `n`. It walks up to the root, collecting each
// node's partial key bytes and the key byte under which it hangs in its parent
// (collected backwards), and then reverses the result.
Vector getSearchPath(Arena &arena, Node *n) {
  assert(n != nullptr);
  auto result = vector(arena);
  for (;;) {
    for (int i = n->partialKeyLen - 1; i >= 0; --i) {
      result.push_back(n->partialKey()[i]);
    }
    if (n->parent == nullptr) {
      break;
    }
    result.push_back(n->parentsIndex);
    n = n->parent;
  }
  std::reverse(result.begin(), result.end());
  return result;
}

// Return true if the max version among all keys that start with key + [child],
// where begin < child < end, is <= readVersion
bool checkRangeStartsWith(Node *n, std::span key, int begin,
                          int end, int64_t readVersion,
                          ConflictSet::Impl *impl) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
  fprintf(stderr, "%s(%02x,%02x)*\n", printable(key).c_str(), begin, end);
#endif
  auto remaining = key;
  for (;;) {
    if (maxVersion(n, impl) <= readVersion) {
      return true;
    }
    if (remaining.size() == 0) {
      // Search path equals `key`: inspect children strictly between the bounds.
      return maxBetweenExclusive(n, begin, end) <= readVersion;
    }

    auto *child = getChild(n, remaining[0]);
    if (child == nullptr) {
      // No key starts with `key`; the first entry after it decides.
      int c = getChildGeq(n, remaining[0]);
      if (c >= 0) {
        n = getChildExists(n, c);
        goto downLeftSpine;
      } else {
        n = nextSibling(n);
        goto downLeftSpine;
      }
    }

    n = child;
    remaining = remaining.subspan(1, remaining.size() - 1);

    if (n->partialKeyLen > 0) {
      int commonLen = std::min(n->partialKeyLen, remaining.size());
      int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
      if (i < commonLen) {
        auto c = n->partialKey()[i] <=> remaining[i];
        if (c > 0) {
          goto downLeftSpine;
        } else {
          n = nextSibling(n);
          goto downLeftSpine;
        }
      }
      if (commonLen == n->partialKeyLen) {
        // partial key matches
        remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
      } else if (n->partialKeyLen > int(remaining.size())) {
        // `key` ends inside n's partial key, so every key under n continues
        // `key` with the byte n->partialKey()[remaining.size()]. Only that
        // byte can fall within (begin, end).
        if (begin < n->partialKey()[remaining.size()] &&
            n->partialKey()[remaining.size()] < end) {
          if (n->entryPresent && n->entry.rangeVersion > readVersion) {
            return false;
          }
          return maxVersion(n, impl) <= readVersion;
        }
        return true;
      }
    }
  }
downLeftSpine:
  // Descend to the smallest entry at or below `n`; nullptr means no later
  // entry exists.
  if (n == nullptr) {
    return true;
  }
  for (;;) {
    if (n->entryPresent) {
      return n->entry.rangeVersion <= readVersion;
    }
    int c = getChildGeq(n, 0);
    assert(c >= 0);
    n = getChildExists(n, c);
  }
}

// Return true if the max version among all keys that start with key[:prefixLen]
// that are >= key is <= readVersion
//
// Resumable: call step() until it returns true; the answer is then in `ok`.
struct CheckRangeLeftSide {
  CheckRangeLeftSide(Node *n, std::span key, int prefixLen,
                     int64_t readVersion, ConflictSet::Impl *impl)
      : n(n), remaining(key), prefixLen(prefixLen), readVersion(readVersion),
        impl(impl) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
    fprintf(stderr, "Check range left side from %s for keys starting with %s\n",
            printable(key).c_str(),
            printable(key.subspan(0, prefixLen)).c_str());
#endif
  }

  Node *n;
  std::span remaining;
  int prefixLen;
  int64_t readVersion;
  ConflictSet::Impl *impl;
  // Number of bytes of `key` matched by the path to `n` so far.
  int searchPathLen = 0;
  // Result; valid once step() has returned true.
  bool ok;

  enum Phase { Search, DownLeftSpine } phase = Search;

  // Performs one unit of work. Returns true when finished (result in `ok`),
  // false if it should be called again.
  bool step() {
    switch (phase) {
    case Search: {
      if (maxVersion(n, impl) <= readVersion) {
        ok = true;
        return true;
      }
      if (remaining.size() == 0) {
        assert(searchPathLen >= prefixLen);
        ok = maxVersion(n, impl) <= readVersion;
        return true;
      }

      // Once past the prefix, children after remaining[0] hold keys > `key`
      // that still share the prefix.
      if (searchPathLen >= prefixLen) {
        if (maxBetweenExclusive(n, remaining[0], 256) > readVersion) {
          ok = false;
          return true;
        }
      }

      auto *child = getChild(n, remaining[0]);
      if (child == nullptr) {
        int c = getChildGeq(n, remaining[0]);
        if (c >= 0) {
          if (searchPathLen < prefixLen) {
            n = getChildExists(n, c);
            return downLeftSpine();
          }
          n = getChildExists(n, c);
          ok = maxVersion(n, impl) <= readVersion;
          return true;
        } else {
          n = nextSibling(n);
          return downLeftSpine();
        }
      }

      n = child;
      remaining = remaining.subspan(1, remaining.size() - 1);
      ++searchPathLen;

      if (n->partialKeyLen > 0) {
        int commonLen = std::min(n->partialKeyLen, remaining.size());
        int i =
            longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
        searchPathLen += i;
        if (i < commonLen) {
          auto c = n->partialKey()[i] <=> remaining[i];
          if (c > 0) {
            // n's subtree lies entirely after `key`.
            if (searchPathLen < prefixLen) {
              return downLeftSpine();
            }
            if (n->entryPresent && n->entry.rangeVersion > readVersion) {
              ok = false;
              return true;
            }
            ok = maxVersion(n, impl) <= readVersion;
            return true;
          } else {
            n = nextSibling(n);
            return downLeftSpine();
          }
        }
        if (commonLen == n->partialKeyLen) {
          // partial key matches
          remaining =
              remaining.subspan(commonLen, remaining.size() - commonLen);
        } else if (n->partialKeyLen > int(remaining.size())) {
          // `key` ends inside n's partial key: n's subtree lies after `key`.
          if (searchPathLen < prefixLen) {
            return downLeftSpine();
          }
          if (n->entryPresent && n->entry.rangeVersion > readVersion) {
            ok = false;
            return true;
          }
          ok = maxVersion(n, impl) <= readVersion;
          return true;
        }
      }
      break;
    }
    case DownLeftSpine:
      // One level per step toward the smallest entry at or below `n`.
      if (n->entryPresent) {
        ok = n->entry.rangeVersion <= readVersion;
        return true;
      }
      int c = getChildGeq(n, 0);
      assert(c >= 0);
      n = getChildExists(n, c);
      break;
    }
    return false;
  }

  // Switches to the DownLeftSpine phase; finishes immediately with ok = true
  // if there is no node to descend from.
  bool downLeftSpine() {
    phase = DownLeftSpine;
    if (n == nullptr) {
      ok = true;
      return true;
    }
    return false;
  }
};

// Return true if the max version among all keys that start with key[:prefixLen]
// that are < key is <= readVersion
struct
CheckRangeRightSide {
  // Step-wise checker for the right half of a range read. Once step() has
  // returned true, `ok` is true iff the max version among all keys that start
  // with key[:prefixLen] and are < key is <= readVersion. Drive it by calling
  // step() until it returns true.
  CheckRangeRightSide(Node *n, std::span key, int prefixLen,
                      int64_t readVersion, ConflictSet::Impl *impl)
      : n(n), key(key), remaining(key), prefixLen(prefixLen),
        readVersion(readVersion), impl(impl) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
    fprintf(stderr, "Check range right side to %s for keys starting with %s\n",
            printable(key).c_str(),
            printable(key.subspan(0, prefixLen)).c_str());
#endif
  }

  // Node the traversal is currently at.
  Node *n;
  // Exclusive upper bound of the keys being checked.
  std::span key;
  // Suffix of `key` not yet matched along the path to `n`.
  std::span remaining;
  // Only keys that share key[:prefixLen] count toward the result.
  int prefixLen;
  int64_t readVersion;
  ConflictSet::Impl *impl;
  // Tracks how many bytes of `key` the path to `n` covers (relative to the
  // starting node); compared against prefixLen to decide when versions count.
  int searchPathLen = 0;
  // Result; only meaningful after step() has returned true.
  bool ok;

  // Search: walk toward `key`. DownLeftSpine: descend to the first entry at or
  // below `n` and use its rangeVersion.
  enum Phase { Search, DownLeftSpine } phase = Search;

  // Advances by one node. Returns true when `ok` holds the final answer, false
  // when step() must be called again.
  bool step() {
    switch (phase) {
    case Search: {
#if DEBUG_VERBOSE && !defined(NDEBUG)
      fprintf(
          stderr,
          "Search path: %s, searchPathLen: %d, prefixLen: %d, remaining: %s\n",
          getSearchPathPrintable(n).c_str(), searchPathLen, prefixLen,
          printable(remaining).c_str());
#endif
      assert(searchPathLen <= int(key.size()));
      // Key fully matched: finish by inspecting the rangeVersion of the first
      // entry at or below n.
      if (remaining.size() == 0) {
        return downLeftSpine();
      }
      if (searchPathLen >= prefixLen) {
        // n's own key is a proper prefix of `key`, so it is < key.
        if (n->entryPresent && n->entry.pointVersion > readVersion) {
          ok = false;
          return true;
        }
        // Children ordered before remaining[0] hold only keys < key.
        if (maxBetweenExclusive(n, -1, remaining[0]) > readVersion) {
          ok = false;
          return true;
        }
      }
      // The range ending at n lies just before n. It only matters once n's
      // path is strictly longer than the prefix; otherwise those keys are
      // below key[:prefixLen] and don't share it.
      if (searchPathLen > prefixLen && n->entryPresent &&
          n->entry.rangeVersion > readVersion) {
        ok = false;
        return true;
      }
      auto *child = getChild(n, remaining[0]);
      if (child == nullptr) {
        // No exact child: the next greater child (if any) starts the keys
        // > key; its first entry's range covers the tail of the checked range.
        int c = getChildGeq(n, remaining[0]);
        if (c >= 0) {
          n = getChildExists(n, c);
          return downLeftSpine();
        } else {
          return backtrack();
        }
      }
      n = child;
      remaining = remaining.subspan(1, remaining.size() - 1);
      ++searchPathLen;
      if (n->partialKeyLen > 0) {
        int commonLen = std::min(n->partialKeyLen, remaining.size());
        int i =
            longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
        searchPathLen += i;
        if (i < commonLen) {
          ++searchPathLen;
          auto c = n->partialKey()[i] <=> remaining[i];
          if (c > 0) {
            // Everything under n is > key.
            return downLeftSpine();
          } else {
            // Everything under n is < key.
            if (searchPathLen > prefixLen && n->entryPresent &&
                n->entry.rangeVersion > readVersion) {
              ok = false;
              return true;
            }
            return backtrack();
          }
        }
        if (commonLen == n->partialKeyLen) {
          // partial key matches
          remaining =
              remaining.subspan(commonLen, remaining.size() - commonLen);
        } else if (n->partialKeyLen > int(remaining.size())) {
          // `remaining` is a proper prefix of n's partial key, so everything
          // under n is > key.
          return downLeftSpine();
        }
      }
    } break;
    case DownLeftSpine:
      // The first entry found on the left spine decides the result through the
      // rangeVersion of the range that ends at it.
      if (n->entryPresent) {
        ok = n->entry.rangeVersion <= readVersion;
        return true;
      }
      int c = getChildGeq(n, 0);
      assert(c >= 0);
      n = getChildExists(n, c);
      break;
    }
    return false;
  }
  // Called when everything under n is < key. Folds in n's maxVersion (once
  // past the prefix), then climbs until an ancestor has a child after n's
  // slot. The search continues down that child's left spine. Reaching the
  // root finishes with ok = true.
  bool backtrack() {
    for (;;) {
      if (searchPathLen > prefixLen && maxVersion(n, impl) > readVersion) {
        ok = false;
        return true;
      }
      if (n->parent == nullptr) {
        ok = true;
        return true;
      }
      auto next = getChildGeq(n->parent, n->parentsIndex + 1);
      if (next < 0) {
        searchPathLen -= 1 + n->partialKeyLen;
        n = n->parent;
      } else {
        searchPathLen -= n->partialKeyLen;
        n = getChildExists(n->parent, next);
        searchPathLen += n->partialKeyLen;
        return downLeftSpine();
      }
    }
  }
  // Switches to the DownLeftSpine phase. A null n finishes immediately with
  // ok = true.
  bool downLeftSpine() {
    phase = DownLeftSpine;
    if (n == nullptr) {
      ok = true;
      return true;
    }
    return false;
  }
};
// Returns true if no write newer than readVersion intersects the half-open
// key range [begin, end).
bool checkRangeRead(Node *n, std::span begin, std::span end,
                    int64_t readVersion, ConflictSet::Impl *impl) {
  int lcp = longestCommonPrefix(begin.data(), end.data(),
                                std::min(begin.size(), end.size()));
  // [begin, begin + '\0') contains exactly `begin`: use the point-read path.
  if (lcp == int(begin.size()) && end.size() == begin.size() + 1 &&
      end.back() == 0) {
    return checkPointRead(n, begin, readVersion, impl);
  }
  // Descend along the common prefix. If any node on the way has
  // maxVersion <= readVersion, nothing under it (and so nothing in the range)
  // is newer than the read.
  SearchStepWise search{n, begin.subspan(0, lcp)};
  // Only used by the debug-mode assertions below.
  Arena arena;
  for (;;) {
    assert(getSearchPath(arena, search.n) <=>
               begin.subspan(0, lcp - search.remaining.size()) ==
           0);
    if (maxVersion(search.n, impl) <= readVersion) {
      return true;
    }
    if (search.step()) {
      break;
    }
  }
  assert(getSearchPath(arena, search.n) <=>
             begin.subspan(0, lcp - search.remaining.size()) ==
         0);
  // Make begin/end relative to search.n by dropping the consumed prefix.
  const int consumed = lcp - search.remaining.size();
  assume(consumed >= 0);
  begin = begin.subspan(consumed, int(begin.size()) - consumed);
  end = end.subspan(consumed, int(end.size()) - consumed);
  n = search.n;
  lcp -= consumed;
  if (lcp == int(begin.size())) {
    // begin is a prefix of end, so every key in [begin, end) starts with
    // begin: a single right-side check with prefixLen = lcp suffices.
    CheckRangeRightSide checkRangeRightSide{n, end, lcp, readVersion, impl};
    while (!checkRangeRightSide.step())
      ;
    return checkRangeRightSide.ok;
  }
  // NOTE(review): presumably checks keys whose byte at lcp lies strictly
  // between begin[lcp] and end[lcp] — confirm against checkRangeStartsWith.
  if (!checkRangeStartsWith(n, begin.subspan(0, lcp), begin[lcp], end[lcp],
                            readVersion, impl)) {
    return false;
  }
  CheckRangeLeftSide checkRangeLeftSide{n, begin, lcp + 1, readVersion, impl};
  CheckRangeRightSide checkRangeRightSide{n, end, lcp + 1, readVersion, impl};
  // Interleave both sides one step at a time. Once one side finishes, run the
  // other to completion. Both sides always finish before results are combined.
  for (;;) {
    bool leftDone = checkRangeLeftSide.step();
    bool rightDone = checkRangeRightSide.step();
    if (!leftDone && !rightDone) {
      continue;
    }
    if (leftDone && rightDone) {
      break;
    } else if (leftDone) {
      while (!checkRangeRightSide.step())
        ;
      break;
    } else {
      assert(rightDone);
      while (!checkRangeLeftSide.step())
        ;
    }
    break;
  }
  return checkRangeLeftSide.ok & checkRangeRightSide.ok;
}
// Returns a pointer to the newly inserted node. Caller must set
// `entryPresent`, `entry` fields and `maxVersion` on the result. The search
// path of the result's parent will have `maxVersion` at least `writeVersion` as
// a postcondition.
template [[nodiscard]] Node *insert(Node **self, std::span key, int64_t writeVersion, NodeAllocators *allocators, ConflictSet::Impl *impl) { for (;;) { if ((*self)->partialKeyLen > 0) { // Handle an existing partial key int commonLen = std::min((*self)->partialKeyLen, key.size()); int partialKeyIndex = longestCommonPrefix((*self)->partialKey(), key.data(), commonLen); if (partialKeyIndex < (*self)->partialKeyLen) { auto *old = *self; int64_t oldMaxVersion = maxVersion(old, impl); *self = allocators->node0.allocate(partialKeyIndex); memcpy((char *)*self + kNodeCopyBegin, (char *)old + kNodeCopyBegin, kNodeCopySize); (*self)->partialKeyLen = partialKeyIndex; (*self)->entryPresent = false; (*self)->numChildren = 0; memcpy((*self)->partialKey(), old->partialKey(), (*self)->partialKeyLen); getOrCreateChild(*self, old->partialKey()[partialKeyIndex], allocators) = old; old->parent = *self; old->parentsIndex = old->partialKey()[partialKeyIndex]; maxVersion(old, impl) = oldMaxVersion; memmove(old->partialKey(), old->partialKey() + partialKeyIndex + 1, old->partialKeyLen - (partialKeyIndex + 1)); old->partialKeyLen -= partialKeyIndex + 1; } key = key.subspan(partialKeyIndex, key.size() - partialKeyIndex); } else { // Consider adding a partial key if ((*self)->numChildren == 0 && !(*self)->entryPresent) { assert((*self)->partialKeyCapacity >= int(key.size())); (*self)->partialKeyLen = key.size(); memcpy((*self)->partialKey(), key.data(), (*self)->partialKeyLen); key = key.subspan((*self)->partialKeyLen, key.size() - (*self)->partialKeyLen); } } if constexpr (kBegin) { auto &m = maxVersion(*self, impl); assert(writeVersion >= m); m = writeVersion; } if (key.size() == 0) { return *self; } if constexpr (!kBegin) { auto &m = maxVersion(*self, impl); assert(writeVersion >= m); m = writeVersion; } auto &child = getOrCreateChild(*self, key.front(), allocators); if (!child) { child = allocators->node0.allocate(key.size() - 1); child->parent = *self; child->parentsIndex = key.front(); 
maxVersion(child, impl) = kBegin ? writeVersion : std::numeric_limits::lowest(); } self = &child; key = key.subspan(1, key.size() - 1); } } void destroyTree(Node *root) { Arena arena; auto toFree = vector(arena); toFree.push_back(root); while (toFree.size() > 0) { auto *n = toFree.back(); toFree.pop_back(); // Add all children to toFree for (int child = getChildGeq(n, 0); child >= 0; child = getChildGeq(n, child + 1)) { auto *c = getChildExists(n, child); assert(c != nullptr); toFree.push_back(c); } free(n); } } void addPointWrite(Node *&root, int64_t oldestVersion, std::span key, int64_t writeVersion, NodeAllocators *allocators, ConflictSet::Impl *impl) { auto *n = insert(&root, key, writeVersion, allocators, impl); if (!n->entryPresent) { auto *p = nextLogical(n); n->entryPresent = true; n->entry.pointVersion = writeVersion; maxVersion(n, impl) = writeVersion; n->entry.rangeVersion = p != nullptr ? p->entry.rangeVersion : oldestVersion; } else { assert(writeVersion >= n->entry.pointVersion); n->entry.pointVersion = writeVersion; } } void addWriteRange(Node *&root, int64_t oldestVersion, std::span begin, std::span end, int64_t writeVersion, NodeAllocators *allocators, ConflictSet::Impl *impl) { int lcp = longestCommonPrefix(begin.data(), end.data(), std::min(begin.size(), end.size())); if (lcp == int(begin.size()) && end.size() == begin.size() + 1 && end.back() == 0) { return addPointWrite(root, oldestVersion, begin, writeVersion, allocators, impl); } auto remaining = begin.subspan(0, lcp); auto *n = root; for (;;) { if (int(remaining.size()) <= n->partialKeyLen) { break; } int i = longestCommonPrefix(n->partialKey(), remaining.data(), n->partialKeyLen); if (i != n->partialKeyLen) { break; } auto *child = getChild(n, remaining[n->partialKeyLen]); if (child == nullptr) { break; } auto &m = maxVersion(n, impl); assert(writeVersion >= m); m = writeVersion; remaining = remaining.subspan(n->partialKeyLen + 1, remaining.size() - (n->partialKeyLen + 1)); n = child; } 
Node **useAsRoot = n->parent == nullptr ? &root : &getChildExists(n->parent, n->parentsIndex); int consumed = lcp - remaining.size(); begin = begin.subspan(consumed, begin.size() - consumed); end = end.subspan(consumed, end.size() - consumed); auto *beginNode = insert(useAsRoot, begin, writeVersion, allocators, impl); const bool insertedBegin = !beginNode->entryPresent; beginNode->entryPresent = true; if (insertedBegin) { auto *p = nextLogical(beginNode); beginNode->entry.rangeVersion = p != nullptr ? p->entry.rangeVersion : oldestVersion; beginNode->entry.pointVersion = writeVersion; maxVersion(beginNode, impl) = writeVersion; } auto &m = maxVersion(beginNode, impl); assert(writeVersion >= m); m = writeVersion; assert(writeVersion >= beginNode->entry.pointVersion); beginNode->entry.pointVersion = writeVersion; auto *endNode = insert(useAsRoot, end, writeVersion, allocators, impl); const bool insertedEnd = !endNode->entryPresent; endNode->entryPresent = true; if (insertedEnd) { auto *p = nextLogical(endNode); endNode->entry.pointVersion = p != nullptr ? 
p->entry.rangeVersion : oldestVersion; auto &m = maxVersion(endNode, impl); m = std::max(m, endNode->entry.pointVersion); } endNode->entry.rangeVersion = writeVersion; if (insertedEnd) { // beginNode may have been invalidated beginNode = insert(useAsRoot, begin, writeVersion, allocators, impl); assert(beginNode->entryPresent); } for (beginNode = nextLogical(beginNode); beginNode != endNode;) { auto *old = beginNode; beginNode = nextLogical(beginNode); old->entryPresent = false; if (old->numChildren == 0 && old->parent != nullptr) { eraseChild(old->parent, old->parentsIndex, allocators); } } } Iterator firstGeq(Node *n, const std::span key) { auto remaining = key; for (;;) { if (remaining.size() == 0) { if (n->entryPresent) { return {n, 0}; } int c = getChildGeq(n, 0); assert(c >= 0); n = getChildExists(n, c); goto downLeftSpine; } auto *child = getChild(n, remaining[0]); if (child == nullptr) { int c = getChildGeq(n, remaining[0]); if (c >= 0) { n = getChildExists(n, c); goto downLeftSpine; } else { n = nextSibling(n); goto downLeftSpine; } } n = child; remaining = remaining.subspan(1, remaining.size() - 1); if (n->partialKeyLen > 0) { int commonLen = std::min(n->partialKeyLen, remaining.size()); int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen); if (i < commonLen) { auto c = n->partialKey()[i] <=> remaining[i]; if (c > 0) { goto downLeftSpine; } else { n = nextSibling(n); goto downLeftSpine; } } if (commonLen == n->partialKeyLen) { // partial key matches remaining = remaining.subspan(commonLen, remaining.size() - commonLen); } else if (n->partialKeyLen > int(remaining.size())) { // n is the first physical node greater than remaining, and there's no // eq node goto downLeftSpine; } } } downLeftSpine: if (n == nullptr) { return {nullptr, 1}; } for (;;) { if (n->entryPresent) { return {n, 1}; } int c = getChildGeq(n, 0); assert(c >= 0); n = getChildExists(n, c); } } struct __attribute__((visibility("hidden"))) ConflictSet::Impl { void 
check(const ReadRange *reads, Result *result, int count) { for (int i = 0; i < count; ++i) { const auto &r = reads[i]; auto begin = std::span(r.begin.p, r.begin.len); auto end = std::span(r.end.p, r.end.len); result[i] = reads[i].readVersion < oldestVersion ? TooOld : (end.size() > 0 ? checkRangeRead(root, begin, end, reads[i].readVersion, this) : checkPointRead(root, begin, reads[i].readVersion, this)) ? Commit : Conflict; } } void addWrites(const WriteRange *writes, int count, int64_t writeVersion) { for (int i = 0; i < count; ++i) { const auto &w = writes[i]; auto begin = std::span(w.begin.p, w.begin.len); auto end = std::span(w.end.p, w.end.len); if (w.end.len > 0) { keyUpdates += 3; addWriteRange(root, oldestVersion, begin, end, writeVersion, &allocators, this); } else { keyUpdates += 2; addPointWrite(root, oldestVersion, begin, writeVersion, &allocators, this); } } } void setOldestVersion(int64_t oldestVersion) { if (oldestVersion <= this->oldestVersion) { return; } this->oldestVersion = oldestVersion; if (keyUpdates < 100) { return; } Node *prev = firstGeq(root, removalKey).n; // There's no way to erase removalKey without introducing a key after it assert(prev != nullptr); for (; keyUpdates > 0; --keyUpdates) { Node *n = nextLogical(prev); if (n == nullptr) { removalKey = {}; return; } if (std::max(prev->entry.pointVersion, prev->entry.rangeVersion) <= oldestVersion) { // Any transaction prev would have prevented from committing is // going to fail with TooOld anyway. 
// There's no way to insert a range such that range version of the right // node is greater than the point version of the left node assert(n->entry.rangeVersion <= oldestVersion); prev->entryPresent = false; if (prev->numChildren == 0 && prev->parent != nullptr) { eraseChild(prev->parent, prev->parentsIndex, &allocators); } } prev = n; } removalKeyArena = Arena(); removalKey = getSearchPath(removalKeyArena, prev); } explicit Impl(int64_t oldestVersion) : oldestVersion(oldestVersion) { // Insert "" root = allocators.node0.allocate(0); rootMaxVersion = oldestVersion; root->entry.pointVersion = oldestVersion; root->entry.rangeVersion = oldestVersion; root->entryPresent = true; } ~Impl() { destroyTree(root); } NodeAllocators allocators; Arena removalKeyArena; std::span removalKey; int64_t keyUpdates = 0; Node *root; int64_t rootMaxVersion; int64_t oldestVersion; }; // Precondition - an entry for index must exist in the node int64_t &maxVersion(Node *n, ConflictSet::Impl *impl) { int index = n->parentsIndex; n = n->parent; if (n == nullptr) { return impl->rootMaxVersion; } if (n->type <= Type::Node16) { auto *n16 = static_cast(n); int i = getNodeIndex(n16, index); return n16->children[i].childMaxVersion; } else if (n->type == Type::Node48) { auto *n48 = static_cast(n); assert(n48->bitSet.test(index)); return n48->children[n48->index[index]].childMaxVersion; } else { auto *n256 = static_cast(n); assert(n256->bitSet.test(index)); return n256->children[index].childMaxVersion; } } // ==================== END IMPLEMENTATION ==================== // GCOVR_EXCL_START void ConflictSet::check(const ReadRange *reads, Result *results, int count) const { return impl->check(reads, results, count); } void ConflictSet::addWrites(const WriteRange *writes, int count, int64_t writeVersion) { return impl->addWrites(writes, count, writeVersion); } void ConflictSet::setOldestVersion(int64_t oldestVersion) { return impl->setOldestVersion(oldestVersion); } ConflictSet::ConflictSet(int64_t 
oldestVersion) : impl(new (safe_malloc(sizeof(Impl))) Impl{oldestVersion}) {} ConflictSet::~ConflictSet() { if (impl) { impl->~Impl(); free(impl); } } #if SHOW_MEMORY __attribute__((visibility("default"))) void showMemory(const ConflictSet &cs) { ConflictSet::Impl *impl; memcpy(&impl, &cs, sizeof(impl)); // NOLINT fprintf(stderr, "Max Node0 memory usage: %" PRId64 "\n", impl->allocators.node0.highWaterMarkBytes()); fprintf(stderr, "Max Node4 memory usage: %" PRId64 "\n", impl->allocators.node4.highWaterMarkBytes()); fprintf(stderr, "Max Node16 memory usage: %" PRId64 "\n", impl->allocators.node16.highWaterMarkBytes()); fprintf(stderr, "Max Node48 memory usage: %" PRId64 "\n", impl->allocators.node48.highWaterMarkBytes()); fprintf(stderr, "Max Node256 memory usage: %" PRId64 "\n", impl->allocators.node256.highWaterMarkBytes()); } #endif ConflictSet::ConflictSet(ConflictSet &&other) noexcept : impl(std::exchange(other.impl, nullptr)) {} ConflictSet &ConflictSet::operator=(ConflictSet &&other) noexcept { impl = std::exchange(other.impl, nullptr); return *this; } using ConflictSet_Result = ConflictSet::Result; using ConflictSet_Key = ConflictSet::Key; using ConflictSet_ReadRange = ConflictSet::ReadRange; using ConflictSet_WriteRange = ConflictSet::WriteRange; extern "C" { __attribute__((__visibility__("default"))) void ConflictSet_check(void *cs, const ConflictSet_ReadRange *reads, ConflictSet_Result *results, int count) { ((ConflictSet::Impl *)cs)->check(reads, results, count); } __attribute__((__visibility__("default"))) void ConflictSet_addWrites(void *cs, const ConflictSet_WriteRange *writes, int count, int64_t writeVersion) { ((ConflictSet::Impl *)cs)->addWrites(writes, count, writeVersion); } __attribute__((__visibility__("default"))) void ConflictSet_setOldestVersion(void *cs, int64_t oldestVersion) { ((ConflictSet::Impl *)cs)->setOldestVersion(oldestVersion); } __attribute__((__visibility__("default"))) void * ConflictSet_create(int64_t oldestVersion) { return 
new (safe_malloc(sizeof(ConflictSet::Impl))) ConflictSet::Impl{oldestVersion}; } __attribute__((__visibility__("default"))) void ConflictSet_destroy(void *cs) { using Impl = ConflictSet::Impl; ((Impl *)cs)->~Impl(); free(cs); } } namespace { std::string getSearchPathPrintable(Node *n) { Arena arena; if (n == nullptr) { return ""; } auto result = vector(arena); for (;;) { for (int i = n->partialKeyLen - 1; i >= 0; --i) { result.push_back(n->partialKey()[i]); } if (n->parent == nullptr) { break; } result.push_back(n->parentsIndex); n = n->parent; } std::reverse(result.begin(), result.end()); if (result.size() > 0) { return printable(std::string_view((const char *)&result[0], result.size())); // NOLINT } else { return std::string(); } } std::string getPartialKeyPrintable(Node *n) { Arena arena; if (n == nullptr) { return ""; } auto result = std::string((const char *)&n->parentsIndex, n->parent == nullptr ? 0 : 1) + std::string((const char *)n->partialKey(), n->partialKeyLen); return printable(result); // NOLINT } std::string strinc(std::string_view str, bool &ok) { int index; for (index = str.size() - 1; index >= 0; index--) if ((uint8_t &)(str[index]) != 255) break; // Must not be called with a string that consists only of zero or more '\xff' // bytes. 
if (index < 0) { ok = false; return {}; } ok = true; auto r = std::string(str.substr(0, index + 1)); ((uint8_t &)r[r.size() - 1])++; return r; } std::string getSearchPath(Node *n) { assert(n != nullptr); Arena arena; auto result = getSearchPath(arena, n); return std::string((const char *)result.data(), result.size()); } [[maybe_unused]] void debugPrintDot(FILE *file, Node *node, ConflictSet::Impl *impl) { constexpr int kSeparation = 3; struct DebugDotPrinter { explicit DebugDotPrinter(FILE *file, ConflictSet::Impl *impl) : file(file), impl(impl) {} void print(Node *n, int y = 0) { assert(n != nullptr); if (n->entryPresent) { fprintf(file, " k_%p [label=\"m=%" PRId64 " p=%" PRId64 " r=%" PRId64 "\n%s\", pos=\"%d,%d!\"];\n", (void *)n, maxVersion(n, impl), n->entry.pointVersion, n->entry.rangeVersion, getPartialKeyPrintable(n).c_str(), x, y); } else { fprintf(file, " k_%p [label=\"m=%" PRId64 "\n%s\", pos=\"%d,%d!\"];\n", (void *)n, maxVersion(n, impl), getPartialKeyPrintable(n).c_str(), x, y); } x += kSeparation; for (int child = getChildGeq(n, 0); child >= 0; child = getChildGeq(n, child + 1)) { auto *c = getChildExists(n, child); fprintf(file, " k_%p -> k_%p;\n", (void *)n, (void *)c); print(c, y - kSeparation); } } int x = 0; FILE *file; ConflictSet::Impl *impl; }; fprintf(file, "digraph ConflictSet {\n"); fprintf(file, " node [shape = box];\n"); assert(node != nullptr); DebugDotPrinter printer{file, impl}; printer.print(node); fprintf(file, "}\n"); } void checkParentPointers(Node *node, bool &success) { for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) { auto *child = getChildExists(node, i); if (child->parent != node) { fprintf(stderr, "%s child %d has parent pointer %p. 
Expected %p\n", getSearchPathPrintable(node).c_str(), i, (void *)child->parent, (void *)node); success = false; } checkParentPointers(child, success); } } Iterator firstGeq(Node *n, std::string_view key) { return firstGeq( n, std::span((const uint8_t *)key.data(), key.size())); } [[maybe_unused]] int64_t checkMaxVersion(Node *root, Node *node, int64_t oldestVersion, bool &success, ConflictSet::Impl *impl) { int64_t expected = std::numeric_limits::lowest(); if (node->entryPresent) { expected = std::max(expected, node->entry.pointVersion); } for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) { auto *child = getChildExists(node, i); expected = std::max( expected, checkMaxVersion(root, child, oldestVersion, success, impl)); if (child->entryPresent) { expected = std::max(expected, child->entry.rangeVersion); } } auto key = getSearchPath(root); bool ok; auto inc = strinc(key, ok); if (ok) { auto borrowed = firstGeq(root, inc); if (borrowed.n != nullptr) { expected = std::max(expected, borrowed.n->entry.rangeVersion); } } if (node->parent != nullptr && getChildMaxVersion(node->parent, node->parentsIndex) != maxVersion(node, impl)) { fprintf(stderr, "%s has max version %" PRId64 " . But parent has child max version %" PRId64 "\n", getSearchPathPrintable(node).c_str(), maxVersion(node, impl), getChildMaxVersion(node->parent, node->parentsIndex)); success = false; } if (maxVersion(node, impl) > oldestVersion && maxVersion(node, impl) != expected) { fprintf(stderr, "%s has max version %" PRId64 " . 
Expected %" PRId64 "\n", getSearchPathPrintable(node).c_str(), maxVersion(node, impl), expected); success = false; } return expected; } [[maybe_unused]] int64_t checkEntriesExist(Node *node, bool &success) { int64_t total = node->entryPresent; for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) { auto *child = getChildExists(node, i); int64_t e = checkEntriesExist(child, success); total += e; if (e == 0) { Arena arena; fprintf(stderr, "%s has child %02x with no reachable entries\n", getSearchPathPrintable(node).c_str(), i); success = false; } } return total; } bool checkCorrectness(Node *node, int64_t oldestVersion, ConflictSet::Impl *impl) { bool success = true; checkParentPointers(node, success); checkMaxVersion(node, node, oldestVersion, success, impl); checkEntriesExist(node, success); return success; } } // namespace namespace std { void __throw_length_error(const char *) { __builtin_unreachable(); } } // namespace std #ifdef ENABLE_MAIN void printTree() { int64_t writeVersion = 0; ConflictSet::Impl cs{writeVersion}; ReferenceImpl refImpl{writeVersion}; Arena arena; ConflictSet::WriteRange write; write.begin = "and"_s; write.end = "ant"_s; cs.addWrites(&write, 1, ++writeVersion); write.begin = "any"_s; write.end = ""_s; cs.addWrites(&write, 1, ++writeVersion); write.begin = "are"_s; write.end = ""_s; cs.addWrites(&write, 1, ++writeVersion); write.begin = "art"_s; write.end = ""_s; cs.addWrites(&write, 1, ++writeVersion); debugPrintDot(stdout, cs.root, &cs); } #define ANKERL_NANOBENCH_IMPLEMENT #include "third_party/nanobench.h" int main(void) { printTree(); return 0; ankerl::nanobench::Bench bench; ConflictSet::Impl cs{0}; for (int j = 0; j < 256; ++j) { getOrCreateChild(cs.root, j, &cs.allocators) = cs.allocators.node0.allocate(0); if (j % 10 == 0) { bench.run("MaxExclusive " + std::to_string(j), [&]() { bench.doNotOptimizeAway(maxBetweenExclusive(cs.root, 0, 256)); }); } } return 0; } #endif #ifdef ENABLE_FUZZ extern "C" int 
LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { TestDriver driver{data, size}; for (;;) { bool done = driver.next(); if (!driver.ok) { debugPrintDot(stdout, driver.cs.root, &driver.cs); fflush(stdout); abort(); } #if DEBUG_VERBOSE && !defined(NDEBUG) fprintf(stderr, "Check correctness\n"); #endif bool success = checkCorrectness(driver.cs.root, driver.cs.oldestVersion, &driver.cs); if (!success) { debugPrintDot(stdout, driver.cs.root, &driver.cs); fflush(stdout); abort(); } if (done) { break; } } return 0; } #endif // GCOVR_EXCL_STOP