Files
conflict-set/ConflictSet.cpp

2321 lines
66 KiB
C++

/*
Copyright 2024 Andrew Noyes
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "ConflictSet.h"
#include "Internal.h"
#include <algorithm>
#include <bit>
#include <cassert>
#include <compare>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <inttypes.h>
#include <limits>
#include <span>
#include <string>
#include <string_view>
#include <utility>
#ifdef HAS_AVX
#include <immintrin.h>
#elif defined(HAS_ARM_NEON)
#include <arm_neon.h>
#endif
#include <memcheck.h>
// Use assert for checking potentially complex properties during tests.
// Use assume to hint simple properties to the optimizer.
// TODO use the c++23 version when that's available
//
// In NDEBUG builds `assume(e)` expands to a compiler hint: `__builtin_assume`
// when the compiler reports it through __has_builtin, otherwise the `assume`
// attribute. In every other build `assume` is simply `assert`, so a violated
// hint is caught at runtime instead of being handed to the optimizer.
#ifdef NDEBUG
#if __has_builtin(__builtin_assume)
#define assume(e) __builtin_assume(e)
#else
#define assume(e) __attribute__((assume(e)))
#endif
#else
#define assume assert
#endif
// ==================== BEGIN IMPLEMENTATION ====================
// Pair of versions stored in every Node (see Node::entry).
struct Entry {
// Compared against the read version when a point read ends exactly on the
// node holding this entry (see checkPointRead).
int64_t pointVersion;
// Compared against the read version when a search instead ends on this entry
// after walking down to the first entry of a subtree (the downLeftSpine paths
// of the check functions), and folded into maxBetweenExclusive.
int64_t rangeVersion;
};
// Set of integers in [0, 256), stored as four 64-bit words.
struct BitSet {
  bool test(int i) const;
  void set(int i);
  void reset(int i);
  int firstSetGeq(int i) const;

  // Calls `f` with the index of each bit set in [begin, end), in increasing
  // order.
  template <class F> void forEachInRange(F f, int begin, int end) {
    // See section 3.1 in https://arxiv.org/pdf/1709.07821.pdf for details about
    // this approach
    const uint64_t allOnes = uint64_t(-1);

    // Reports every set bit of `bits`, lowest first. Bit 0 of `bits`
    // corresponds to index `base`.
    auto visitWord = [&f](uint64_t bits, int base) {
      for (; bits != 0; bits &= bits - 1) {
        f(base + std::countr_zero(bits));
      }
    };

    // The whole range lies inside a single word.
    if ((begin >> 6) == (end >> 6)) {
      const uint64_t only = words[begin >> 6] & (allOnes << (begin & 63)) &
                            ~(allOnes << (end & 63));
      visitWord(only, begin & ~63);
      return;
    }

    // Leading partial word: bits from `begin` up to the end of its word.
    if (begin & 63) {
      const uint64_t head = words[begin >> 6] & (allOnes << (begin & 63));
      if (std::popcount(head) + (begin & 63) == 64) {
        // Every remaining bit of this word is set, so skip the bit scan.
        while (begin & 63) {
          f(begin++);
        }
      } else {
        visitWord(head, begin & ~63);
        begin = (begin & ~63) + 64;
      }
    }

    // Whole words strictly inside the range. `begin` is word-aligned here.
    for (const int innerEnd = end & ~63; begin != innerEnd; begin += 64) {
      const uint64_t full = words[begin >> 6];
      if (full == allOnes) {
        for (int offset = 0; offset < 64; ++offset) {
          f(begin + offset);
        }
      } else {
        visitWord(full, begin);
      }
    }

    // Trailing partial word: bits below `end` in the last word.
    if (end & 63) {
      const uint64_t tail = words[end >> 6] & ~(allOnes << (end & 63));
      if (std::popcount(tail) == (end & 63)) {
        // Every bit below `end` is set.
        while (begin < end) {
          f(begin++);
        }
      } else {
        visitWord(tail, begin);
      }
    }
  }

private:
  uint64_t words[4] = {};
};
// Returns whether bit `i` is set. `i` must be in [0, 256).
bool BitSet::test(int i) const {
  assert(0 <= i);
  assert(i < 256);
  const uint64_t bit = uint64_t(1) << (i & 63);
  return (words[i >> 6] & bit) != 0;
}
void BitSet::set(int i) {
assert(0 <= i);
assert(i < 256);
words[i >> 6] |= uint64_t(1) << (i & 63);
}
// Clears bit `i`. `i` must be in [0, 256).
void BitSet::reset(int i) {
  assert(0 <= i);
  assert(i < 256);
  const uint64_t keep = ~(uint64_t(1) << (i & 63));
  words[i >> 6] &= keep;
}
// Returns the smallest set index that is >= i, or -1 if there is none.
int BitSet::firstSetGeq(int i) const {
  assume(0 <= i);
  // i may be >= 256, in which case there is nothing left to search
  int wordIndex = i >> 6;
  if (wordIndex >= 4) {
    return -1;
  }
  // Only the first word examined needs the bits below `i` masked off.
  uint64_t candidates = words[wordIndex] & (uint64_t(-1) << (i & 63));
  for (;;) {
    if (candidates) {
      return (wordIndex << 6) + std::countr_zero(candidates);
    }
    if (++wordIndex == 4) {
      return -1;
    }
    candidates = words[wordIndex];
  }
}
// Identifies the concrete node struct behind a Node*. The enumerator order is
// significant: the code below tests `type <= Type::Node16` to select the node
// types that are accessed through the Node16 layout (Node0, Node4, Node16).
enum class Type : int8_t {
Node0,
Node4,
Node16,
Node48,
Node256,
};
// Common header shared by all node types (Node0 ... Node256 derive from it).
struct Node {
/* begin section that's copied to the next node */
// Parent node; nullptr means there is no parent (traversals stop there).
Node *parent = nullptr;
// Versions for this node; only consulted when entryPresent is true.
Entry entry;
// Number of partial key bytes in use (at most partialKeyCapacity is
// allocated; see BoundedFreeListAllocator::allocate).
int32_t partialKeyLen = 0;
// Number of children currently stored in this node.
int16_t numChildren = 0;
// Whether `entry` holds a value for this node.
bool entryPresent = false;
// The key byte under which this node is stored in `parent`.
uint8_t parentsIndex = 0;
/* end section that's copied to the next node */
Type type;
// Leaving this uninitialized is intentional and necessary for correctness.
// Basically it needs to be preserved when going to the free list and back.
int32_t partialKeyCapacity;
// Dispatches on `type` to the concrete node's partialKey().
uint8_t *partialKey();
};
// Byte range [kNodeCopyBegin, kNodeCopyBegin + kNodeCopySize) of a Node covers
// the members from `parent` up to (not including) `type`. getOrCreateChild
// memcpy's this range into the replacement node when a node grows.
constexpr int kNodeCopyBegin = offsetof(Node, parent);
constexpr int kNodeCopySize = offsetof(Node, type) - kNodeCopyBegin;
// One child slot of an inner node.
struct Child {
// Version associated with `child`; maxBetweenExclusive takes the maximum of
// these over a range of children.
int64_t childMaxVersion;
// The child node itself.
Node *child;
};
// Node type with no child slots. It still carries a 16-byte `index` array
// because nodes with type <= Type::Node16 are read through the Node16 layout
// (see getNodeIndex and getChild).
struct Node0 : Node {
// Sorted
uint8_t index[16]; // 16 so that we can use the same simd index search
// implementation as Node16
Node0() { this->type = Type::Node0; }
// The partial key bytes are stored directly after the node struct.
uint8_t *partialKey() { return (uint8_t *)(this + 1); }
};
// Node type with up to 4 children. `index[i]` is the key byte of
// `children[i]` for i < numChildren. It is read through the Node16 layout by
// the functions that test `type <= Type::Node16`.
struct Node4 : Node {
// Sorted
uint8_t index[16]; // 16 so that we can use the same simd index search
// implementation as Node16
Child children[4];
Node4() { this->type = Type::Node4; }
// The partial key bytes are stored directly after the node struct.
uint8_t *partialKey() { return (uint8_t *)(this + 1); }
};
// Node type with up to 16 children. `index[i]` is the key byte of
// `children[i]` for i < numChildren; entries at or beyond numChildren are not
// meaningful.
struct Node16 : Node {
// Sorted
uint8_t index[16];
Child children[16];
Node16() { this->type = Type::Node16; }
// The partial key bytes are stored directly after the node struct.
uint8_t *partialKey() { return (uint8_t *)(this + 1); }
};
// Node type with up to 48 children, addressed indirectly by key byte.
struct Node48 : Node {
// Bit b is set iff a child is stored under key byte b.
BitSet bitSet;
// Child slots; slots [0, nextFree) are in use.
Child children[48];
// Next unused position in `children` (see getOrCreateChild / eraseChild).
int8_t nextFree = 0;
// index[b] is the position in `children` of the child for key byte b, or -1
// if there is none.
int8_t index[256];
Node48() {
// Mark every key byte as having no child.
memset(index, -1, 256);
this->type = Type::Node48;
}
// The partial key bytes are stored directly after the node struct.
uint8_t *partialKey() { return (uint8_t *)(this + 1); }
};
// Node type with one child slot per possible key byte.
struct Node256 : Node {
// Bit b is set iff a child is stored under key byte b.
BitSet bitSet;
// children[b] is the slot for key byte b; a null `child` means absent.
Child children[256];
Node256() {
this->type = Type::Node256;
// Only the child pointers are cleared; childMaxVersion is left as is.
for (int i = 0; i < 256; ++i) {
children[i].child = nullptr;
}
}
// The partial key bytes are stored directly after the node struct.
uint8_t *partialKey() { return (uint8_t *)(this + 1); }
};
// Bound memory usage following the analysis in the ART paper
constexpr int kBytesPerKey = 86;
// Minimum child counts used in the size checks below, one per node type.
// NOTE(review): in the visible part of the file these constants are only used
// by the static_asserts; presumably node shrinking elsewhere enforces them —
// confirm.
constexpr int kMinChildrenNode4 = 2;
constexpr int kMinChildrenNode16 = 5;
constexpr int kMinChildrenNode48 = 17;
constexpr int kMinChildrenNode256 = 49;
// Each node struct must be smaller than its minimum child count times
// kBytesPerKey.
static_assert(sizeof(Node256) < kMinChildrenNode256 * kBytesPerKey);
static_assert(sizeof(Node48) < kMinChildrenNode48 * kBytesPerKey);
static_assert(sizeof(Node16) < kMinChildrenNode16 * kBytesPerKey);
static_assert(sizeof(Node4) < kMinChildrenNode4 * kBytesPerKey);
// Bounds memory usage in free list, but does not account for memory for partial
// keys.
// NOTE(review): freeListBytes below does add partialKeyCapacity for each node,
// so the sentence above may be out of date — confirm.
//
// Allocator for one concrete node type T. Every allocation is sizeof(T) bytes
// followed by `partialKeyCapacity` bytes. Released nodes are pushed onto a
// singly-linked free list whose link is stored in the first bytes of the
// released node, unless the node is larger than kMaxIndividual bytes or the
// list already holds kMemoryBound bytes, in which case the node is freed.
template <class T, int64_t kMemoryBound = (1 << 20),
int64_t kMaxIndividual = (1 << 10)>
struct BoundedFreeListAllocator {
static_assert(sizeof(T) >= sizeof(void *));
static_assert(std::derived_from<T, Node>);
// Returns a newly constructed T with room for at least `partialKeyCapacity`
// partial key bytes after it.
T *allocate(int partialKeyCapacity) {
#if SHOW_MEMORY
++liveAllocations;
maxLiveAllocations = std::max(maxLiveAllocations, liveAllocations);
#endif
if (freeList != nullptr) {
// Pop the head of the free list.
T *n = (T *)freeList;
VALGRIND_MAKE_MEM_UNDEFINED(n, sizeof(T));
// partialKeyCapacity survives the trip through the free list and is read
// below, so tell valgrind it is still defined.
VALGRIND_MAKE_MEM_DEFINED(&n->partialKeyCapacity,
sizeof(n->partialKeyCapacity));
VALGRIND_MAKE_MEM_DEFINED(freeList, sizeof(freeList));
// The first bytes of a free node hold the pointer to the next free node.
memcpy(&freeList, freeList, sizeof(freeList));
freeListBytes -= sizeof(T) + n->partialKeyCapacity;
if (n->partialKeyCapacity >= partialKeyCapacity) {
// T's constructor does not assign partialKeyCapacity, so the recycled
// node keeps its capacity.
return new (n) T;
} else {
// The intent is to filter out too-small nodes in the freelist
free(n);
}
}
// Nothing reusable: allocate the node and its partial key storage together.
auto *result = new (safe_malloc(sizeof(T) + partialKeyCapacity)) T;
result->partialKeyCapacity = partialKeyCapacity;
return result;
}
// Returns `p` to the allocator. `p` must not be used afterwards.
void release(T *p) {
#if SHOW_MEMORY
--liveAllocations;
#endif
static_assert(std::is_trivially_destructible_v<T>);
// Nodes that are too large, or that arrive when the list is at its bound,
// are freed immediately instead of being cached.
if (sizeof(T) + p->partialKeyCapacity > kMaxIndividual ||
freeListBytes >= kMemoryBound) {
return free(p);
}
// Push onto the free list, storing the old head in the node's first bytes.
memcpy((void *)p, &freeList, sizeof(freeList));
freeList = p;
freeListBytes += sizeof(T) + p->partialKeyCapacity;
VALGRIND_MAKE_MEM_NOACCESS(freeList, sizeof(T));
}
// Frees every node still held on the free list.
~BoundedFreeListAllocator() {
for (void *iter = freeList; iter != nullptr;) {
VALGRIND_MAKE_MEM_DEFINED(iter, sizeof(iter));
auto *tmp = iter;
// Read the link before freeing the node that contains it.
memcpy(&iter, iter, sizeof(void *));
free(tmp);
}
}
#if SHOW_MEMORY
// Peak number of simultaneously live allocations times sizeof(T).
int64_t highWaterMarkBytes() const { return maxLiveAllocations * sizeof(T); }
#endif
private:
// Sum of sizeof(T) + partialKeyCapacity over the nodes on the free list.
int64_t freeListBytes = 0;
// Head of the free list, or nullptr if it is empty.
void *freeList = nullptr;
#if SHOW_MEMORY
// TODO Track partial key bytes
int64_t maxLiveAllocations = 0;
int64_t liveAllocations = 0;
#endif
};
uint8_t *Node::partialKey() {
switch (type) {
case Type::Node0:
return ((Node0 *)this)->partialKey();
case Type::Node4:
return ((Node4 *)this)->partialKey();
case Type::Node16:
return ((Node16 *)this)->partialKey();
case Type::Node48:
return ((Node48 *)this)->partialKey();
case Type::Node256:
return ((Node256 *)this)->partialKey();
}
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
// Bundles one free-list allocator per concrete node type so they can be passed
// around together (see getOrCreateChild and eraseChild).
struct NodeAllocators {
BoundedFreeListAllocator<Node0> node0;
BoundedFreeListAllocator<Node4> node4;
BoundedFreeListAllocator<Node16> node16;
BoundedFreeListAllocator<Node48> node48;
BoundedFreeListAllocator<Node256> node256;
};
// Returns the position i < self->numChildren with self->index[i] == index, or
// -1 if no such position exists. Callers also pass Node0/Node4 nodes cast to
// Node16.
int getNodeIndex(Node16 *self, uint8_t index) {
#ifdef HAS_AVX
// Based on https://www.the-paper-trail.org/post/art-paper-notes/
// key_vec is 16 repeated copies of the searched-for byte, one for every
// possible position in child_keys that needs to be searched.
__m128i key_vec = _mm_set1_epi8(index);
// Compare all child_keys to 'index' in parallel. Don't worry if some of the
// keys aren't valid, we'll mask the results to only consider the valid ones
// below.
__m128i indices;
memcpy(&indices, self->index, sizeof(self->index));
__m128i results = _mm_cmpeq_epi8(key_vec, indices);
// Build a mask to select only the first node->num_children values from the
// comparison (because the other values are meaningless)
uint32_t mask = (1 << self->numChildren) - 1;
// Change the results of the comparison into a bitfield, masking off any
// invalid comparisons.
uint32_t bitfield = _mm_movemask_epi8(results) & mask;
// No match if there are no '1's in the bitfield.
if (bitfield == 0)
return -1;
// Find the index of the first '1' in the bitfield by counting the trailing
// zeros.
return std::countr_zero(bitfield);
#elif defined(HAS_ARM_NEON)
// Based on
// https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
uint8x16_t indices;
memcpy(&indices, self->index, sizeof(self->index));
// 0xff for each match
uint16x8_t results =
vreinterpretq_u16_u8(vceqq_u8(vdupq_n_u8(index), indices));
// Four mask bits per valid position; shifting a uint64_t by 64 is not
// allowed, so the all-valid case is handled separately.
uint64_t mask = self->numChildren == 16
? uint64_t(-1)
: (uint64_t(1) << (self->numChildren * 4)) - 1;
// 0xf for each match in valid range
uint64_t bitfield =
vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) & mask;
if (bitfield == 0)
return -1;
// Each position occupies four bits of the bitfield.
return std::countr_zero(bitfield) / 4;
#else
// Portable fallback: linear scan over the valid positions.
for (int i = 0; i < self->numChildren; ++i) {
if (self->index[i] == index) {
return i;
}
}
return -1;
#endif
}
// Returns a reference to the child pointer stored in `self` under key byte
// `index`.
// Precondition - an entry for index must exist in the node
Node *&getChildExists(Node *self, uint8_t index) {
  if (self->type <= Type::Node16) {
    // Node0, Node4 and Node16 are all read through the Node16 layout.
    auto *self16 = static_cast<Node16 *>(self);
    const int slot = getNodeIndex(self16, index);
    return self16->children[slot].child;
  }
  if (self->type == Type::Node48) {
    auto *self48 = static_cast<Node48 *>(self);
    assert(self48->bitSet.test(index));
    const int8_t slot = self48->index[index];
    return self48->children[slot].child;
  }
  auto *self256 = static_cast<Node256 *>(self);
  assert(self256->bitSet.test(index));
  return self256->children[index].child;
}
// Returns the childMaxVersion stored in `self` for the child under key byte
// `index`.
// Precondition - an entry for index must exist in the node
int64_t getChildMaxVersion(Node *self, uint8_t index) {
  if (self->type <= Type::Node16) {
    // Node0, Node4 and Node16 are all read through the Node16 layout.
    auto *self16 = static_cast<Node16 *>(self);
    const int slot = getNodeIndex(self16, index);
    return self16->children[slot].childMaxVersion;
  }
  if (self->type == Type::Node48) {
    auto *self48 = static_cast<Node48 *>(self);
    assert(self48->bitSet.test(index));
    const int8_t slot = self48->index[index];
    return self48->children[slot].childMaxVersion;
  }
  auto *self256 = static_cast<Node256 *>(self);
  assert(self256->bitSet.test(index));
  return self256->children[index].childMaxVersion;
}
// Precondition - an entry for index must exist in the node
// NOTE(review): the precondition above mentions `index`, which this function
// does not take; it reads like a leftover from getChildMaxVersion — confirm.
// Forward declaration; the definition is not in this part of the file. The
// check functions below compare the returned value against a read version and
// stop searching beneath `n` when it is <= that version.
int64_t &maxVersion(Node *n, ConflictSet::Impl *);
// Returns the child of `self` stored under key byte `index`, or nullptr if
// there is none.
Node *getChild(Node *self, uint8_t index) {
  if (self->type <= Type::Node16) {
    // Node0, Node4 and Node16 are all read through the Node16 layout.
    auto *self16 = static_cast<Node16 *>(self);
    const int slot = getNodeIndex(self16, index);
    return slot >= 0 ? self16->children[slot].child : nullptr;
  }
  if (self->type == Type::Node48) {
    auto *self48 = static_cast<Node48 *>(self);
    // A negative entry means no child is stored under this key byte.
    const int slot = self48->index[index];
    return slot >= 0 ? self48->children[slot].child : nullptr;
  }
  // Node256 slots are addressed directly; absent children are null.
  return static_cast<Node256 *>(self)->children[index].child;
}
// Returns the smallest key byte >= `child` under which `self` stores a child,
// or -1 if there is none. `child` may be 256 (one past the largest key byte),
// in which case the result is -1.
int getChildGeq(Node *self, int child) {
if (child > 255) {
return -1;
}
if (self->type <= Type::Node16) {
// Node0, Node4 and Node16 are all read through the Node16 layout.
auto *self16 = static_cast<Node16 *>(self);
#ifdef HAS_AVX
__m128i key_vec = _mm_set1_epi8(child);
__m128i indices;
memcpy(&indices, self16->index, sizeof(self16->index));
// min(key, index[i]) == key holds exactly when key <= index[i] (unsigned).
__m128i results = _mm_cmpeq_epi8(key_vec, _mm_min_epu8(key_vec, indices));
// Only the first numChildren positions are meaningful.
int mask = (1 << self16->numChildren) - 1;
uint32_t bitfield = _mm_movemask_epi8(results) & mask;
int result = bitfield == 0 ? -1 : self16->index[std::countr_zero(bitfield)];
// Cross-check the SIMD result against a plain linear scan.
assert(result == [&]() -> int {
for (int i = 0; i < self16->numChildren; ++i) {
if (self16->index[i] >= child) {
return self16->index[i];
}
}
return -1;
}());
return result;
#elif defined(HAS_ARM_NEON)
uint8x16_t indices;
memcpy(&indices, self16->index, sizeof(self16->index));
// 0xff for each leq
auto results = vcleq_u8(vdupq_n_u8(child), indices);
// Four mask bits per valid position; the all-valid case is special-cased to
// avoid shifting a uint64_t by 64.
uint64_t mask = self->numChildren == 16
? uint64_t(-1)
: (uint64_t(1) << (self->numChildren * 4)) - 1;
// 0xf for each 0xff (within mask)
uint64_t bitfield =
vget_lane_u64(
vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)),
0) &
mask;
int simd =
bitfield == 0 ? -1 : self16->index[std::countr_zero(bitfield) / 4];
// Cross-check the SIMD result against a plain linear scan.
assert(simd == [&]() -> int {
for (int i = 0; i < self->numChildren; ++i) {
if (self16->index[i] >= child) {
return self16->index[i];
}
}
return -1;
}());
return simd;
#else
// Portable fallback: the index array is sorted, so the first entry >= child
// is the answer.
for (int i = 0; i < self->numChildren; ++i) {
if (i > 0) {
assert(self16->index[i - 1] < self16->index[i]);
}
if (self16->index[i] >= child) {
return self16->index[i];
}
}
#endif
} else {
// Node48 and Node256 keep their bitSet at the same offset, so either can be
// queried through the Node48 type.
static_assert(offsetof(Node48, bitSet) == offsetof(Node256, bitSet));
auto *self48 = static_cast<Node48 *>(self);
return self48->bitSet.firstSetGeq(child);
}
return -1;
}
// Points the `parent` of each of the first numChildren children of `n` at `n`.
void setChildrenParents(Node16 *n) {
  Child *slot = n->children;
  for (int remaining = n->numChildren; remaining > 0; --remaining, ++slot) {
    slot->child->parent = n;
  }
}
// Points the `parent` of every child recorded in n->bitSet at `n`.
void setChildrenParents(Node48 *n) {
  auto adopt = [&](int keyByte) {
    Child &slot = n->children[n->index[keyByte]];
    slot.child->parent = n;
  };
  n->bitSet.forEachInRange(adopt, 0, 256);
}
// Points the `parent` of every child recorded in n->bitSet at `n`.
void setChildrenParents(Node256 *n) {
  auto adopt = [&](int keyByte) {
    Child &slot = n->children[keyByte];
    slot.child->parent = n;
  };
  n->bitSet.forEachInRange(adopt, 0, 256);
}
// Returns a reference to the child pointer slot of `self` for key byte
// `index`, creating the slot if it does not exist yet. A newly created slot
// holds nullptr; its childMaxVersion is not written here.
//
// If `self` has no room for another child it is replaced by a node of the next
// larger type: the copied header section and the partial key are carried over,
// the old node is released to its allocator, the children's parent pointers
// are updated, and `self` is assigned the replacement.
//
// Caller is responsible for assigning a non-null pointer to the returned
// reference if null
Node *&getOrCreateChild(Node *&self, uint8_t index,
NodeAllocators *allocators) {
// Fast path for if it exists already
if (self->type <= Type::Node16) {
auto *self16 = static_cast<Node16 *>(self);
int i = getNodeIndex(self16, index);
if (i >= 0) {
return self16->children[i].child;
}
} else if (self->type == Type::Node48) {
auto *self48 = static_cast<Node48 *>(self);
int secondIndex = self48->index[index];
if (secondIndex >= 0) {
return self48->children[secondIndex].child;
}
} else {
auto *self256 = static_cast<Node256 *>(self);
if (auto &result = self256->children[index].child; result != nullptr) {
return result;
}
}
// Slow path: the child does not exist. Grow the node if it is full, then
// insert a new slot.
if (self->type == Type::Node0) {
// A Node0 has no child slots, so it always becomes a Node4. There are no
// children to copy or re-parent.
auto *self0 = static_cast<Node0 *>(self);
auto *newSelf = allocators->node4.allocate(self->partialKeyLen);
memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin,
kNodeCopySize);
memcpy(newSelf->partialKey(), self0->partialKey(), self->partialKeyLen);
allocators->node0.release(self0);
self = newSelf;
goto insert16;
} else if (self->type == Type::Node4) {
auto *self4 = static_cast<Node4 *>(self);
if (self->numChildren == 4) {
// Full Node4: move everything into a Node16.
auto *newSelf = allocators->node16.allocate(self->partialKeyLen);
memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin,
kNodeCopySize);
memcpy(newSelf->partialKey(), self4->partialKey(), self->partialKeyLen);
// TODO replace with memcpy?
for (int i = 0; i < 4; ++i) {
newSelf->index[i] = self4->index[i];
newSelf->children[i] = self4->children[i];
}
allocators->node4.release(self4);
setChildrenParents(newSelf);
self = newSelf;
}
goto insert16;
} else if (self->type == Type::Node16) {
if (self->numChildren == 16) {
// Full Node16: move everything into a Node48, assigning child slots in
// the existing order.
auto *self16 = static_cast<Node16 *>(self);
auto *newSelf = allocators->node48.allocate(self->partialKeyLen);
memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin,
kNodeCopySize);
memcpy(newSelf->partialKey(), self16->partialKey(), self->partialKeyLen);
newSelf->nextFree = 16;
int i = 0;
for (auto x : self16->index) {
newSelf->bitSet.set(x);
newSelf->children[i] = self16->children[i];
newSelf->index[x] = i;
++i;
}
assert(i == 16);
allocators->node16.release(self16);
setChildrenParents(newSelf);
self = newSelf;
goto insert48;
}
insert16:
// Shared insertion for Node4 and Node16 (both read through the Node16
// layout): keep `index` sorted by shifting the larger entries up one slot.
auto *self16 = static_cast<Node16 *>(self);
++self->numChildren;
int i = 0;
for (; i < int(self->numChildren) - 1; ++i) {
if (int(self16->index[i]) > int(index)) {
memmove(self16->index + i + 1, self16->index + i,
self->numChildren - (i + 1));
memmove(self16->children + i + 1, self16->children + i,
(self->numChildren - (i + 1)) * sizeof(Child));
break;
}
}
// i is now the insertion position (the end if no larger entry was found).
self16->index[i] = index;
auto &result = self16->children[i].child;
result = nullptr;
return result;
} else if (self->type == Type::Node48) {
if (self->numChildren == 48) {
// Full Node48: move everything into a Node256, placing each child in the
// slot named by its key byte.
auto *self48 = static_cast<Node48 *>(self);
auto *newSelf = allocators->node256.allocate(self->partialKeyLen);
memcpy((char *)newSelf + kNodeCopyBegin, (char *)self + kNodeCopyBegin,
kNodeCopySize);
memcpy(newSelf->partialKey(), self48->partialKey(), self->partialKeyLen);
newSelf->bitSet = self48->bitSet;
newSelf->bitSet.forEachInRange(
[&](int i) {
newSelf->children[i] = self48->children[self48->index[i]];
},
0, 256);
allocators->node48.release(self48);
setChildrenParents(newSelf);
self = newSelf;
goto insert256;
}
insert48:
// Take the next unused child slot and record it for this key byte.
auto *self48 = static_cast<Node48 *>(self);
self48->bitSet.set(index);
++self->numChildren;
assert(self48->nextFree < 48);
int nextFree = self48->nextFree++;
self48->index[index] = nextFree;
auto &result = self48->children[nextFree].child;
result = nullptr;
return result;
} else {
assert(self->type == Type::Node256);
insert256:
// The slot's child pointer is already null here: the fast path above did
// not return for an existing Node256, and a freshly constructed Node256
// starts with null child pointers.
auto *self256 = static_cast<Node256 *>(self);
++self->numChildren;
self256->bitSet.set(index);
return self256->children[index].child;
}
}
// Releases the child of `self` stored under key byte `index` to its allocator
// and removes its slot from `self`. If that leaves `self` with no children and
// no entry, and `self` has a parent, `self` is in turn erased from its parent.
// The node type of `self` is not changed by this function.
// Precondition - an entry for index must exist in the node
void eraseChild(Node *self, uint8_t index, NodeAllocators *allocators) {
auto *child = getChildExists(self, index);
// Return the child to the allocator matching its concrete type.
switch (child->type) {
case Type::Node0:
allocators->node0.release((Node0 *)child);
break;
case Type::Node4:
allocators->node4.release((Node4 *)child);
break;
case Type::Node16:
allocators->node16.release((Node16 *)child);
break;
case Type::Node48:
allocators->node48.release((Node48 *)child);
break;
case Type::Node256:
allocators->node256.release((Node256 *)child);
break;
}
if (self->type <= Type::Node16) {
// Close the gap in both parallel arrays so they stay sorted and dense.
auto *self16 = static_cast<Node16 *>(self);
int nodeIndex = getNodeIndex(self16, index);
memmove(self16->index + nodeIndex, self16->index + nodeIndex + 1,
sizeof(self16->index[0]) * (self->numChildren - (nodeIndex + 1)));
memmove(self16->children + nodeIndex, self16->children + nodeIndex + 1,
sizeof(self16->children[0]) * // NOLINT
(self->numChildren - (nodeIndex + 1)));
} else if (self->type == Type::Node48) {
auto *self48 = static_cast<Node48 *>(self);
self48->bitSet.reset(index);
int8_t toRemoveChildrenIndex = std::exchange(self48->index[index], -1);
int8_t lastChildrenIndex = --self48->nextFree;
assert(toRemoveChildrenIndex >= 0);
assert(lastChildrenIndex >= 0);
// Keep the used child slots dense: move the last used slot into the hole
// and update the index entry of the child that was moved.
if (toRemoveChildrenIndex != lastChildrenIndex) {
self48->children[toRemoveChildrenIndex] =
self48->children[lastChildrenIndex];
self48
->index[self48->children[toRemoveChildrenIndex].child->parentsIndex] =
toRemoveChildrenIndex;
}
} else {
auto *self256 = static_cast<Node256 *>(self);
self256->bitSet.reset(index);
self256->children[index].child = nullptr;
}
--self->numChildren;
// A node with no children and no entry is removed from its parent as well.
if (self->numChildren == 0 && !self->entryPresent &&
self->parent != nullptr) {
eraseChild(self->parent, self->parentsIndex, allocators);
}
}
// Returns the node that follows `node` in a pre-order walk of the tree (first
// child if any, otherwise the next child of the nearest ancestor that has
// one), or nullptr if there is none.
Node *nextPhysical(Node *node) {
  // Smallest key byte still to be considered in `node`.
  int searchFrom = 0;
  for (;;) {
    const int found = getChildGeq(node, searchFrom);
    if (found >= 0) {
      return getChildExists(node, found);
    }
    // Nothing left beneath `node`: continue in its parent, after `node`.
    searchFrom = node->parentsIndex + 1;
    node = node->parent;
    if (node == nullptr) {
      return nullptr;
    }
  }
}
// Returns the first node after `node` (in nextPhysical order) whose
// entryPresent flag is set, or nullptr if there is none.
Node *nextLogical(Node *node) {
  do {
    node = nextPhysical(node);
  } while (node != nullptr && !node->entryPresent);
  return node;
}
// A node paired with an integer comparison result.
// NOTE(review): the meaning of `cmp` is not established in the visible part of
// the file; presumably it records how `n` compares to a searched key (see the
// firstGeq remark above checkPointRead) — confirm against the producer.
struct Iterator {
Node *n;
int cmp;
};
// Returns the next child after `node` in its parent. If the parent has none,
// the search continues one level up with the parent itself. Returns nullptr
// once a node without a parent is reached.
Node *nextSibling(Node *node) {
  for (Node *parent = node->parent; parent != nullptr;
       node = parent, parent = node->parent) {
    const int candidate = getChildGeq(parent, node->parentsIndex + 1);
    if (candidate >= 0) {
      return getChildExists(parent, candidate);
    }
  }
  return nullptr;
}
// Number of bytes compared per call to compareStride / firstNeqStride.
#if defined(HAS_AVX) || defined(HAS_ARM_NEON)
constexpr int kStride = 64;
#else
constexpr int kStride = 16;
#endif
// Number of strides handled per iteration of the unrolled loop in
// longestCommonPrefix.
constexpr int kUnrollFactor = 4;

// Returns true iff the kStride bytes at `ap` equal the kStride bytes at `bp`.
bool compareStride(const uint8_t *ap, const uint8_t *bp) {
#if defined(HAS_ARM_NEON)
  static_assert(kStride == 64);
  // AND together the per-byte equality masks of the four 16-byte lanes.
  uint8x16_t allEqual = vceqq_u8(vld1q_u8(ap), vld1q_u8(bp));
  for (int offset = 16; offset < kStride; offset += 16) {
    allEqual = vandq_u8(
        allEqual, vceqq_u8(vld1q_u8(ap + offset), vld1q_u8(bp + offset)));
  }
  // Narrow each byte to a nibble; all 64 bits set means every byte matched.
  const uint64_t nibbles = vget_lane_u64(
      vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(allEqual), 4)), 0);
  return nibbles == uint64_t(-1);
#elif defined(HAS_AVX)
  static_assert(kStride == 64);
  // AND together the per-byte equality masks of the four 16-byte lanes.
  __m128i allEqual = _mm_cmpeq_epi8(_mm_loadu_si128((__m128i *)ap),
                                    _mm_loadu_si128((__m128i *)bp));
  for (int offset = 16; offset < kStride; offset += 16) {
    allEqual = _mm_and_si128(
        allEqual, _mm_cmpeq_epi8(_mm_loadu_si128((__m128i *)(ap + offset)),
                                 _mm_loadu_si128((__m128i *)(bp + offset))));
  }
  // One mask bit per byte; all 16 set means every byte matched.
  return _mm_movemask_epi8(allEqual) == 0xffff;
#else
  // Hope it gets vectorized
  return memcmp(ap, bp, kStride) == 0;
#endif
}
// Returns the offset of the first byte at which the kStride bytes at `ap` and
// the kStride bytes at `bp` differ.
// Precondition: ap[:kStride] != bp[:kStride]
int firstNeqStride(const uint8_t *ap, const uint8_t *bp) {
#if defined(HAS_AVX)
  static_assert(kStride == 64);
  // Bit k of `matches` is set iff byte k of the two strides is equal.
  uint64_t matches = 0;
  for (int offset = 0; offset < kStride; offset += 16) {
    const auto lhs = _mm_loadu_si128((__m128i *)(ap + offset));
    const auto rhs = _mm_loadu_si128((__m128i *)(bp + offset));
    const uint64_t laneMatches =
        _mm_movemask_epi8(_mm_cmpeq_epi8(lhs, rhs)) & 0xffff;
    matches |= laneMatches << offset;
  }
  // The lowest clear bit marks the first mismatching byte.
  return std::countr_zero(~matches);
#elif defined(HAS_ARM_NEON)
  static_assert(kStride == 64);
  for (int offset = 0; offset < kStride; offset += 16) {
    // 0xff for each match
    const uint16x8_t equalBytes = vreinterpretq_u16_u8(
        vceqq_u8(vld1q_u8(ap + offset), vld1q_u8(bp + offset)));
    // 0xf for each mismatch
    const uint64_t mismatchNibbles =
        ~vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(equalBytes, 4)), 0);
    if (mismatchNibbles) {
      // Four bits per byte position.
      return offset + (std::countr_zero(mismatchNibbles) >> 2);
    }
  }
  __builtin_unreachable(); // GCOVR_EXCL_LINE
#else
  // The last byte never needs to be inspected: by the precondition, if all
  // earlier bytes match then the last one is the mismatch.
  int offset = 0;
  while (offset < kStride - 1 && ap[offset] == bp[offset]) {
    ++offset;
  }
  return offset;
#endif
}
int longestCommonPrefix(const uint8_t *ap, const uint8_t *bp, int cl) {
assume(cl >= 0);
int i = 0;
int end;
if (cl < 8) {
goto bytes;
}
// Optimistic early return
{
uint64_t a;
uint64_t b;
memcpy(&a, ap, 8);
memcpy(&b, bp, 8);
const auto mismatched = a ^ b;
if (mismatched) {
return std::countr_zero(mismatched) / 8;
}
}
// kStride * kUnrollCount at a time
end = cl & ~(kStride * kUnrollFactor - 1);
while (i < end) {
for (int j = 0; j < kUnrollFactor; ++j) {
if (!compareStride(ap, bp)) {
return i + firstNeqStride(ap, bp);
}
i += kStride;
ap += kStride;
bp += kStride;
}
}
// kStride at a time
end = cl & ~(kStride - 1);
while (i < end) {
if (!compareStride(ap, bp)) {
return i + firstNeqStride(ap, bp);
}
i += kStride;
ap += kStride;
bp += kStride;
}
// word at a time
end = cl & ~(sizeof(uint64_t) - 1);
while (i < end) {
uint64_t a;
uint64_t b;
memcpy(&a, ap, 8);
memcpy(&b, bp, 8);
const auto mismatched = a ^ b;
if (mismatched) {
return i + std::countr_zero(mismatched) / 8;
}
i += 8;
ap += 8;
bp += 8;
}
bytes:
// byte at a time
while (i < cl) {
if (*ap != *bp) {
break;
}
++ap;
++bp;
++i;
}
return i;
}
// Performs a physical search for remaining
struct SearchStepWise {
Node *n;
std::span<const uint8_t> remaining;
SearchStepWise() {}
SearchStepWise(Node *n, std::span<const uint8_t> remaining)
: n(n), remaining(remaining) {
assert(n->partialKeyLen == 0);
}
bool step() {
if (remaining.size() == 0) {
return true;
}
auto *child = getChild(n, remaining[0]);
if (child == nullptr) {
return true;
}
int cl = std::min<int>(child->partialKeyLen, remaining.size() - 1);
int i = longestCommonPrefix(child->partialKey(), remaining.data() + 1, cl);
if (i != child->partialKeyLen) {
return true;
}
n = child;
remaining =
remaining.subspan(1 + child->partialKeyLen,
remaining.size() - (1 + child->partialKeyLen));
return false;
}
};
namespace {
// Forward declaration; the definition is not in this part of the file. It is
// used by the DEBUG_VERBOSE logging below to print a node's search path.
std::string getSearchPathPrintable(Node *n);
}
// Logically this is the same as performing firstGeq and then checking against
// point or range version according to cmp, but this version short circuits as
// soon as it can prove that there's no conflict.
//
// Searches from `n` for `key`. Returns true if the relevant version is <=
// readVersion: the point version of the entry at `key` if one exists,
// otherwise the range version of the first entry after `key`. Also returns
// true as soon as maxVersion of the current node is <= readVersion, or if no
// entry follows `key`.
bool checkPointRead(Node *n, const std::span<const uint8_t> key,
int64_t readVersion, ConflictSet::Impl *impl) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check point read: %s\n", printable(key).c_str());
#endif
auto remaining = key;
for (;;) {
// Short circuit: nothing at or below n can conflict.
if (maxVersion(n, impl) <= readVersion) {
return true;
}
if (remaining.size() == 0) {
// The whole key matched the search path of n.
if (n->entryPresent) {
return n->entry.pointVersion <= readVersion;
}
// No entry at the key itself: the first entry beneath n is the next one.
int c = getChildGeq(n, 0);
assert(c >= 0);
n = getChildExists(n, c);
goto downLeftSpine;
}
auto *child = getChild(n, remaining[0]);
if (child == nullptr) {
// No child for the next key byte: continue from the next larger child, or
// from the next sibling if there is no larger child.
int c = getChildGeq(n, remaining[0]);
if (c >= 0) {
n = getChildExists(n, c);
goto downLeftSpine;
} else {
n = nextSibling(n);
goto downLeftSpine;
}
}
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
if (i < commonLen) {
// The partial key and the key diverge at byte i.
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
// Everything beneath n sorts after the key.
goto downLeftSpine;
} else {
// Everything beneath n sorts before the key.
n = nextSibling(n);
goto downLeftSpine;
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// n is the first physical node greater than remaining, and there's no
// eq node
goto downLeftSpine;
}
}
}
downLeftSpine:
// nextSibling may have returned nullptr: there is no entry after the key.
if (n == nullptr) {
return true;
}
// Walk to the first entry at or beneath n and check its range version.
for (;;) {
if (n->entryPresent) {
return n->entry.rangeVersion <= readVersion;
}
int c = getChildGeq(n, 0);
assert(c >= 0);
n = getChildExists(n, c);
}
}
// Return the max version among all keys starting with the search path of n +
// [child], where child in (begin, end)
//
// `begin` may be -1 and `end` may be 256 to leave that side unbounded.
// Returns the lowest int64_t value if n has no child in (begin, end).
int64_t maxBetweenExclusive(Node *n, int begin, int end) {
assume(-1 <= begin);
assume(begin <= 256);
assume(-1 <= end);
assume(end <= 256);
assume(begin < end);
int64_t result = std::numeric_limits<int64_t>::lowest();
{
// Find the first child strictly after `begin`. If there is none before
// `end`, the range is empty.
int c = getChildGeq(n, begin + 1);
if (c >= 0 && c < end) {
auto *child = getChildExists(n, c);
if (child->entryPresent) {
result = std::max(result, child->entry.rangeVersion);
}
// From here on `begin` is inclusive: it names the first child in range.
begin = c;
} else {
return result;
}
}
// Fold in childMaxVersion of every child with key byte in [begin, end).
switch (n->type) {
case Type::Node0:
[[fallthrough]];
case Type::Node4:
[[fallthrough]];
case Type::Node16: {
auto *self = static_cast<Node16 *>(n);
// index is sorted, so the scan can stop at the first key byte >= end.
for (int i = 0; i < self->numChildren && self->index[i] < end; ++i) {
if (begin <= self->index[i]) {
result = std::max(result, self->children[i].childMaxVersion);
}
}
break;
}
case Type::Node48: {
auto *self = static_cast<Node48 *>(n);
self->bitSet.forEachInRange(
[&](int i) {
result =
std::max(result, self->children[self->index[i]].childMaxVersion);
},
begin, end);
break;
}
case Type::Node256: {
auto *self = static_cast<Node256 *>(n);
self->bitSet.forEachInRange(
[&](int i) {
result = std::max(result, self->children[i].childMaxVersion);
},
begin, end);
break;
}
}
#if DEBUG_VERBOSE && !defined(NDEBUG)
// Note that `begin` printed here is the adjusted, inclusive value.
fprintf(stderr, "At `%s', max version in (%02x, %02x) is %" PRId64 "\n",
getSearchPathPrintable(n).c_str(), begin, end, result);
#endif
return result;
}
// Returns the key bytes on the path from the root to `n`: for each node on the
// path, the byte it is stored under in its parent followed by its partial key.
// The result is allocated from `arena`.
Vector<uint8_t> getSearchPath(Arena &arena, Node *n) {
  assert(n != nullptr);
  // Bytes are collected from `n` upwards, i.e. in reverse, and flipped at the
  // end.
  auto reversedPath = vector<uint8_t>(arena);
  for (Node *cursor = n;; cursor = cursor->parent) {
    for (int left = cursor->partialKeyLen; left > 0; --left) {
      reversedPath.push_back(cursor->partialKey()[left - 1]);
    }
    if (cursor->parent == nullptr) {
      break;
    }
    reversedPath.push_back(cursor->parentsIndex);
  }
  std::reverse(reversedPath.begin(), reversedPath.end());
  return reversedPath;
}
// Return true if the max version among all keys that start with key + [child],
// where begin < child < end, is <= readVersion
//
// Searches from `n` for `key`, returning true early whenever maxVersion of the
// current node is already <= readVersion.
bool checkRangeStartsWith(Node *n, std::span<const uint8_t> key, int begin,
int end, int64_t readVersion,
ConflictSet::Impl *impl) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "%s(%02x,%02x)*\n", printable(key).c_str(), begin, end);
#endif
auto remaining = key;
for (;;) {
// Short circuit: nothing at or below n can conflict.
if (maxVersion(n, impl) <= readVersion) {
return true;
}
if (remaining.size() == 0) {
// n's search path equals key: check its children in (begin, end).
return maxBetweenExclusive(n, begin, end) <= readVersion;
}
auto *child = getChild(n, remaining[0]);
if (child == nullptr) {
// No child for the next key byte: continue from the next larger child, or
// from the next sibling if there is no larger child.
int c = getChildGeq(n, remaining[0]);
if (c >= 0) {
n = getChildExists(n, c);
goto downLeftSpine;
} else {
n = nextSibling(n);
goto downLeftSpine;
}
}
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
if (i < commonLen) {
// The partial key and the key diverge at byte i.
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
goto downLeftSpine;
} else {
n = nextSibling(n);
goto downLeftSpine;
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// The key ends inside n's partial key. The partial key byte right after
// the key plays the role of [child]: n matters only if it lies strictly
// between begin and end.
if (begin < n->partialKey()[remaining.size()] &&
n->partialKey()[remaining.size()] < end) {
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
return false;
}
return maxVersion(n, impl) <= readVersion;
}
return true;
}
}
}
downLeftSpine:
// nextSibling may have returned nullptr: there is no entry after the key.
if (n == nullptr) {
return true;
}
// Walk to the first entry at or beneath n and check its range version.
for (;;) {
if (n->entryPresent) {
return n->entry.rangeVersion <= readVersion;
}
int c = getChildGeq(n, 0);
assert(c >= 0);
n = getChildExists(n, c);
}
}
// Return true if the max version among all keys that start with key[:prefixLen]
// that are >= key is <= readVersion
//
// The check is written as a resumable state machine: call step() repeatedly
// until it returns true, then read the answer from `ok`.
struct CheckRangeLeftSide {
CheckRangeLeftSide(Node *n, std::span<const uint8_t> key, int prefixLen,
int64_t readVersion, ConflictSet::Impl *impl)
: n(n), remaining(key), prefixLen(prefixLen), readVersion(readVersion),
impl(impl) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check range left side from %s for keys starting with %s\n",
printable(key).c_str(),
printable(key.subspan(0, prefixLen)).c_str());
#endif
}
// Current node of the search.
Node *n;
// The part of the key that has not been matched yet.
std::span<const uint8_t> remaining;
int prefixLen;
int64_t readVersion;
ConflictSet::Impl *impl;
// Number of key bytes matched on the way to n.
int searchPathLen = 0;
// Result of the check; only meaningful once step() has returned true.
bool ok;
enum Phase { Search, DownLeftSpine } phase = Search;
// Advances the check by one node. Returns true once `ok` has been set.
bool step() {
switch (phase) {
case Search: {
// Short circuit: nothing at or below n can conflict.
if (maxVersion(n, impl) <= readVersion) {
ok = true;
return true;
}
if (remaining.size() == 0) {
assert(searchPathLen >= prefixLen);
ok = maxVersion(n, impl) <= readVersion;
return true;
}
// Once the prefix has been matched, every child after the next key byte
// is in range.
if (searchPathLen >= prefixLen) {
if (maxBetweenExclusive(n, remaining[0], 256) > readVersion) {
ok = false;
return true;
}
}
auto *child = getChild(n, remaining[0]);
if (child == nullptr) {
int c = getChildGeq(n, remaining[0]);
if (c >= 0) {
if (searchPathLen < prefixLen) {
n = getChildExists(n, c);
return downLeftSpine();
}
n = getChildExists(n, c);
ok = maxVersion(n, impl) <= readVersion;
return true;
} else {
n = nextSibling(n);
return downLeftSpine();
}
}
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
++searchPathLen;
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i =
longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
searchPathLen += i;
if (i < commonLen) {
// The partial key and the key diverge at byte i.
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
if (searchPathLen < prefixLen) {
return downLeftSpine();
}
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
ok = false;
return true;
}
ok = maxVersion(n, impl) <= readVersion;
return true;
} else {
n = nextSibling(n);
return downLeftSpine();
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining =
remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// The key ends inside n's partial key.
if (searchPathLen < prefixLen) {
return downLeftSpine();
}
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
ok = false;
return true;
}
ok = maxVersion(n, impl) <= readVersion;
return true;
}
}
break;
}
case DownLeftSpine:
// Walk to the first entry at or beneath n, one node per step.
if (n->entryPresent) {
ok = n->entry.rangeVersion <= readVersion;
return true;
}
int c = getChildGeq(n, 0);
assert(c >= 0);
n = getChildExists(n, c);
break;
}
return false;
}
// Switches to the DownLeftSpine phase. Returns true (with ok = true) if there
// is no node to walk, i.e. n is nullptr; otherwise returns false so that the
// caller keeps stepping.
bool downLeftSpine() {
phase = DownLeftSpine;
if (n == nullptr) {
ok = true;
return true;
}
return false;
}
};
// Return true if the max version among all keys that start with key[:prefixLen]
// that are < key is <= readVersion
//
// Written as a resumable state machine rather than a plain function: call
// step() until it returns true, then read the answer from `ok`. checkRangeRead
// drives it this way so it can alternate steps with CheckRangeLeftSide.
struct CheckRangeRightSide {
CheckRangeRightSide(Node *n, std::span<const uint8_t> key, int prefixLen,
int64_t readVersion, ConflictSet::Impl *impl)
: n(n), key(key), remaining(key), prefixLen(prefixLen),
readVersion(readVersion), impl(impl) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check range right side to %s for keys starting with %s\n",
printable(key).c_str(),
printable(key.subspan(0, prefixLen)).c_str());
#endif
}
// Node the search is currently positioned at
Node *n;
// The full key passed to the constructor (only read by the assert in step())
std::span<const uint8_t> key;
// Suffix of the key that has not been matched against the tree yet
std::span<const uint8_t> remaining;
// Length of the key prefix that every key under consideration must share
int prefixLen;
int64_t readVersion;
ConflictSet::Impl *impl;
// Number of key bytes accounted for on the way down to `n`; backtrack()
// decreases it again while walking back up
int searchPathLen = 0;
// Result of the check. Only assigned on paths where a step reports completion,
// so it must not be read before step() has returned true.
bool ok;
enum Phase { Search, DownLeftSpine } phase = Search;
// Advance the state machine by one node. Returns true once `ok` holds the
// final answer, false if another call is needed.
bool step() {
switch (phase) {
case Search: {
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(
stderr,
"Search path: %s, searchPathLen: %d, prefixLen: %d, remaining: %s\n",
getSearchPathPrintable(n).c_str(), searchPathLen, prefixLen,
printable(remaining).c_str());
#endif
assert(searchPathLen <= int(key.size()));
// Whole key matched: the answer is decided by the first entry at or below n
if (remaining.size() == 0) {
return downLeftSpine();
}
if (searchPathLen >= prefixLen) {
// n's own key is within the prefix and shorter than the search key, so its
// point version is part of the range
if (n->entryPresent && n->entry.pointVersion > readVersion) {
ok = false;
return true;
}
// Children to the left of the byte we are about to follow are inside the
// range being checked.
// NOTE(review): presumably maxBetweenExclusive(n, lo, hi) is the max version
// over children with index strictly between lo and hi - confirm.
if (maxBetweenExclusive(n, -1, remaining[0]) > readVersion) {
ok = false;
return true;
}
}
// The range version is only consulted strictly past the prefix
if (searchPathLen > prefixLen && n->entryPresent &&
n->entry.rangeVersion > readVersion) {
ok = false;
return true;
}
auto *child = getChild(n, remaining[0]);
if (child == nullptr) {
// No child for the next key byte: continue from the next larger child if
// there is one, otherwise walk back up the tree
int c = getChildGeq(n, remaining[0]);
if (c >= 0) {
n = getChildExists(n, c);
return downLeftSpine();
} else {
return backtrack();
}
}
// Descend into the matching child; one key byte is consumed by the edge
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
++searchPathLen;
if (n->partialKeyLen > 0) {
// Compare the child's partial key against the rest of the search key
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i =
longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
searchPathLen += i;
if (i < commonLen) {
// Mismatch inside the partial key
++searchPathLen;
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
// The node's key is greater than the search key
return downLeftSpine();
} else {
// The node's key is less than the search key
if (searchPathLen > prefixLen && n->entryPresent &&
n->entry.rangeVersion > readVersion) {
ok = false;
return true;
}
return backtrack();
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining =
remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// The search key is a proper prefix of the node's key
return downLeftSpine();
}
}
} break;
case DownLeftSpine:
// Follow first children until a node holding an entry is found; the answer
// is decided by that entry's rangeVersion
if (n->entryPresent) {
ok = n->entry.rangeVersion <= readVersion;
return true;
}
int c = getChildGeq(n, 0);
assert(c >= 0);
n = getChildExists(n, c);
break;
}
return false;
}
// Walk up from `n` until some ancestor (or n itself) has a later sibling, then
// continue down that sibling's left spine. Fails the check if a node visited
// while strictly past the prefix has max version > readVersion; succeeds if
// the root is reached without finding a later sibling. Returns whether the
// check is complete.
bool backtrack() {
for (;;) {
if (searchPathLen > prefixLen && maxVersion(n, impl) > readVersion) {
ok = false;
return true;
}
if (n->parent == nullptr) {
ok = true;
return true;
}
auto next = getChildGeq(n->parent, n->parentsIndex + 1);
if (next < 0) {
// No later sibling: move to the parent, dropping the edge byte and n's
// partial key from the path length
searchPathLen -= 1 + n->partialKeyLen;
n = n->parent;
} else {
// Swap n for its next sibling: the edge byte count is unchanged, only the
// partial key contribution differs
searchPathLen -= n->partialKeyLen;
n = getChildExists(n->parent, next);
searchPathLen += n->partialKeyLen;
return downLeftSpine();
}
}
}
// Switch to the DownLeftSpine phase starting from the current `n`. If `n` is
// nullptr there is nothing left to inspect: finishes immediately with
// ok = true. Returns whether the check is complete.
bool downLeftSpine() {
phase = DownLeftSpine;
if (n == nullptr) {
ok = true;
return true;
}
return false;
}
};
// Returns true if a read of [begin, end) at `readVersion` does not conflict
// with anything recorded in the tree rooted at `n`.
//
// The range `[k, k + '\0')` is handled as a point read of `k`. Otherwise the
// tree is first searched along the common prefix of `begin` and `end`
// (returning early as soon as a node's max version is <= readVersion), and the
// remainder is decided by CheckRangeLeftSide / CheckRangeRightSide.
bool checkRangeRead(Node *n, std::span<const uint8_t> begin,
                    std::span<const uint8_t> end, int64_t readVersion,
                    ConflictSet::Impl *impl) {
  int lcp = longestCommonPrefix(begin.data(), end.data(),
                                std::min(begin.size(), end.size()));

  // end == begin + '\0' describes exactly one key
  const bool isSingleKey = lcp == int(begin.size()) &&
                           end.size() == begin.size() + 1 && end.back() == 0;
  if (isSingleKey) {
    return checkPointRead(n, begin, readVersion, impl);
  }

  // Walk down the common prefix one node at a time
  SearchStepWise search{n, begin.subspan(0, lcp)};
  Arena arena;
  bool searchFinished = false;
  while (!searchFinished) {
    assert(getSearchPath(arena, search.n) <=>
               begin.subspan(0, lcp - search.remaining.size()) ==
           0);
    if (maxVersion(search.n, impl) <= readVersion) {
      return true;
    }
    searchFinished = search.step();
  }
  assert(getSearchPath(arena, search.n) <=>
             begin.subspan(0, lcp - search.remaining.size()) ==
         0);

  // Re-base both keys (and the common prefix length) on the node reached
  const int consumed = lcp - search.remaining.size();
  assume(consumed >= 0);
  begin = begin.subspan(consumed);
  end = end.subspan(consumed);
  n = search.n;
  lcp -= consumed;

  // begin is a prefix of end: only the right side needs checking
  if (lcp == int(begin.size())) {
    CheckRangeRightSide rightOnly{n, end, lcp, readVersion, impl};
    while (!rightOnly.step()) {
    }
    return rightOnly.ok;
  }

  if (!checkRangeStartsWith(n, begin.subspan(0, lcp), begin[lcp], end[lcp],
                            readVersion, impl)) {
    return false;
  }

  CheckRangeLeftSide left{n, begin, lcp + 1, readVersion, impl};
  CheckRangeRightSide right{n, end, lcp + 1, readVersion, impl};

  // Advance both checks in lockstep (left first) until at least one completes
  bool leftDone = false;
  bool rightDone = false;
  do {
    leftDone = left.step();
    rightDone = right.step();
  } while (!leftDone && !rightDone);

  // Then run whichever one is still pending to completion
  if (!rightDone) {
    while (!right.step()) {
    }
  }
  if (!leftDone) {
    while (!left.step()) {
    }
  }
  return left.ok && right.ok;
}
// Returns a pointer to the newly inserted node. Caller must set
// `entryPresent`, `entry` fields and `maxVersion` on the result. The search
// path of the result's parent will have `maxVersion` at least `writeVersion` as
// a postcondition.
//
// `self` is the slot holding the subtree root to insert into; it may be
// rewritten when an existing node has to be split. `key` is relative to that
// subtree. With kBegin == true the max version of every node visited,
// including the returned one, is set to `writeVersion`; with kBegin == false
// only nodes from which the descent continues are updated, so the returned
// node's max version is left as it was.
template <bool kBegin>
[[nodiscard]] Node *insert(Node **self, std::span<const uint8_t> key,
int64_t writeVersion, NodeAllocators *allocators,
ConflictSet::Impl *impl) {
for (;;) {
if ((*self)->partialKeyLen > 0) {
// Handle an existing partial key
int commonLen = std::min<int>((*self)->partialKeyLen, key.size());
int partialKeyIndex =
longestCommonPrefix((*self)->partialKey(), key.data(), commonLen);
if (partialKeyIndex < (*self)->partialKeyLen) {
// The key diverges from (or ends inside) the partial key: split the node.
// A fresh Node4 takes over the slot and the matched part of the partial
// key, and the old node becomes its child with the rest of the partial key.
auto *old = *self;
// Read before the slot is replaced; written back once `old` has its new
// parent
int64_t oldMaxVersion = maxVersion(old, impl);
*self = allocators->node4.allocate(partialKeyIndex);
memcpy((char *)*self + kNodeCopyBegin, (char *)old + kNodeCopyBegin,
kNodeCopySize);
(*self)->partialKeyLen = partialKeyIndex;
(*self)->entryPresent = false;
(*self)->numChildren = 0;
memcpy((*self)->partialKey(), old->partialKey(),
(*self)->partialKeyLen);
getOrCreateChild(*self, old->partialKey()[partialKeyIndex],
allocators) = old;
old->parent = *self;
old->parentsIndex = old->partialKey()[partialKeyIndex];
maxVersion(old, impl) = oldMaxVersion;
// Drop the bytes now represented by the new parent and by the edge byte
memmove(old->partialKey(), old->partialKey() + partialKeyIndex + 1,
old->partialKeyLen - (partialKeyIndex + 1));
old->partialKeyLen -= partialKeyIndex + 1;
}
key = key.subspan(partialKeyIndex, key.size() - partialKeyIndex);
} else {
// Consider adding a partial key
if ((*self)->numChildren == 0 && !(*self)->entryPresent) {
// An empty node with no entry can absorb the whole remaining key
assert((*self)->partialKeyCapacity >= int(key.size()));
(*self)->partialKeyLen = key.size();
memcpy((*self)->partialKey(), key.data(), (*self)->partialKeyLen);
key = key.subspan((*self)->partialKeyLen,
key.size() - (*self)->partialKeyLen);
}
}
if constexpr (kBegin) {
// Applied before the exhausted-key check, so the returned node is included
auto &m = maxVersion(*self, impl);
assert(writeVersion >= m);
m = writeVersion;
}
if (key.size() == 0) {
return *self;
}
if constexpr (!kBegin) {
// Applied after the exhausted-key check, so the returned node is excluded
auto &m = maxVersion(*self, impl);
assert(writeVersion >= m);
m = writeVersion;
}
auto &child = getOrCreateChild(*self, key.front(), allocators);
if (!child) {
// New leaf sized to hold the rest of the key (minus the edge byte) as its
// partial key
child = allocators->node0.allocate(key.size() - 1);
child->parent = *self;
child->parentsIndex = key.front();
maxVersion(child, impl) =
kBegin ? writeVersion : std::numeric_limits<int64_t>::lowest();
}
self = &child;
key = key.subspan(1, key.size() - 1);
}
}
// Frees every node reachable from `root` (including `root` itself) using an
// explicit work stack instead of recursion.
void destroyTree(Node *root) {
  Arena arena;
  auto pending = vector<Node *>(arena);
  pending.push_back(root);
  do {
    Node *current = pending.back();
    pending.pop_back();

    // Queue up the children before releasing the node that references them
    int index = getChildGeq(current, 0);
    while (index >= 0) {
      Node *childNode = getChildExists(current, index);
      assert(childNode != nullptr);
      pending.push_back(childNode);
      index = getChildGeq(current, index + 1);
    }
    free(current);
  } while (pending.size() > 0);
}
// Records a write to the single key `key` at `writeVersion`.
//
// If the key already has an entry only its point version is raised. Otherwise
// a new entry is created whose range version is copied from the node returned
// by nextLogical (or `oldestVersion` when there is none).
void addPointWrite(Node *&root, int64_t oldestVersion,
                   std::span<const uint8_t> key, int64_t writeVersion,
                   NodeAllocators *allocators, ConflictSet::Impl *impl) {
  auto *node = insert<true>(&root, key, writeVersion, allocators, impl);

  if (node->entryPresent) {
    // Existing entry: versions only ever move forward
    assert(writeVersion >= node->entry.pointVersion);
    node->entry.pointVersion = writeVersion;
    return;
  }

  // Look up the successor before marking this node as holding an entry
  auto *successor = nextLogical(node);
  node->entryPresent = true;
  node->entry.pointVersion = writeVersion;
  maxVersion(node, impl) = writeVersion;
  if (successor != nullptr) {
    node->entry.rangeVersion = successor->entry.rangeVersion;
  } else {
    node->entry.rangeVersion = oldestVersion;
  }
}
// Records a write to the key range [begin, end) at `writeVersion`.
//
// The range `[k, k + '\0')` is forwarded to addPointWrite. Otherwise entries
// are inserted (or updated) for `begin` and `end`, and every entry strictly
// between them is removed.
void addWriteRange(Node *&root, int64_t oldestVersion,
std::span<const uint8_t> begin, std::span<const uint8_t> end,
int64_t writeVersion, NodeAllocators *allocators,
ConflictSet::Impl *impl) {
int lcp = longestCommonPrefix(begin.data(), end.data(),
std::min(begin.size(), end.size()));
// end == begin + '\0' describes exactly one key
if (lcp == int(begin.size()) && end.size() == begin.size() + 1 &&
end.back() == 0) {
return addPointWrite(root, oldestVersion, begin, writeVersion, allocators,
impl);
}
// Descend along the common prefix of begin and end for as long as nodes match
// completely, raising the max version of each node that is left behind. Both
// inserts below then start from the node where this walk stops.
auto remaining = begin.subspan(0, lcp);
auto *n = root;
for (;;) {
if (int(remaining.size()) <= n->partialKeyLen) {
break;
}
int i = longestCommonPrefix(n->partialKey(), remaining.data(),
n->partialKeyLen);
if (i != n->partialKeyLen) {
break;
}
auto *child = getChild(n, remaining[n->partialKeyLen]);
if (child == nullptr) {
break;
}
auto &m = maxVersion(n, impl);
assert(writeVersion >= m);
m = writeVersion;
remaining = remaining.subspan(n->partialKeyLen + 1,
remaining.size() - (n->partialKeyLen + 1));
n = child;
}
// The slot that holds `n`: either the root pointer or the parent's child slot.
// insert() needs the slot because it may replace the node stored in it.
Node **useAsRoot = n->parent == nullptr
? &root
: &getChildExists(n->parent, n->parentsIndex);
// Make both keys relative to the node reached above
int consumed = lcp - remaining.size();
begin = begin.subspan(consumed, begin.size() - consumed);
end = end.subspan(consumed, end.size() - consumed);
auto *beginNode =
insert<true>(useAsRoot, begin, writeVersion, allocators, impl);
const bool insertedBegin = !beginNode->entryPresent;
beginNode->entryPresent = true;
if (insertedBegin) {
// A new begin entry takes its range version from the node returned by
// nextLogical, or oldestVersion when there is none
auto *p = nextLogical(beginNode);
beginNode->entry.rangeVersion =
p != nullptr ? p->entry.rangeVersion : oldestVersion;
beginNode->entry.pointVersion = writeVersion;
maxVersion(beginNode, impl) = writeVersion;
}
auto &m = maxVersion(beginNode, impl);
assert(writeVersion >= m);
m = writeVersion;
assert(writeVersion >= beginNode->entry.pointVersion);
beginNode->entry.pointVersion = writeVersion;
auto *endNode = insert<false>(useAsRoot, end, writeVersion, allocators, impl);
const bool insertedEnd = !endNode->entryPresent;
endNode->entryPresent = true;
if (insertedEnd) {
// A new end entry takes its point version from the range version of the
// node returned by nextLogical, or oldestVersion when there is none
auto *p = nextLogical(endNode);
endNode->entry.pointVersion =
p != nullptr ? p->entry.rangeVersion : oldestVersion;
auto &m = maxVersion(endNode, impl);
m = std::max(m, endNode->entry.pointVersion);
}
// The range version of the end entry carries the version of this write
endNode->entry.rangeVersion = writeVersion;
if (insertedEnd) {
// beginNode may have been invalidated
beginNode = insert<true>(useAsRoot, begin, writeVersion, allocators, impl);
assert(beginNode->entryPresent);
}
// Remove every entry strictly between begin and end. The successor is
// fetched before the current node is modified or erased.
for (beginNode = nextLogical(beginNode); beginNode != endNode;) {
auto *old = beginNode;
beginNode = nextLogical(beginNode);
old->entryPresent = false;
if (old->numChildren == 0 && old->parent != nullptr) {
eraseChild(old->parent, old->parentsIndex, allocators);
}
}
}
// Searches the tree rooted at `n` for `key`.
//
// Returns {node, 0} when a node holding an entry is reached with the whole key
// consumed (an exact match). Otherwise returns {node, 1}, where node is found
// by descending first children from the point where the search diverged, or
// {nullptr, 1} when the search runs off the end of the tree.
Iterator firstGeq(Node *n, const std::span<const uint8_t> key) {
auto remaining = key;
for (;;) {
if (remaining.size() == 0) {
if (n->entryPresent) {
return {n, 0};
}
// Key consumed but no entry here: continue from the first child
int c = getChildGeq(n, 0);
assert(c >= 0);
n = getChildExists(n, c);
goto downLeftSpine;
}
auto *child = getChild(n, remaining[0]);
if (child == nullptr) {
// No child for the next key byte: continue from the next larger child if
// there is one, otherwise from n's next sibling
int c = getChildGeq(n, remaining[0]);
if (c >= 0) {
n = getChildExists(n, c);
goto downLeftSpine;
} else {
n = nextSibling(n);
goto downLeftSpine;
}
}
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
if (n->partialKeyLen > 0) {
// Compare the child's partial key against the rest of the search key
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
if (i < commonLen) {
// Mismatch inside the partial key
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
// The node's key is greater than the search key
goto downLeftSpine;
} else {
// The node's key is less than the search key; move on to its sibling
n = nextSibling(n);
goto downLeftSpine;
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// n is the first physical node greater than remaining, and there's no
// eq node
goto downLeftSpine;
}
}
}
downLeftSpine:
// nextSibling may have returned nullptr
if (n == nullptr) {
return {nullptr, 1};
}
// Follow first children until a node holding an entry is found
for (;;) {
if (n->entryPresent) {
return {n, 1};
}
int c = getChildGeq(n, 0);
assert(c >= 0);
n = getChildExists(n, c);
}
}
// Private implementation behind ConflictSet. Owns the tree (`root`), the node
// allocators, and the bookkeeping that setOldestVersion uses to expire old
// entries a bounded number at a time.
struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
// For each of the `count` reads, stores TooOld in result[i] if the read's
// version is below oldestVersion; otherwise Commit if the read check passes
// and Conflict if it does not. A read with an empty `end` key is checked as a
// point read, any other as a range read.
void check(const ReadRange *reads, Result *result, int count) {
for (int i = 0; i < count; ++i) {
const auto &r = reads[i];
auto begin = std::span<const uint8_t>(r.begin.p, r.begin.len);
auto end = std::span<const uint8_t>(r.end.p, r.end.len);
result[i] =
reads[i].readVersion < oldestVersion ? TooOld
: (end.size() > 0
? checkRangeRead(root, begin, end, reads[i].readVersion, this)
: checkPointRead(root, begin, reads[i].readVersion, this))
? Commit
: Conflict;
}
}
// Applies `count` writes at `writeVersion`. A write with an empty `end` key is
// a point write, any other a range write. Each write also adds to
// `keyUpdates`, the budget later spent by setOldestVersion.
void addWrites(const WriteRange *writes, int count, int64_t writeVersion) {
for (int i = 0; i < count; ++i) {
const auto &w = writes[i];
auto begin = std::span<const uint8_t>(w.begin.p, w.begin.len);
auto end = std::span<const uint8_t>(w.end.p, w.end.len);
if (w.end.len > 0) {
keyUpdates += 3;
addWriteRange(root, oldestVersion, begin, end, writeVersion,
&allocators, this);
} else {
keyUpdates += 2;
addPointWrite(root, oldestVersion, begin, writeVersion, &allocators,
this);
}
}
}
// Raises oldestVersion; a value that is not an increase is ignored. Once at
// least 100 key updates have accumulated, scans forward from `removalKey`
// removing entries whose point and range versions are both <= oldestVersion.
// The scan visits at most `keyUpdates` entries and remembers where it stopped
// in `removalKey`, or resets `removalKey` to empty when it reaches the end of
// the tree.
void setOldestVersion(int64_t oldestVersion) {
if (oldestVersion <= this->oldestVersion) {
return;
}
this->oldestVersion = oldestVersion;
if (keyUpdates < 100) {
return;
}
Node *prev = firstGeq(root, removalKey).n;
// There's no way to erase removalKey without introducing a key after it
assert(prev != nullptr);
for (; keyUpdates > 0; --keyUpdates) {
// Fetch the successor first; `prev` may be erased below
Node *n = nextLogical(prev);
if (n == nullptr) {
removalKey = {};
return;
}
if (std::max(prev->entry.pointVersion, prev->entry.rangeVersion) <=
oldestVersion) {
// Any transaction prev would have prevented from committing is
// going to fail with TooOld anyway.
// There's no way to insert a range such that range version of the right
// node is greater than the point version of the left node
assert(n->entry.rangeVersion <= oldestVersion);
prev->entryPresent = false;
if (prev->numChildren == 0 && prev->parent != nullptr) {
eraseChild(prev->parent, prev->parentsIndex, &allocators);
}
}
prev = n;
}
// Budget exhausted: save the key of the node to resume from next time
removalKeyArena = Arena();
removalKey = getSearchPath(removalKeyArena, prev);
}
// Creates a tree containing a single entry for the empty key, with both of its
// versions set to `oldestVersion`.
explicit Impl(int64_t oldestVersion) : oldestVersion(oldestVersion) {
// Insert ""
root = allocators.node0.allocate(0);
rootMaxVersion = oldestVersion;
root->entry.pointVersion = oldestVersion;
root->entry.rangeVersion = oldestVersion;
root->entryPresent = true;
}
~Impl() { destroyTree(root); }
NodeAllocators allocators;
// Arena passed to getSearchPath when `removalKey` is recomputed
Arena removalKeyArena;
// Key at which the next setOldestVersion scan starts; empty means start from
// the beginning
std::span<const uint8_t> removalKey;
// Incremented by addWrites and counted down by the setOldestVersion scan
int64_t keyUpdates = 0;
Node *root;
// Max version of the root; maxVersion() returns this for a node without a
// parent
int64_t rootMaxVersion;
int64_t oldestVersion;
};
// Returns a reference to the max version stored for `n`. For the root (a node
// without a parent) that is `impl->rootMaxVersion`; for any other node it is
// the `childMaxVersion` kept in the parent's slot for `n`.
// Precondition - an entry for index must exist in the node
int64_t &maxVersion(Node *n, ConflictSet::Impl *impl) {
  const int slot = n->parentsIndex;
  Node *const parent = n->parent;
  if (parent == nullptr) {
    return impl->rootMaxVersion;
  }

  // Small node types share the Node16 lookup
  if (parent->type <= Type::Node16) {
    auto *parent16 = static_cast<Node16 *>(parent);
    const int position = getNodeIndex(parent16, slot);
    return parent16->children[position].childMaxVersion;
  }

  if (parent->type == Type::Node48) {
    auto *parent48 = static_cast<Node48 *>(parent);
    assert(parent48->bitSet.test(slot));
    return parent48->children[parent48->index[slot]].childMaxVersion;
  }

  auto *parent256 = static_cast<Node256 *>(parent);
  assert(parent256->bitSet.test(slot));
  return parent256->children[slot].childMaxVersion;
}
// ==================== END IMPLEMENTATION ====================
// GCOVR_EXCL_START
// Public entry point; forwards the batch of reads to the implementation
// object, which fills in `results`.
void ConflictSet::check(const ReadRange *reads, Result *results,
                        int count) const {
  impl->check(reads, results, count);
}
// Public entry point; forwards the batch of writes to the implementation
// object.
void ConflictSet::addWrites(const WriteRange *writes, int count,
                            int64_t writeVersion) {
  impl->addWrites(writes, count, writeVersion);
}
// Public entry point; forwards the new oldest version to the implementation
// object.
void ConflictSet::setOldestVersion(int64_t oldestVersion) {
  impl->setOldestVersion(oldestVersion);
}
// Constructs the implementation object with placement new in storage obtained
// from safe_malloc; the destructor releases it with an explicit destructor
// call followed by free().
ConflictSet::ConflictSet(int64_t oldestVersion)
: impl(new (safe_malloc(sizeof(Impl))) Impl{oldestVersion}) {}
// Destroys and frees the implementation object, if this instance still owns
// one (a moved-from ConflictSet holds nullptr).
ConflictSet::~ConflictSet() {
  if (impl == nullptr) {
    return;
  }
  impl->~Impl();
  free(impl);
}
#if SHOW_MEMORY
// Prints the high-water mark, in bytes, of each per-node-type allocator to
// stderr. The Impl pointer is read out of the ConflictSet's object
// representation with memcpy.
__attribute__((visibility("default"))) void showMemory(const ConflictSet &cs) {
  ConflictSet::Impl *impl;
  memcpy(&impl, &cs, sizeof(impl)); // NOLINT
  auto &nodeAllocators = impl->allocators;
  fprintf(stderr, "Max Node0 memory usage: %" PRId64 "\n",
          nodeAllocators.node0.highWaterMarkBytes());
  fprintf(stderr, "Max Node4 memory usage: %" PRId64 "\n",
          nodeAllocators.node4.highWaterMarkBytes());
  fprintf(stderr, "Max Node16 memory usage: %" PRId64 "\n",
          nodeAllocators.node16.highWaterMarkBytes());
  fprintf(stderr, "Max Node48 memory usage: %" PRId64 "\n",
          nodeAllocators.node48.highWaterMarkBytes());
  fprintf(stderr, "Max Node256 memory usage: %" PRId64 "\n",
          nodeAllocators.node256.highWaterMarkBytes());
}
#endif
// Move constructor: takes over `other`'s implementation object and leaves
// `other` holding nullptr.
ConflictSet::ConflictSet(ConflictSet &&other) noexcept : impl(other.impl) {
  other.impl = nullptr;
}
// Move assignment: releases the implementation object this instance currently
// owns (if any), then takes over `other`'s and leaves `other` holding nullptr.
// Self-assignment leaves the object unchanged.
ConflictSet &ConflictSet::operator=(ConflictSet &&other) noexcept {
  if (this != &other) {
    // Without this the previously owned tree would be leaked, since nothing
    // else holds a pointer to it once `impl` is overwritten
    if (impl) {
      impl->~Impl();
      free(impl);
    }
    impl = std::exchange(other.impl, nullptr);
  }
  return *this;
}
// Flat names for the nested ConflictSet types, used in the signatures of the
// extern "C" functions that follow.
using ConflictSet_Result = ConflictSet::Result;
using ConflictSet_Key = ConflictSet::Key;
using ConflictSet_ReadRange = ConflictSet::ReadRange;
using ConflictSet_WriteRange = ConflictSet::WriteRange;
extern "C" {
__attribute__((__visibility__("default"))) void
ConflictSet_check(void *cs, const ConflictSet_ReadRange *reads,
ConflictSet_Result *results, int count) {
((ConflictSet::Impl *)cs)->check(reads, results, count);
}
__attribute__((__visibility__("default"))) void
ConflictSet_addWrites(void *cs, const ConflictSet_WriteRange *writes, int count,
int64_t writeVersion) {
((ConflictSet::Impl *)cs)->addWrites(writes, count, writeVersion);
}
__attribute__((__visibility__("default"))) void
ConflictSet_setOldestVersion(void *cs, int64_t oldestVersion) {
((ConflictSet::Impl *)cs)->setOldestVersion(oldestVersion);
}
__attribute__((__visibility__("default"))) void *
ConflictSet_create(int64_t oldestVersion) {
return new (safe_malloc(sizeof(ConflictSet::Impl)))
ConflictSet::Impl{oldestVersion};
}
__attribute__((__visibility__("default"))) void ConflictSet_destroy(void *cs) {
using Impl = ConflictSet::Impl;
((Impl *)cs)->~Impl();
free(cs);
}
}
namespace {
// Returns the key bytes on the path from the root to `n`, passed through
// printable(). A null `n` yields "<end>"; an empty path yields an empty string.
std::string getSearchPathPrintable(Node *n) {
  Arena arena;
  if (n == nullptr) {
    return "<end>";
  }

  // Collect the bytes leaf-to-root (each partial key back to front, followed by
  // the byte indexing the node in its parent), then flip the whole sequence.
  auto pathBytes = vector<char>(arena);
  for (Node *cur = n;; cur = cur->parent) {
    for (int i = cur->partialKeyLen; i-- > 0;) {
      pathBytes.push_back(cur->partialKey()[i]);
    }
    if (cur->parent == nullptr) {
      break;
    }
    pathBytes.push_back(cur->parentsIndex);
  }
  std::reverse(pathBytes.begin(), pathBytes.end());

  if (pathBytes.size() == 0) {
    return std::string();
  }
  return printable(std::string_view((const char *)&pathBytes[0],
                                    pathBytes.size())); // NOLINT
}
// Returns the bytes that `n` itself contributes to its key, passed through
// printable(): the byte indexing it in its parent (omitted for a node without a
// parent) followed by its partial key. A null `n` yields "<end>".
std::string getPartialKeyPrintable(Node *n) {
  if (n == nullptr) {
    return "<end>";
  }
  auto result = std::string((const char *)&n->parentsIndex,
                            n->parent == nullptr ? 0 : 1) +
                std::string((const char *)n->partialKey(), n->partialKeyLen);
  return printable(result); // NOLINT
}
// Computes the smallest string that is greater than every string having `str`
// as a prefix: trailing '\xff' bytes are dropped and the last remaining byte is
// incremented.
//
// Must not be called with a string that consists only of zero or more '\xff'
// bytes; in that case `ok` is set to false and an empty string is returned.
// Otherwise `ok` is set to true.
std::string strinc(std::string_view str, bool &ok) {
  const auto lastIncrementable = str.find_last_not_of('\xff');
  if (lastIncrementable == std::string_view::npos) {
    ok = false;
    return {};
  }
  ok = true;
  std::string result(str.substr(0, lastIncrementable + 1));
  // The byte is known not to be 0xff, so the increment cannot wrap
  const auto bumped = static_cast<uint8_t>(result.back()) + 1;
  result.back() = static_cast<char>(bumped);
  return result;
}
// Convenience overload returning the search path of `n` as an owning
// std::string. `n` must not be null.
std::string getSearchPath(Node *n) {
  assert(n != nullptr);
  Arena arena;
  const auto path = getSearchPath(arena, n);
  // Copy out of the arena before it goes out of scope
  const auto *bytes = reinterpret_cast<const char *>(path.data());
  return std::string(bytes, path.size());
}
[[maybe_unused]] void debugPrintDot(FILE *file, Node *node,
ConflictSet::Impl *impl) {
constexpr int kSeparation = 3;
struct DebugDotPrinter {
explicit DebugDotPrinter(FILE *file, ConflictSet::Impl *impl)
: file(file), impl(impl) {}
void print(Node *n, int y = 0) {
assert(n != nullptr);
if (n->entryPresent) {
fprintf(file,
" k_%p [label=\"m=%" PRId64 " p=%" PRId64 " r=%" PRId64
"\n%s\", pos=\"%d,%d!\"];\n",
(void *)n, maxVersion(n, impl), n->entry.pointVersion,
n->entry.rangeVersion, getPartialKeyPrintable(n).c_str(), x, y);
} else {
fprintf(file, " k_%p [label=\"m=%" PRId64 "\n%s\", pos=\"%d,%d!\"];\n",
(void *)n, maxVersion(n, impl),
getPartialKeyPrintable(n).c_str(), x, y);
}
x += kSeparation;
for (int child = getChildGeq(n, 0); child >= 0;
child = getChildGeq(n, child + 1)) {
auto *c = getChildExists(n, child);
fprintf(file, " k_%p -> k_%p;\n", (void *)n, (void *)c);
print(c, y - kSeparation);
}
}
int x = 0;
FILE *file;
ConflictSet::Impl *impl;
};
fprintf(file, "digraph ConflictSet {\n");
fprintf(file, " node [shape = box];\n");
assert(node != nullptr);
DebugDotPrinter printer{file, impl};
printer.print(node);
fprintf(file, "}\n");
}
// Recursively verifies that every child's `parent` pointer refers to the node
// it is stored in. Reports each violation on stderr and clears `success`;
// `success` is never set back to true.
void checkParentPointers(Node *node, bool &success) {
  int index = getChildGeq(node, 0);
  while (index >= 0) {
    auto *childNode = getChildExists(node, index);
    const bool parentMatches = childNode->parent == node;
    if (!parentMatches) {
      fprintf(stderr, "%s child %d has parent pointer %p. Expected %p\n",
              getSearchPathPrintable(node).c_str(), index,
              (void *)childNode->parent, (void *)node);
      success = false;
    }
    checkParentPointers(childNode, success);
    index = getChildGeq(node, index + 1);
  }
}
// Convenience overload: reinterprets the characters of `key` as bytes and
// forwards to the span-based firstGeq.
Iterator firstGeq(Node *n, std::string_view key) {
  const auto *bytes = (const uint8_t *)key.data();
  const std::span<const uint8_t> keyBytes(bytes, key.size());
  return firstGeq(n, keyBytes);
}
// Test helper: recomputes the max version that `node` is expected to carry and
// compares it with the stored one, recursing over the whole subtree. Problems
// are reported on stderr and clear `success`. Returns the recomputed value.
//
// The expected value is the max of the node's own point version (if it has an
// entry), the recomputed value of every child, the range version of every
// child that has an entry, and the "borrowed" range version described below.
[[maybe_unused]] int64_t checkMaxVersion(Node *root, Node *node,
int64_t oldestVersion, bool &success,
ConflictSet::Impl *impl) {
int64_t expected = std::numeric_limits<int64_t>::lowest();
if (node->entryPresent) {
expected = std::max(expected, node->entry.pointVersion);
}
for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) {
auto *child = getChildExists(node, i);
expected = std::max(
expected, checkMaxVersion(root, child, oldestVersion, success, impl));
if (child->entryPresent) {
expected = std::max(expected, child->entry.rangeVersion);
}
}
// "Borrowed" version: the range version of the first entry at or after
// strinc(key).
// NOTE(review): the key is taken from `root`, not from `node`. When `root` is
// the tree root its search path is presumably empty, so strinc reports
// !ok and this block is skipped - confirm whether `node` was intended.
auto key = getSearchPath(root);
bool ok;
auto inc = strinc(key, ok);
if (ok) {
auto borrowed = firstGeq(root, inc);
if (borrowed.n != nullptr) {
expected = std::max(expected, borrowed.n->entry.rangeVersion);
}
}
// The value read through the parent's slot must agree with maxVersion()
if (node->parent != nullptr &&
getChildMaxVersion(node->parent, node->parentsIndex) !=
maxVersion(node, impl)) {
fprintf(stderr,
"%s has max version %" PRId64
" . But parent has child max version %" PRId64 "\n",
getSearchPathPrintable(node).c_str(), maxVersion(node, impl),
getChildMaxVersion(node->parent, node->parentsIndex));
success = false;
}
// Stored max versions at or below oldestVersion are not compared
if (maxVersion(node, impl) > oldestVersion &&
maxVersion(node, impl) != expected) {
fprintf(stderr, "%s has max version %" PRId64 " . Expected %" PRId64 "\n",
getSearchPathPrintable(node).c_str(), maxVersion(node, impl),
expected);
success = false;
}
return expected;
}
// Test helper: returns the number of entries in the subtree rooted at `node`
// (counting `node` itself). Every child subtree is expected to contain at
// least one entry; a child without any is reported on stderr and clears
// `success`.
[[maybe_unused]] int64_t checkEntriesExist(Node *node, bool &success) {
  int64_t total = node->entryPresent;
  for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) {
    auto *child = getChildExists(node, i);
    const int64_t e = checkEntriesExist(child, success);
    total += e;
    if (e == 0) {
      fprintf(stderr, "%s has child %02x with no reachable entries\n",
              getSearchPathPrintable(node).c_str(), i);
      success = false;
    }
  }
  return total;
}
// Runs all structural self-checks (parent pointers, max versions, reachable
// entries) on the tree rooted at `node`. Every check always runs, so all
// problems are reported; returns true only if none of them failed.
bool checkCorrectness(Node *node, int64_t oldestVersion,
                      ConflictSet::Impl *impl) {
  bool allChecksPassed = true;
  checkParentPointers(node, allChecksPassed);
  checkMaxVersion(node, node, oldestVersion, allChecksPassed, impl);
  checkEntriesExist(node, allChecksPassed);
  return allChecksPassed;
}
} // namespace
// Defines std::__throw_length_error in this translation unit with a body that
// is marked unreachable instead of throwing.
// NOTE(review): presumably this avoids depending on the standard library's
// exception-throwing helper - confirm against the build configuration.
namespace std {
void __throw_length_error(const char *) { __builtin_unreachable(); }
} // namespace std
#ifdef ENABLE_MAIN
// Builds a small example tree (one range write followed by three point writes,
// each at a new version) and prints it to stdout in Graphviz format.
void printTree() {
int64_t writeVersion = 0;
ConflictSet::Impl cs{writeVersion};
// NOTE(review): refImpl and arena are not used below - confirm whether they
// can be removed.
ReferenceImpl refImpl{writeVersion};
Arena arena;
ConflictSet::WriteRange write;
// Range write ["and", "ant")
write.begin = "and"_s;
write.end = "ant"_s;
cs.addWrites(&write, 1, ++writeVersion);
// An empty end key makes the following point writes
write.begin = "any"_s;
write.end = ""_s;
cs.addWrites(&write, 1, ++writeVersion);
write.begin = "are"_s;
write.end = ""_s;
cs.addWrites(&write, 1, ++writeVersion);
write.begin = "art"_s;
write.end = ""_s;
cs.addWrites(&write, 1, ++writeVersion);
debugPrintDot(stdout, cs.root, &cs);
}
#define ANKERL_NANOBENCH_IMPLEMENT
#include "third_party/nanobench.h"
// Entry point for ENABLE_MAIN builds: prints the example tree and exits.
int main(void) {
printTree();
return 0;
// NOTE(review): everything below is unreachable because of the return above.
// Presumably the return is a manual switch between printing the example tree
// and running the benchmark - confirm before removing either part.
ankerl::nanobench::Bench bench;
ConflictSet::Impl cs{0};
// Add children to the root one at a time, benchmarking maxBetweenExclusive
// over the root after every tenth addition
for (int j = 0; j < 256; ++j) {
getOrCreateChild(cs.root, j, &cs.allocators) =
cs.allocators.node0.allocate(0);
if (j % 10 == 0) {
bench.run("MaxExclusive " + std::to_string(j), [&]() {
bench.doNotOptimizeAway(maxBetweenExclusive(cs.root, 0, 256));
});
}
}
return 0;
}
#endif
#ifdef ENABLE_FUZZ
extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
TestDriver<ConflictSet::Impl> driver{data, size};
for (;;) {
bool done = driver.next();
if (!driver.ok) {
debugPrintDot(stdout, driver.cs.root, &driver.cs);
fflush(stdout);
abort();
}
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check correctness\n");
#endif
bool success =
checkCorrectness(driver.cs.root, driver.cs.oldestVersion, &driver.cs);
if (!success) {
debugPrintDot(stdout, driver.cs.root, &driver.cs);
fflush(stdout);
abort();
}
if (done) {
break;
}
}
return 0;
}
#endif
// GCOVR_EXCL_STOP