Files
conflict-set/ConflictSet.cpp
Andrew Noyes 4f32ecc26e Make "begin" a template parameter to insert
cachegrind says this saves instructions
2024-02-23 14:00:55 -08:00

2098 lines
58 KiB
C++

/*
Copyright 2024 Andrew Noyes
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "ConflictSet.h"
#include "Internal.h"
#include <algorithm>
#include <bit>
#include <cassert>
#include <compare>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <inttypes.h>
#include <limits>
#include <span>
#include <string>
#include <string_view>
#include <utility>
#ifdef HAS_AVX
#include <immintrin.h>
#elif defined(HAS_ARM_NEON)
#include <arm_neon.h>
#endif
#include <memcheck.h>
// ==================== BEGIN IMPLEMENTATION ====================
// Child-count cutoff used by the Node48/Node256 scans in this file: a node
// with fewer children than this is walked through its BitSet, otherwise all
// 256 slots are scanned linearly.
constexpr int kSparseScanThreshold = 32;
// The pair of write versions stored for a key that is present in the tree.
struct Entry {
// Compared against the read version when a point read lands exactly on this
// key (see checkPointRead).
int64_t pointVersion;
// Compared against the read version when this entry is the first one reached
// after a read key that has no entry of its own (see checkPointRead).
int64_t rangeVersion;
};
// Allocator for objects of a single type T that recycles released objects
// through an intrusive singly-linked free list. The link pointer is stored in
// the first sizeof(void *) bytes of each free object. At most
// kMemoryBound / sizeof(T) objects are retained on the free list; anything
// released beyond that is returned to free() immediately.
//
// The allocator owns the raw memory on its free list, so it is neither
// copyable nor movable: a copy would share the list and free it twice.
template <class T, size_t kMemoryBound = (1 << 20)>
struct BoundedFreeListAllocator {
  // Each free object must be able to hold the "next" pointer.
  static_assert(sizeof(T) >= sizeof(void *));

  BoundedFreeListAllocator() = default;
  BoundedFreeListAllocator(const BoundedFreeListAllocator &) = delete;
  BoundedFreeListAllocator &
  operator=(const BoundedFreeListAllocator &) = delete;

  // Returns a default-initialized T, reusing memory from the free list when
  // available and falling back to safe_malloc otherwise.
  T *allocate() {
    if (freeListSize == 0) {
      assert(freeList == nullptr);
      return new (safe_malloc(sizeof(T))) T;
    }
    assert(freeList != nullptr);
    void *buffer = freeList;
    // The link bytes were marked inaccessible in release(); make them readable
    // before popping the head of the list.
    VALGRIND_MAKE_MEM_DEFINED(freeList, sizeof(freeList));
    memcpy(&freeList, freeList, sizeof(freeList));
    --freeListSize;
    VALGRIND_MAKE_MEM_UNDEFINED(buffer, sizeof(T));
    return new (buffer) T;
  }

  // Destroys *p and either pushes its memory onto the free list or, if the
  // list is already at capacity, frees it.
  void release(T *p) {
    p->~T();
    if (freeListSize == kMaxFreeListSize) {
      free(p);
      return;
    }
    // Store the current head in the object's first bytes, then make the object
    // the new head.
    memcpy((void *)p, &freeList, sizeof(freeList));
    freeList = p;
    ++freeListSize;
    VALGRIND_MAKE_MEM_NOACCESS(p, sizeof(T));
  }

  // Frees every object still parked on the free list.
  ~BoundedFreeListAllocator() {
    for (void *iter = freeList; iter != nullptr;) {
      VALGRIND_MAKE_MEM_DEFINED(iter, sizeof(iter));
      auto *tmp = iter;
      // Read the link before the memory that holds it is freed.
      memcpy(&iter, iter, sizeof(void *));
      free(tmp);
    }
  }

private:
  static constexpr int kMaxFreeListSize = kMemoryBound / sizeof(T);
  int freeListSize = 0;
  void *freeList = nullptr;
};
// A fixed 256-bit set, stored as two 128-bit halves: `lo` holds bits 0..127
// and `hi` holds bits 128..255.
struct BitSet {
  // Returns whether bit i (0 <= i < 256) is set.
  bool test(int i) const {
    assert(0 <= i);
    assert(i < 256);
    const __uint128_t &half = i < 128 ? lo : hi;
    return (half >> (i & 127)) & 1;
  }

  // Sets bit i (0 <= i < 256).
  void set(int i) {
    assert(0 <= i);
    assert(i < 256);
    __uint128_t &half = i < 128 ? lo : hi;
    half |= maskFor(i);
  }

  // Clears bit i (0 <= i < 256).
  void reset(int i) {
    assert(0 <= i);
    assert(i < 256);
    __uint128_t &half = i < 128 ? lo : hi;
    half &= ~maskFor(i);
  }

  // Returns the smallest set bit position >= i, or -1 if there is none
  // (including when i >= 256).
  int firstSetGeq(int i) const {
    assert(0 <= i);
    if (i >= 256) {
      return -1;
    }
    if (i < 128) {
      // countr_zero yields 128 when nothing at or above i is set in `lo`.
      const int skippedLo = std::countr_zero(lo >> i);
      if (skippedLo < 128) {
        assert(i + skippedLo < 128);
        return i + skippedLo;
      }
      // Nothing found in the low half; continue from the start of `hi`.
      i = 128;
    }
    const int skippedHi = std::countr_zero(hi >> (i - 128));
    if (skippedHi >= 128) {
      return -1;
    }
    assert(i + skippedHi < 256);
    return i + skippedHi;
  }

private:
  // Single-bit mask for position i within its own 128-bit half.
  static __uint128_t maskFor(int i) { return __uint128_t(1) << (i & 127); }

  __uint128_t lo = 0;
  __uint128_t hi = 0;
};
// Tag identifying the concrete layout behind a Node *. The enumerator order
// matters: code in this file tests `type <= Type::Node16` to handle Node4 and
// Node16 through the same path.
enum class Type : int8_t {
Node4,
Node16,
Node48,
Node256,
// Default value of Node::type until a concrete node constructor sets it.
Invalid,
};
// Header shared by every node layout (Node4, Node16, Node48, Node256).
struct Node {
/* begin section that's copied to the next node */
// Parent node; nullptr for a node without a parent (the traversal helpers
// stop there).
Node *parent = nullptr;
// The max write version over all keys that start with the search path up to
// this point
int64_t maxVersion;
// Versions for the key that ends at this node; read only when entryPresent is
// set.
Entry entry;
// Number of children currently stored in this node.
int16_t numChildren = 0;
// Whether `entry` holds versions for the key ending at this node.
bool entryPresent = false;
// The key byte under which this node is stored in its parent.
uint8_t parentsIndex = 0;
constexpr static auto kPartialKeyMaxLen = 26;
// Key bytes that follow parentsIndex on the search path to this node; only
// the first partialKeyLen bytes are valid.
uint8_t partialKey[kPartialKeyMaxLen];
int8_t partialKeyLen = 0;
/* end section that's copied to the next node */
// Kept after the copied section: growing a node memcpy's
// offsetof(Node, type) bytes so the new node keeps its own type tag.
Type type = Type::Invalid;
};
// Node layout holding up to 4 children. Code in this file accesses a Node4
// through a Node16 pointer (see the `type <= Type::Node16` branches), which
// relies on `index` and `children` being declared in the same order as in
// Node16.
struct Node4 : Node {
// Sorted
uint8_t index[16]; // 16 so that we can use the same simd index search
// implementation for Node4 as Node16
// children[i] is the child stored under key byte index[i].
Node *children[4];
Node4() { this->type = Type::Node4; }
};
// Node layout holding up to 16 children in two parallel arrays.
struct Node16 : Node {
// Sorted
uint8_t index[16];
// children[i] is the child stored under key byte index[i]; only the first
// numChildren slots of both arrays are valid.
Node *children[16];
Node16() { this->type = Type::Node16; }
};
// Node layout holding up to 48 children, addressed indirectly through a
// 256-entry byte table.
struct Node48 : Node {
// Bit i is set iff a child exists for key byte i.
BitSet bitSet;
// Densely packed child pointers; slots [0, nextFree) are in use.
Node *children[48];
// Next unused slot in `children`.
int8_t nextFree = 0;
// index[b] is the slot in `children` for key byte b, or -1 if there is none.
int8_t index[256];
Node48() {
// Start with no key byte mapped to a child slot.
memset(index, -1, 256);
this->type = Type::Node48;
}
};
// Node layout with one child slot per possible key byte.
struct Node256 : Node {
// Bit i is set iff children[i] is occupied. Declared first so that it sits
// at the same offset as Node48::bitSet (getChildGeq static_asserts this).
BitSet bitSet;
// children[b] is the child for key byte b, or nullptr if there is none.
Node *children[256] = {};
Node256() { this->type = Type::Node256; }
};
// Bundles one free-list allocator per concrete node layout so they can be
// passed around together.
struct NodeAllocators {
BoundedFreeListAllocator<Node4> node4;
BoundedFreeListAllocator<Node16> node16;
BoundedFreeListAllocator<Node48> node48;
BoundedFreeListAllocator<Node256> node256;
};
// Returns the position i in [0, self->numChildren) such that
// self->index[i] == index, or -1 if no such position exists. Callers in this
// file also pass Node4 nodes cast to Node16. Three implementations are
// selected at compile time: SSE (HAS_AVX), NEON (HAS_ARM_NEON), and a plain
// loop.
int getNodeIndex(Node16 *self, uint8_t index) {
#ifdef HAS_AVX
// Based on https://www.the-paper-trail.org/post/art-paper-notes/
// key_vec is 16 repeated copies of the searched-for byte, one for every
// possible position in child_keys that needs to be searched.
__m128i key_vec = _mm_set1_epi8(index);
// Compare all child_keys to 'index' in parallel. Don't worry if some of the
// keys aren't valid, we'll mask the results to only consider the valid ones
// below.
__m128i indices;
memcpy(&indices, self->index, sizeof(self->index));
__m128i results = _mm_cmpeq_epi8(key_vec, indices);
// Build a mask to select only the first node->num_children values from the
// comparison (because the other values are meaningless)
uint32_t mask = (1 << self->numChildren) - 1;
// Change the results of the comparison into a bitfield, masking off any
// invalid comparisons.
uint32_t bitfield = _mm_movemask_epi8(results) & mask;
// No match if there are no '1's in the bitfield.
if (bitfield == 0)
return -1;
// Find the index of the first '1' in the bitfield by counting the trailing
// zeros.
return std::countr_zero(bitfield);
#elif defined(HAS_ARM_NEON)
// Based on
// https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
uint8x16_t indices;
memcpy(&indices, self->index, sizeof(self->index));
// 0xff for each match
uint16x8_t results =
vreinterpretq_u16_u8(vceqq_u8(vdupq_n_u8(index), indices));
// Four mask bits per valid lane; the 16-children case is special-cased
// because shifting a 64-bit value by 64 is undefined.
uint64_t mask = self->numChildren == 16
? uint64_t(-1)
: (uint64_t(1) << (self->numChildren * 4)) - 1;
// 0xf for each match in valid range
uint64_t bitfield =
vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) & mask;
if (bitfield == 0)
return -1;
// Each lane contributes 4 bits, so divide the bit position by 4.
return std::countr_zero(bitfield) / 4;
#else
// Portable fallback: linear scan over the valid prefix of the index array.
for (int i = 0; i < self->numChildren; ++i) {
if (self->index[i] == index) {
return i;
}
}
return -1;
#endif
}
// Precondition - an entry for index must exist in the node
//
// Returns a reference to the child-pointer slot of `self` that holds the
// child for key byte `index`.
Node *&getChildExists(Node *self, uint8_t index) {
  switch (self->type) {
  case Type::Node4:
  case Type::Node16: {
    // Node4 is accessed through the Node16 view.
    auto *small = static_cast<Node16 *>(self);
    const int slot = getNodeIndex(small, index);
    return small->children[slot];
  }
  case Type::Node48: {
    auto *indirect = static_cast<Node48 *>(self);
    assert(indirect->bitSet.test(index));
    const auto slot = indirect->index[index];
    return indirect->children[slot];
  }
  default: {
    auto *direct = static_cast<Node256 *>(self);
    return direct->children[index];
  }
  }
  __builtin_unreachable(); // GCOVR_EXCL_LINE
}
// Returns the child of `self` stored under key byte `index`, or nullptr if
// there is no such child.
Node *getChild(Node *self, uint8_t index) {
  switch (self->type) {
  case Type::Node4:
  case Type::Node16: {
    // Node4 is accessed through the Node16 view.
    auto *small = static_cast<Node16 *>(self);
    const int slot = getNodeIndex(small, index);
    return slot < 0 ? nullptr : small->children[slot];
  }
  case Type::Node48: {
    auto *indirect = static_cast<Node48 *>(self);
    const int slot = indirect->index[index];
    return slot < 0 ? nullptr : indirect->children[slot];
  }
  default:
    // Node256 stores nullptr in unoccupied slots.
    return static_cast<Node256 *>(self)->children[index];
  }
}
// Returns the smallest key byte >= child for which `self` has a child, or -1
// if there is none (always -1 when child > 255). Node4/Node16 search their
// sorted index array (SSE, NEON, or a plain loop depending on the build);
// Node48/Node256 consult their BitSet.
int getChildGeq(Node *self, int child) {
if (child > 255) {
return -1;
}
if (self->type <= Type::Node16) {
auto *self16 = static_cast<Node16 *>(self);
#ifdef HAS_AVX
__m128i key_vec = _mm_set1_epi8(child);
__m128i indices;
memcpy(&indices, self16->index, sizeof(self16->index));
// min(child, index[i]) == child exactly when index[i] >= child (unsigned).
__m128i results = _mm_cmpeq_epi8(key_vec, _mm_min_epu8(key_vec, indices));
// Keep only lanes that correspond to valid children.
int mask = (1 << self16->numChildren) - 1;
uint32_t bitfield = _mm_movemask_epi8(results) & mask;
// index is sorted, so the lowest set bit is the smallest qualifying byte.
int result = bitfield == 0 ? -1 : self16->index[std::countr_zero(bitfield)];
// Debug builds cross-check the SIMD answer against a scalar scan.
assert(result == [&]() -> int {
for (int i = 0; i < self16->numChildren; ++i) {
if (self16->index[i] >= child) {
return self16->index[i];
}
}
return -1;
}());
return result;
#elif defined(HAS_ARM_NEON)
uint8x16_t indices;
memcpy(&indices, self16->index, sizeof(self16->index));
// 0xff for each leq
auto results = vcleq_u8(vdupq_n_u8(child), indices);
// Four mask bits per valid lane; 16 children is special-cased because
// shifting a 64-bit value by 64 is undefined.
uint64_t mask = self->numChildren == 16
? uint64_t(-1)
: (uint64_t(1) << (self->numChildren * 4)) - 1;
// 0xf for each 0xff (within mask)
uint64_t bitfield =
vget_lane_u64(
vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)),
0) &
mask;
int simd =
bitfield == 0 ? -1 : self16->index[std::countr_zero(bitfield) / 4];
// Debug builds cross-check the SIMD answer against a scalar scan.
assert(simd == [&]() -> int {
for (int i = 0; i < self->numChildren; ++i) {
if (self16->index[i] >= child) {
return self16->index[i];
}
}
return -1;
}());
return simd;
#else
// Portable fallback: scan the sorted index array for the first byte >= child.
for (int i = 0; i < self->numChildren; ++i) {
if (i > 0) {
assert(self16->index[i - 1] < self16->index[i]);
}
if (self16->index[i] >= child) {
return self16->index[i];
}
}
#endif
} else {
// Node48 and Node256 keep bitSet at the same offset, so a Node256 can be
// read through the Node48 view here.
static_assert(offsetof(Node48, bitSet) == offsetof(Node256, bitSet));
auto *self48 = static_cast<Node48 *>(self);
return self48->bitSet.firstSetGeq(child);
}
// Reached only by the portable Node4/Node16 loop when nothing qualifies.
return -1;
}
// Points the `parent` field of each of the first n->numChildren children at n.
void setChildrenParents(Node16 *n) {
  Node **const last = n->children + n->numChildren;
  for (Node **slot = n->children; slot != last; ++slot) {
    (*slot)->parent = n;
  }
}
// Points the `parent` field of every child of n at n. Sparse nodes are walked
// via the bit set, denser ones by scanning the whole index table.
void setChildrenParents(Node48 *n) {
  if (n->numChildren >= kSparseScanThreshold) {
    for (int key = 0; key < 256; ++key) {
      const int slot = n->index[key];
      if (slot == -1) {
        continue;
      }
      n->children[slot]->parent = n;
    }
    return;
  }
  int key = n->bitSet.firstSetGeq(0);
  while (key >= 0) {
    n->children[n->index[key]]->parent = n;
    key = n->bitSet.firstSetGeq(key + 1);
  }
}
// Points the `parent` field of every child of n at n. Sparse nodes are walked
// via the bit set, denser ones by scanning all 256 child slots.
void setChildrenParents(Node256 *n) {
  if (n->numChildren >= kSparseScanThreshold) {
    for (Node *candidate : n->children) {
      if (candidate != nullptr) {
        candidate->parent = n;
      }
    }
    return;
  }
  int key = n->bitSet.firstSetGeq(0);
  while (key >= 0) {
    n->children[key]->parent = n;
    key = n->bitSet.firstSetGeq(key + 1);
  }
}
// Caller is responsible for assigning a non-null pointer to the returned
// reference if null
//
// Returns a reference to the child slot of `self` for key byte `index`,
// creating the slot if it does not exist yet. If `self` is full it is replaced
// by the next larger node layout: the old node is released to `allocators`,
// the children are re-parented, and `self` (taken by reference) is updated to
// point at the replacement.
Node *&getOrCreateChild(Node *&self, uint8_t index,
NodeAllocators *allocators) {
// Fast path for if it exists already
if (self->type <= Type::Node16) {
auto *self16 = static_cast<Node16 *>(self);
int i = getNodeIndex(self16, index);
if (i >= 0) {
return self16->children[i];
}
} else if (self->type == Type::Node48) {
auto *self48 = static_cast<Node48 *>(self);
int secondIndex = self48->index[index];
if (secondIndex >= 0) {
return self48->children[secondIndex];
}
} else {
auto *self256 = static_cast<Node256 *>(self);
if (auto &result = self256->children[index]; result != nullptr) {
return result;
}
}
// Slow path: no slot for `index` yet. Grow the node first if it is full, then
// insert into the (possibly new) node.
if (self->type == Type::Node4) {
auto *self4 = static_cast<Node4 *>(self);
if (self->numChildren == 4) {
// Node4 -> Node16: copy sizeof(Node4) bytes (header, index array and the
// four child pointers), then restore the type tag.
auto *newSelf = allocators->node16.allocate();
memcpy((void *)newSelf, self, sizeof(Node4));
newSelf->type = Type::Node16;
allocators->node4.release(self4);
setChildrenParents(newSelf);
self = newSelf;
}
// A Node4 with room is inserted into through the Node16 code path.
goto insert16;
} else if (self->type == Type::Node16) {
if (self->numChildren == 16) {
// Node16 -> Node48: copy the header up to (not including) `type`, then
// rebuild the bit set and the byte -> slot table.
auto *self16 = static_cast<Node16 *>(self);
auto *newSelf = allocators->node48.allocate();
memcpy((void *)newSelf, self, offsetof(Node, type));
newSelf->nextFree = 16;
int i = 0;
for (auto x : self16->index) {
newSelf->bitSet.set(x);
newSelf->children[i] = self16->children[i];
newSelf->index[x] = i;
++i;
}
assert(i == 16);
allocators->node16.release(self16);
setChildrenParents(newSelf);
self = newSelf;
goto insert48;
}
insert16:
auto *self16 = static_cast<Node16 *>(self);
++self->numChildren;
// Find the first existing byte greater than `index` and shift it and
// everything after it one position right to keep the arrays sorted. If there
// is none, the loop ends with i at the new last position.
int i = 0;
for (; i < int(self->numChildren) - 1; ++i) {
if (int(self16->index[i]) > int(index)) {
memmove(self16->index + i + 1, self16->index + i,
self->numChildren - (i + 1));
memmove(self16->children + i + 1, self16->children + i,
(self->numChildren - (i + 1)) * sizeof(void *));
break;
}
}
self16->index[i] = index;
auto &result = self16->children[i];
result = nullptr;
return result;
} else if (self->type == Type::Node48) {
if (self->numChildren == 48) {
// Node48 -> Node256: copy the header and bit set, then scatter the
// children to their direct positions.
auto *self48 = static_cast<Node48 *>(self);
auto *newSelf = allocators->node256.allocate();
memcpy((void *)newSelf, self, offsetof(Node, type));
newSelf->bitSet = self48->bitSet;
for (int i = 0; i < 256; ++i) {
int c = self48->index[i];
if (c >= 0) {
newSelf->children[i] = self48->children[c];
}
}
allocators->node48.release(self48);
setChildrenParents(newSelf);
self = newSelf;
goto insert256;
}
insert48:
auto *self48 = static_cast<Node48 *>(self);
self48->bitSet.set(index);
++self->numChildren;
assert(self48->nextFree < 48);
// Take the next unused slot in `children` and map `index` to it.
int nextFree = self48->nextFree++;
self48->index[index] = nextFree;
auto &result = self48->children[nextFree];
result = nullptr;
return result;
} else {
insert256:
auto *self256 = static_cast<Node256 *>(self);
++self->numChildren;
self256->bitSet.set(index);
// The slot is not assigned here; the fast path above already established
// that it holds nullptr (or the node was freshly grown).
return self256->children[index];
}
}
// Precondition - an entry for index must exist in the node
//
// Releases the child stored under key byte `index` to the allocator matching
// its layout and removes its slot from `self`. If `self` is then left with no
// children and no entry, and it has a parent, `self` is erased from that
// parent in the same way.
void eraseChild(Node *self, uint8_t index, NodeAllocators *allocators) {
  Node *const doomed = getChildExists(self, index);
  switch (doomed->type) {
  case Type::Node4:
    allocators->node4.release(static_cast<Node4 *>(doomed));
    break;
  case Type::Node16:
    allocators->node16.release(static_cast<Node16 *>(doomed));
    break;
  case Type::Node48:
    allocators->node48.release(static_cast<Node48 *>(doomed));
    break;
  case Type::Node256:
    allocators->node256.release(static_cast<Node256 *>(doomed));
    break;
  case Type::Invalid:
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }

  if (self->type == Type::Node48) {
    auto *indirect = static_cast<Node48 *>(self);
    indirect->bitSet.reset(index);
    const int8_t vacated = std::exchange(indirect->index[index], -1);
    const int8_t lastUsed = --indirect->nextFree;
    assert(vacated >= 0);
    assert(lastUsed >= 0);
    if (vacated != lastUsed) {
      // Keep `children` dense: move the last used slot into the hole and
      // repoint that child's key byte at its new slot.
      Node *moved = std::exchange(indirect->children[lastUsed], nullptr);
      indirect->children[vacated] = moved;
      indirect->index[moved->parentsIndex] = vacated;
    }
  } else if (self->type <= Type::Node16) {
    // Close the gap in both parallel arrays by shifting the tail left.
    auto *small = static_cast<Node16 *>(self);
    const int slot = getNodeIndex(small, index);
    const int tail = self->numChildren - (slot + 1);
    memmove(small->index + slot, small->index + slot + 1,
            sizeof(small->index[0]) * tail);
    memmove(small->children + slot, small->children + slot + 1,
            sizeof(small->children[0]) * tail); // NOLINT
  } else {
    auto *direct = static_cast<Node256 *>(self);
    direct->bitSet.reset(index);
    direct->children[index] = nullptr;
  }
  --self->numChildren;

  const bool nowUseless = self->numChildren == 0 && !self->entryPresent;
  if (nowUseless && self->parent != nullptr) {
    eraseChild(self->parent, self->parentsIndex, allocators);
  }
}
// Returns the node that follows `node` in a pre-order walk of the tree: its
// first child if it has one, otherwise the next child of the nearest ancestor
// that has one after the subtree just left. Returns nullptr when the walk runs
// off the top of the tree.
Node *nextPhysical(Node *node) {
  // Smallest child byte still worth considering at the current node.
  int lowerBound = 0;
  do {
    const int childByte = getChildGeq(node, lowerBound);
    if (childByte >= 0) {
      return getChildExists(node, childByte);
    }
    // Nothing left below this node: resume in the parent just past this node.
    lowerBound = node->parentsIndex + 1;
    node = node->parent;
  } while (node != nullptr);
  return nullptr;
}
// Advances with nextPhysical until reaching a node that has an entry, and
// returns it; returns nullptr if the walk ends first.
Node *nextLogical(Node *node) {
  do {
    node = nextPhysical(node);
  } while (node != nullptr && !node->entryPresent);
  return node;
}
// A node paired with a comparison result.
// NOTE(review): no use is visible in this part of the file; presumably the
// result of a "first greater-or-equal" search, where cmp says how n's key
// compares to the searched key - confirm against the users of this struct.
struct Iterator {
Node *n;
int cmp;
};
// Returns the next child after `node` in its parent. If the parent has no
// later child, the search continues from the parent upwards. Returns nullptr
// once a node without a parent is reached.
Node *nextSibling(Node *node) {
  for (; node->parent != nullptr; node = node->parent) {
    const int following = getChildGeq(node->parent, node->parentsIndex + 1);
    if (following >= 0) {
      return getChildExists(node->parent, following);
    }
  }
  return nullptr;
}
// Number of bytes compared per call to compareStride / firstNeqStride. The
// SIMD implementations static_assert that this is 64.
#if defined(HAS_AVX) || defined(HAS_ARM_NEON)
constexpr int kStride = 64;
#else
constexpr int kStride = 16;
#endif
// Number of strides compared per iteration of the first loop in
// longestCommonPrefix.
constexpr int kUnrollFactor = 4;
// Returns true iff the kStride bytes at ap equal the kStride bytes at bp.
// Both pointers must have kStride readable bytes.
bool compareStride(const uint8_t *ap, const uint8_t *bp) {
#if defined(HAS_ARM_NEON)
static_assert(kStride == 64);
// Compare four 16-byte chunks; each lane is 0xff where the bytes match.
uint8x16_t x[4];
for (int i = 0; i < 4; ++i) {
x[i] = vceqq_u8(vld1q_u8(ap + i * 16), vld1q_u8(bp + i * 16));
}
// AND the four results together, then narrow to 4 bits per lane so the whole
// outcome fits in one 64-bit value; all-ones means every byte matched.
auto results = vreinterpretq_u16_u8(
vandq_u8(vandq_u8(x[0], x[1]), vandq_u8(x[2], x[3])));
bool eq = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) ==
uint64_t(-1);
#elif defined(HAS_AVX)
static_assert(kStride == 64);
// Compare four 16-byte chunks; each lane is 0xff where the bytes match.
__m128i x[4];
for (int i = 0; i < 4; ++i) {
x[i] = _mm_cmpeq_epi8(_mm_loadu_si128((__m128i *)(ap + i * 16)),
_mm_loadu_si128((__m128i *)(bp + i * 16)));
}
// AND the four results; a movemask of 0xffff means all 16 lanes matched in
// every chunk.
auto eq =
_mm_movemask_epi8(_mm_and_si128(_mm_and_si128(x[0], x[1]),
_mm_and_si128(x[2], x[3]))) == 0xffff;
#else
// Hope it gets vectorized
auto eq = memcmp(ap, bp, kStride) == 0;
#endif
return eq;
}
// Precondition: ap[:kStride] != bp[:kStride]
//
// Returns the offset in [0, kStride) of the first byte at which ap and bp
// differ.
int firstNeqStride(const uint8_t *ap, const uint8_t *bp) {
#if defined(HAS_AVX)
static_assert(kStride == 64);
// One 16-bit equality mask per 16-byte chunk (bit set = bytes equal).
uint64_t c[kStride / 16];
for (int i = 0; i < kStride; i += 16) {
const auto a = _mm_loadu_si128((__m128i *)(ap + i));
const auto b = _mm_loadu_si128((__m128i *)(bp + i));
const auto compared = _mm_cmpeq_epi8(a, b);
c[i / 16] = _mm_movemask_epi8(compared) & 0xffff;
}
// Concatenate the four masks into 64 bits, invert so mismatches are 1, and
// take the lowest set bit.
return std::countr_zero(~(c[0] | c[1] << 16 | c[2] << 32 | c[3] << 48));
#elif defined(HAS_ARM_NEON)
static_assert(kStride == 64);
for (int i = 0; i < kStride; i += 16) {
// 0xff for each match
uint16x8_t results =
vreinterpretq_u16_u8(vceqq_u8(vld1q_u8(ap + i), vld1q_u8(bp + i)));
// 0xf for each mismatch
uint64_t bitfield =
~vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0);
if (bitfield) {
// Four bits per byte lane, so shift the bit position right by 2.
return i + (std::countr_zero(bitfield) >> 2);
}
}
__builtin_unreachable(); // GCOVR_EXCL_LINE
#else
// Only the first kStride - 1 bytes are compared: by the precondition, if they
// all match then the last byte must be the one that differs.
int i = 0;
for (; i < kStride - 1; ++i) {
if (*ap++ != *bp++) {
break;
}
}
return i;
#endif
}
// Returns the number of leading bytes, at most cl, on which ap and bp agree.
// cl must be non-negative and both buffers must have at least cl readable
// bytes. Inputs of 8 bytes or more are compared in progressively smaller
// units: kStride * kUnrollFactor bytes, kStride bytes, 8-byte words, and
// finally single bytes.
int longestCommonPrefix(const uint8_t *ap, const uint8_t *bp, int cl) {
  if (cl < 0) {
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
  int matched = 0;
  if (cl >= 8) {
    // XOR of the 8-byte words at x and y; zero iff the words are equal.
    const auto wordDiff = [](const uint8_t *x, const uint8_t *y) {
      uint64_t a;
      uint64_t b;
      memcpy(&a, x, 8);
      memcpy(&b, y, 8);
      return a ^ b;
    };

    // Optimistic early return
    if (const uint64_t diff = wordDiff(ap, bp); diff) {
      return std::countr_zero(diff) / 8;
    }

    // kStride * kUnrollFactor at a time
    const int unrolledEnd = cl & ~(kStride * kUnrollFactor - 1);
    while (matched < unrolledEnd) {
      for (int j = 0; j < kUnrollFactor; ++j) {
        if (!compareStride(ap, bp)) {
          return matched + firstNeqStride(ap, bp);
        }
        matched += kStride;
        ap += kStride;
        bp += kStride;
      }
    }

    // kStride at a time
    const int strideEnd = cl & ~(kStride - 1);
    for (; matched < strideEnd;
         matched += kStride, ap += kStride, bp += kStride) {
      if (!compareStride(ap, bp)) {
        return matched + firstNeqStride(ap, bp);
      }
    }

    // word at a time
    const int wordEnd = cl & ~(sizeof(uint64_t) - 1);
    for (; matched < wordEnd; matched += 8, ap += 8, bp += 8) {
      if (const uint64_t diff = wordDiff(ap, bp); diff) {
        return matched + std::countr_zero(diff) / 8;
      }
    }
  }

  // byte at a time
  while (matched < cl && *ap == *bp) {
    ++ap;
    ++bp;
    ++matched;
  }
  return matched;
}
// Byte-wise variant of longestCommonPrefix for inputs no longer than a
// node's partial key: returns how many of the first cl bytes of ap and bp
// agree.
int longestCommonPrefixPartialKey(const uint8_t *ap, const uint8_t *bp,
                                  int cl) {
  assert(cl <= Node::kPartialKeyMaxLen);
  int matched = 0;
  while (matched < cl && ap[matched] == bp[matched]) {
    ++matched;
  }
  return matched;
}
// Performs a physical search for remaining
struct SearchStepWise {
Node *n;
std::span<const uint8_t> remaining;
SearchStepWise() {}
SearchStepWise(Node *n, std::span<const uint8_t> remaining)
: n(n), remaining(remaining) {
assert(n->partialKeyLen == 0);
}
bool step() {
if (remaining.size() == 0) {
return true;
}
auto *child = getChild(n, remaining[0]);
if (child == nullptr) {
return true;
}
int cl = std::min<int>(child->partialKeyLen, remaining.size() - 1);
int i = longestCommonPrefixPartialKey(child->partialKey,
remaining.data() + 1, cl);
if (i != child->partialKeyLen) {
return true;
}
n = child;
remaining =
remaining.subspan(1 + child->partialKeyLen,
remaining.size() - (1 + child->partialKeyLen));
return false;
}
};
// Forward declaration; the definition is not in this part of the file. The
// uses below are all inside DEBUG_VERBOSE logging.
namespace {
std::string getSearchPathPrintable(Node *n);
}
// Logically this is the same as performing firstGeq and then checking against
// point or range version according to cmp, but this version short circuits as
// soon as it can prove that there's no conflict.
//
// Returns true if the version governing `key` is <= readVersion. Starting at
// n, the search follows `key` downwards. If a node for exactly `key` has an
// entry, its pointVersion decides. Otherwise the first entry after `key`
// (found by walking down the left spine) decides with its rangeVersion. If
// there is no such entry the result is true.
bool checkPointRead(Node *n, const std::span<const uint8_t> key,
int64_t readVersion) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check point read: %s\n", printable(key).c_str());
#endif
auto remaining = key;
for (;;) {
// Nothing under this search path was written after readVersion.
if (n->maxVersion <= readVersion) {
return true;
}
if (remaining.size() == 0) {
// The whole key has been consumed: n is the node for exactly `key`.
if (n->entryPresent) {
return n->entry.pointVersion <= readVersion;
}
// No entry at n itself, so the next entry is the leftmost one below n.
int c = getChildGeq(n, 0);
assert(c >= 0);
n = getChildExists(n, c);
goto downLeftSpine;
}
auto *child = getChild(n, remaining[0]);
if (child == nullptr) {
// No child for the next key byte: continue from the next larger child, or
// from the next sibling subtree if there is none.
int c = getChildGeq(n, remaining[0]);
if (c >= 0) {
n = getChildExists(n, c);
goto downLeftSpine;
} else {
n = nextSibling(n);
goto downLeftSpine;
}
}
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(),
commonLen);
if (i < commonLen) {
// The partial key and the key diverge at byte i.
auto c = n->partialKey[i] <=> remaining[i];
if (c > 0) {
// n's subtree sorts after the key.
goto downLeftSpine;
} else {
// n's subtree sorts before the key; move past it.
n = nextSibling(n);
goto downLeftSpine;
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// n is the first physical node greater than remaining, and there's no
// eq node
goto downLeftSpine;
}
}
}
downLeftSpine:
// nextSibling may have returned nullptr: nothing follows the key.
if (n == nullptr) {
return true;
}
// Follow first children until a node with an entry is found.
for (;;) {
if (n->entryPresent) {
return n->entry.rangeVersion <= readVersion;
}
int c = getChildGeq(n, 0);
assert(c >= 0);
n = getChildExists(n, c);
}
}
// Returns the largest version found among the children of n whose key byte
// lies strictly between begin and end, or the lowest int64_t if there is
// none. Every such child contributes its maxVersion. In addition, the first
// child after begin (if it is below end and has an entry) contributes its
// entry's rangeVersion. Requires -1 <= begin < end <= 256.
int64_t maxBetweenExclusive(Node *n, int begin, int end) {
  assert(-1 <= begin);
  assert(begin <= 256);
  assert(-1 <= end);
  assert(end <= 256);
  assert(begin < end);

  int64_t best = std::numeric_limits<int64_t>::lowest();
  const auto raise = [&best](int64_t candidate) {
    best = std::max(best, candidate);
  };

  if (const int first = getChildGeq(n, begin + 1); first >= 0 && first < end) {
    const Node *firstChild = getChildExists(n, first);
    if (firstChild->entryPresent) {
      raise(firstChild->entry.rangeVersion);
    }
  }

  switch (n->type) {
  case Type::Node4:
  case Type::Node16: {
    // Node4 is accessed through the Node16 view; the index array is sorted,
    // so stop at the first byte that reaches `end`.
    auto *small = static_cast<Node16 *>(n);
    for (int slot = 0; slot < small->numChildren; ++slot) {
      const int key = small->index[slot];
      if (!(key < end)) {
        break;
      }
      if (begin < key) {
        raise(small->children[slot]->maxVersion);
      }
    }
    break;
  }
  case Type::Node48: {
    auto *indirect = static_cast<Node48 *>(n);
    const auto visit = [&](int key) {
      const auto slot = indirect->index[key];
      if (slot != -1) {
        raise(indirect->children[slot]->maxVersion);
      }
    };
    if (indirect->numChildren < kSparseScanThreshold) {
      for (int key = indirect->bitSet.firstSetGeq(begin + 1);
           key >= 0 && key < end; key = indirect->bitSet.firstSetGeq(key + 1)) {
        visit(key);
      }
    } else {
      for (int key = begin + 1; key < end; ++key) {
        visit(key);
      }
    }
    break;
  }
  case Type::Node256: {
    auto *direct = static_cast<Node256 *>(n);
    if (direct->numChildren < kSparseScanThreshold) {
      for (int key = direct->bitSet.firstSetGeq(begin + 1);
           key >= 0 && key < end; key = direct->bitSet.firstSetGeq(key + 1)) {
        raise(direct->children[key]->maxVersion);
      }
    } else {
      for (int key = begin + 1; key < end; ++key) {
        if (Node *candidate = direct->children[key]; candidate != nullptr) {
          raise(candidate->maxVersion);
        }
      }
    }
    break;
  }
  case Type::Invalid:
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
#if DEBUG_VERBOSE && !defined(NDEBUG)
  fprintf(stderr, "At `%s', max version in (%02x, %02x) is %" PRId64 "\n",
          getSearchPathPrintable(n).c_str(), begin, end, best);
#endif
  return best;
}
// Reconstructs the full key bytes leading to n by walking parent pointers up
// to the node without a parent. The bytes are collected leaf-to-root (each
// node's partial key backwards, then the byte it is stored under in its
// parent) and reversed once at the end. The result is allocated from `arena`.
Vector<uint8_t> getSearchPath(Arena &arena, Node *n) {
  assert(n != nullptr);
  auto path = vector<uint8_t>(arena);
  for (Node *cur = n;; cur = cur->parent) {
    for (int k = cur->partialKeyLen; k-- > 0;) {
      path.push_back(cur->partialKey[k]);
    }
    if (cur->parent == nullptr) {
      break;
    }
    path.push_back(cur->parentsIndex);
  }
  std::reverse(path.begin(), path.end());
  return path;
}
// Return true if the max version among all keys that start with key + [child],
// where begin < child < end, is <= readVersion
//
// The search follows `key` down from n. Once the key is fully consumed at a
// node, maxBetweenExclusive over (begin, end) decides. If the search falls off
// the key earlier, the first entry after that point decides with its
// rangeVersion, as in checkPointRead.
bool checkRangeStartsWith(Node *n, std::span<const uint8_t> key, int begin,
int end, int64_t readVersion) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "%s(%02x,%02x)*\n", printable(key).c_str(), begin, end);
#endif
auto remaining = key;
for (;;) {
// Nothing under this search path was written after readVersion.
if (n->maxVersion <= readVersion) {
return true;
}
if (remaining.size() == 0) {
// n is the node for exactly `key`; inspect its children in (begin, end).
return maxBetweenExclusive(n, begin, end) <= readVersion;
}
auto *child = getChild(n, remaining[0]);
if (child == nullptr) {
// No child for the next key byte: continue from the next larger child, or
// from the next sibling subtree if there is none.
int c = getChildGeq(n, remaining[0]);
if (c >= 0) {
n = getChildExists(n, c);
goto downLeftSpine;
} else {
n = nextSibling(n);
goto downLeftSpine;
}
}
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(),
commonLen);
if (i < commonLen) {
// The partial key and the key diverge at byte i.
auto c = n->partialKey[i] <=> remaining[i];
if (c > 0) {
goto downLeftSpine;
} else {
n = nextSibling(n);
goto downLeftSpine;
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// The key ends inside n's partial key. The partial-key byte right after
// the key plays the role of `child`: n's subtree matters only if that
// byte lies strictly between begin and end.
if (begin < n->partialKey[remaining.size()] &&
n->partialKey[remaining.size()] < end) {
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
return false;
}
return n->maxVersion <= readVersion;
}
return true;
}
}
}
downLeftSpine:
// nextSibling may have returned nullptr: nothing follows the key.
if (n == nullptr) {
return true;
}
// Follow first children until a node with an entry is found.
for (;;) {
if (n->entryPresent) {
return n->entry.rangeVersion <= readVersion;
}
int c = getChildGeq(n, 0);
assert(c >= 0);
n = getChildExists(n, c);
}
}
// Return true if the max version among all keys that start with key[:prefixLen]
// that are >= key is <= readVersion
//
// Implemented as a resumable state machine so a caller can interleave it with
// other work (checkRangeRead alternates it with CheckRangeRightSide). Call
// step() repeatedly until it returns true; the answer is then in `ok`.
struct CheckRangeLeftSide {
CheckRangeLeftSide(Node *n, std::span<const uint8_t> key, int prefixLen,
int64_t readVersion)
: n(n), remaining(key), prefixLen(prefixLen), readVersion(readVersion) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check range left side from %s for keys starting with %s\n",
printable(key).c_str(),
printable(key.subspan(0, prefixLen)).c_str());
#endif
}
// Current node of the search.
Node *n;
// Part of the key not yet consumed by the search.
std::span<const uint8_t> remaining;
int prefixLen;
int64_t readVersion;
// Number of key bytes consumed so far; compared against prefixLen to tell
// whether the search is still inside the common prefix.
int searchPathLen = 0;
// Result; only valid once step() has returned true.
bool ok;
enum Phase { Search, DownLeftSpine } phase = Search;
// Advances the state machine by one node. Returns true when finished.
bool step() {
switch (phase) {
case Search: {
// Nothing under this search path was written after readVersion.
if (n->maxVersion <= readVersion) {
ok = true;
return true;
}
if (remaining.size() == 0) {
assert(searchPathLen >= prefixLen);
ok = n->maxVersion <= readVersion;
return true;
}
if (searchPathLen >= prefixLen) {
// Past the prefix: every child after the next key byte holds keys > key.
if (maxBetweenExclusive(n, remaining[0], 256) > readVersion) {
ok = false;
return true;
}
}
auto *child = getChild(n, remaining[0]);
if (child == nullptr) {
int c = getChildGeq(n, remaining[0]);
if (c >= 0) {
if (searchPathLen < prefixLen) {
n = getChildExists(n, c);
return downLeftSpine();
}
n = getChildExists(n, c);
ok = n->maxVersion <= readVersion;
return true;
} else {
n = nextSibling(n);
return downLeftSpine();
}
}
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
++searchPathLen;
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(),
commonLen);
searchPathLen += i;
if (i < commonLen) {
// The partial key and the key diverge at byte i.
auto c = n->partialKey[i] <=> remaining[i];
if (c > 0) {
// n's subtree sorts after the key.
if (searchPathLen < prefixLen) {
return downLeftSpine();
}
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
ok = false;
return true;
}
ok = n->maxVersion <= readVersion;
return true;
} else {
// n's subtree sorts before the key; move past it.
n = nextSibling(n);
return downLeftSpine();
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining =
remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// The key ends inside n's partial key, so n's subtree sorts after it.
if (searchPathLen < prefixLen) {
return downLeftSpine();
}
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
ok = false;
return true;
}
ok = n->maxVersion <= readVersion;
return true;
}
}
break;
}
case DownLeftSpine:
// Follow first children until a node with an entry is found; its
// rangeVersion decides.
if (n->entryPresent) {
ok = n->entry.rangeVersion <= readVersion;
return true;
}
int c = getChildGeq(n, 0);
assert(c >= 0);
n = getChildExists(n, c);
break;
}
return false;
}
// Switches to the DownLeftSpine phase. Finishes immediately with ok = true if
// there is no node to descend from (n is nullptr).
bool downLeftSpine() {
phase = DownLeftSpine;
if (n == nullptr) {
ok = true;
return true;
}
return false;
}
};
// Return true if the max version among all keys that start with key[:prefixLen]
// that are < key is <= readVersion
//
// Implemented as a resumable state machine so a caller can interleave it with
// other work (checkRangeRead alternates it with CheckRangeLeftSide). Call
// step() repeatedly until it returns true; the answer is then in `ok`.
struct CheckRangeRightSide {
CheckRangeRightSide(Node *n, std::span<const uint8_t> key, int prefixLen,
int64_t readVersion)
: n(n), key(key), remaining(key), prefixLen(prefixLen),
readVersion(readVersion) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check range right side to %s for keys starting with %s\n",
printable(key).c_str(),
printable(key.subspan(0, prefixLen)).c_str());
#endif
}
// Current node of the search.
Node *n;
// The full key; used here only for the bound assertion in step().
std::span<const uint8_t> key;
// Part of the key not yet consumed by the search.
std::span<const uint8_t> remaining;
int prefixLen;
int64_t readVersion;
// Number of key bytes consumed so far; compared against prefixLen to tell
// whether the search is still inside the common prefix.
int searchPathLen = 0;
// Result; only valid once step() has returned true.
bool ok;
enum Phase { Search, DownLeftSpine } phase = Search;
// Advances the state machine by one node. Returns true when finished.
bool step() {
switch (phase) {
case Search: {
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(
stderr,
"Search path: %s, searchPathLen: %d, prefixLen: %d, remaining: %s\n",
getSearchPathPrintable(n).c_str(), searchPathLen, prefixLen,
printable(remaining).c_str());
#endif
assert(searchPathLen <= int(key.size()));
if (remaining.size() == 0) {
// n is the node for exactly `key`; the first entry at or below it decides.
return downLeftSpine();
}
if (searchPathLen >= prefixLen) {
// n's own key is a proper prefix of `key` and therefore < key.
if (n->entryPresent && n->entry.pointVersion > readVersion) {
ok = false;
return true;
}
// Every child before the next key byte holds keys < key.
if (maxBetweenExclusive(n, -1, remaining[0]) > readVersion) {
ok = false;
return true;
}
}
if (searchPathLen > prefixLen && n->entryPresent &&
n->entry.rangeVersion > readVersion) {
ok = false;
return true;
}
auto *child = getChild(n, remaining[0]);
if (child == nullptr) {
int c = getChildGeq(n, remaining[0]);
if (c >= 0) {
n = getChildExists(n, c);
return downLeftSpine();
} else {
return backtrack();
}
}
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
++searchPathLen;
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(),
commonLen);
searchPathLen += i;
if (i < commonLen) {
// The partial key and the key diverge at byte i.
++searchPathLen;
auto c = n->partialKey[i] <=> remaining[i];
if (c > 0) {
// n's subtree sorts after the key.
return downLeftSpine();
} else {
// n's subtree sorts before the key.
if (searchPathLen > prefixLen && n->entryPresent &&
n->entry.rangeVersion > readVersion) {
ok = false;
return true;
}
return backtrack();
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining =
remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// The key ends inside n's partial key, so n's subtree sorts after it.
return downLeftSpine();
}
}
} break;
case DownLeftSpine:
// Follow first children until a node with an entry is found; its
// rangeVersion decides.
if (n->entryPresent) {
ok = n->entry.rangeVersion <= readVersion;
return true;
}
int c = getChildGeq(n, 0);
assert(c >= 0);
n = getChildExists(n, c);
break;
}
return false;
}
// Walks upwards from n looking for the next sibling subtree, adjusting
// searchPathLen for each node left behind. On the way, a node beyond the
// prefix whose maxVersion exceeds readVersion fails the check. Ends with
// ok = true on reaching a node without a parent; otherwise it switches to the
// DownLeftSpine phase at the sibling that was found.
bool backtrack() {
for (;;) {
if (searchPathLen > prefixLen && n->maxVersion > readVersion) {
ok = false;
return true;
}
if (n->parent == nullptr) {
ok = true;
return true;
}
auto next = getChildGeq(n->parent, n->parentsIndex + 1);
if (next < 0) {
// No later sibling: give back this node's byte and partial key and go up.
searchPathLen -= 1 + n->partialKeyLen;
n = n->parent;
} else {
// Swap this node's partial key length for the sibling's.
searchPathLen -= n->partialKeyLen;
n = getChildExists(n->parent, next);
searchPathLen += n->partialKeyLen;
return downLeftSpine();
}
}
}
// Switches to the DownLeftSpine phase. Finishes immediately with ok = true if
// there is no node to descend from (n is nullptr).
bool downLeftSpine() {
phase = DownLeftSpine;
if (n == nullptr) {
ok = true;
return true;
}
return false;
}
};
// Returns true if no version relevant to the key range [begin, end) exceeds
// readVersion, searching the tree rooted at n.
//
// A range of the form [k, k + '\0') is handled as a point read of k. For any
// other range the search first descends along the common prefix of begin and
// end, then checks three parts: the children strictly between the two
// diverging bytes (checkRangeStartsWith), the keys >= begin under begin's
// branch (CheckRangeLeftSide), and the keys < end under end's branch
// (CheckRangeRightSide). The last two are stepped in lock-step until one of
// them finishes.
bool checkRangeRead(Node *n, std::span<const uint8_t> begin,
                    std::span<const uint8_t> end, int64_t readVersion) {
  int lcp = longestCommonPrefix(begin.data(), end.data(),
                                std::min(begin.size(), end.size()));
  const bool coversSingleKey = lcp == int(begin.size()) &&
                               end.size() == begin.size() + 1 &&
                               end.back() == 0;
  if (coversSingleKey) {
    return checkPointRead(n, begin, readVersion);
  }

  // Descend along the common prefix as far as the tree allows.
  SearchStepWise search{n, begin.subspan(0, lcp)};
  Arena arena;
  for (bool stuck = false; !stuck; stuck = search.step()) {
    assert(getSearchPath(arena, search.n) <=>
               begin.subspan(0, lcp - search.remaining.size()) ==
           0);
    if (search.n->maxVersion <= readVersion) {
      return true;
    }
  }
  assert(getSearchPath(arena, search.n) <=>
             begin.subspan(0, lcp - search.remaining.size()) ==
         0);

  // Re-base both keys on the node the prefix search stopped at.
  const int consumed = lcp - search.remaining.size();
  assert(consumed >= 0);
  begin = begin.subspan(consumed);
  end = end.subspan(consumed);
  n = search.n;
  lcp -= consumed;

  if (lcp == int(begin.size())) {
    // begin is a prefix of end, so only the right side needs checking.
    CheckRangeRightSide rightOnly{n, end, lcp, readVersion};
    while (!rightOnly.step()) {
    }
    return rightOnly.ok;
  }

  if (!checkRangeStartsWith(n, begin.subspan(0, lcp), begin[lcp], end[lcp],
                            readVersion)) {
    return false;
  }

  CheckRangeLeftSide checkRangeLeftSide{n, begin, lcp + 1, readVersion};
  CheckRangeRightSide checkRangeRightSide{n, end, lcp + 1, readVersion};
  // Alternate the two state machines until at least one of them finishes.
  bool leftDone;
  bool rightDone;
  do {
    leftDone = checkRangeLeftSide.step();
    rightDone = checkRangeRightSide.step();
  } while (!leftDone && !rightDone);
  // Run whichever one is still pending to completion.
  if (!rightDone) {
    while (!checkRangeRightSide.step()) {
    }
  }
  if (!leftDone) {
    while (!checkRangeLeftSide.step()) {
    }
  }
  return checkRangeLeftSide.ok && checkRangeRightSide.ok;
}
// Returns a pointer to the newly inserted node. caller is responsible for
// setting 'entry' fields and `maxVersion` on the result, which may have
// !entryPresent. The search path of the result's parent will have
// `maxVersion` at least `writeVersion` as a postcondition.
//
// `self_` points at the slot holding the subtree root; the slot is rewritten
// when that node has to be replaced. With kBegin == true, `maxVersion` is
// raised to at least `writeVersion` on every node visited, including the
// returned one. With kBegin == false it is raised only on the nodes above the
// returned one.
template <bool kBegin>
[[nodiscard]] Node *insert(Node **self_, std::span<const uint8_t> key,
int64_t writeVersion, NodeAllocators *allocators) {
for (;;) {
auto &self = *self_;
// Handle an existing partial key
int commonLen = std::min<int>(self->partialKeyLen, key.size());
int partialKeyIndex =
longestCommonPrefixPartialKey(self->partialKey, key.data(), commonLen);
if (partialKeyIndex < self->partialKeyLen) {
// The key diverges from (or ends inside) this node's partial key: split the
// node. A fresh node takes over this slot and keeps only the matching
// prefix; the old node becomes its child under the first non-matching byte.
auto *old = self;
self = allocators->node4.allocate();
// Copy the leading Node fields, i.e. everything laid out before `type`.
memcpy((void *)self, old, offsetof(Node, type));
self->partialKeyLen = partialKeyIndex;
self->entryPresent = false;
self->numChildren = 0;
getOrCreateChild(self, old->partialKey[partialKeyIndex], allocators) =
old;
old->parent = self;
old->parentsIndex = old->partialKey[partialKeyIndex];
// Drop the shared prefix and the byte now used as the child index from the
// old node's partial key.
memmove(old->partialKey, old->partialKey + partialKeyIndex + 1,
old->partialKeyLen - (partialKeyIndex + 1));
old->partialKeyLen -= partialKeyIndex + 1;
}
key = key.subspan(partialKeyIndex, key.size() - partialKeyIndex);
// Consider adding a partial key
// (only a node with no children and no entry can absorb key bytes)
if (self->numChildren == 0 && !self->entryPresent) {
self->partialKeyLen = std::min<int>(key.size(), self->kPartialKeyMaxLen);
memcpy(self->partialKey, key.data(), self->partialKeyLen);
key = key.subspan(self->partialKeyLen, key.size() - self->partialKeyLen);
}
if constexpr (kBegin) {
self->maxVersion = std::max(self->maxVersion, writeVersion);
}
if (key.size() == 0) {
// The whole key is consumed: this is the node for `key`.
return self;
}
if constexpr (!kBegin) {
// Only reached for nodes strictly above the result.
self->maxVersion = std::max(self->maxVersion, writeVersion);
}
auto &child = getOrCreateChild(self, key.front(), allocators);
if (!child) {
// New child: link it to its parent. Its maxVersion starts at writeVersion for
// kBegin and at the lowest representable value otherwise.
child = allocators->node4.allocate();
child->parent = self;
child->parentsIndex = key.front();
child->maxVersion =
kBegin ? writeVersion : std::numeric_limits<int64_t>::lowest();
}
// Descend: the child's slot is the new subtree root, minus the index byte.
self_ = &child;
key = key.subspan(1, key.size() - 1);
}
}
// Frees every node reachable from `root`. Uses an explicit work stack rather
// than recursion; each node's children are queued before the node itself is
// released with free().
void destroyTree(Node *root) {
  Arena arena;
  auto pending = vector<Node *>(arena);
  pending.push_back(root);
  while (pending.size() > 0) {
    Node *const current = pending.back();
    pending.pop_back();
    // Queue all children before the node that links to them goes away.
    int index = getChildGeq(current, 0);
    while (index >= 0) {
      auto *kid = getChildExists(current, index);
      assert(kid != nullptr);
      pending.push_back(kid);
      index = getChildGeq(current, index + 1);
    }
    free(current);
  }
}
// Records a write to the single key `key` at `writeVersion`.
//
// If the key already has an entry only its point version is raised. Otherwise
// a new entry is created whose range version is inherited from the next
// logical entry, or `oldestVersion` when there is none.
void addPointWrite(Node *&root, int64_t oldestVersion,
                   std::span<const uint8_t> key, int64_t writeVersion,
                   NodeAllocators *allocators) {
  Node *const node = insert<true>(&root, key, writeVersion, allocators);
  if (node->entryPresent) {
    node->entry.pointVersion = std::max(node->entry.pointVersion, writeVersion);
    return;
  }
  // Look up the successor before marking this node as present.
  auto *const successor = nextLogical(node);
  const int64_t inheritedRange =
      successor != nullptr ? successor->entry.rangeVersion : oldestVersion;
  node->entryPresent = true;
  node->entry.pointVersion = writeVersion;
  node->maxVersion = writeVersion;
  node->entry.rangeVersion = inheritedRange;
}
// Records a write to the key range [begin, end) at `writeVersion`.
//
// Inserts (or reuses) entries for `begin` and `end`, stamps them with the new
// version, and removes every entry strictly between the two. A range of the
// form [k, k + '\0') is forwarded to addPointWrite.
void addWriteRange(Node *&root, int64_t oldestVersion,
std::span<const uint8_t> begin, std::span<const uint8_t> end,
int64_t writeVersion, NodeAllocators *allocators) {
int lcp = longestCommonPrefix(begin.data(), end.data(),
std::min(begin.size(), end.size()));
if (lcp == int(begin.size()) && end.size() == begin.size() + 1 &&
end.back() == 0) {
return addPointWrite(root, oldestVersion, begin, writeVersion, allocators);
}
// Descend through existing nodes along the prefix shared by begin and end, as
// long as the node's whole partial key matches, more prefix remains after it,
// and the matching child exists. Every node passed gets its maxVersion raised.
auto remaining = begin.subspan(0, lcp);
auto *n = root;
for (;;) {
if (int(remaining.size()) <= n->partialKeyLen) {
break;
}
int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(),
n->partialKeyLen);
if (i != n->partialKeyLen) {
break;
}
auto *child = getChild(n, remaining[n->partialKeyLen]);
if (child == nullptr) {
break;
}
n->maxVersion = std::max(n->maxVersion, writeVersion);
remaining = remaining.subspan(n->partialKeyLen + 1,
remaining.size() - (n->partialKeyLen + 1));
n = child;
}
// Address of the slot that holds `n`, so insert() can replace the node in
// place if it has to.
Node **useAsRoot = n->parent == nullptr
? &root
: &getChildExists(n->parent, n->parentsIndex);
// Make both keys relative to `n` by dropping the bytes consumed above.
int consumed = lcp - remaining.size();
begin = begin.subspan(consumed, begin.size() - consumed);
end = end.subspan(consumed, end.size() - consumed);
auto *beginNode = insert<true>(useAsRoot, begin, writeVersion, allocators);
const bool insertedBegin = !std::exchange(beginNode->entryPresent, true);
if (insertedBegin) {
// New entry: inherit the range version from the next logical entry, or
// oldestVersion if there is none.
auto *p = nextLogical(beginNode);
beginNode->entry.rangeVersion =
p != nullptr ? p->entry.rangeVersion : oldestVersion;
beginNode->entry.pointVersion = writeVersion;
beginNode->maxVersion = writeVersion;
}
beginNode->maxVersion = std::max(beginNode->maxVersion, writeVersion);
beginNode->entry.pointVersion =
std::max(beginNode->entry.pointVersion, writeVersion);
auto *endNode = insert<false>(useAsRoot, end, writeVersion, allocators);
const bool insertedEnd = !std::exchange(endNode->entryPresent, true);
if (insertedEnd) {
// New end entry: its point version is taken from the range version of the
// next logical entry (or oldestVersion if there is none).
auto *p = nextLogical(endNode);
endNode->entry.pointVersion =
p != nullptr ? p->entry.rangeVersion : oldestVersion;
endNode->maxVersion =
std::max(endNode->maxVersion, endNode->entry.pointVersion);
}
// The end entry's range version carries the version of this write.
endNode->entry.rangeVersion = writeVersion;
if (insertedEnd) {
// beginNode may have been invalidated
beginNode = insert<true>(useAsRoot, begin, writeVersion, allocators);
}
// Remove every entry strictly between begin and end. The iterator is advanced
// before the current node is touched because eraseChild may release it.
for (beginNode = nextLogical(beginNode); beginNode != endNode;) {
auto *old = beginNode;
beginNode = nextLogical(beginNode);
old->entryPresent = false;
if (old->numChildren == 0 && old->parent != nullptr) {
eraseChild(old->parent, old->parentsIndex, allocators);
}
}
}
// Incremental search for the first entry whose key is >= a target key.
//
// Call step() repeatedly until it returns true. At that point `n` is the
// matching node (or nullptr if the search ran off the end of the tree) and
// `cmp` is 0 for an exact match or 1 when the found key is greater.
struct FirstGeqStepwise {
// Current node; on completion, the result.
Node *n;
// Suffix of the target key not yet matched against the search path.
std::span<const uint8_t> remaining;
// Only meaningful once step() has returned true.
int cmp;
enum Phase {
Init,
// Being in this phase implies that the key matches the search path exactly
// up to this point
Search,
DownLeftSpine
};
Phase phase;
FirstGeqStepwise(Node *n, std::span<const uint8_t> remaining)
: n(n), remaining(remaining), phase(Init) {}
// Not being done implies that n is not the firstGeq
bool step() {
switch (phase) {
case Search: {
if (remaining.size() == 0) {
// The key is fully matched but this node has no entry (that case was
// answered by the check under Init), so continue with the first child.
int c = getChildGeq(n, 0);
assert(c >= 0);
n = getChildExists(n, c);
return downLeftSpine();
}
auto *child = getChild(n, remaining[0]);
if (child == nullptr) {
// No exact child: take the next larger child if there is one, otherwise
// move on to whatever nextSibling(n) yields.
int c = getChildGeq(n, remaining[0]);
if (c >= 0) {
n = getChildExists(n, c);
return downLeftSpine();
} else {
n = nextSibling(n);
return downLeftSpine();
}
}
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefixPartialKey(n->partialKey, remaining.data(),
commonLen);
if (i < commonLen) {
// The partial key and the target differ at byte i.
auto c = n->partialKey[i] <=> remaining[i];
if (c > 0) {
// This subtree sorts after the target.
return downLeftSpine();
} else {
// This subtree sorts before the target; skip past it.
n = nextSibling(n);
return downLeftSpine();
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining =
remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// n is the first physical node greater than remaining, and there's no
// eq node
return downLeftSpine();
}
}
}
// After a successful descent, run the same "is this an exact match?" check
// that the very first step performs.
[[fallthrough]];
case Init:
phase = Search;
if (remaining.size() == 0 && n->entryPresent) {
cmp = 0;
return true;
}
return false;
case DownLeftSpine:
// Descend to the first child until a node with an entry is reached.
int c = getChildGeq(n, 0);
assert(c >= 0);
n = getChildExists(n, c);
if (n->entryPresent) {
cmp = 1;
return true;
}
return false;
}
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
// Switches to the DownLeftSpine phase. Finishes immediately with cmp = 1 if
// `n` is null or already has an entry; otherwise takes one step right away.
bool downLeftSpine() {
phase = DownLeftSpine;
if (n == nullptr || n->entryPresent) {
cmp = 1;
return true;
}
return step();
}
};
// Runs a FirstGeqStepwise search for `key` under `n` to completion and returns
// the node it ended on together with its comparison result.
Iterator firstGeq(Node *n, const std::span<const uint8_t> key) {
  FirstGeqStepwise search{n, key};
  for (bool done = false; !done;) {
    done = search.step();
  }
  return {search.n, search.cmp};
}
// Convenience overload: views the characters of `key` as bytes and forwards to
// the span-based firstGeq.
Iterator firstGeq(Node *n, std::string_view key) {
  const auto *bytes = reinterpret_cast<const uint8_t *>(key.data());
  return firstGeq(n, std::span<const uint8_t>(bytes, key.size()));
}
// Implementation behind ConflictSet (and the C API): owns the tree root, the
// node allocators and the bookkeeping for incremental removal of old entries.
struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
// Evaluates `count` reads. result[i] is TooOld if the read version is below
// oldestVersion; otherwise Commit or Conflict depending on the point/range
// check. A read with an empty `end` key is treated as a point read.
void check(const ReadRange *reads, Result *result, int count) const {
for (int i = 0; i < count; ++i) {
const auto &r = reads[i];
auto begin = std::span<const uint8_t>(r.begin.p, r.begin.len);
auto end = std::span<const uint8_t>(r.end.p, r.end.len);
result[i] =
reads[i].readVersion < oldestVersion ? TooOld
: (end.size() > 0
? checkRangeRead(root, begin, end, reads[i].readVersion)
: checkPointRead(root, begin, reads[i].readVersion))
? Commit
: Conflict;
}
}
// Applies `count` writes. A write with an empty `end` key is a point write.
// keyUpdates is credited with the number of keys each write may have added
// (2 for a range, 1 for a point); setOldestVersion spends that budget.
void addWrites(const WriteRange *writes, int count) {
for (int i = 0; i < count; ++i) {
const auto &w = writes[i];
auto begin = std::span<const uint8_t>(w.begin.p, w.begin.len);
auto end = std::span<const uint8_t>(w.end.p, w.end.len);
if (w.end.len > 0) {
keyUpdates += 2;
addWriteRange(root, oldestVersion, begin, end, w.writeVersion,
&allocators);
} else {
keyUpdates += 1;
addPointWrite(root, oldestVersion, begin, w.writeVersion, &allocators);
}
}
}
// Raises oldestVersion and removes entries that can no longer matter, visiting
// at most `keyUpdates` entries per call. The scan resumes from `removalKey`,
// the search path saved at the end of the previous call.
void setOldestVersion(int64_t oldestVersion) {
this->oldestVersion = oldestVersion;
Node *prev = firstGeq(root, removalKey).n;
// There's no way to erase removalKey without introducing a key after it
assert(prev != nullptr);
// NOTE(review): the post-decrement leaves keyUpdates at -1 when the loop ends
// by running out of budget, so the next addWrites credit is effectively one
// short -- confirm this is intended.
while (keyUpdates-- > 0) {
Node *n = nextLogical(prev);
if (n == nullptr) {
// Reached the last entry: start over from the beginning next time.
removalKey = {};
return;
}
if (std::max(prev->entry.pointVersion, prev->entry.rangeVersion) <=
oldestVersion) {
// Any transaction prev would have prevented from committing is
// going to fail with TooOld anyway.
// There's no way to insert a range such that range version of the right
// node is greater than the point version of the left node
assert(n->entry.rangeVersion <= oldestVersion);
prev->entryPresent = false;
if (prev->numChildren == 0 && prev->parent != nullptr) {
eraseChild(prev->parent, prev->parentsIndex, &allocators);
}
}
prev = n;
}
// Remember where to resume. The key bytes live in removalKeyArena, which is
// replaced with a fresh arena first.
removalKeyArena = Arena();
removalKey = getSearchPath(removalKeyArena, prev);
}
// Creates the tree with a single entry for the empty key, with all of its
// versions set to oldestVersion.
explicit Impl(int64_t oldestVersion) : oldestVersion(oldestVersion) {
// Insert ""
root = allocators.node4.allocate();
root->maxVersion = oldestVersion;
root->entry.pointVersion = oldestVersion;
root->entry.rangeVersion = oldestVersion;
root->entryPresent = true;
}
~Impl() { destroyTree(root); }
NodeAllocators allocators;
// Backing storage for removalKey.
Arena removalKeyArena;
// Search path at which the next setOldestVersion scan resumes.
std::span<const uint8_t> removalKey;
// Remaining scan budget for setOldestVersion; credited by addWrites.
int64_t keyUpdates = 0;
Node *root;
int64_t oldestVersion;
};
// ==================== END IMPLEMENTATION ====================
// GCOVR_EXCL_START
// Forwards to Impl::check.
void ConflictSet::check(const ReadRange *reads, Result *results,
                        int count) const {
  impl->check(reads, results, count);
}
// Forwards to Impl::addWrites.
void ConflictSet::addWrites(const WriteRange *writes, int count) {
  impl->addWrites(writes, count);
}
// Forwards to Impl::setOldestVersion.
void ConflictSet::setOldestVersion(int64_t oldestVersion) {
  impl->setOldestVersion(oldestVersion);
}
// Allocates raw storage with safe_malloc and constructs the Impl in place; the
// destructor undoes this with an explicit destructor call plus free().
ConflictSet::ConflictSet(int64_t oldestVersion) : impl(nullptr) {
  void *const storage = safe_malloc(sizeof(Impl));
  impl = new (storage) Impl{oldestVersion};
}
// Destroys and frees the Impl, if this object still owns one (a moved-from
// ConflictSet holds a null pointer).
ConflictSet::~ConflictSet() {
  if (impl == nullptr) {
    return;
  }
  impl->~Impl();
  free(impl);
}
// Move constructor: takes over other's Impl and leaves `other` empty.
ConflictSet::ConflictSet(ConflictSet &&other) noexcept : impl(other.impl) {
  other.impl = nullptr;
}
// Move assignment: takes over other's Impl and leaves `other` empty.
//
// The Impl currently owned by this object (if any) is destroyed and freed
// first; simply overwriting the pointer would leak the whole tree.
// Self-move-assignment leaves the object unchanged.
ConflictSet &ConflictSet::operator=(ConflictSet &&other) noexcept {
  if (this != &other) {
    if (impl != nullptr) {
      // Mirrors ~ConflictSet: the Impl was placement-constructed in malloc'd
      // storage, so destroy it explicitly and then free the storage.
      impl->~Impl();
      free(impl);
    }
    impl = std::exchange(other.impl, nullptr);
  }
  return *this;
}
// Flat aliases for ConflictSet's nested types, used in the signatures of the
// extern "C" functions that follow.
using ConflictSet_Result = ConflictSet::Result;
using ConflictSet_Key = ConflictSet::Key;
using ConflictSet_ReadRange = ConflictSet::ReadRange;
using ConflictSet_WriteRange = ConflictSet::WriteRange;
extern "C" {
__attribute__((__visibility__("default"))) void
ConflictSet_check(void *cs, const ConflictSet_ReadRange *reads,
ConflictSet_Result *results, int count) {
((ConflictSet::Impl *)cs)->check(reads, results, count);
}
__attribute__((__visibility__("default"))) void
ConflictSet_addWrites(void *cs, const ConflictSet_WriteRange *writes,
int count) {
((ConflictSet::Impl *)cs)->addWrites(writes, count);
}
__attribute__((__visibility__("default"))) void
ConflictSet_setOldestVersion(void *cs, int64_t oldestVersion) {
((ConflictSet::Impl *)cs)->setOldestVersion(oldestVersion);
}
// C API: creates a conflict set and returns it as an opaque handle. The Impl
// is placement-constructed in storage obtained from safe_malloc; release it
// with ConflictSet_destroy.
__attribute__((__visibility__("default"))) void *
ConflictSet_create(int64_t oldestVersion) {
  void *const storage = safe_malloc(sizeof(ConflictSet::Impl));
  return new (storage) ConflictSet::Impl{oldestVersion};
}
__attribute__((__visibility__("default"))) void ConflictSet_destroy(void *cs) {
using Impl = ConflictSet::Impl;
((Impl *)cs)->~Impl();
free(cs);
}
}
namespace {
// Returns the full key leading to `n` (every ancestor's index byte and partial
// key, root first) passed through printable(). A null node is rendered as
// "<end>".
std::string getSearchPathPrintable(Node *n) {
  Arena arena;
  if (n == nullptr) {
    return "<end>";
  }
  // Bytes are collected leaf-to-root, so each partial key is appended
  // back-to-front and the whole buffer is reversed at the end.
  auto reversed = vector<char>(arena);
  for (Node *cur = n;; cur = cur->parent) {
    for (int i = cur->partialKeyLen; i-- > 0;) {
      reversed.push_back(cur->partialKey[i]);
    }
    if (cur->parent == nullptr) {
      break;
    }
    reversed.push_back(cur->parentsIndex);
  }
  std::reverse(reversed.begin(), reversed.end());
  if (reversed.size() == 0) {
    return std::string();
  }
  return printable(std::string_view((const char *)&reversed[0],
                                    reversed.size())); // NOLINT
}
// Returns the bytes this node itself contributes to its key: the index byte
// under its parent (omitted for a node without a parent) followed by its
// partial key, passed through printable(). A null node is rendered as "<end>".
std::string getPartialKeyPrintable(Node *n) {
  if (n == nullptr) {
    return "<end>";
  }
  auto result = std::string((const char *)&n->parentsIndex,
                            n->parent == nullptr ? 0 : 1) +
                std::string((const char *)n->partialKey, n->partialKeyLen);
  return printable(result); // NOLINT
}
// Computes the smallest string that is greater than every string having `str`
// as a prefix: trailing 0xff bytes are dropped and the last remaining byte is
// incremented.
//
// Must not be called with a string that consists only of zero or more '\xff'
// bytes; in that case `ok` is set to false and an empty string is returned.
// Otherwise `ok` is set to true.
std::string strinc(std::string_view str, bool &ok) {
  // Position of the last byte that can be incremented without wrapping.
  const auto last = str.find_last_not_of('\xff');
  if (last == std::string_view::npos) {
    ok = false;
    return {};
  }
  ok = true;
  std::string result(str.substr(0, last + 1));
  result.back() =
      static_cast<char>(static_cast<uint8_t>(result.back()) + 1);
  return result;
}
// Returns the search path of `n` as an owning std::string, copied out of a
// temporary arena. `n` must not be null.
std::string getSearchPath(Node *n) {
  assert(n != nullptr);
  Arena arena;
  const auto path = getSearchPath(arena, n);
  const char *const first = reinterpret_cast<const char *>(path.data());
  return std::string(first, path.size());
}
// Writes the tree rooted at `node` to `file` as a Graphviz digraph. Each node
// is labelled with its maxVersion (and point/range versions when it has an
// entry) plus the key bytes it contributes, and is pinned to an explicit
// position. `node` must not be null.
[[maybe_unused]] void debugPrintDot(FILE *file, Node *node) {
// Spacing between consecutive nodes horizontally and between levels
// vertically.
constexpr int kSeparation = 3;
struct DebugDotPrinter {
explicit DebugDotPrinter(FILE *file) : file(file) {}
// Emits `n` and then, recursively, its children one level further down.
void print(Node *n, int y = 0) {
assert(n != nullptr);
if (n->entryPresent) {
fprintf(file,
" k_%p [label=\"m=%" PRId64 " p=%" PRId64 " r=%" PRId64
"\n%s\", pos=\"%d,%d!\"];\n",
(void *)n, n->maxVersion, n->entry.pointVersion,
n->entry.rangeVersion, getPartialKeyPrintable(n).c_str(), x, y);
} else {
fprintf(file, " k_%p [label=\"m=%" PRId64 "\n%s\", pos=\"%d,%d!\"];\n",
(void *)n, n->maxVersion, getPartialKeyPrintable(n).c_str(), x,
y);
}
// Every node gets its own column, in the order nodes are printed.
x += kSeparation;
for (int child = getChildGeq(n, 0); child >= 0;
child = getChildGeq(n, child + 1)) {
auto *c = getChildExists(n, child);
fprintf(file, " k_%p -> k_%p;\n", (void *)n, (void *)c);
print(c, y - kSeparation);
}
}
// Column of the next node to be printed.
int x = 0;
FILE *file;
};
fprintf(file, "digraph ConflictSet {\n");
fprintf(file, " node [shape = box];\n");
assert(node != nullptr);
DebugDotPrinter printer{file};
printer.print(node);
fprintf(file, "}\n");
}
// Recursively verifies that every child's `parent` pointer refers back to the
// node it hangs off. Each violation is reported on stderr and clears
// `success`; the walk continues regardless.
void checkParentPointers(Node *node, bool &success) {
  int i = getChildGeq(node, 0);
  while (i >= 0) {
    Node *const child = getChildExists(node, i);
    const bool linkedBack = (child->parent == node);
    if (!linkedBack) {
      fprintf(stderr, "%s child %d has parent pointer %p. Expected %p\n",
              getSearchPathPrintable(node).c_str(), i, (void *)child->parent,
              (void *)node);
      success = false;
    }
    checkParentPointers(child, success);
    i = getChildGeq(node, i + 1);
  }
}
// Recomputes the value `node->maxVersion` is expected to have and reports a
// mismatch on stderr (clearing `success`). Nodes whose maxVersion is not above
// `oldestVersion` are exempt. Returns the expected value so the caller can
// fold it into its own expectation.
//
// The expected value is the maximum of: the node's own point version (if it
// has an entry), the expected value of every child, the range version of
// every child that has an entry, and the "borrowed" range version computed
// below.
[[maybe_unused]] int64_t checkMaxVersion(Node *root, Node *node,
int64_t oldestVersion, bool &success) {
int64_t expected = std::numeric_limits<int64_t>::lowest();
if (node->entryPresent) {
expected = std::max(expected, node->entry.pointVersion);
}
for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) {
auto *child = getChildExists(node, i);
expected = std::max(expected,
checkMaxVersion(root, child, oldestVersion, success));
if (child->entryPresent) {
expected = std::max(expected, child->entry.rangeVersion);
}
}
// Range version of the first entry at or after strinc(key).
// NOTE(review): `key` is the search path of `root`, not of `node`. If root's
// search path is empty (presumably the usual case), strinc reports !ok and
// this branch never runs -- confirm whether getSearchPath(node) was intended.
auto key = getSearchPath(root);
bool ok;
auto inc = strinc(key, ok);
if (ok) {
auto borrowed = firstGeq(root, inc);
if (borrowed.n != nullptr) {
expected = std::max(expected, borrowed.n->entry.rangeVersion);
}
}
if (node->maxVersion > oldestVersion && node->maxVersion != expected) {
fprintf(stderr, "%s has max version %" PRId64 " . Expected %" PRId64 "\n",
getSearchPathPrintable(node).c_str(), node->maxVersion, expected);
success = false;
}
return expected;
}
// Counts the entries in the subtree rooted at `node` and verifies that every
// child subtree contains at least one. A child with no reachable entries is
// reported on stderr and clears `success`. Returns the number of entries in
// the subtree, including `node` itself.
[[maybe_unused]] int64_t checkEntriesExist(Node *node, bool &success) {
  int64_t total = node->entryPresent;
  for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) {
    auto *child = getChildExists(node, i);
    const int64_t e = checkEntriesExist(child, success);
    total += e;
    if (e == 0) {
      fprintf(stderr, "%s has child %02x with no reachable entries\n",
              getSearchPathPrintable(node).c_str(), i);
      success = false;
    }
  }
  return total;
}
// Runs all structural checks (parent pointers, max versions, reachable
// entries) on the tree rooted at `node`. Every check runs even if an earlier
// one failed; returns true only when none reported a problem.
bool checkCorrectness(Node *node, int64_t oldestVersion) {
  bool allPassed = true;
  checkParentPointers(node, allPassed);
  checkMaxVersion(node, node, oldestVersion, allPassed);
  checkEntriesExist(node, allPassed);
  return allPassed;
}
} // namespace
// Provides a definition of std::__throw_length_error that is declared
// unreachable instead of throwing.
// NOTE(review): presumably this replaces the standard library's own helper so
// this file does not depend on exception-throwing runtime support -- confirm
// against the build configuration.
namespace std {
void __throw_length_error(const char *) { __builtin_unreachable(); }
} // namespace std
#ifdef ENABLE_MAIN
// Demo: point-writes kNumKeys keys twice (with increasing versions) into a
// fresh conflict set and prints the resulting tree to stdout in dot format.
void printTree() {
  int64_t writeVersion = 0;
  ConflictSet::Impl cs{writeVersion};
  ReferenceImpl refImpl{writeVersion};
  Arena arena;
  constexpr int kNumKeys = 5;
  auto *write = new (arena) ConflictSet::WriteRange[kNumKeys];
  // First pass: build the point writes (an empty end key marks a point write).
  for (int i = 0; i < kNumKeys; ++i) {
    auto &w = write[i];
    w.begin = toKey(arena, i);
    w.end.len = 0;
    w.writeVersion = ++writeVersion;
  }
  cs.addWrites(write, kNumKeys);
  // Second pass: same keys, newer versions.
  for (auto *w = write; w != write + kNumKeys; ++w) {
    w->writeVersion = ++writeVersion;
  }
  cs.addWrites(write, kNumKeys);
  debugPrintDot(stdout, cs.root);
}
// Entry point for the ENABLE_MAIN build: prints the demo tree and exits with
// status 0.
int main() {
  printTree();
  return 0;
}
#endif
#ifdef ENABLE_FUZZ
extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
TestDriver<ConflictSet::Impl> driver{data, size};
static_assert(driver.kMaxKeyLen > Node::kPartialKeyMaxLen);
for (;;) {
bool done = driver.next();
if (!driver.ok) {
debugPrintDot(stdout, driver.cs.root);
fflush(stdout);
abort();
}
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check correctness\n");
#endif
bool success = checkCorrectness(driver.cs.root, driver.cs.oldestVersion);
if (!success) {
debugPrintDot(stdout, driver.cs.root);
fflush(stdout);
abort();
}
if (done) {
break;
}
}
return 0;
}
#endif
// GCOVR_EXCL_STOP