2305 lines
66 KiB
C++
2305 lines
66 KiB
C++
/*
Copyright 2024 Andrew Noyes

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
|
|
|
|
#include "ConflictSet.h"
|
|
#include "Internal.h"
|
|
|
|
#include <algorithm>
|
|
#include <bit>
|
|
#include <cassert>
|
|
#include <compare>
|
|
#include <cstddef>
|
|
#include <cstdint>
|
|
#include <cstring>
|
|
#include <inttypes.h>
|
|
#include <limits>
|
|
#include <span>
|
|
#include <string>
|
|
#include <string_view>
|
|
#include <utility>
|
|
|
|
#ifdef HAS_AVX
|
|
#include <immintrin.h>
|
|
#elif defined(HAS_ARM_NEON)
|
|
#include <arm_neon.h>
|
|
#endif
|
|
|
|
#include <memcheck.h>
|
|
|
|
// ==================== BEGIN IMPLEMENTATION ====================
|
|
|
|
// Versions stored on a node whose `entryPresent` flag is set.
struct Entry {
  // Compared against the read version when a point read lands exactly on this
  // node (see checkPointRead).
  int64_t pointVersion;
  // Compared against the read version when this node is reached by walking
  // down the left spine, i.e. it is the first entry after the searched key.
  int64_t rangeVersion;
};
|
|
|
|
struct BitSet {
|
|
bool test(int i) const;
|
|
void set(int i);
|
|
void reset(int i);
|
|
int firstSetGeq(int i) const;
|
|
|
|
// Calls `f` with the index of each bit set in [begin, end)
|
|
template <class F> void forEachInRange(F f, int begin, int end) {
|
|
// See section 3.1 in https://arxiv.org/pdf/1709.07821.pdf for details about
|
|
// this approach
|
|
|
|
if ((begin >> 6) == (end >> 6)) {
|
|
uint64_t word = words[begin >> 6] & (uint64_t(-1) << (begin & 63)) &
|
|
~(uint64_t(-1) << (end & 63));
|
|
while (word) {
|
|
uint64_t temp = word & -word;
|
|
int index = (begin & ~63) + std::countr_zero(word);
|
|
f(index);
|
|
word ^= temp;
|
|
}
|
|
return;
|
|
}
|
|
|
|
// Check begin partial word
|
|
if (begin & 63) {
|
|
uint64_t word = words[begin >> 6] & (uint64_t(-1) << (begin & 63));
|
|
if (std::popcount(word) + (begin & 63) == 64) {
|
|
while (begin & 63) {
|
|
f(begin++);
|
|
}
|
|
} else {
|
|
while (word) {
|
|
uint64_t temp = word & -word;
|
|
int index = (begin & ~63) + std::countr_zero(word);
|
|
f(index);
|
|
word ^= temp;
|
|
}
|
|
begin &= ~63;
|
|
begin += 64;
|
|
}
|
|
}
|
|
|
|
// Check inner, full words
|
|
while (begin != (end & ~63)) {
|
|
uint64_t word = words[begin >> 6];
|
|
if (word == uint64_t(-1)) {
|
|
for (int i = 0; i < 64; ++i) {
|
|
f(begin + i);
|
|
}
|
|
} else {
|
|
while (word) {
|
|
uint64_t temp = word & -word;
|
|
int index = begin + std::countr_zero(word);
|
|
f(index);
|
|
word ^= temp;
|
|
}
|
|
}
|
|
begin += 64;
|
|
}
|
|
|
|
if (end & 63) {
|
|
// Check end partial word
|
|
uint64_t word = words[end >> 6] & ~(uint64_t(-1) << (end & 63));
|
|
if (std::popcount(word) == (end & 63)) {
|
|
while (begin < end) {
|
|
f(begin++);
|
|
}
|
|
} else {
|
|
while (word) {
|
|
uint64_t temp = word & -word;
|
|
int index = begin + std::countr_zero(word);
|
|
f(index);
|
|
word ^= temp;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
private:
|
|
uint64_t words[4] = {};
|
|
};
|
|
|
|
bool BitSet::test(int i) const {
|
|
assert(0 <= i);
|
|
assert(i < 256);
|
|
return words[i >> 6] & (uint64_t(1) << (i & 63));
|
|
}
|
|
|
|
// Sets bit `i`. Requires 0 <= i < 256 (asserted).
void BitSet::set(int i) {
  assert(0 <= i);
  assert(i < 256);
  const uint64_t bit = uint64_t(1) << (i & 63);
  words[i >> 6] |= bit;
}
|
|
|
|
// Clears bit `i`. Requires 0 <= i < 256 (asserted).
void BitSet::reset(int i) {
  assert(0 <= i);
  assert(i < 256);
  const uint64_t bit = uint64_t(1) << (i & 63);
  words[i >> 6] &= ~bit;
}
|
|
|
|
int BitSet::firstSetGeq(int i) const {
|
|
assert(0 <= i);
|
|
// i may be >= 256
|
|
uint64_t mask = uint64_t(-1) << (i & 63);
|
|
for (int j = i >> 6; j < 4; ++j) {
|
|
uint64_t masked = mask & words[j];
|
|
if (masked) {
|
|
return (j << 6) + std::countr_zero(masked);
|
|
}
|
|
mask = -1;
|
|
}
|
|
return -1;
|
|
}
|
|
|
|
// Discriminates the concrete node struct behind a `Node *`. The numeric order
// matters: code in this file tests `type <= Type::Node16` to select the nodes
// that use the sorted `index` array layout.
enum class Type : int8_t {
  Node0 = 0,
  Node4 = 1,
  Node16 = 2,
  Node48 = 3,
  Node256 = 4,
};
|
|
|
|
struct Node {
|
|
|
|
/* begin section that's copied to the next node */
|
|
Node *parent = nullptr;
|
|
Entry entry;
|
|
int32_t partialKeyLen = 0;
|
|
int16_t numChildren : 15 = 0;
|
|
bool entryPresent : 1 = false;
|
|
uint8_t parentsIndex = 0;
|
|
/* end section that's copied to the next node */
|
|
|
|
Type type;
|
|
#ifndef NDEBUG
|
|
int32_t partialKeyCapacity;
|
|
#endif
|
|
|
|
uint8_t *partialKey();
|
|
};
|
|
|
|
// Byte range [kNodeCopyBegin, kNodeCopyBegin + kNodeCopySize) of a Node that
// is memcpy'd into the replacement when a node is converted to a larger type.
// It spans the members from `parent` up to, but not including, `type`.
constexpr int kNodeCopyBegin = offsetof(Node, parent);
constexpr int kNodeCopySize = offsetof(Node, type) - kNodeCopyBegin;
|
|
|
|
struct Child {
|
|
int64_t childMaxVersion;
|
|
Node *child;
|
|
};
|
|
|
|
struct Node0 : Node {
|
|
// Sorted
|
|
uint8_t index[16]; // 16 so that we can use the same simd index search
|
|
// implementation as Node16
|
|
Node0() { this->type = Type::Node0; }
|
|
uint8_t *partialKey() { return (uint8_t *)(this + 1); }
|
|
};
|
|
|
|
struct Node4 : Node {
|
|
// Sorted
|
|
uint8_t index[16]; // 16 so that we can use the same simd index search
|
|
// implementation as Node16
|
|
Child children[4];
|
|
Node4() { this->type = Type::Node4; }
|
|
uint8_t *partialKey() { return (uint8_t *)(this + 1); }
|
|
};
|
|
|
|
struct Node16 : Node {
|
|
// Sorted
|
|
uint8_t index[16];
|
|
Child children[16];
|
|
Node16() { this->type = Type::Node16; }
|
|
uint8_t *partialKey() { return (uint8_t *)(this + 1); }
|
|
};
|
|
|
|
struct Node48 : Node {
|
|
BitSet bitSet;
|
|
Child children[48];
|
|
int8_t nextFree = 0;
|
|
int8_t index[256];
|
|
Node48() {
|
|
memset(index, -1, 256);
|
|
this->type = Type::Node48;
|
|
}
|
|
uint8_t *partialKey() { return (uint8_t *)(this + 1); }
|
|
};
|
|
|
|
struct Node256 : Node {
|
|
BitSet bitSet;
|
|
Child children[256];
|
|
Node256() {
|
|
this->type = Type::Node256;
|
|
for (int i = 0; i < 256; ++i) {
|
|
children[i].child = nullptr;
|
|
}
|
|
}
|
|
uint8_t *partialKey() { return (uint8_t *)(this + 1); }
|
|
};
|
|
|
|
// Bounds memory usage in free list, but does not account for memory for partial
|
|
// keys.
|
|
template <class T, size_t kMemoryBound = (1 << 20)>
|
|
struct BoundedFreeListAllocator {
|
|
static_assert(sizeof(T) >= sizeof(void *));
|
|
static_assert(std::derived_from<T, Node>);
|
|
|
|
T *allocate(int partialKeyCapacity) {
|
|
#if SHOW_MEMORY
|
|
++liveAllocations;
|
|
maxLiveAllocations = std::max(maxLiveAllocations, liveAllocations);
|
|
#endif
|
|
if (freeList != nullptr) {
|
|
T *n = (T *)freeList;
|
|
VALGRIND_MAKE_MEM_DEFINED(n, sizeof(T));
|
|
if (n->partialKeyLen >= partialKeyCapacity) {
|
|
memcpy(&freeList, freeList, sizeof(freeList));
|
|
--freeListSize;
|
|
VALGRIND_MAKE_MEM_UNDEFINED(n, sizeof(T));
|
|
return new (n) T;
|
|
}
|
|
VALGRIND_MAKE_MEM_NOACCESS(n, sizeof(T));
|
|
}
|
|
|
|
auto *result = new (safe_malloc(sizeof(T) + partialKeyCapacity)) T;
|
|
#ifndef NDEBUG
|
|
result->partialKeyCapacity = partialKeyCapacity;
|
|
#endif
|
|
return result;
|
|
}
|
|
|
|
void release(T *p) {
|
|
#if SHOW_MEMORY
|
|
--liveAllocations;
|
|
#endif
|
|
p->~T();
|
|
if (freeListSize == kMaxFreeListSize) {
|
|
return free(p);
|
|
}
|
|
memcpy((void *)p, &freeList, sizeof(freeList));
|
|
freeList = p;
|
|
++freeListSize;
|
|
VALGRIND_MAKE_MEM_NOACCESS(freeList, sizeof(T));
|
|
}
|
|
|
|
~BoundedFreeListAllocator() {
|
|
for (void *iter = freeList; iter != nullptr;) {
|
|
VALGRIND_MAKE_MEM_DEFINED(iter, sizeof(iter));
|
|
auto *tmp = iter;
|
|
memcpy(&iter, iter, sizeof(void *));
|
|
free(tmp);
|
|
}
|
|
}
|
|
|
|
#if SHOW_MEMORY
|
|
int64_t highWaterMarkBytes() const { return maxLiveAllocations * sizeof(T); }
|
|
#endif
|
|
|
|
private:
|
|
static constexpr int kMaxFreeListSize = kMemoryBound / sizeof(T);
|
|
int freeListSize = 0;
|
|
void *freeList = nullptr;
|
|
#if SHOW_MEMORY
|
|
int64_t maxLiveAllocations = 0;
|
|
int64_t liveAllocations = 0;
|
|
#endif
|
|
};
|
|
|
|
uint8_t *Node::partialKey() {
|
|
switch (type) {
|
|
case Type::Node0:
|
|
return ((Node0 *)this)->partialKey();
|
|
case Type::Node4:
|
|
return ((Node4 *)this)->partialKey();
|
|
case Type::Node16:
|
|
return ((Node16 *)this)->partialKey();
|
|
case Type::Node48:
|
|
return ((Node48 *)this)->partialKey();
|
|
case Type::Node256:
|
|
return ((Node256 *)this)->partialKey();
|
|
}
|
|
}
|
|
|
|
struct NodeAllocators {
|
|
BoundedFreeListAllocator<Node0> node0;
|
|
BoundedFreeListAllocator<Node4> node4;
|
|
BoundedFreeListAllocator<Node16> node16;
|
|
BoundedFreeListAllocator<Node48> node48;
|
|
BoundedFreeListAllocator<Node256> node256;
|
|
};
|
|
|
|
// Returns the position of `index` among the first `numChildren` entries of
// self->index, or -1 if it is not there. Callers in this file also pass
// Node0/Node4 nodes cast to Node16, relying on the shared `index` layout.
int getNodeIndex(Node16 *self, uint8_t index) {
#ifdef HAS_AVX
  // Based on https://www.the-paper-trail.org/post/art-paper-notes/

  // 16 copies of the searched-for byte, one per slot of self->index
  const __m128i needle = _mm_set1_epi8(index);

  // Compare all 16 slots at once. Slots at or beyond numChildren hold
  // meaningless values; they are masked off below.
  __m128i slots;
  memcpy(&slots, self->index, sizeof(self->index));
  const __m128i matches = _mm_cmpeq_epi8(needle, slots);

  // One bit per valid slot
  const uint32_t validSlots = (1 << self->numChildren) - 1;

  // One bit per matching slot, restricted to the valid ones
  const uint32_t hits = _mm_movemask_epi8(matches) & validSlots;

  // No bits set means no match
  if (hits == 0)
    return -1;

  // The lowest set bit is the first matching slot, so count trailing zeros
  return std::countr_zero(hits);
#elif defined(HAS_ARM_NEON)
  // Based on
  // https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon

  uint8x16_t slots;
  memcpy(&slots, self->index, sizeof(self->index));
  // 0xff for each match
  const uint16x8_t matches =
      vreinterpretq_u16_u8(vceqq_u8(vdupq_n_u8(index), slots));
  // Four bits per valid slot. Shifting a 64-bit value by 64 is undefined, so
  // the full case is handled separately.
  const uint64_t validSlots =
      self->numChildren == 16 ? uint64_t(-1)
                              : (uint64_t(1) << (self->numChildren * 4)) - 1;
  // 0xf for each match in valid range
  const uint64_t hits =
      vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(matches, 4)), 0) &
      validSlots;
  if (hits == 0)
    return -1;
  return std::countr_zero(hits) / 4;
#else
  const int count = self->numChildren;
  for (int slot = 0; slot < count; ++slot) {
    if (self->index[slot] == index) {
      return slot;
    }
  }
  return -1;
#endif
}
|
|
|
|
// Precondition - an entry for index must exist in the node
|
|
Node *&getChildExists(Node *self, uint8_t index) {
|
|
if (self->type <= Type::Node16) {
|
|
auto *self16 = static_cast<Node16 *>(self);
|
|
return self16->children[getNodeIndex(self16, index)].child;
|
|
} else if (self->type == Type::Node48) {
|
|
auto *self48 = static_cast<Node48 *>(self);
|
|
assert(self48->bitSet.test(index));
|
|
return self48->children[self48->index[index]].child;
|
|
} else {
|
|
auto *self256 = static_cast<Node256 *>(self);
|
|
assert(self256->bitSet.test(index));
|
|
return self256->children[index].child;
|
|
}
|
|
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
|
}
|
|
|
|
// Precondition - an entry for index must exist in the node
|
|
int64_t getChildMaxVersion(Node *self, uint8_t index) {
|
|
if (self->type <= Type::Node16) {
|
|
auto *self16 = static_cast<Node16 *>(self);
|
|
return self16->children[getNodeIndex(self16, index)].childMaxVersion;
|
|
} else if (self->type == Type::Node48) {
|
|
auto *self48 = static_cast<Node48 *>(self);
|
|
assert(self48->bitSet.test(index));
|
|
return self48->children[self48->index[index]].childMaxVersion;
|
|
} else {
|
|
auto *self256 = static_cast<Node256 *>(self);
|
|
assert(self256->bitSet.test(index));
|
|
return self256->children[index].childMaxVersion;
|
|
}
|
|
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
|
}
|
|
|
|
// Precondition - an entry for index must exist in the node
|
|
int64_t &maxVersion(Node *n, ConflictSet::Impl *);
|
|
|
|
// Returns the child of `self` for key byte `index`, or nullptr if there is
// none.
Node *getChild(Node *self, uint8_t index) {
  if (self->type <= Type::Node16) {
    auto *small = static_cast<Node16 *>(self);
    const int slot = getNodeIndex(small, index);
    return slot >= 0 ? small->children[slot].child : nullptr;
  }
  if (self->type == Type::Node48) {
    auto *mid = static_cast<Node48 *>(self);
    // Negative means no child for this byte
    const int slot = mid->index[index];
    return slot >= 0 ? mid->children[slot].child : nullptr;
  }
  // Node256 stores a (possibly null) pointer for every byte
  return static_cast<Node256 *>(self)->children[index].child;
}
|
|
|
|
// Returns the smallest key byte >= `child` for which `self` has a child, or
// -1 if there is none. `child` may exceed 255, in which case the result is
// -1.
int getChildGeq(Node *self, int child) {
  if (child > 255) {
    return -1;
  }
  if (self->type > Type::Node16) {
    // Node48 and Node256 keep their bitSet at the same offset, so either can
    // be read through a Node48 pointer
    static_assert(offsetof(Node48, bitSet) == offsetof(Node256, bitSet));
    return static_cast<Node48 *>(self)->bitSet.firstSetGeq(child);
  }
  auto *self16 = static_cast<Node16 *>(self);
#ifdef HAS_AVX
  const __m128i needle = _mm_set1_epi8(child);
  __m128i slots;
  memcpy(&slots, self16->index, sizeof(self16->index));
  // min(needle, slot) == needle exactly when needle <= slot (unsigned)
  const __m128i geq = _mm_cmpeq_epi8(needle, _mm_min_epu8(needle, slots));
  const int validSlots = (1 << self16->numChildren) - 1;
  const uint32_t hits = _mm_movemask_epi8(geq) & validSlots;
  // index is sorted, so the lowest hit is the smallest qualifying byte
  const int result = hits == 0 ? -1 : self16->index[std::countr_zero(hits)];
  assert(result == [&]() -> int {
    for (int i = 0; i < self16->numChildren; ++i) {
      if (self16->index[i] >= child) {
        return self16->index[i];
      }
    }
    return -1;
  }());
  return result;
#elif defined(HAS_ARM_NEON)
  uint8x16_t slots;
  memcpy(&slots, self16->index, sizeof(self16->index));
  // 0xff for each leq
  const auto geq = vcleq_u8(vdupq_n_u8(child), slots);
  // Four bits per valid slot; a 64-bit shift by 64 is undefined, hence the
  // separate full case
  const uint64_t validSlots =
      self->numChildren == 16 ? uint64_t(-1)
                              : (uint64_t(1) << (self->numChildren * 4)) - 1;
  // 0xf for each 0xff (within mask)
  const uint64_t hits =
      vget_lane_u64(
          vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(geq), 4)), 0) &
      validSlots;
  const int simd = hits == 0 ? -1 : self16->index[std::countr_zero(hits) / 4];
  assert(simd == [&]() -> int {
    for (int i = 0; i < self->numChildren; ++i) {
      if (self16->index[i] >= child) {
        return self16->index[i];
      }
    }
    return -1;
  }());
  return simd;
#else
  for (int i = 0; i < self->numChildren; ++i) {
    if (i > 0) {
      // index must be strictly increasing
      assert(self16->index[i - 1] < self16->index[i]);
    }
    if (self16->index[i] >= child) {
      return self16->index[i];
    }
  }
  return -1;
#endif
}
|
|
|
|
// Points the `parent` of each of n's children at n.
void setChildrenParents(Node4 *n) {
  Child *const last = n->children + n->numChildren;
  for (Child *slot = n->children; slot != last; ++slot) {
    slot->child->parent = n;
  }
}
|
|
|
|
// Points the `parent` of each of n's children at n.
void setChildrenParents(Node16 *n) {
  Child *const last = n->children + n->numChildren;
  for (Child *slot = n->children; slot != last; ++slot) {
    slot->child->parent = n;
  }
}
|
|
|
|
// Points the `parent` of each of n's children at n. The children are found by
// visiting every key byte present in n->bitSet.
void setChildrenParents(Node48 *n) {
  const auto adopt = [&](int keyByte) {
    n->children[n->index[keyByte]].child->parent = n;
  };
  n->bitSet.forEachInRange(adopt, 0, 256);
}
|
|
|
|
// Points the `parent` of each of n's children at n. The children are found by
// visiting every key byte present in n->bitSet.
void setChildrenParents(Node256 *n) {
  const auto adopt = [&](int keyByte) {
    n->children[keyByte].child->parent = n;
  };
  n->bitSet.forEachInRange(adopt, 0, 256);
}
|
|
|
|
// Caller is responsible for assigning a non-null pointer to the returned
|
|
// reference if null
|
|
Node *&getOrCreateChild(Node *&self, uint8_t index,
|
|
NodeAllocators *allocators) {
|
|
|
|
// Fast path for if it exists already
|
|
if (self->type <= Type::Node16) {
|
|
auto *self16 = static_cast<Node16 *>(self);
|
|
int i = getNodeIndex(self16, index);
|
|
if (i >= 0) {
|
|
return self16->children[i].child;
|
|
}
|
|
} else if (self->type == Type::Node48) {
|
|
auto *self48 = static_cast<Node48 *>(self);
|
|
int secondIndex = self48->index[index];
|
|
if (secondIndex >= 0) {
|
|
return self48->children[secondIndex].child;
|
|
}
|
|
} else {
|
|
auto *self256 = static_cast<Node256 *>(self);
|
|
if (auto &result = self256->children[index].child; result != nullptr) {
|
|
return result;
|
|
}
|
|
}
|
|
|
|
if (self->type == Type::Node0) {
|
|
auto *self0 = static_cast<Node0 *>(self);
|
|
|
|
auto *newSelf = allocators->node4.allocate(self->partialKeyLen);
|
|
memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin,
|
|
kNodeCopySize);
|
|
memcpy(newSelf->partialKey(), self0->partialKey(), self->partialKeyLen);
|
|
allocators->node0.release(self0);
|
|
self = newSelf;
|
|
|
|
goto insert16;
|
|
|
|
} else if (self->type == Type::Node4) {
|
|
auto *self4 = static_cast<Node4 *>(self);
|
|
|
|
if (self->numChildren == 4) {
|
|
auto *newSelf = allocators->node16.allocate(self->partialKeyLen);
|
|
memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin,
|
|
kNodeCopySize);
|
|
memcpy(newSelf->partialKey(), self4->partialKey(), self->partialKeyLen);
|
|
// TODO replace with memcpy?
|
|
for (int i = 0; i < 4; ++i) {
|
|
newSelf->index[i] = self4->index[i];
|
|
newSelf->children[i] = self4->children[i];
|
|
}
|
|
allocators->node4.release(self4);
|
|
setChildrenParents(newSelf);
|
|
self = newSelf;
|
|
}
|
|
|
|
goto insert16;
|
|
|
|
} else if (self->type == Type::Node16) {
|
|
|
|
if (self->numChildren == 16) {
|
|
auto *self16 = static_cast<Node16 *>(self);
|
|
auto *newSelf = allocators->node48.allocate(self->partialKeyLen);
|
|
memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin,
|
|
kNodeCopySize);
|
|
memcpy(newSelf->partialKey(), self16->partialKey(), self->partialKeyLen);
|
|
newSelf->nextFree = 16;
|
|
int i = 0;
|
|
for (auto x : self16->index) {
|
|
newSelf->bitSet.set(x);
|
|
newSelf->children[i] = self16->children[i];
|
|
newSelf->index[x] = i;
|
|
++i;
|
|
}
|
|
assert(i == 16);
|
|
allocators->node16.release(self16);
|
|
setChildrenParents(newSelf);
|
|
self = newSelf;
|
|
goto insert48;
|
|
}
|
|
|
|
insert16:
|
|
auto *self16 = static_cast<Node16 *>(self);
|
|
|
|
++self->numChildren;
|
|
int i = 0;
|
|
for (; i < int(self->numChildren) - 1; ++i) {
|
|
if (int(self16->index[i]) > int(index)) {
|
|
memmove(self16->index + i + 1, self16->index + i,
|
|
self->numChildren - (i + 1));
|
|
memmove(self16->children + i + 1, self16->children + i,
|
|
(self->numChildren - (i + 1)) * sizeof(Child));
|
|
break;
|
|
}
|
|
}
|
|
self16->index[i] = index;
|
|
auto &result = self16->children[i].child;
|
|
result = nullptr;
|
|
return result;
|
|
} else if (self->type == Type::Node48) {
|
|
|
|
if (self->numChildren == 48) {
|
|
auto *self48 = static_cast<Node48 *>(self);
|
|
auto *newSelf = allocators->node256.allocate(self->partialKeyLen);
|
|
memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin,
|
|
kNodeCopySize);
|
|
memcpy(newSelf->partialKey(), self48->partialKey(), self->partialKeyLen);
|
|
newSelf->bitSet = self48->bitSet;
|
|
newSelf->bitSet.forEachInRange(
|
|
[&](int i) {
|
|
newSelf->children[i] = self48->children[self48->index[i]];
|
|
},
|
|
0, 256);
|
|
allocators->node48.release(self48);
|
|
setChildrenParents(newSelf);
|
|
self = newSelf;
|
|
goto insert256;
|
|
}
|
|
insert48:
|
|
|
|
auto *self48 = static_cast<Node48 *>(self);
|
|
self48->bitSet.set(index);
|
|
++self->numChildren;
|
|
assert(self48->nextFree < 48);
|
|
int nextFree = self48->nextFree++;
|
|
self48->index[index] = nextFree;
|
|
auto &result = self48->children[nextFree].child;
|
|
result = nullptr;
|
|
return result;
|
|
} else {
|
|
assert(self->type == Type::Node256);
|
|
insert256:
|
|
auto *self256 = static_cast<Node256 *>(self);
|
|
++self->numChildren;
|
|
self256->bitSet.set(index);
|
|
return self256->children[index].child;
|
|
}
|
|
}
|
|
|
|
// Precondition - an entry for index must exist in the node
|
|
void eraseChild(Node *self, uint8_t index, NodeAllocators *allocators) {
|
|
auto *child = getChildExists(self, index);
|
|
switch (child->type) {
|
|
case Type::Node0:
|
|
allocators->node0.release((Node0 *)child);
|
|
break;
|
|
case Type::Node4:
|
|
allocators->node4.release((Node4 *)child);
|
|
break;
|
|
case Type::Node16:
|
|
allocators->node16.release((Node16 *)child);
|
|
break;
|
|
case Type::Node48:
|
|
allocators->node48.release((Node48 *)child);
|
|
break;
|
|
case Type::Node256:
|
|
allocators->node256.release((Node256 *)child);
|
|
break;
|
|
}
|
|
|
|
if (self->type <= Type::Node16) {
|
|
auto *self16 = static_cast<Node16 *>(self);
|
|
int nodeIndex = getNodeIndex(self16, index);
|
|
memmove(self16->index + nodeIndex, self16->index + nodeIndex + 1,
|
|
sizeof(self16->index[0]) * (self->numChildren - (nodeIndex + 1)));
|
|
memmove(self16->children + nodeIndex, self16->children + nodeIndex + 1,
|
|
sizeof(self16->children[0]) * // NOLINT
|
|
(self->numChildren - (nodeIndex + 1)));
|
|
} else if (self->type == Type::Node48) {
|
|
auto *self48 = static_cast<Node48 *>(self);
|
|
self48->bitSet.reset(index);
|
|
int8_t toRemoveChildrenIndex = std::exchange(self48->index[index], -1);
|
|
int8_t lastChildrenIndex = --self48->nextFree;
|
|
assert(toRemoveChildrenIndex >= 0);
|
|
assert(lastChildrenIndex >= 0);
|
|
if (toRemoveChildrenIndex != lastChildrenIndex) {
|
|
self48->children[toRemoveChildrenIndex] =
|
|
self48->children[lastChildrenIndex];
|
|
self48
|
|
->index[self48->children[toRemoveChildrenIndex].child->parentsIndex] =
|
|
toRemoveChildrenIndex;
|
|
}
|
|
} else {
|
|
auto *self256 = static_cast<Node256 *>(self);
|
|
self256->bitSet.reset(index);
|
|
self256->children[index].child = nullptr;
|
|
}
|
|
--self->numChildren;
|
|
if (self->numChildren == 0 && !self->entryPresent &&
|
|
self->parent != nullptr) {
|
|
eraseChild(self->parent, self->parentsIndex, allocators);
|
|
}
|
|
}
|
|
|
|
// Returns the node that follows `node` in a pre-order walk: its first child
// if it has one, otherwise the next child (by key byte) of the nearest
// ancestor that has one. Returns nullptr when no such node exists.
Node *nextPhysical(Node *node) {
  // Smallest key byte still to be considered in `node`
  int lowerBound = 0;
  do {
    const int found = getChildGeq(node, lowerBound);
    if (found >= 0) {
      return getChildExists(node, found);
    }
    // Nothing left below `node`: resume in the parent just past `node`
    lowerBound = node->parentsIndex + 1;
    node = node->parent;
  } while (node != nullptr);
  return nullptr;
}
|
|
|
|
// Returns the first node after `node` (in nextPhysical order) whose
// entryPresent flag is set, or nullptr if there is none.
Node *nextLogical(Node *node) {
  do {
    node = nextPhysical(node);
  } while (node != nullptr && !node->entryPresent);
  return node;
}
|
|
|
|
struct Iterator {
|
|
Node *n;
|
|
int cmp;
|
|
};
|
|
|
|
// Returns the next child (by key byte) after `node` in its parent; if the
// parent has none, repeats the search from the parent. Returns nullptr once a
// node without a parent is reached. Unlike nextPhysical, this never descends
// into `node` itself.
Node *nextSibling(Node *node) {
  for (Node *up = node->parent; up != nullptr; node = up, up = node->parent) {
    const int next = getChildGeq(up, node->parentsIndex + 1);
    if (next >= 0) {
      return getChildExists(up, next);
    }
  }
  return nullptr;
}
|
|
|
|
// Number of bytes examined by one call to compareStride / firstNeqStride.
// The SIMD implementations of those functions require 64.
#if defined(HAS_AVX) || defined(HAS_ARM_NEON)
constexpr int kStride = 64;
#else
constexpr int kStride = 16;
#endif

// Number of strides handled per iteration of the first loop in
// longestCommonPrefix.
constexpr int kUnrollFactor = 4;
|
|
|
|
bool compareStride(const uint8_t *ap, const uint8_t *bp) {
|
|
#if defined(HAS_ARM_NEON)
|
|
static_assert(kStride == 64);
|
|
uint8x16_t x[4];
|
|
for (int i = 0; i < 4; ++i) {
|
|
x[i] = vceqq_u8(vld1q_u8(ap + i * 16), vld1q_u8(bp + i * 16));
|
|
}
|
|
auto results = vreinterpretq_u16_u8(
|
|
vandq_u8(vandq_u8(x[0], x[1]), vandq_u8(x[2], x[3])));
|
|
bool eq = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) ==
|
|
uint64_t(-1);
|
|
#elif defined(HAS_AVX)
|
|
static_assert(kStride == 64);
|
|
__m128i x[4];
|
|
for (int i = 0; i < 4; ++i) {
|
|
x[i] = _mm_cmpeq_epi8(_mm_loadu_si128((__m128i *)(ap + i * 16)),
|
|
_mm_loadu_si128((__m128i *)(bp + i * 16)));
|
|
}
|
|
auto eq =
|
|
_mm_movemask_epi8(_mm_and_si128(_mm_and_si128(x[0], x[1]),
|
|
_mm_and_si128(x[2], x[3]))) == 0xffff;
|
|
#else
|
|
// Hope it gets vectorized
|
|
auto eq = memcmp(ap, bp, kStride) == 0;
|
|
#endif
|
|
return eq;
|
|
}
|
|
|
|
// Precondition: ap[:kStride] != bp[:kStride]
|
|
int firstNeqStride(const uint8_t *ap, const uint8_t *bp) {
|
|
#if defined(HAS_AVX)
|
|
static_assert(kStride == 64);
|
|
uint64_t c[kStride / 16];
|
|
for (int i = 0; i < kStride; i += 16) {
|
|
const auto a = _mm_loadu_si128((__m128i *)(ap + i));
|
|
const auto b = _mm_loadu_si128((__m128i *)(bp + i));
|
|
const auto compared = _mm_cmpeq_epi8(a, b);
|
|
c[i / 16] = _mm_movemask_epi8(compared) & 0xffff;
|
|
}
|
|
return std::countr_zero(~(c[0] | c[1] << 16 | c[2] << 32 | c[3] << 48));
|
|
#elif defined(HAS_ARM_NEON)
|
|
static_assert(kStride == 64);
|
|
for (int i = 0; i < kStride; i += 16) {
|
|
// 0xff for each match
|
|
uint16x8_t results =
|
|
vreinterpretq_u16_u8(vceqq_u8(vld1q_u8(ap + i), vld1q_u8(bp + i)));
|
|
// 0xf for each mismatch
|
|
uint64_t bitfield =
|
|
~vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0);
|
|
if (bitfield) {
|
|
return i + (std::countr_zero(bitfield) >> 2);
|
|
}
|
|
}
|
|
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
|
#else
|
|
int i = 0;
|
|
for (; i < kStride - 1; ++i) {
|
|
if (*ap++ != *bp++) {
|
|
break;
|
|
}
|
|
}
|
|
return i;
|
|
#endif
|
|
}
|
|
|
|
int longestCommonPrefix(const uint8_t *ap, const uint8_t *bp, int cl) {
|
|
if (cl < 0) {
|
|
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
|
}
|
|
int i = 0;
|
|
int end;
|
|
|
|
if (cl < 8) {
|
|
goto bytes;
|
|
}
|
|
|
|
// Optimistic early return
|
|
{
|
|
uint64_t a;
|
|
uint64_t b;
|
|
memcpy(&a, ap, 8);
|
|
memcpy(&b, bp, 8);
|
|
const auto mismatched = a ^ b;
|
|
if (mismatched) {
|
|
return std::countr_zero(mismatched) / 8;
|
|
}
|
|
}
|
|
|
|
// kStride * kUnrollCount at a time
|
|
end = cl & ~(kStride * kUnrollFactor - 1);
|
|
while (i < end) {
|
|
for (int j = 0; j < kUnrollFactor; ++j) {
|
|
if (!compareStride(ap, bp)) {
|
|
return i + firstNeqStride(ap, bp);
|
|
}
|
|
i += kStride;
|
|
ap += kStride;
|
|
bp += kStride;
|
|
}
|
|
}
|
|
|
|
// kStride at a time
|
|
end = cl & ~(kStride - 1);
|
|
while (i < end) {
|
|
if (!compareStride(ap, bp)) {
|
|
return i + firstNeqStride(ap, bp);
|
|
}
|
|
i += kStride;
|
|
ap += kStride;
|
|
bp += kStride;
|
|
}
|
|
|
|
// word at a time
|
|
end = cl & ~(sizeof(uint64_t) - 1);
|
|
while (i < end) {
|
|
uint64_t a;
|
|
uint64_t b;
|
|
memcpy(&a, ap, 8);
|
|
memcpy(&b, bp, 8);
|
|
const auto mismatched = a ^ b;
|
|
if (mismatched) {
|
|
return i + std::countr_zero(mismatched) / 8;
|
|
}
|
|
i += 8;
|
|
ap += 8;
|
|
bp += 8;
|
|
}
|
|
|
|
bytes:
|
|
// byte at a time
|
|
while (i < cl) {
|
|
if (*ap != *bp) {
|
|
break;
|
|
}
|
|
++ap;
|
|
++bp;
|
|
++i;
|
|
}
|
|
|
|
return i;
|
|
}
|
|
|
|
int longestCommonPrefixPartialKey(const uint8_t *ap, const uint8_t *bp,
|
|
int cl) {
|
|
return longestCommonPrefix(ap, bp, cl);
|
|
}
|
|
|
|
// Performs a physical search for remaining
|
|
struct SearchStepWise {
|
|
Node *n;
|
|
std::span<const uint8_t> remaining;
|
|
|
|
SearchStepWise() {}
|
|
SearchStepWise(Node *n, std::span<const uint8_t> remaining)
|
|
: n(n), remaining(remaining) {
|
|
assert(n->partialKeyLen == 0);
|
|
}
|
|
|
|
bool step() {
|
|
if (remaining.size() == 0) {
|
|
return true;
|
|
}
|
|
auto *child = getChild(n, remaining[0]);
|
|
if (child == nullptr) {
|
|
return true;
|
|
}
|
|
int cl = std::min<int>(child->partialKeyLen, remaining.size() - 1);
|
|
int i = longestCommonPrefixPartialKey(child->partialKey(),
|
|
remaining.data() + 1, cl);
|
|
if (i != child->partialKeyLen) {
|
|
return true;
|
|
}
|
|
n = child;
|
|
remaining =
|
|
remaining.subspan(1 + child->partialKeyLen,
|
|
remaining.size() - (1 + child->partialKeyLen));
|
|
return false;
|
|
}
|
|
};
|
|
|
|
namespace {
|
|
std::string getSearchPathPrintable(Node *n);
|
|
}
|
|
|
|
// Logically this is the same as performing firstGeq and then checking against
|
|
// point or range version according to cmp, but this version short circuits as
|
|
// soon as it can prove that there's no conflict.
|
|
bool checkPointRead(Node *n, const std::span<const uint8_t> key,
|
|
int64_t readVersion, ConflictSet::Impl *impl) {
|
|
#if DEBUG_VERBOSE && !defined(NDEBUG)
|
|
fprintf(stderr, "Check point read: %s\n", printable(key).c_str());
|
|
#endif
|
|
auto remaining = key;
|
|
for (;;) {
|
|
if (maxVersion(n, impl) <= readVersion) {
|
|
return true;
|
|
}
|
|
if (remaining.size() == 0) {
|
|
if (n->entryPresent) {
|
|
return n->entry.pointVersion <= readVersion;
|
|
}
|
|
int c = getChildGeq(n, 0);
|
|
assert(c >= 0);
|
|
n = getChildExists(n, c);
|
|
goto downLeftSpine;
|
|
}
|
|
|
|
auto *child = getChild(n, remaining[0]);
|
|
if (child == nullptr) {
|
|
int c = getChildGeq(n, remaining[0]);
|
|
if (c >= 0) {
|
|
n = getChildExists(n, c);
|
|
goto downLeftSpine;
|
|
} else {
|
|
n = nextSibling(n);
|
|
goto downLeftSpine;
|
|
}
|
|
}
|
|
|
|
n = child;
|
|
remaining = remaining.subspan(1, remaining.size() - 1);
|
|
|
|
if (n->partialKeyLen > 0) {
|
|
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
|
|
int i = longestCommonPrefixPartialKey(n->partialKey(), remaining.data(),
|
|
commonLen);
|
|
if (i < commonLen) {
|
|
auto c = n->partialKey()[i] <=> remaining[i];
|
|
if (c > 0) {
|
|
goto downLeftSpine;
|
|
} else {
|
|
n = nextSibling(n);
|
|
goto downLeftSpine;
|
|
}
|
|
}
|
|
if (commonLen == n->partialKeyLen) {
|
|
// partial key matches
|
|
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
|
|
} else if (n->partialKeyLen > int(remaining.size())) {
|
|
// n is the first physical node greater than remaining, and there's no
|
|
// eq node
|
|
goto downLeftSpine;
|
|
}
|
|
}
|
|
}
|
|
downLeftSpine:
|
|
if (n == nullptr) {
|
|
return true;
|
|
}
|
|
for (;;) {
|
|
if (n->entryPresent) {
|
|
return n->entry.rangeVersion <= readVersion;
|
|
}
|
|
int c = getChildGeq(n, 0);
|
|
assert(c >= 0);
|
|
n = getChildExists(n, c);
|
|
}
|
|
}
|
|
|
|
// Return the max version among all keys starting with the search path of n +
|
|
// [child], where child in (begin, end)
|
|
int64_t maxBetweenExclusive(Node *n, int begin, int end) {
|
|
assert(-1 <= begin);
|
|
assert(begin <= 256);
|
|
assert(-1 <= end);
|
|
assert(end <= 256);
|
|
assert(begin < end);
|
|
int64_t result = std::numeric_limits<int64_t>::lowest();
|
|
{
|
|
int c = getChildGeq(n, begin + 1);
|
|
if (c >= 0 && c < end) {
|
|
auto *child = getChildExists(n, c);
|
|
if (child->entryPresent) {
|
|
result = std::max(result, child->entry.rangeVersion);
|
|
}
|
|
begin = c;
|
|
} else {
|
|
return result;
|
|
}
|
|
}
|
|
switch (n->type) {
|
|
case Type::Node0:
|
|
[[fallthrough]];
|
|
case Type::Node4:
|
|
[[fallthrough]];
|
|
case Type::Node16: {
|
|
auto *self = static_cast<Node16 *>(n);
|
|
for (int i = 0; i < self->numChildren && self->index[i] < end; ++i) {
|
|
if (begin <= self->index[i]) {
|
|
result = std::max(result, self->children[i].childMaxVersion);
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
case Type::Node48: {
|
|
auto *self = static_cast<Node48 *>(n);
|
|
self->bitSet.forEachInRange(
|
|
[&](int i) {
|
|
result =
|
|
std::max(result, self->children[self->index[i]].childMaxVersion);
|
|
},
|
|
begin, end);
|
|
break;
|
|
}
|
|
case Type::Node256: {
|
|
auto *self = static_cast<Node256 *>(n);
|
|
self->bitSet.forEachInRange(
|
|
[&](int i) {
|
|
result = std::max(result, self->children[i].childMaxVersion);
|
|
},
|
|
begin, end);
|
|
break;
|
|
}
|
|
}
|
|
#if DEBUG_VERBOSE && !defined(NDEBUG)
|
|
fprintf(stderr, "At `%s', max version in (%02x, %02x) is %" PRId64 "\n",
|
|
getSearchPathPrintable(n).c_str(), begin, end, result);
|
|
#endif
|
|
return result;
|
|
}
|
|
|
|
// Returns the bytes on the path from the topmost ancestor of `n` down to `n`:
// for each node on the way down, the byte under which it is stored in its
// parent (if it has a parent) followed by its partial key.
Vector<uint8_t> getSearchPath(Arena &arena, Node *n) {
  assert(n != nullptr);
  auto result = vector<uint8_t>(arena);
  // Collect the bytes back to front while walking up, then reverse once
  for (Node *cur = n;; cur = cur->parent) {
    const uint8_t *const partial = cur->partialKey();
    for (int i = cur->partialKeyLen; i-- > 0;) {
      result.push_back(partial[i]);
    }
    if (cur->parent == nullptr) {
      break;
    }
    result.push_back(cur->parentsIndex);
  }
  std::reverse(result.begin(), result.end());
  return result;
}
|
|
|
|
// Return true if the max version among all keys that start with key + [child],
|
|
// where begin < child < end, is <= readVersion
|
|
bool checkRangeStartsWith(Node *n, std::span<const uint8_t> key, int begin,
|
|
int end, int64_t readVersion,
|
|
ConflictSet::Impl *impl) {
|
|
#if DEBUG_VERBOSE && !defined(NDEBUG)
|
|
fprintf(stderr, "%s(%02x,%02x)*\n", printable(key).c_str(), begin, end);
|
|
#endif
|
|
auto remaining = key;
|
|
for (;;) {
|
|
if (maxVersion(n, impl) <= readVersion) {
|
|
return true;
|
|
}
|
|
if (remaining.size() == 0) {
|
|
return maxBetweenExclusive(n, begin, end) <= readVersion;
|
|
}
|
|
|
|
auto *child = getChild(n, remaining[0]);
|
|
if (child == nullptr) {
|
|
int c = getChildGeq(n, remaining[0]);
|
|
if (c >= 0) {
|
|
n = getChildExists(n, c);
|
|
goto downLeftSpine;
|
|
} else {
|
|
n = nextSibling(n);
|
|
goto downLeftSpine;
|
|
}
|
|
}
|
|
|
|
n = child;
|
|
remaining = remaining.subspan(1, remaining.size() - 1);
|
|
|
|
if (n->partialKeyLen > 0) {
|
|
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
|
|
int i = longestCommonPrefixPartialKey(n->partialKey(), remaining.data(),
|
|
commonLen);
|
|
if (i < commonLen) {
|
|
auto c = n->partialKey()[i] <=> remaining[i];
|
|
if (c > 0) {
|
|
goto downLeftSpine;
|
|
} else {
|
|
n = nextSibling(n);
|
|
goto downLeftSpine;
|
|
}
|
|
}
|
|
if (commonLen == n->partialKeyLen) {
|
|
// partial key matches
|
|
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
|
|
} else if (n->partialKeyLen > int(remaining.size())) {
|
|
if (begin < n->partialKey()[remaining.size()] &&
|
|
n->partialKey()[remaining.size()] < end) {
|
|
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
|
|
return false;
|
|
}
|
|
return maxVersion(n, impl) <= readVersion;
|
|
}
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
downLeftSpine:
|
|
if (n == nullptr) {
|
|
return true;
|
|
}
|
|
for (;;) {
|
|
if (n->entryPresent) {
|
|
return n->entry.rangeVersion <= readVersion;
|
|
}
|
|
int c = getChildGeq(n, 0);
|
|
assert(c >= 0);
|
|
n = getChildExists(n, c);
|
|
}
|
|
}
|
|
|
|
// Return true if the max version among all keys that start with key[:prefixLen]
|
|
// that are >= key is <= readVersion
|
|
struct CheckRangeLeftSide {
|
|
CheckRangeLeftSide(Node *n, std::span<const uint8_t> key, int prefixLen,
|
|
int64_t readVersion, ConflictSet::Impl *impl)
|
|
: n(n), remaining(key), prefixLen(prefixLen), readVersion(readVersion),
|
|
impl(impl) {
|
|
#if DEBUG_VERBOSE && !defined(NDEBUG)
|
|
fprintf(stderr, "Check range left side from %s for keys starting with %s\n",
|
|
printable(key).c_str(),
|
|
printable(key.subspan(0, prefixLen)).c_str());
|
|
#endif
|
|
}
|
|
|
|
Node *n;
|
|
std::span<const uint8_t> remaining;
|
|
int prefixLen;
|
|
int64_t readVersion;
|
|
ConflictSet::Impl *impl;
|
|
int searchPathLen = 0;
|
|
bool ok;
|
|
|
|
enum Phase { Search, DownLeftSpine } phase = Search;
|
|
|
|
bool step() {
|
|
switch (phase) {
|
|
case Search: {
|
|
if (maxVersion(n, impl) <= readVersion) {
|
|
ok = true;
|
|
return true;
|
|
}
|
|
if (remaining.size() == 0) {
|
|
assert(searchPathLen >= prefixLen);
|
|
ok = maxVersion(n, impl) <= readVersion;
|
|
return true;
|
|
}
|
|
|
|
if (searchPathLen >= prefixLen) {
|
|
if (maxBetweenExclusive(n, remaining[0], 256) > readVersion) {
|
|
ok = false;
|
|
return true;
|
|
}
|
|
}
|
|
|
|
auto *child = getChild(n, remaining[0]);
|
|
if (child == nullptr) {
|
|
int c = getChildGeq(n, remaining[0]);
|
|
if (c >= 0) {
|
|
if (searchPathLen < prefixLen) {
|
|
n = getChildExists(n, c);
|
|
return downLeftSpine();
|
|
}
|
|
n = getChildExists(n, c);
|
|
ok = maxVersion(n, impl) <= readVersion;
|
|
return true;
|
|
} else {
|
|
n = nextSibling(n);
|
|
return downLeftSpine();
|
|
}
|
|
}
|
|
|
|
n = child;
|
|
remaining = remaining.subspan(1, remaining.size() - 1);
|
|
++searchPathLen;
|
|
|
|
if (n->partialKeyLen > 0) {
|
|
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
|
|
int i = longestCommonPrefixPartialKey(n->partialKey(), remaining.data(),
|
|
commonLen);
|
|
searchPathLen += i;
|
|
if (i < commonLen) {
|
|
auto c = n->partialKey()[i] <=> remaining[i];
|
|
if (c > 0) {
|
|
if (searchPathLen < prefixLen) {
|
|
return downLeftSpine();
|
|
}
|
|
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
|
|
ok = false;
|
|
return true;
|
|
}
|
|
ok = maxVersion(n, impl) <= readVersion;
|
|
return true;
|
|
} else {
|
|
n = nextSibling(n);
|
|
return downLeftSpine();
|
|
}
|
|
}
|
|
if (commonLen == n->partialKeyLen) {
|
|
// partial key matches
|
|
remaining =
|
|
remaining.subspan(commonLen, remaining.size() - commonLen);
|
|
} else if (n->partialKeyLen > int(remaining.size())) {
|
|
if (searchPathLen < prefixLen) {
|
|
return downLeftSpine();
|
|
}
|
|
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
|
|
ok = false;
|
|
return true;
|
|
}
|
|
ok = maxVersion(n, impl) <= readVersion;
|
|
return true;
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
case DownLeftSpine:
|
|
if (n->entryPresent) {
|
|
ok = n->entry.rangeVersion <= readVersion;
|
|
return true;
|
|
}
|
|
int c = getChildGeq(n, 0);
|
|
assert(c >= 0);
|
|
n = getChildExists(n, c);
|
|
break;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool downLeftSpine() {
|
|
phase = DownLeftSpine;
|
|
if (n == nullptr) {
|
|
ok = true;
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
};
|
|
|
|
// Return true if the max version among all keys that start with key[:prefixLen]
|
|
// that are < key is <= readVersion
|
|
struct CheckRangeRightSide {
|
|
CheckRangeRightSide(Node *n, std::span<const uint8_t> key, int prefixLen,
|
|
int64_t readVersion, ConflictSet::Impl *impl)
|
|
: n(n), key(key), remaining(key), prefixLen(prefixLen),
|
|
readVersion(readVersion), impl(impl) {
|
|
#if DEBUG_VERBOSE && !defined(NDEBUG)
|
|
fprintf(stderr, "Check range right side to %s for keys starting with %s\n",
|
|
printable(key).c_str(),
|
|
printable(key.subspan(0, prefixLen)).c_str());
|
|
#endif
|
|
}
|
|
|
|
Node *n;
|
|
std::span<const uint8_t> key;
|
|
std::span<const uint8_t> remaining;
|
|
int prefixLen;
|
|
int64_t readVersion;
|
|
ConflictSet::Impl *impl;
|
|
int searchPathLen = 0;
|
|
bool ok;
|
|
|
|
enum Phase { Search, DownLeftSpine } phase = Search;
|
|
|
|
bool step() {
|
|
switch (phase) {
|
|
case Search: {
|
|
#if DEBUG_VERBOSE && !defined(NDEBUG)
|
|
fprintf(
|
|
stderr,
|
|
"Search path: %s, searchPathLen: %d, prefixLen: %d, remaining: %s\n",
|
|
getSearchPathPrintable(n).c_str(), searchPathLen, prefixLen,
|
|
printable(remaining).c_str());
|
|
#endif
|
|
|
|
assert(searchPathLen <= int(key.size()));
|
|
|
|
if (remaining.size() == 0) {
|
|
return downLeftSpine();
|
|
}
|
|
|
|
if (searchPathLen >= prefixLen) {
|
|
if (n->entryPresent && n->entry.pointVersion > readVersion) {
|
|
ok = false;
|
|
return true;
|
|
}
|
|
|
|
if (maxBetweenExclusive(n, -1, remaining[0]) > readVersion) {
|
|
ok = false;
|
|
return true;
|
|
}
|
|
}
|
|
|
|
if (searchPathLen > prefixLen && n->entryPresent &&
|
|
n->entry.rangeVersion > readVersion) {
|
|
ok = false;
|
|
return true;
|
|
}
|
|
|
|
auto *child = getChild(n, remaining[0]);
|
|
if (child == nullptr) {
|
|
int c = getChildGeq(n, remaining[0]);
|
|
if (c >= 0) {
|
|
n = getChildExists(n, c);
|
|
return downLeftSpine();
|
|
} else {
|
|
return backtrack();
|
|
}
|
|
}
|
|
|
|
n = child;
|
|
remaining = remaining.subspan(1, remaining.size() - 1);
|
|
++searchPathLen;
|
|
|
|
if (n->partialKeyLen > 0) {
|
|
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
|
|
int i = longestCommonPrefixPartialKey(n->partialKey(), remaining.data(),
|
|
commonLen);
|
|
searchPathLen += i;
|
|
if (i < commonLen) {
|
|
++searchPathLen;
|
|
auto c = n->partialKey()[i] <=> remaining[i];
|
|
if (c > 0) {
|
|
return downLeftSpine();
|
|
} else {
|
|
if (searchPathLen > prefixLen && n->entryPresent &&
|
|
n->entry.rangeVersion > readVersion) {
|
|
ok = false;
|
|
return true;
|
|
}
|
|
return backtrack();
|
|
}
|
|
}
|
|
if (commonLen == n->partialKeyLen) {
|
|
// partial key matches
|
|
remaining =
|
|
remaining.subspan(commonLen, remaining.size() - commonLen);
|
|
} else if (n->partialKeyLen > int(remaining.size())) {
|
|
return downLeftSpine();
|
|
}
|
|
}
|
|
} break;
|
|
case DownLeftSpine:
|
|
if (n->entryPresent) {
|
|
ok = n->entry.rangeVersion <= readVersion;
|
|
return true;
|
|
}
|
|
int c = getChildGeq(n, 0);
|
|
assert(c >= 0);
|
|
n = getChildExists(n, c);
|
|
break;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool backtrack() {
|
|
for (;;) {
|
|
if (searchPathLen > prefixLen && maxVersion(n, impl) > readVersion) {
|
|
ok = false;
|
|
return true;
|
|
}
|
|
if (n->parent == nullptr) {
|
|
ok = true;
|
|
return true;
|
|
}
|
|
auto next = getChildGeq(n->parent, n->parentsIndex + 1);
|
|
if (next < 0) {
|
|
searchPathLen -= 1 + n->partialKeyLen;
|
|
n = n->parent;
|
|
} else {
|
|
searchPathLen -= n->partialKeyLen;
|
|
n = getChildExists(n->parent, next);
|
|
searchPathLen += n->partialKeyLen;
|
|
return downLeftSpine();
|
|
}
|
|
}
|
|
}
|
|
|
|
bool downLeftSpine() {
|
|
phase = DownLeftSpine;
|
|
if (n == nullptr) {
|
|
ok = true;
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
};
|
|
|
|
// Return true if no key in [begin, end) has a version newer than readVersion.
//
// The common prefix of begin and end is searched first; the remainder is
// split into the keys strictly between the two diverging bytes
// (checkRangeStartsWith), the keys >= begin (CheckRangeLeftSide) and the keys
// < end (CheckRangeRightSide).
bool checkRangeRead(Node *n, std::span<const uint8_t> begin,
                    std::span<const uint8_t> end, int64_t readVersion,
                    ConflictSet::Impl *impl) {
  int lcp = longestCommonPrefix(begin.data(), end.data(),
                                std::min(begin.size(), end.size()));
  // [k, k + '\0') contains exactly the key k
  const bool readsSingleKey = lcp == int(begin.size()) &&
                              end.size() == begin.size() + 1 &&
                              end.back() == 0;
  if (readsSingleKey) {
    return checkPointRead(n, begin, readVersion, impl);
  }

  SearchStepWise search{n, begin.subspan(0, lcp)};
  Arena arena;
  do {
    assert(getSearchPath(arena, search.n) <=>
               begin.subspan(0, lcp - search.remaining.size()) ==
           0);
    if (maxVersion(search.n, impl) <= readVersion) {
      return true;
    }
  } while (!search.step());
  assert(getSearchPath(arena, search.n) <=>
             begin.subspan(0, lcp - search.remaining.size()) ==
         0);

  // Drop the part of the common prefix the search consumed and continue
  // relative to the node it stopped at.
  const int consumed = lcp - search.remaining.size();
  assert(consumed >= 0);

  begin = begin.subspan(consumed);
  end = end.subspan(consumed);
  n = search.n;
  lcp -= consumed;

  if (lcp == int(begin.size())) {
    // begin is a prefix of end: only keys < end need checking
    CheckRangeRightSide rightOnly{n, end, lcp, readVersion, impl};
    while (!rightOnly.step()) {
    }
    return rightOnly.ok;
  }

  if (!checkRangeStartsWith(n, begin.first(lcp), begin[lcp], end[lcp],
                            readVersion, impl)) {
    return false;
  }

  CheckRangeLeftSide checkRangeLeftSide{n, begin, lcp + 1, readVersion, impl};
  CheckRangeRightSide checkRangeRightSide{n, end, lcp + 1, readVersion, impl};

  // Step both sides in lock step until one finishes, then run the other to
  // completion. A finished side is never stepped again.
  for (;;) {
    const bool leftDone = checkRangeLeftSide.step();
    const bool rightDone = checkRangeRightSide.step();
    if (!leftDone && !rightDone) {
      continue;
    }
    if (!rightDone) {
      while (!checkRangeRightSide.step()) {
      }
    }
    if (!leftDone) {
      while (!checkRangeLeftSide.step()) {
      }
    }
    break;
  }

  return checkRangeLeftSide.ok & checkRangeRightSide.ok;
}
|
|
|
|
// Returns a pointer to the newly inserted node. Caller must set
|
|
// `entryPresent`, `entry` fields and `maxVersion` on the result. The search
|
|
// path of the result's parent will have `maxVersion` at least `writeVersion` as
|
|
// a postcondition.
|
|
template <bool kBegin>
|
|
[[nodiscard]] Node *insert(Node **self, std::span<const uint8_t> key,
|
|
int64_t writeVersion, NodeAllocators *allocators,
|
|
ConflictSet::Impl *impl) {
|
|
|
|
for (;;) {
|
|
|
|
if ((*self)->partialKeyLen > 0) {
|
|
// Handle an existing partial key
|
|
int commonLen = std::min<int>((*self)->partialKeyLen, key.size());
|
|
int partialKeyIndex = longestCommonPrefixPartialKey(
|
|
(*self)->partialKey(), key.data(), commonLen);
|
|
if (partialKeyIndex < (*self)->partialKeyLen) {
|
|
auto *old = *self;
|
|
int64_t oldMaxVersion = maxVersion(old, impl);
|
|
|
|
*self = allocators->node4.allocate(partialKeyIndex);
|
|
|
|
memcpy((char *)*self + kNodeCopyBegin, (char *)old + kNodeCopyBegin,
|
|
kNodeCopySize);
|
|
(*self)->partialKeyLen = partialKeyIndex;
|
|
(*self)->entryPresent = false;
|
|
(*self)->numChildren = 0;
|
|
memcpy((*self)->partialKey(), old->partialKey(),
|
|
(*self)->partialKeyLen);
|
|
|
|
getOrCreateChild(*self, old->partialKey()[partialKeyIndex],
|
|
allocators) = old;
|
|
old->parent = *self;
|
|
old->parentsIndex = old->partialKey()[partialKeyIndex];
|
|
maxVersion(old, impl) = oldMaxVersion;
|
|
|
|
memmove(old->partialKey(), old->partialKey() + partialKeyIndex + 1,
|
|
old->partialKeyLen - (partialKeyIndex + 1));
|
|
old->partialKeyLen -= partialKeyIndex + 1;
|
|
}
|
|
key = key.subspan(partialKeyIndex, key.size() - partialKeyIndex);
|
|
|
|
} else {
|
|
// Consider adding a partial key
|
|
if ((*self)->numChildren == 0 && !(*self)->entryPresent) {
|
|
assert((*self)->partialKeyCapacity >= int(key.size()));
|
|
(*self)->partialKeyLen = key.size();
|
|
memcpy((*self)->partialKey(), key.data(), (*self)->partialKeyLen);
|
|
key = key.subspan((*self)->partialKeyLen,
|
|
key.size() - (*self)->partialKeyLen);
|
|
}
|
|
}
|
|
|
|
if constexpr (kBegin) {
|
|
auto &m = maxVersion(*self, impl);
|
|
assert(writeVersion >= m);
|
|
m = writeVersion;
|
|
}
|
|
|
|
if (key.size() == 0) {
|
|
return *self;
|
|
}
|
|
|
|
if constexpr (!kBegin) {
|
|
auto &m = maxVersion(*self, impl);
|
|
assert(writeVersion >= m);
|
|
m = writeVersion;
|
|
}
|
|
|
|
auto &child = getOrCreateChild(*self, key.front(), allocators);
|
|
if (!child) {
|
|
child = allocators->node0.allocate(key.size() - 1);
|
|
child->parent = *self;
|
|
child->parentsIndex = key.front();
|
|
maxVersion(child, impl) =
|
|
kBegin ? writeVersion : std::numeric_limits<int64_t>::lowest();
|
|
}
|
|
|
|
self = &child;
|
|
key = key.subspan(1, key.size() - 1);
|
|
}
|
|
}
|
|
|
|
// Free every node reachable from `root`, using an explicit work list instead
// of recursion.
void destroyTree(Node *root) {
  Arena arena;
  auto pending = vector<Node *>(arena);
  pending.push_back(root);
  do {
    Node *const node = pending.back();
    pending.pop_back();
    // Queue all children before the node itself is released
    int index = getChildGeq(node, 0);
    while (index >= 0) {
      auto *c = getChildExists(node, index);
      assert(c != nullptr);
      pending.push_back(c);
      index = getChildGeq(node, index + 1);
    }
    free(node);
  } while (pending.size() > 0);
}
|
|
|
|
// Record a write to the single key `key` at `writeVersion`. A newly created
// entry takes its range version from the next logical entry, or from
// `oldestVersion` if there is none.
void addPointWrite(Node *&root, int64_t oldestVersion,
                   std::span<const uint8_t> key, int64_t writeVersion,
                   NodeAllocators *allocators, ConflictSet::Impl *impl) {
  Node *const node = insert<true>(&root, key, writeVersion, allocators, impl);

  if (node->entryPresent) {
    // Existing entry: only the point version moves forward
    assert(writeVersion >= node->entry.pointVersion);
    node->entry.pointVersion = writeVersion;
    return;
  }

  Node *const successor = nextLogical(node);
  node->entryPresent = true;
  node->entry.pointVersion = writeVersion;
  maxVersion(node, impl) = writeVersion;
  if (successor != nullptr) {
    node->entry.rangeVersion = successor->entry.rangeVersion;
  } else {
    node->entry.rangeVersion = oldestVersion;
  }
}
|
|
|
|
// Record a write to every key in [begin, end) at `writeVersion`. A range of
// the form [k, k + '\0') is handled as a point write to k. Entries strictly
// between the begin and end nodes are removed afterwards.
void addWriteRange(Node *&root, int64_t oldestVersion,
                   std::span<const uint8_t> begin, std::span<const uint8_t> end,
                   int64_t writeVersion, NodeAllocators *allocators,
                   ConflictSet::Impl *impl) {

  const int lcp = longestCommonPrefix(begin.data(), end.data(),
                                      std::min(begin.size(), end.size()));
  const bool writesSingleKey = lcp == int(begin.size()) &&
                               end.size() == begin.size() + 1 &&
                               end.back() == 0;
  if (writesSingleKey) {
    return addPointWrite(root, oldestVersion, begin, writeVersion, allocators,
                         impl);
  }

  // Walk down the common prefix as far as existing nodes allow, raising the
  // max version of every node that is passed.
  auto prefix = begin.first(lcp);
  Node *n = root;
  while (int(prefix.size()) > n->partialKeyLen) {
    const int matched = longestCommonPrefixPartialKey(
        n->partialKey(), prefix.data(), n->partialKeyLen);
    if (matched != n->partialKeyLen) {
      break;
    }

    auto *child = getChild(n, prefix[n->partialKeyLen]);
    if (child == nullptr) {
      break;
    }

    int64_t &passed = maxVersion(n, impl);
    assert(writeVersion >= passed);
    passed = writeVersion;

    prefix = prefix.subspan(n->partialKeyLen + 1);
    n = child;
  }

  // Slot that holds `n`: the root pointer, or the child slot in n's parent
  Node **useAsRoot = &root;
  if (n->parent != nullptr) {
    useAsRoot = &getChildExists(n->parent, n->parentsIndex);
  }

  const int consumed = lcp - prefix.size();
  begin = begin.subspan(consumed);
  end = end.subspan(consumed);

  auto *beginNode =
      insert<true>(useAsRoot, begin, writeVersion, allocators, impl);

  const bool insertedBegin = !beginNode->entryPresent;
  beginNode->entryPresent = true;

  if (insertedBegin) {
    auto *p = nextLogical(beginNode);
    beginNode->entry.rangeVersion =
        p != nullptr ? p->entry.rangeVersion : oldestVersion;
    beginNode->entry.pointVersion = writeVersion;
    maxVersion(beginNode, impl) = writeVersion;
  }
  {
    int64_t &beginMax = maxVersion(beginNode, impl);
    assert(writeVersion >= beginMax);
    beginMax = writeVersion;
  }
  assert(writeVersion >= beginNode->entry.pointVersion);
  beginNode->entry.pointVersion = writeVersion;

  auto *endNode = insert<false>(useAsRoot, end, writeVersion, allocators, impl);

  const bool insertedEnd = !endNode->entryPresent;
  endNode->entryPresent = true;

  if (insertedEnd) {
    auto *p = nextLogical(endNode);
    endNode->entry.pointVersion =
        p != nullptr ? p->entry.rangeVersion : oldestVersion;
    int64_t &endMax = maxVersion(endNode, impl);
    endMax = std::max(endMax, endNode->entry.pointVersion);
  }
  endNode->entry.rangeVersion = writeVersion;

  if (insertedEnd) {
    // beginNode may have been invalidated
    beginNode = insert<true>(useAsRoot, begin, writeVersion, allocators, impl);
    assert(beginNode->entryPresent);
  }

  // Remove every entry strictly between begin and end. The successor is
  // looked up before the current node is touched.
  Node *cursor = nextLogical(beginNode);
  while (cursor != endNode) {
    Node *const stale = cursor;
    cursor = nextLogical(cursor);
    stale->entryPresent = false;
    if (stale->numChildren == 0 && stale->parent != nullptr) {
      eraseChild(stale->parent, stale->parentsIndex, allocators);
    }
  }
}
|
|
|
|
// Find the first entry whose key is >= `key`. The second field of the result
// is 0 when the entry's key equals `key` and 1 otherwise; the node is null
// when no such entry exists.
Iterator firstGeq(Node *n, const std::span<const uint8_t> key) {
  // Phase 1: follow `key` down the tree. Every path out of this loop either
  // returns directly or leaves `n` at the node (possibly null) from which the
  // left spine must be descended.
  for (auto rest = key;;) {
    if (rest.empty()) {
      if (n->entryPresent) {
        return {n, 0};
      }
      const int c = getChildGeq(n, 0);
      assert(c >= 0);
      n = getChildExists(n, c);
      break;
    }

    auto *child = getChild(n, rest[0]);
    if (child == nullptr) {
      const int c = getChildGeq(n, rest[0]);
      if (c >= 0) {
        n = getChildExists(n, c);
      } else {
        n = nextSibling(n);
      }
      break;
    }

    n = child;
    rest = rest.subspan(1);

    if (n->partialKeyLen > 0) {
      const int commonLen = std::min<int>(n->partialKeyLen, rest.size());
      const int i =
          longestCommonPrefixPartialKey(n->partialKey(), rest.data(), commonLen);
      if (i < commonLen) {
        // Mismatch inside the partial key. If the partial key is the smaller
        // side, everything under n sorts before the key.
        if (n->partialKey()[i] <= rest[i]) {
          n = nextSibling(n);
        }
        break;
      }
      if (commonLen == n->partialKeyLen) {
        // partial key matches
        rest = rest.subspan(commonLen);
      } else if (n->partialKeyLen > int(rest.size())) {
        // n is the first physical node greater than remaining, and there's no
        // eq node
        break;
      }
    }
  }

  // Phase 2: descend to the first entry at or below n
  if (n == nullptr) {
    return {nullptr, 1};
  }
  while (!n->entryPresent) {
    const int c = getChildGeq(n, 0);
    assert(c >= 0);
    n = getChildExists(n, c);
  }
  return {n, 1};
}
|
|
|
|
// Implementation behind ConflictSet: owns the tree, the node allocators and
// the bookkeeping used to incrementally remove old entries.
struct __attribute__((visibility("hidden"))) ConflictSet::Impl {

  // For each read, store TooOld if its read version predates oldestVersion,
  // otherwise Commit or Conflict depending on the point/range check.
  void check(const ReadRange *reads, Result *result, int count) {
    for (int i = 0; i < count; ++i) {
      const ReadRange &read = reads[i];
      if (read.readVersion < oldestVersion) {
        result[i] = TooOld;
        continue;
      }
      const std::span<const uint8_t> begin(read.begin.p, read.begin.len);
      const std::span<const uint8_t> end(read.end.p, read.end.len);
      // An empty end key denotes a point read
      const bool noConflict =
          end.size() > 0
              ? checkRangeRead(root, begin, end, read.readVersion, this)
              : checkPointRead(root, begin, read.readVersion, this);
      result[i] = noConflict ? Commit : Conflict;
    }
  }

  // Apply `count` writes at `writeVersion`. An empty end key denotes a point
  // write. `keyUpdates` grows with every write and is spent by
  // setOldestVersion.
  void addWrites(const WriteRange *writes, int count, int64_t writeVersion) {
    for (const WriteRange *w = writes; w != writes + count; ++w) {
      const std::span<const uint8_t> begin(w->begin.p, w->begin.len);
      const std::span<const uint8_t> end(w->end.p, w->end.len);
      if (w->end.len > 0) {
        keyUpdates += 3;
        addWriteRange(root, oldestVersion, begin, end, writeVersion,
                      &allocators, this);
      } else {
        keyUpdates += 2;
        addPointWrite(root, oldestVersion, begin, writeVersion, &allocators,
                      this);
      }
    }
  }

  // Advance oldestVersion. Once at least 100 key updates have accumulated,
  // also scan up to `keyUpdates` entries starting at `removalKey` and drop
  // those whose versions are all <= oldestVersion.
  void setOldestVersion(int64_t oldestVersion) {
    if (oldestVersion <= this->oldestVersion) {
      return;
    }
    this->oldestVersion = oldestVersion;
    if (keyUpdates < 100) {
      return;
    }
    Node *prev = firstGeq(root, removalKey).n;
    // There's no way to erase removalKey without introducing a key after it
    assert(prev != nullptr);
    while (keyUpdates > 0) {
      Node *const next = nextLogical(prev);
      if (next == nullptr) {
        // Reached the end of the key space; start over next time
        removalKey = {};
        return;
      }

      if (prev->entry.pointVersion <= oldestVersion &&
          prev->entry.rangeVersion <= oldestVersion) {
        // Any transaction prev would have prevented from committing is
        // going to fail with TooOld anyway.

        // There's no way to insert a range such that range version of the right
        // node is greater than the point version of the left node
        assert(next->entry.rangeVersion <= oldestVersion);
        prev->entryPresent = false;
        if (prev->numChildren == 0 && prev->parent != nullptr) {
          eraseChild(prev->parent, prev->parentsIndex, &allocators);
        }
      }

      prev = next;
      --keyUpdates;
    }
    // Remember where to resume on the next call
    removalKeyArena = Arena();
    removalKey = getSearchPath(removalKeyArena, prev);
  }

  explicit Impl(int64_t oldestVersion) : oldestVersion(oldestVersion) {
    // Insert ""
    root = allocators.node0.allocate(0);
    rootMaxVersion = oldestVersion;
    root->entryPresent = true;
    root->entry.pointVersion = oldestVersion;
    root->entry.rangeVersion = oldestVersion;
  }
  ~Impl() { destroyTree(root); }

  NodeAllocators allocators;

  // Key at which setOldestVersion resumes its scan, and the arena backing it
  Arena removalKeyArena;
  std::span<const uint8_t> removalKey;
  int64_t keyUpdates = 0;

  Node *root;
  // The root has no parent slot, so its max version is stored here
  int64_t rootMaxVersion;
  int64_t oldestVersion;
};
|
|
|
|
// Precondition - an entry for index must exist in the node
|
|
int64_t &maxVersion(Node *n, ConflictSet::Impl *impl) {
|
|
int index = n->parentsIndex;
|
|
n = n->parent;
|
|
if (n == nullptr) {
|
|
return impl->rootMaxVersion;
|
|
}
|
|
if (n->type <= Type::Node16) {
|
|
auto *n16 = static_cast<Node16 *>(n);
|
|
int i = getNodeIndex(n16, index);
|
|
return n16->children[i].childMaxVersion;
|
|
} else if (n->type == Type::Node48) {
|
|
auto *n48 = static_cast<Node48 *>(n);
|
|
assert(n48->bitSet.test(index));
|
|
return n48->children[n48->index[index]].childMaxVersion;
|
|
} else {
|
|
auto *n256 = static_cast<Node256 *>(n);
|
|
assert(n256->bitSet.test(index));
|
|
return n256->children[index].childMaxVersion;
|
|
}
|
|
}
|
|
|
|
// ==================== END IMPLEMENTATION ====================
|
|
|
|
// GCOVR_EXCL_START
|
|
|
|
void ConflictSet::check(const ReadRange *reads, Result *results,
|
|
int count) const {
|
|
return impl->check(reads, results, count);
|
|
}
|
|
|
|
void ConflictSet::addWrites(const WriteRange *writes, int count,
|
|
int64_t writeVersion) {
|
|
return impl->addWrites(writes, count, writeVersion);
|
|
}
|
|
|
|
void ConflictSet::setOldestVersion(int64_t oldestVersion) {
|
|
return impl->setOldestVersion(oldestVersion);
|
|
}
|
|
|
|
// Construct the Impl with placement new in storage from safe_malloc; the
// destructor releases it with an explicit destructor call followed by free().
ConflictSet::ConflictSet(int64_t oldestVersion) : impl(nullptr) {
  void *const storage = safe_malloc(sizeof(Impl));
  impl = new (storage) Impl{oldestVersion};
}
|
|
|
|
// Destroy and free the Impl. `impl` is null after the object was moved from.
ConflictSet::~ConflictSet() {
  if (impl == nullptr) {
    return;
  }
  impl->~Impl();
  free(impl);
}
|
|
|
|
#if SHOW_MEMORY
|
|
// Print the high water mark of each node allocator to stderr.
__attribute__((visibility("default"))) void showMemory(const ConflictSet &cs) {
  // Copy the leading pointer-sized bytes of `cs` into an Impl pointer
  ConflictSet::Impl *impl;
  memcpy(&impl, &cs, sizeof(impl)); // NOLINT
  auto &pools = impl->allocators;
  fprintf(stderr, "Max Node0 memory usage: %" PRId64 "\n",
          pools.node0.highWaterMarkBytes());
  fprintf(stderr, "Max Node4 memory usage: %" PRId64 "\n",
          pools.node4.highWaterMarkBytes());
  fprintf(stderr, "Max Node16 memory usage: %" PRId64 "\n",
          pools.node16.highWaterMarkBytes());
  fprintf(stderr, "Max Node48 memory usage: %" PRId64 "\n",
          pools.node48.highWaterMarkBytes());
  fprintf(stderr, "Max Node256 memory usage: %" PRId64 "\n",
          pools.node256.highWaterMarkBytes());
}
|
|
#endif
|
|
|
|
// Move constructor: take over other's Impl and leave `other` empty.
ConflictSet::ConflictSet(ConflictSet &&other) noexcept : impl(other.impl) {
  other.impl = nullptr;
}
|
|
|
|
// Move assignment: release the Impl currently owned by this object (if any),
// then take over other's Impl and leave `other` empty.
//
// The previous implementation overwrote `impl` without destroying it, which
// leaked the tree owned by the assigned-to object.
ConflictSet &ConflictSet::operator=(ConflictSet &&other) noexcept {
  if (this != &other) {
    if (impl) {
      // Mirror the destructor: explicit destructor call, then free()
      impl->~Impl();
      free(impl);
    }
    impl = std::exchange(other.impl, nullptr);
  }
  return *this;
}
|
|
|
|
// Flat names for the nested ConflictSet types, used in the signatures of the
// extern "C" entry points below.
using ConflictSet_Result = ConflictSet::Result;
using ConflictSet_Key = ConflictSet::Key;
using ConflictSet_ReadRange = ConflictSet::ReadRange;
using ConflictSet_WriteRange = ConflictSet::WriteRange;
|
|
|
|
extern "C" {
|
|
__attribute__((__visibility__("default"))) void
|
|
ConflictSet_check(void *cs, const ConflictSet_ReadRange *reads,
|
|
ConflictSet_Result *results, int count) {
|
|
((ConflictSet::Impl *)cs)->check(reads, results, count);
|
|
}
|
|
__attribute__((__visibility__("default"))) void
|
|
ConflictSet_addWrites(void *cs, const ConflictSet_WriteRange *writes, int count,
|
|
int64_t writeVersion) {
|
|
((ConflictSet::Impl *)cs)->addWrites(writes, count, writeVersion);
|
|
}
|
|
__attribute__((__visibility__("default"))) void
|
|
ConflictSet_setOldestVersion(void *cs, int64_t oldestVersion) {
|
|
((ConflictSet::Impl *)cs)->setOldestVersion(oldestVersion);
|
|
}
|
|
__attribute__((__visibility__("default"))) void *
|
|
ConflictSet_create(int64_t oldestVersion) {
|
|
return new (safe_malloc(sizeof(ConflictSet::Impl)))
|
|
ConflictSet::Impl{oldestVersion};
|
|
}
|
|
__attribute__((__visibility__("default"))) void ConflictSet_destroy(void *cs) {
|
|
using Impl = ConflictSet::Impl;
|
|
((Impl *)cs)->~Impl();
|
|
free(cs);
|
|
}
|
|
}
|
|
|
|
namespace {
|
|
|
|
// Printable form of the full key leading to `n`, or "<end>" for a null node.
std::string getSearchPathPrintable(Node *n) {
  Arena arena;
  if (n == nullptr) {
    return "<end>";
  }
  // Collect bytes leaf-to-root, then reverse
  auto bytes = vector<char>(arena);
  for (Node *cur = n;; cur = cur->parent) {
    int i = cur->partialKeyLen;
    while (i-- > 0) {
      bytes.push_back(cur->partialKey()[i]);
    }
    if (cur->parent == nullptr) {
      break;
    }
    bytes.push_back(cur->parentsIndex);
  }
  std::reverse(bytes.begin(), bytes.end());
  if (bytes.size() == 0) {
    return std::string();
  }
  return printable(std::string_view((const char *)&bytes[0],
                                    bytes.size())); // NOLINT
}
|
|
|
|
// Printable form of the bytes `n` contributes to its search path: the byte it
// is stored under in its parent (omitted for a node without a parent)
// followed by its partial key. Returns "<end>" for a null node.
std::string getPartialKeyPrintable(Node *n) {
  if (n == nullptr) {
    return "<end>";
  }
  // The unused local Arena the original constructed here has been removed.
  auto result = std::string((const char *)&n->parentsIndex,
                            n->parent == nullptr ? 0 : 1) +
                std::string((const char *)n->partialKey(), n->partialKeyLen);
  return printable(result); // NOLINT
}
|
|
|
|
// Return the smallest string that is greater than every string having `str`
// as a prefix: strip trailing 0xff bytes, then increment the last remaining
// byte.
//
// Must not be called with a string that consists only of zero or more '\xff'
// bytes. In that case `ok` is set to false and an empty string is returned;
// otherwise `ok` is set to true.
std::string strinc(std::string_view str, bool &ok) {
  // Number of leading bytes to keep, i.e. one past the last byte != 0xff
  size_t keep = str.size();
  while (keep > 0 && static_cast<uint8_t>(str[keep - 1]) == 255) {
    --keep;
  }

  ok = keep > 0;
  if (!ok) {
    return {};
  }

  std::string r(str.substr(0, keep));
  r.back() = static_cast<char>(static_cast<uint8_t>(r.back()) + 1);
  return r;
}
|
|
|
|
// Convenience overload: the full key leading to `n` as a std::string.
std::string getSearchPath(Node *n) {
  assert(n != nullptr);
  Arena arena;
  const auto path = getSearchPath(arena, n);
  const auto *bytes = (const char *)path.data();
  return std::string(bytes, path.size());
}
|
|
|
|
[[maybe_unused]] void debugPrintDot(FILE *file, Node *node,
|
|
ConflictSet::Impl *impl) {
|
|
|
|
constexpr int kSeparation = 3;
|
|
|
|
struct DebugDotPrinter {
|
|
|
|
explicit DebugDotPrinter(FILE *file, ConflictSet::Impl *impl)
|
|
: file(file), impl(impl) {}
|
|
|
|
void print(Node *n, int y = 0) {
|
|
assert(n != nullptr);
|
|
if (n->entryPresent) {
|
|
fprintf(file,
|
|
" k_%p [label=\"m=%" PRId64 " p=%" PRId64 " r=%" PRId64
|
|
"\n%s\", pos=\"%d,%d!\"];\n",
|
|
(void *)n, maxVersion(n, impl), n->entry.pointVersion,
|
|
n->entry.rangeVersion, getPartialKeyPrintable(n).c_str(), x, y);
|
|
} else {
|
|
fprintf(file, " k_%p [label=\"m=%" PRId64 "\n%s\", pos=\"%d,%d!\"];\n",
|
|
(void *)n, maxVersion(n, impl),
|
|
getPartialKeyPrintable(n).c_str(), x, y);
|
|
}
|
|
x += kSeparation;
|
|
for (int child = getChildGeq(n, 0); child >= 0;
|
|
child = getChildGeq(n, child + 1)) {
|
|
auto *c = getChildExists(n, child);
|
|
fprintf(file, " k_%p -> k_%p;\n", (void *)n, (void *)c);
|
|
print(c, y - kSeparation);
|
|
}
|
|
}
|
|
int x = 0;
|
|
FILE *file;
|
|
ConflictSet::Impl *impl;
|
|
};
|
|
|
|
fprintf(file, "digraph ConflictSet {\n");
|
|
fprintf(file, " node [shape = box];\n");
|
|
assert(node != nullptr);
|
|
DebugDotPrinter printer{file, impl};
|
|
printer.print(node);
|
|
fprintf(file, "}\n");
|
|
}
|
|
|
|
// Recursively verify that every child's `parent` pointer refers to the node
// that holds it. Mismatches are reported on stderr and clear `success`.
void checkParentPointers(Node *node, bool &success) {
  int i = getChildGeq(node, 0);
  while (i >= 0) {
    auto *child = getChildExists(node, i);
    const bool linked = child->parent == node;
    if (!linked) {
      fprintf(stderr, "%s child %d has parent pointer %p. Expected %p\n",
              getSearchPathPrintable(node).c_str(), i, (void *)child->parent,
              (void *)node);
      success = false;
    }
    checkParentPointers(child, success);
    i = getChildGeq(node, i + 1);
  }
}
|
|
|
|
// Convenience overload: view the bytes of `key` as a span and forward to the
// span overload.
Iterator firstGeq(Node *n, std::string_view key) {
  const auto *bytes = (const uint8_t *)key.data();
  return firstGeq(n, std::span<const uint8_t>(bytes, key.size()));
}
|
|
|
|
// Recompute the expected max version of `node` from its own point version,
// its children (recursively) and their range versions, and compare it with
// the stored value. Also checks that the value stored in the parent's slot
// agrees with maxVersion(node). Problems are reported on stderr and clear
// `success`. Returns the expected value.
[[maybe_unused]] int64_t checkMaxVersion(Node *root, Node *node,
                                         int64_t oldestVersion, bool &success,
                                         ConflictSet::Impl *impl) {
  int64_t expected = std::numeric_limits<int64_t>::lowest();
  if (node->entryPresent) {
    expected = std::max(expected, node->entry.pointVersion);
  }
  int i = getChildGeq(node, 0);
  while (i >= 0) {
    auto *child = getChildExists(node, i);
    const int64_t childExpected =
        checkMaxVersion(root, child, oldestVersion, success, impl);
    expected = std::max(expected, childExpected);
    if (child->entryPresent) {
      expected = std::max(expected, child->entry.rangeVersion);
    }
    i = getChildGeq(node, i + 1);
  }

  // NOTE(review): this uses the search path of `root`, not of `node` - confirm
  // that this is intended.
  const auto key = getSearchPath(root);
  bool ok;
  const auto inc = strinc(key, ok);
  if (ok) {
    auto borrowed = firstGeq(root, inc);
    if (borrowed.n != nullptr) {
      expected = std::max(expected, borrowed.n->entry.rangeVersion);
    }
  }

  const bool parentSlotDisagrees =
      node->parent != nullptr &&
      getChildMaxVersion(node->parent, node->parentsIndex) !=
          maxVersion(node, impl);
  if (parentSlotDisagrees) {
    fprintf(stderr,
            "%s has max version %" PRId64
            " . But parent has child max version %" PRId64 "\n",
            getSearchPathPrintable(node).c_str(), maxVersion(node, impl),
            getChildMaxVersion(node->parent, node->parentsIndex));
    success = false;
  }

  // Only values above oldestVersion are compared against the expected value
  const int64_t stored = maxVersion(node, impl);
  if (stored > oldestVersion && stored != expected) {
    fprintf(stderr, "%s has max version %" PRId64 " . Expected %" PRId64 "\n",
            getSearchPathPrintable(node).c_str(), maxVersion(node, impl),
            expected);
    success = false;
  }
  return expected;
}
|
|
|
|
// Count the entries in the subtree rooted at `node` (including `node`
// itself). Every child subtree must contain at least one entry; a child
// without any is reported on stderr and clears `success`.
[[maybe_unused]] int64_t checkEntriesExist(Node *node, bool &success) {
  int64_t total = node->entryPresent;
  for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) {
    auto *child = getChildExists(node, i);
    const int64_t e = checkEntriesExist(child, success);
    total += e;
    if (e == 0) {
      // The unused local Arena the original constructed here has been removed.
      fprintf(stderr, "%s has child %02x with no reachable entries\n",
              getSearchPathPrintable(node).c_str(), i);
      success = false;
    }
  }
  return total;
}
|
|
|
|
// Runs the tree checks on the tree rooted at `node`, in this order:
// checkParentPointers, checkMaxVersion (with `node` passed as both the root
// and the starting node), and checkEntriesExist.
//
// Returns the shared flag that starts out true and is handed by reference to
// each check.
bool checkCorrectness(Node *node, int64_t oldestVersion,
                      ConflictSet::Impl *impl) {
  bool allChecksPassed{true};
  checkParentPointers(node, allChecksPassed);
  checkMaxVersion(node, node, oldestVersion, allChecksPassed, impl);
  checkEntriesExist(node, allChecksPassed);
  return allChecksPassed;
}
|
|
|
|
} // namespace
|
|
|
|
// Supplies a definition of std::__throw_length_error whose whole body is
// __builtin_unreachable(), i.e. this file treats any call to it as impossible.
// NOTE(review): presumably this stands in for the standard library's own
// definition so that symbol does not have to be linked in -- confirm.
namespace std {
void __throw_length_error(const char *) {
  __builtin_unreachable();
}
} // namespace std
|
|
|
|
#ifdef ENABLE_MAIN
|
|
|
|
// Builds a small example conflict set and prints it with debugPrintDot to
// stdout.
//
// Four writes are added, each at its own consecutively increasing version
// (1 through 4): one with begin "and" and end "ant", then three with an empty
// end key whose begin keys are "any", "are" and "art".
void printTree() {
  int64_t writeVersion = 0;
  ConflictSet::Impl cs{writeVersion};

  // A single WriteRange is reused; both fields are overwritten before each
  // addWrites call.
  ConflictSet::WriteRange write;

  write.begin = "and"_s;
  write.end = "ant"_s;
  cs.addWrites(&write, 1, ++writeVersion);

  write.begin = "any"_s;
  write.end = ""_s;
  cs.addWrites(&write, 1, ++writeVersion);

  write.begin = "are"_s;
  write.end = ""_s;
  cs.addWrites(&write, 1, ++writeVersion);

  write.begin = "art"_s;
  write.end = ""_s;
  cs.addWrites(&write, 1, ++writeVersion);

  debugPrintDot(stdout, cs.root, &cs);
}
|
|
|
|
#define ANKERL_NANOBENCH_IMPLEMENT
|
|
#include "third_party/nanobench.h"
|
|
|
|
int main(void) {
|
|
printTree();
|
|
return 0;
|
|
ankerl::nanobench::Bench bench;
|
|
ConflictSet::Impl cs{0};
|
|
for (int j = 0; j < 256; ++j) {
|
|
getOrCreateChild(cs.root, j, &cs.allocators) =
|
|
cs.allocators.node0.allocate(0);
|
|
if (j % 10 == 0) {
|
|
bench.run("MaxExclusive " + std::to_string(j), [&]() {
|
|
bench.doNotOptimizeAway(maxBetweenExclusive(cs.root, 0, 256));
|
|
});
|
|
}
|
|
}
|
|
return 0;
|
|
}
|
|
#endif
|
|
|
|
#ifdef ENABLE_FUZZ
|
|
extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
|
TestDriver<ConflictSet::Impl> driver{data, size};
|
|
|
|
for (;;) {
|
|
bool done = driver.next();
|
|
if (!driver.ok) {
|
|
debugPrintDot(stdout, driver.cs.root, &driver.cs);
|
|
fflush(stdout);
|
|
abort();
|
|
}
|
|
#if DEBUG_VERBOSE && !defined(NDEBUG)
|
|
fprintf(stderr, "Check correctness\n");
|
|
#endif
|
|
bool success =
|
|
checkCorrectness(driver.cs.root, driver.cs.oldestVersion, &driver.cs);
|
|
if (!success) {
|
|
debugPrintDot(stdout, driver.cs.root, &driver.cs);
|
|
fflush(stdout);
|
|
abort();
|
|
}
|
|
if (done) {
|
|
break;
|
|
}
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
#endif
|
|
|
|
// GCOVR_EXCL_STOP
|