/*
Copyright 2024 Andrew Noyes
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#if !defined(USE_SIMD_FALLBACK) && defined(__has_include)
#if defined(__x86_64__) && __has_include("immintrin.h")
#define HAS_AVX 1
#include <immintrin.h>
#elif __has_include("arm_neon.h")
#define HAS_ARM_NEON 1
#include <arm_neon.h>
#endif
#endif
#include "ConflictSet.h"
#include "Internal.h"
#include "LongestCommonPrefix.h"
#include "Metrics.h"
#include <algorithm>
#include <bit>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <inttypes.h>
#include <limits>
#include <span>
#include <string>
#include <string_view>
#include <sys/time.h>
#include <type_traits>
#include <utility>
#ifndef __SANITIZE_THREAD__
#if defined(__has_feature)
#if __has_feature(thread_sanitizer)
#define __SANITIZE_THREAD__
#endif
#endif
#endif
#include <memcheck.h>
using namespace weaselab;
// Use assert for checking potentially complex properties during tests.
// Use assume to hint simple properties to the optimizer.
// TODO: use the C++23 [[assume]] attribute when it's available
#ifdef NDEBUG
#if __has_builtin(__builtin_assume)
#define assume(e) __builtin_assume(e)
#else
#define assume(e) \
if (!(e)) \
__builtin_unreachable()
#endif
#else
#define assume assert
#endif
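// Illustrative sketch (not from the original source): `assume` lets the
// optimizer drop redundant checks in release builds, while debug builds get
// the same condition as an assert. With a hypothetical helper:
//
//   int firstByte(TrivialSpan key) {
//     assume(key.size() > 0); // asserted in debug, optimizer hint in release
//     return key.front();
//   }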
#if SHOW_MEMORY
void addNode(struct Node *);
void removeNode(struct Node *);
void addKey(struct Node *);
void removeKey(struct Node *);
#else
constexpr void addNode(struct Node *) {}
constexpr void removeNode(struct Node *) {}
constexpr void addKey(struct Node *) {}
constexpr void removeKey(struct Node *) {}
#endif
// ==================== BEGIN IMPLEMENTATION ====================
constexpr int64_t kNominalVersionWindow = 2e9;
constexpr int64_t kMaxCorrectVersionWindow =
std::numeric_limits<int32_t>::max();
static_assert(kNominalVersionWindow <= kMaxCorrectVersionWindow);
#ifndef USE_64_BIT
#define USE_64_BIT 0
#endif
auto operator<(const ConflictSet::WriteRange &lhs,
const ConflictSet::WriteRange &rhs) {
if (lhs.end.len == 0) {
return lhs.begin < rhs.begin;
} else {
return lhs.end < rhs.begin;
}
}
struct InternalVersionT {
constexpr InternalVersionT() = default;
constexpr explicit InternalVersionT(int64_t value) : value(value) {}
constexpr int64_t toInt64() const { return value; } // GCOVR_EXCL_LINE
constexpr auto operator<=>(const InternalVersionT &rhs) const {
#if USE_64_BIT
return value <=> rhs.value;
#else
// Maintains ordering after overflow, as long as the full-precision versions
// are within `kMaxCorrectVersionWindow` of each other.
return int32_t(value - rhs.value) <=> 0;
#endif
}
constexpr bool operator==(const InternalVersionT &) const = default;
#if USE_64_BIT
static const InternalVersionT zero;
#else
static thread_local InternalVersionT zero;
#endif
private:
#if USE_64_BIT
int64_t value;
#else
uint32_t value;
#endif
};
#if USE_64_BIT
const InternalVersionT InternalVersionT::zero{0};
#else
thread_local InternalVersionT InternalVersionT::zero;
#endif
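// Worked example (illustrative) of the wraparound-tolerant comparison above
// when !USE_64_BIT: let the full-precision versions be 0xffffffff and
// 0x100000001, which differ by 2 < kMaxCorrectVersionWindow. Truncated to
// uint32_t they become 0xffffffff and 0x00000001, and
// int32_t(0xffffffff - 0x00000001) == -2 < 0, so the smaller version still
// compares less despite the overflow.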
struct Entry {
InternalVersionT pointVersion;
InternalVersionT rangeVersion;
};
struct BitSet {
bool test(int i) const;
void set(int i);
void reset(int i);
int firstSetGeq(int i) const;
// Calls `f` with the index of each bit set
template <class F> void forEachSet(F f) const {
// See section 3.1 in https://arxiv.org/pdf/1709.07821.pdf for details about
// this approach
for (int begin = 0; begin < 256; begin += 64) {
uint64_t word = words[begin >> 6];
while (word) {
uint64_t temp = word & -word;
int index = begin + std::countr_zero(word);
f(index);
word ^= temp;
}
}
}
void init() {
for (auto &w : words) {
w = 0;
}
}
private:
uint64_t words[4];
};
bool BitSet::test(int i) const {
assert(0 <= i);
assert(i < 256);
return words[i >> 6] & (uint64_t(1) << (i & 63));
}
void BitSet::set(int i) {
assert(0 <= i);
assert(i < 256);
words[i >> 6] |= uint64_t(1) << (i & 63);
}
void BitSet::reset(int i) {
assert(0 <= i);
assert(i < 256);
words[i >> 6] &= ~(uint64_t(1) << (i & 63));
}
int BitSet::firstSetGeq(int i) const {
assume(0 <= i);
// i may be >= 256
uint64_t mask = uint64_t(-1) << (i & 63);
for (int j = i >> 6; j < 4; ++j) {
uint64_t masked = mask & words[j];
if (masked) {
return (j << 6) + std::countr_zero(masked);
}
mask = -1;
}
return -1;
}
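// Usage sketch (illustrative only):
//
//   BitSet bs;
//   bs.init();
//   bs.set(3);
//   bs.set(200);
//   bs.forEachSet([](int i) { /* called with 3, then 200 */ });
//   int c = bs.firstSetGeq(4);   // 200
//   int d = bs.firstSetGeq(201); // -1, nothing set at or after 201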
enum Type : int8_t {
Type_Node0,
Type_Node3,
Type_Node16,
Type_Node48,
Type_Node256,
};
template <class T> struct NodeAllocator;
struct TaggedNodePointer {
TaggedNodePointer() = default;
operator struct Node *() { return (struct Node *)withoutType(); }
operator struct Node0 *() {
assert(getType() == Type_Node0);
return (struct Node0 *)withoutType();
}
operator struct Node3 *() {
assert(getType() == Type_Node3);
return (struct Node3 *)withoutType();
}
operator struct Node16 *() {
assert(getType() == Type_Node16);
return (struct Node16 *)withoutType();
}
operator struct Node48 *() {
assert(getType() == Type_Node48);
return (struct Node48 *)withoutType();
}
operator struct Node256 *() {
assert(getType() == Type_Node256);
return (struct Node256 *)withoutType();
}
/*implicit*/ TaggedNodePointer(std::nullptr_t) : p(0) {}
/*implicit*/ TaggedNodePointer(Node0 *x)
: TaggedNodePointer((struct Node *)x, Type_Node0) {}
/*implicit*/ TaggedNodePointer(Node3 *x)
: TaggedNodePointer((struct Node *)x, Type_Node3) {}
/*implicit*/ TaggedNodePointer(Node16 *x)
: TaggedNodePointer((struct Node *)x, Type_Node16) {}
/*implicit*/ TaggedNodePointer(Node48 *x)
: TaggedNodePointer((struct Node *)x, Type_Node48) {}
/*implicit*/ TaggedNodePointer(Node256 *x)
: TaggedNodePointer((struct Node *)x, Type_Node256) {}
bool operator!=(std::nullptr_t) { return p != 0; }
bool operator==(std::nullptr_t) { return p == 0; }
bool operator==(const TaggedNodePointer &) const = default;
bool operator==(Node *n) const { return (uintptr_t)n == withoutType(); }
Node *operator->() { return (Node *)withoutType(); }
Type getType();
TaggedNodePointer(const TaggedNodePointer &) = default;
TaggedNodePointer &operator=(const TaggedNodePointer &) = default;
/*implicit*/ TaggedNodePointer(Node *n);
private:
TaggedNodePointer(struct Node *p, Type t) : p((uintptr_t)p) {
assert((this->p & 7) == 0);
this->p |= t;
assume(p != 0);
}
uintptr_t withoutType() const { return p & ~uintptr_t(7); }
uintptr_t p;
};
struct Node {
/* begin section that's copied to the next node */
union {
Entry entry;
/* Set to the forwarding point for this node if releaseDeferred is set */
Node *forwardTo;
};
Node *parent;
int32_t partialKeyLen;
int16_t numChildren;
bool entryPresent;
// Temp variable used to signal the end of the range during addWriteRange
bool endOfRange;
uint8_t parentsIndex;
/* end section that's copied to the next node */
/* If set, this node has been replaced and the next node in the forwarding
* chain is `forwardTo`*/
bool releaseDeferred;
uint8_t *partialKey();
Type getType() const { return type; }
int32_t getCapacity() const {
assert(!releaseDeferred);
return partialKeyCapacity;
}
private:
template <class T> friend struct NodeAllocator;
// These are publicly readable, but should only be written by
// NodeAllocator
Type type;
int32_t partialKeyCapacity;
};
TaggedNodePointer::TaggedNodePointer(Node *n)
: TaggedNodePointer(n, n->getType()) {}
Type TaggedNodePointer::getType() {
assert(p != 0);
return Type(p & uintptr_t(7));
}
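// The tagging above relies on nodes having at least 8-byte alignment, so the
// low 3 bits of a Node pointer are always zero and can carry the Type. A
// minimal sketch (illustrative) of the round trip:
//
//   uintptr_t p = (uintptr_t)someNode3;     // low 3 bits are 0
//   p |= Type_Node3;                        // stash the type
//   Type t = Type(p & 7);                   // recover the type
//   auto *n = (Node *)(p & ~uintptr_t(7));  // recover the pointer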
constexpr int kNodeCopyBegin = offsetof(Node, entry);
constexpr int kNodeCopySize =
offsetof(Node, parentsIndex) + sizeof(Node::parentsIndex) - kNodeCopyBegin;
// copyChildrenAndKeyFrom is responsible for copying all
// public members of Node, copying the partial key, logically copying the
// children (converting representation if necessary), and updating all the
// children's parent pointers. The caller must then insert the new node into the
// tree.
struct Node0 : Node {
constexpr static auto kType = Type_Node0;
uint8_t *partialKey() {
assert(!releaseDeferred);
return (uint8_t *)(this + 1);
}
void copyChildrenAndKeyFrom(const Node0 &other);
void copyChildrenAndKeyFrom(const struct Node3 &other);
size_t size() const { return sizeof(Node0) + getCapacity(); }
};
struct Node3 : Node {
constexpr static auto kMaxNodes = 3;
constexpr static auto kType = Type_Node3;
// Sorted
uint8_t index[kMaxNodes];
InternalVersionT childMaxVersion[kMaxNodes];
TaggedNodePointer children[kMaxNodes];
uint8_t *partialKey() {
assert(!releaseDeferred);
return (uint8_t *)(this + 1);
}
void copyChildrenAndKeyFrom(const Node0 &other);
void copyChildrenAndKeyFrom(const Node3 &other);
void copyChildrenAndKeyFrom(const struct Node16 &other);
size_t size() const { return sizeof(Node3) + getCapacity(); }
};
struct Node16 : Node {
constexpr static auto kType = Type_Node16;
constexpr static auto kMaxNodes = 16;
// Sorted
uint8_t index[kMaxNodes];
TaggedNodePointer children[kMaxNodes];
InternalVersionT childMaxVersion[kMaxNodes];
uint8_t *partialKey() {
assert(!releaseDeferred);
return (uint8_t *)(this + 1);
}
void copyChildrenAndKeyFrom(const Node3 &other);
void copyChildrenAndKeyFrom(const Node16 &other);
void copyChildrenAndKeyFrom(const struct Node48 &other);
size_t size() const { return sizeof(Node16) + getCapacity(); }
};
struct Node48 : Node {
constexpr static auto kType = Type_Node48;
constexpr static auto kMaxNodes = 48;
constexpr static int kMaxOfMaxPageSize = 16;
constexpr static int kMaxOfMaxShift =
std::countr_zero(uint32_t(kMaxOfMaxPageSize));
constexpr static int kMaxOfMaxTotalPages = kMaxNodes / kMaxOfMaxPageSize;
BitSet bitSet;
TaggedNodePointer children[kMaxNodes];
InternalVersionT childMaxVersion[kMaxNodes];
InternalVersionT maxOfMax[kMaxOfMaxTotalPages];
uint8_t reverseIndex[kMaxNodes];
int8_t index[256];
uint8_t *partialKey() {
assert(!releaseDeferred);
return (uint8_t *)(this + 1);
}
void copyChildrenAndKeyFrom(const Node16 &other);
void copyChildrenAndKeyFrom(const Node48 &other);
void copyChildrenAndKeyFrom(const struct Node256 &other);
size_t size() const { return sizeof(Node48) + getCapacity(); }
};
struct Node256 : Node {
constexpr static auto kType = Type_Node256;
constexpr static auto kMaxNodes = 256;
constexpr static int kMaxOfMaxPageSize = 16;
constexpr static int kMaxOfMaxShift =
std::countr_zero(uint32_t(kMaxOfMaxPageSize));
constexpr static int kMaxOfMaxTotalPages = kMaxNodes / kMaxOfMaxPageSize;
BitSet bitSet;
TaggedNodePointer children[kMaxNodes];
InternalVersionT childMaxVersion[kMaxNodes];
InternalVersionT maxOfMax[kMaxOfMaxTotalPages];
uint8_t *partialKey() {
assert(!releaseDeferred);
return (uint8_t *)(this + 1);
}
void copyChildrenAndKeyFrom(const Node48 &other);
void copyChildrenAndKeyFrom(const Node256 &other);
size_t size() const { return sizeof(Node256) + getCapacity(); }
};
inline void Node0::copyChildrenAndKeyFrom(const Node0 &other) {
memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
kNodeCopySize);
memcpy(partialKey(), &other + 1, partialKeyLen);
}
inline void Node0::copyChildrenAndKeyFrom(const Node3 &other) {
memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
kNodeCopySize);
memcpy(partialKey(), &other + 1, partialKeyLen);
}
inline void Node3::copyChildrenAndKeyFrom(const Node0 &other) {
memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
kNodeCopySize);
memcpy(partialKey(), &other + 1, partialKeyLen);
}
inline void Node3::copyChildrenAndKeyFrom(const Node3 &other) {
memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
kNodeCopySize);
memcpy(index, other.index, kMaxNodes);
memcpy(children, other.children, kMaxNodes * sizeof(children[0])); // NOLINT
memcpy(childMaxVersion, other.childMaxVersion,
kMaxNodes * sizeof(childMaxVersion[0]));
memcpy(partialKey(), &other + 1, partialKeyLen);
for (int i = 0; i < numChildren; ++i) {
assert(children[i]->parent == &other);
children[i]->parent = this;
}
}
inline void Node3::copyChildrenAndKeyFrom(const Node16 &other) {
memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
kNodeCopySize);
memcpy(index, other.index, kMaxNodes);
memcpy(children, other.children, kMaxNodes * sizeof(children[0])); // NOLINT
memcpy(childMaxVersion, other.childMaxVersion,
kMaxNodes * sizeof(childMaxVersion[0]));
memcpy(partialKey(), &other + 1, partialKeyLen);
for (int i = 0; i < numChildren; ++i) {
assert(children[i]->parent == &other);
children[i]->parent = this;
}
}
inline void Node16::copyChildrenAndKeyFrom(const Node3 &other) {
memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
kNodeCopySize);
memcpy(index, other.index, Node3::kMaxNodes);
memcpy(children, other.children,
Node3::kMaxNodes * sizeof(children[0])); // NOLINT
memcpy(childMaxVersion, other.childMaxVersion,
Node3::kMaxNodes * sizeof(childMaxVersion[0]));
memcpy(partialKey(), &other + 1, partialKeyLen);
assert(numChildren == Node3::kMaxNodes);
for (int i = 0; i < Node3::kMaxNodes; ++i) {
assert(children[i]->parent == &other);
children[i]->parent = this;
}
}
inline void Node16::copyChildrenAndKeyFrom(const Node16 &other) {
memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
kNodeCopySize);
memcpy(index, other.index, sizeof(index));
for (int i = 0; i < numChildren; ++i) {
children[i] = other.children[i];
childMaxVersion[i] = other.childMaxVersion[i];
assert(children[i]->parent == &other);
children[i]->parent = this;
}
memcpy(partialKey(), &other + 1, partialKeyLen);
}
inline void Node16::copyChildrenAndKeyFrom(const Node48 &other) {
memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
kNodeCopySize);
int i = 0;
other.bitSet.forEachSet([&](int c) {
// Suppress a false positive -Waggressive-loop-optimizations warning
// in gcc
assume(i < Node16::kMaxNodes);
index[i] = c;
children[i] = other.children[other.index[c]];
childMaxVersion[i] = other.childMaxVersion[other.index[c]];
assert(children[i]->parent == &other);
children[i]->parent = this;
++i;
});
memcpy(partialKey(), &other + 1, partialKeyLen);
}
inline void Node48::copyChildrenAndKeyFrom(const Node16 &other) {
memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
kNodeCopySize);
assert(numChildren == Node16::kMaxNodes);
memset(index, -1, sizeof(index));
memset(children, 0, sizeof(children));
const auto z = InternalVersionT::zero;
for (auto &v : childMaxVersion) {
v = z;
}
memcpy(partialKey(), &other + 1, partialKeyLen);
bitSet.init();
int i = 0;
for (auto x : other.index) {
bitSet.set(x);
index[x] = i;
children[i] = other.children[i];
childMaxVersion[i] = other.childMaxVersion[i];
assert(children[i]->parent == &other);
children[i]->parent = this;
reverseIndex[i] = x;
maxOfMax[i >> Node48::kMaxOfMaxShift] =
std::max(maxOfMax[i >> Node48::kMaxOfMaxShift], childMaxVersion[i]);
++i;
}
}
inline void Node48::copyChildrenAndKeyFrom(const Node48 &other) {
memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
kNodeCopySize);
bitSet = other.bitSet;
memcpy(index, other.index, sizeof(index));
memset(children, 0, sizeof(children));
const auto z = InternalVersionT::zero;
for (auto &v : childMaxVersion) {
v = z;
}
for (int i = 0; i < numChildren; ++i) {
children[i] = other.children[i];
childMaxVersion[i] = other.childMaxVersion[i];
assert(children[i]->parent == &other);
children[i]->parent = this;
}
memcpy(reverseIndex, other.reverseIndex, sizeof(reverseIndex));
memcpy(maxOfMax, other.maxOfMax, sizeof(maxOfMax));
memcpy(partialKey(), &other + 1, partialKeyLen);
}
inline void Node48::copyChildrenAndKeyFrom(const Node256 &other) {
memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
kNodeCopySize);
memset(index, -1, sizeof(index));
memset(children, 0, sizeof(children));
const auto z = InternalVersionT::zero;
for (auto &v : childMaxVersion) {
v = z;
}
bitSet = other.bitSet;
int i = 0;
bitSet.forEachSet([&](int c) {
// Suppress a false positive -Waggressive-loop-optimizations warning
// in gcc.
assume(i < Node48::kMaxNodes);
index[c] = i;
children[i] = other.children[c];
childMaxVersion[i] = other.childMaxVersion[c];
assert(children[i]->parent == &other);
children[i]->parent = this;
reverseIndex[i] = c;
maxOfMax[i >> Node48::kMaxOfMaxShift] =
std::max(maxOfMax[i >> Node48::kMaxOfMaxShift], childMaxVersion[i]);
++i;
});
memcpy(partialKey(), &other + 1, partialKeyLen);
}
inline void Node256::copyChildrenAndKeyFrom(const Node48 &other) {
memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
kNodeCopySize);
bitSet = other.bitSet;
memset(children, 0, sizeof(children));
const auto z = InternalVersionT::zero;
for (auto &v : childMaxVersion) {
v = z;
}
for (auto &v : maxOfMax) {
v = z;
}
bitSet.forEachSet([&](int c) {
children[c] = other.children[other.index[c]];
childMaxVersion[c] = other.childMaxVersion[other.index[c]];
assert(children[c]->parent == &other);
children[c]->parent = this;
maxOfMax[c >> Node256::kMaxOfMaxShift] =
std::max(maxOfMax[c >> Node256::kMaxOfMaxShift], childMaxVersion[c]);
});
memcpy(partialKey(), &other + 1, partialKeyLen);
}
inline void Node256::copyChildrenAndKeyFrom(const Node256 &other) {
memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
kNodeCopySize);
memset(children, 0, sizeof(children));
const auto z = InternalVersionT::zero;
for (auto &v : childMaxVersion) {
v = z;
}
bitSet = other.bitSet;
bitSet.forEachSet([&](int c) {
children[c] = other.children[c];
childMaxVersion[c] = other.childMaxVersion[c];
assert(children[c]->parent == &other);
children[c]->parent = this;
});
memcpy(maxOfMax, other.maxOfMax, sizeof(maxOfMax));
memcpy(partialKey(), &other + 1, partialKeyLen);
}
namespace {
std::string getSearchPathPrintable(Node *n);
std::string getSearchPath(Node *n);
} // namespace
// Bound memory usage following the analysis in the ART paper.
// Each node with an entry present gets a budget of kBytesPerKey. Node0 always
// has an entry present.
// The induction hypothesis is that each node's surplus is >= kMinNodeSurplus.
#if USE_64_BIT
constexpr int kBytesPerKey = 144;
constexpr int kMinNodeSurplus = 104;
#else
constexpr int kBytesPerKey = 112;
constexpr int kMinNodeSurplus = 80;
#endif
// Count the entry itself as a child
constexpr int kMinChildrenNode0 = 1;
constexpr int kMinChildrenNode3 = 2;
constexpr int kMinChildrenNode16 = 4;
constexpr int kMinChildrenNode48 = 17;
constexpr int kMinChildrenNode256 = 49;
constexpr int kNode256Surplus =
kMinChildrenNode256 * kMinNodeSurplus - sizeof(Node256);
static_assert(kNode256Surplus >= kMinNodeSurplus);
constexpr int kNode48Surplus =
kMinChildrenNode48 * kMinNodeSurplus - sizeof(Node48);
static_assert(kNode48Surplus >= kMinNodeSurplus);
constexpr int kNode16Surplus =
kMinChildrenNode16 * kMinNodeSurplus - sizeof(Node16);
static_assert(kNode16Surplus >= kMinNodeSurplus);
constexpr int kNode3Surplus =
kMinChildrenNode3 * kMinNodeSurplus - sizeof(Node3);
static_assert(kNode3Surplus >= kMinNodeSurplus);
static_assert(kBytesPerKey - sizeof(Node0) >= kMinNodeSurplus);
// We'll additionally maintain this property:
// `(children + entryPresent) * length >= capacity`
//
// This should give us the budget to pay for the key bytes; (children +
// entryPresent) is a lower bound on how many keys these bytes are a prefix of.
constexpr int getMaxCapacity(int numChildren, int entryPresent,
int partialKeyLen) {
return (numChildren + entryPresent) * (partialKeyLen + 1);
}
constexpr int getMaxCapacity(Node *self) {
return getMaxCapacity(self->numChildren, self->entryPresent,
self->partialKeyLen);
}
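// For example (illustrative), a node with 2 children, an entry present, and a
// 3-byte partial key may keep at most (2 + 1) * (3 + 1) = 12 bytes of
// capacity:
static_assert(getMaxCapacity(2, 1, 3) == 12);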
constexpr int64_t kMaxFreeListBytes = 1 << 20;
// Maintains a free list up to kMaxFreeListBytes. If the top element of the list
// doesn't meet the capacity constraints, it's freed and a new node is allocated
// with the minimum capacity. The hope is that "unfit" nodes don't get stuck in
// the free list.
template <class T> struct NodeAllocator {
static_assert(std::derived_from<T, Node>);
static_assert(std::is_trivial_v<T>);
T *allocate(int minCapacity, int maxCapacity) {
assert(minCapacity <= maxCapacity);
assert(freeListSize >= 0);
assert(freeListSize <= kMaxFreeListBytes);
T *result = allocate_helper(minCapacity, maxCapacity);
result->endOfRange = false;
result->releaseDeferred = false;
if constexpr (!std::is_same_v<T, Node0>) {
memset(result->children, 0, sizeof(result->children));
const auto z = InternalVersionT::zero;
for (auto &v : result->childMaxVersion) {
v = z;
}
}
if constexpr (std::is_same_v<T, Node48> || std::is_same_v<T, Node256>) {
const auto z = InternalVersionT::zero;
for (auto &v : result->maxOfMax) {
v = z;
}
}
return result;
}
void release(T *p) {
if (freeListSize + sizeof(T) + p->partialKeyCapacity > kMaxFreeListBytes) {
removeNode(p);
return safe_free(p, sizeof(T) + p->partialKeyCapacity);
}
p->parent = freeList;
freeList = p;
freeListSize += sizeof(T) + p->partialKeyCapacity;
VALGRIND_MAKE_MEM_NOACCESS(p, sizeof(T) + p->partialKeyCapacity);
}
void deferRelease(T *p, Node *forwardTo) {
p->releaseDeferred = true;
p->forwardTo = forwardTo;
if (freeListSize + sizeof(T) + p->partialKeyCapacity > kMaxFreeListBytes) {
p->parent = deferredListOverflow;
deferredListOverflow = p;
} else {
if (deferredList == nullptr) {
deferredListFront = p;
}
p->parent = deferredList;
deferredList = p;
freeListSize += sizeof(T) + p->partialKeyCapacity;
}
}
void releaseDeferred() {
if (deferredList != nullptr) {
deferredListFront->parent = freeList;
#ifndef NVALGRIND
for (auto *iter = deferredList; iter != freeList;) {
auto *tmp = iter;
iter = (T *)iter->parent;
VALGRIND_MAKE_MEM_NOACCESS(tmp, sizeof(T) + tmp->partialKeyCapacity);
}
#endif
freeList = std::exchange(deferredList, nullptr);
}
for (T *n = std::exchange(deferredListOverflow, nullptr); n != nullptr;) {
auto *tmp = n;
n = (T *)n->parent;
release(tmp);
}
}
NodeAllocator() = default;
NodeAllocator(const NodeAllocator &) = delete;
NodeAllocator &operator=(const NodeAllocator &) = delete;
NodeAllocator(NodeAllocator &&) = delete;
NodeAllocator &operator=(NodeAllocator &&) = delete;
~NodeAllocator() {
assert(deferredList == nullptr);
assert(deferredListOverflow == nullptr);
for (T *iter = freeList; iter != nullptr;) {
VALGRIND_MAKE_MEM_DEFINED(iter, sizeof(T));
auto *tmp = iter;
iter = (T *)iter->parent;
removeNode(tmp);
safe_free(tmp, sizeof(T) + tmp->partialKeyCapacity);
}
}
private:
int64_t freeListSize = 0;
T *freeList = nullptr;
T *deferredList = nullptr;
// Used to concatenate deferredList to freeList
T *deferredListFront;
T *deferredListOverflow = nullptr;
T *allocate_helper(int minCapacity, int maxCapacity) {
if (freeList != nullptr) {
VALGRIND_MAKE_MEM_DEFINED(freeList, sizeof(T));
freeListSize -= sizeof(T) + freeList->partialKeyCapacity;
assume(freeList->partialKeyCapacity >= 0);
assume(minCapacity >= 0);
assume(minCapacity <= maxCapacity);
if (freeList->partialKeyCapacity >= minCapacity &&
freeList->partialKeyCapacity <= maxCapacity) {
auto *result = freeList;
freeList = (T *)freeList->parent;
VALGRIND_MAKE_MEM_UNDEFINED(result,
sizeof(T) + result->partialKeyCapacity);
VALGRIND_MAKE_MEM_DEFINED(&result->partialKeyCapacity,
sizeof(result->partialKeyCapacity));
VALGRIND_MAKE_MEM_DEFINED(&result->type, sizeof(result->type));
return result;
} else {
auto *p = freeList;
freeList = (T *)p->parent;
removeNode(p);
safe_free(p, sizeof(T) + p->partialKeyCapacity);
}
}
auto *result = (T *)safe_malloc(sizeof(T) + minCapacity);
result->type = T::kType;
result->partialKeyCapacity = minCapacity;
addNode(result);
return result;
}
};
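// Usage sketch (illustrative only): request a capacity between what the
// partial key needs and what the memory invariant allows, then return the
// node to the free list when done.
//
//   NodeAllocator<Node0> alloc;
//   Node0 *n = alloc.allocate(/*minCapacity*/ 5, /*maxCapacity*/ 12);
//   // ... initialize and use n ...
//   alloc.release(n); // kept on the free list if it fits in kMaxFreeListBytes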
uint8_t *Node::partialKey() {
switch (type) {
case Type_Node0:
return ((Node0 *)this)->partialKey();
case Type_Node3:
return ((Node3 *)this)->partialKey();
case Type_Node16:
return ((Node16 *)this)->partialKey();
case Type_Node48:
return ((Node48 *)this)->partialKey();
case Type_Node256:
return ((Node256 *)this)->partialKey();
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
// A type that's plumbed along the check call tree. Lifetime ends after each
// check call.
struct ReadContext {
int64_t point_read_accum = 0;
int64_t prefix_read_accum = 0;
int64_t range_read_accum = 0;
int64_t point_read_short_circuit_accum = 0;
int64_t prefix_read_short_circuit_accum = 0;
int64_t range_read_short_circuit_accum = 0;
int64_t point_read_iterations_accum = 0;
int64_t prefix_read_iterations_accum = 0;
int64_t range_read_iterations_accum = 0;
int64_t range_read_node_scan_accum = 0;
int64_t commits_accum = 0;
int64_t conflicts_accum = 0;
int64_t too_olds_accum = 0;
ConflictSet::Impl *impl;
bool operator==(const ReadContext &) const = default; // GCOVR_EXCL_LINE
};
// A type that's plumbed along the non-const call tree. Same lifetime as
// ConflictSet::Impl
struct WriteContext {
struct Accum {
int64_t entries_erased;
int64_t insert_iterations;
int64_t entries_inserted;
int64_t nodes_allocated;
int64_t nodes_released;
int64_t point_writes;
int64_t range_writes;
int64_t write_bytes;
} accum;
#if USE_64_BIT
static constexpr InternalVersionT zero{0};
#else
// Cache a copy of InternalVersionT::zero, so we don't need to do the TLS
// lookup as often.
InternalVersionT zero;
#endif
WriteContext() { memset(&accum, 0, sizeof(accum)); }
template <class T> T *allocate(int minCapacity, int maxCapacity) {
static_assert(!std::is_same_v<T, Node>);
++accum.nodes_allocated;
if constexpr (std::is_same_v<T, Node0>) {
return node0.allocate(minCapacity, maxCapacity);
} else if constexpr (std::is_same_v<T, Node3>) {
return node3.allocate(minCapacity, maxCapacity);
} else if constexpr (std::is_same_v<T, Node16>) {
return node16.allocate(minCapacity, maxCapacity);
} else if constexpr (std::is_same_v<T, Node48>) {
return node48.allocate(minCapacity, maxCapacity);
} else if constexpr (std::is_same_v<T, Node256>) {
return node256.allocate(minCapacity, maxCapacity);
}
}
template <class T> void release(T *c) {
static_assert(!std::is_same_v<T, Node>);
++accum.nodes_released;
if constexpr (std::is_same_v<T, Node0>) {
return node0.release(c);
} else if constexpr (std::is_same_v<T, Node3>) {
return node3.release(c);
} else if constexpr (std::is_same_v<T, Node16>) {
return node16.release(c);
} else if constexpr (std::is_same_v<T, Node48>) {
return node48.release(c);
} else if constexpr (std::is_same_v<T, Node256>) {
return node256.release(c);
}
}
// Place in a list to be released in the next call to releaseDeferred.
template <class T> void deferRelease(T *n, Node *forwardTo) {
static_assert(!std::is_same_v<T, Node>);
if constexpr (std::is_same_v<T, Node0>) {
return node0.deferRelease(n, forwardTo);
} else if constexpr (std::is_same_v<T, Node3>) {
return node3.deferRelease(n, forwardTo);
} else if constexpr (std::is_same_v<T, Node16>) {
return node16.deferRelease(n, forwardTo);
} else if constexpr (std::is_same_v<T, Node48>) {
return node48.deferRelease(n, forwardTo);
} else if constexpr (std::is_same_v<T, Node256>) {
return node256.deferRelease(n, forwardTo);
}
}
// Release all nodes passed to deferRelease since the last call to
// releaseDeferred.
void releaseDeferred() {
node0.releaseDeferred();
node3.releaseDeferred();
node16.releaseDeferred();
node48.releaseDeferred();
node256.releaseDeferred();
}
private:
NodeAllocator<Node0> node0;
NodeAllocator<Node3> node3;
NodeAllocator<Node16> node16;
NodeAllocator<Node48> node48;
NodeAllocator<Node256> node256;
};
int getNodeIndex(Node3 *n, uint8_t index) {
assume(n->numChildren >= 1);
assume(n->numChildren <= 3);
for (int i = 0; i < n->numChildren; ++i) {
if (n->index[i] == index) {
return i;
}
}
return -1;
}
int getNodeIndexExists(Node3 *n, uint8_t index) {
assume(n->numChildren >= 1);
assume(n->numChildren <= 3);
for (int i = 0; i < n->numChildren; ++i) {
if (n->index[i] == index) {
return i;
}
}
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
int getNodeIndex(Node16 *self, uint8_t index) {
#ifdef HAS_AVX
// Based on https://www.the-paper-trail.org/post/art-paper-notes/
// key_vec is 16 repeated copies of the searched-for byte, one for every
// possible position in child_keys that needs to be searched.
__m128i key_vec = _mm_set1_epi8(index);
// Compare all child_keys to 'index' in parallel. Don't worry if some of the
// keys aren't valid, we'll mask the results to only consider the valid ones
// below.
__m128i indices;
memcpy(&indices, self->index, Node16::kMaxNodes);
__m128i results = _mm_cmpeq_epi8(key_vec, indices);
// Build a mask to select only the first node->num_children values from the
// comparison (because the other values are meaningless)
uint32_t mask = (1 << self->numChildren) - 1;
// Change the results of the comparison into a bitfield, masking off any
// invalid comparisons.
uint32_t bitfield = _mm_movemask_epi8(results) & mask;
// No match if there are no '1's in the bitfield.
if (bitfield == 0)
return -1;
// Find the index of the first '1' in the bitfield by counting the trailing
// zeros.
return std::countr_zero(bitfield);
#elif defined(HAS_ARM_NEON)
// Based on
// https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
uint8x16_t indices;
memcpy(&indices, self->index, Node16::kMaxNodes);
// 0xff for each match
uint16x8_t results =
vreinterpretq_u16_u8(vceqq_u8(vdupq_n_u8(index), indices));
assume(self->numChildren <= Node16::kMaxNodes);
uint64_t mask = self->numChildren == 16
? uint64_t(-1)
: (uint64_t(1) << (self->numChildren * 4)) - 1;
// 0xf for each match in valid range
uint64_t bitfield =
vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) & mask;
if (bitfield == 0)
return -1;
return std::countr_zero(bitfield) / 4;
#else
for (int i = 0; i < self->numChildren; ++i) {
if (self->index[i] == index) {
return i;
}
}
return -1;
#endif
}
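// Worked example (illustrative) for the SSE path above: searching for 0x42 in
// index = {0x10, 0x42, 0x50, ...} with numChildren == 3 gives
// _mm_movemask_epi8(results) with bit 1 set; masking with (1 << 3) - 1 keeps
// 0b010, and countr_zero(0b010) == 1 is the matching slot. The NEON path
// produces 4 bits per matching byte instead of 1, hence the division by 4.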
int getNodeIndexExists(Node16 *self, uint8_t index) {
#ifdef HAS_AVX
__m128i key_vec = _mm_set1_epi8(index);
__m128i indices;
memcpy(&indices, self->index, Node16::kMaxNodes);
__m128i results = _mm_cmpeq_epi8(key_vec, indices);
uint32_t mask = (1 << self->numChildren) - 1;
uint32_t bitfield = _mm_movemask_epi8(results) & mask;
assume(bitfield != 0);
return std::countr_zero(bitfield);
#elif defined(HAS_ARM_NEON)
// Based on
// https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
uint8x16_t indices;
memcpy(&indices, self->index, Node16::kMaxNodes);
// 0xff for each match
uint16x8_t results =
vreinterpretq_u16_u8(vceqq_u8(vdupq_n_u8(index), indices));
assume(self->numChildren <= Node16::kMaxNodes);
uint64_t mask = self->numChildren == 16
? uint64_t(-1)
: (uint64_t(1) << (self->numChildren * 4)) - 1;
// 0xf for each match in valid range
uint64_t bitfield =
vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) & mask;
assume(bitfield != 0);
return std::countr_zero(bitfield) / 4;
#else
for (int i = 0; i < self->numChildren; ++i) {
if (self->index[i] == index) {
return i;
}
}
__builtin_unreachable(); // GCOVR_EXCL_LINE
#endif
}
// Precondition - an entry for index must exist in the node
TaggedNodePointer &getChildExists(Node3 *self, uint8_t index) {
return self->children[getNodeIndexExists(self, index)];
}
// Precondition - an entry for index must exist in the node
TaggedNodePointer &getChildExists(Node16 *self, uint8_t index) {
return self->children[getNodeIndexExists(self, index)];
}
// Precondition - an entry for index must exist in the node
TaggedNodePointer &getChildExists(Node48 *self, uint8_t index) {
assert(self->bitSet.test(index));
return self->children[self->index[index]];
}
// Precondition - an entry for index must exist in the node
TaggedNodePointer &getChildExists(Node256 *self, uint8_t index) {
assert(self->bitSet.test(index));
return self->children[index];
}
// Precondition - an entry for index must exist in the node
TaggedNodePointer &getChildExists(Node *self, uint8_t index) {
switch (self->getType()) {
case Type_Node0: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
case Type_Node3: {
return getChildExists(static_cast<Node3 *>(self), index);
}
case Type_Node16: {
return getChildExists(static_cast<Node16 *>(self), index);
}
case Type_Node48: {
return getChildExists(static_cast<Node48 *>(self), index);
}
case Type_Node256: {
return getChildExists(static_cast<Node256 *>(self), index);
}
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
// Precondition `n` is not the root
InternalVersionT maxVersion(Node *n) {
int index = n->parentsIndex;
n = n->parent;
assert(n != nullptr);
switch (n->getType()) {
case Type_Node0: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
case Type_Node3: {
auto *n3 = static_cast<Node3 *>(n);
int i = getNodeIndexExists(n3, index);
return n3->childMaxVersion[i];
}
case Type_Node16: {
auto *n16 = static_cast<Node16 *>(n);
int i = getNodeIndexExists(n16, index);
return n16->childMaxVersion[i];
}
case Type_Node48: {
auto *n48 = static_cast<Node48 *>(n);
assert(n48->bitSet.test(index));
return n48->childMaxVersion[n48->index[index]];
}
case Type_Node256: {
auto *n256 = static_cast<Node256 *>(n);
assert(n256->bitSet.test(index));
return n256->childMaxVersion[index];
}
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
// Precondition `n` is not the root
InternalVersionT exchangeMaxVersion(Node *n, InternalVersionT newMax) {
int index = n->parentsIndex;
n = n->parent;
assert(n != nullptr);
switch (n->getType()) {
case Type_Node0: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
case Type_Node3: {
auto *n3 = static_cast<Node3 *>(n);
int i = getNodeIndexExists(n3, index);
return std::exchange(n3->childMaxVersion[i], newMax);
}
case Type_Node16: {
auto *n16 = static_cast<Node16 *>(n);
int i = getNodeIndexExists(n16, index);
return std::exchange(n16->childMaxVersion[i], newMax);
}
case Type_Node48: {
auto *n48 = static_cast<Node48 *>(n);
assert(n48->bitSet.test(index));
return std::exchange(n48->childMaxVersion[n48->index[index]], newMax);
}
case Type_Node256: {
auto *n256 = static_cast<Node256 *>(n);
assert(n256->bitSet.test(index));
return std::exchange(n256->childMaxVersion[index], newMax);
}
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
// Precondition `n` is not the root
void setMaxVersion(Node *n, InternalVersionT newMax) {
assert(newMax >= InternalVersionT::zero);
int index = n->parentsIndex;
n = n->parent;
assert(n != nullptr);
switch (n->getType()) {
case Type_Node0: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
case Type_Node3: {
auto *n3 = static_cast<Node3 *>(n);
int i = getNodeIndexExists(n3, index);
n3->childMaxVersion[i] = newMax;
return;
}
case Type_Node16: {
auto *n16 = static_cast<Node16 *>(n);
int i = getNodeIndexExists(n16, index);
n16->childMaxVersion[i] = newMax;
return;
}
case Type_Node48: {
auto *n48 = static_cast<Node48 *>(n);
assert(n48->bitSet.test(index));
int i = n48->index[index];
n48->childMaxVersion[i] = newMax;
n48->maxOfMax[i >> Node48::kMaxOfMaxShift] = std::max<InternalVersionT>(
n48->maxOfMax[i >> Node48::kMaxOfMaxShift], newMax);
return;
}
case Type_Node256: {
auto *n256 = static_cast<Node256 *>(n);
assert(n256->bitSet.test(index));
n256->childMaxVersion[index] = newMax;
n256->maxOfMax[index >> Node256::kMaxOfMaxShift] =
std::max<InternalVersionT>(
n256->maxOfMax[index >> Node256::kMaxOfMaxShift], newMax);
return;
}
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
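// Note on maxOfMax (illustrative): Node48 and Node256 keep a running max per
// page of kMaxOfMaxPageSize == 16 child slots, which presumably lets readers
// rule out a whole page of children at once. Slot i belongs to page
// i >> kMaxOfMaxShift, e.g. slot 37 of a Node48 maps to page 37 >> 4 == 2.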
// If impl is nullptr, then n->parent must not be nullptr
TaggedNodePointer &getInTree(Node *n, ConflictSet::Impl *impl);
TaggedNodePointer getChild(Node0 *, uint8_t) { return nullptr; }
TaggedNodePointer getChild(Node3 *self, uint8_t index) {
int i = getNodeIndex(self, index);
return i < 0 ? nullptr : self->children[i];
}
TaggedNodePointer getChild(Node16 *self, uint8_t index) {
int i = getNodeIndex(self, index);
return i < 0 ? nullptr : self->children[i];
}
TaggedNodePointer getChild(Node48 *self, uint8_t index) {
int i = self->index[index];
return i < 0 ? nullptr : self->children[i];
}
TaggedNodePointer getChild(Node256 *self, uint8_t index) {
return self->children[index];
}
TaggedNodePointer getChild(Node *self, uint8_t index) {
switch (self->getType()) {
case Type_Node0:
return getChild(static_cast<Node0 *>(self), index);
case Type_Node3:
return getChild(static_cast<Node3 *>(self), index);
case Type_Node16:
return getChild(static_cast<Node16 *>(self), index);
case Type_Node48:
return getChild(static_cast<Node48 *>(self), index);
case Type_Node256:
return getChild(static_cast<Node256 *>(self), index);
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
struct ChildAndMaxVersion {
TaggedNodePointer child;
InternalVersionT maxVersion;
static ChildAndMaxVersion empty() {
ChildAndMaxVersion result;
result.child = nullptr;
return result;
}
};
ChildAndMaxVersion getChildAndMaxVersion(Node0 *, uint8_t) { return {}; }
ChildAndMaxVersion getChildAndMaxVersion(Node3 *self, uint8_t index) {
int i = getNodeIndex(self, index);
if (i < 0) {
return ChildAndMaxVersion::empty();
}
return {self->children[i], self->childMaxVersion[i]};
}
ChildAndMaxVersion getChildAndMaxVersion(Node16 *self, uint8_t index) {
int i = getNodeIndex(self, index);
if (i < 0) {
return ChildAndMaxVersion::empty();
}
return {self->children[i], self->childMaxVersion[i]};
}
ChildAndMaxVersion getChildAndMaxVersion(Node48 *self, uint8_t index) {
int i = self->index[index];
if (i < 0) {
return ChildAndMaxVersion::empty();
}
return {self->children[i], self->childMaxVersion[i]};
}
ChildAndMaxVersion getChildAndMaxVersion(Node256 *self, uint8_t index) {
return {self->children[index], self->childMaxVersion[index]};
}
ChildAndMaxVersion getChildAndMaxVersion(Node *self, uint8_t index) {
switch (self->getType()) {
case Type_Node0:
return getChildAndMaxVersion(static_cast<Node0 *>(self), index);
case Type_Node3:
return getChildAndMaxVersion(static_cast<Node3 *>(self), index);
case Type_Node16:
return getChildAndMaxVersion(static_cast<Node16 *>(self), index);
case Type_Node48:
return getChildAndMaxVersion(static_cast<Node48 *>(self), index);
case Type_Node256:
return getChildAndMaxVersion(static_cast<Node256 *>(self), index);
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
TaggedNodePointer getChildGeq(Node0 *, int) { return nullptr; }
TaggedNodePointer getChildGeq(Node3 *n, int child) {
assume(n->numChildren >= 1);
assume(n->numChildren <= 3);
for (int i = 0; i < n->numChildren; ++i) {
if (n->index[i] >= child) {
return n->children[i];
}
}
return nullptr;
}
TaggedNodePointer getChildGeq(Node16 *self, int child) {
if (child > 255) {
return nullptr;
}
#ifdef HAS_AVX
__m128i key_vec = _mm_set1_epi8(child);
__m128i indices;
memcpy(&indices, self->index, Node16::kMaxNodes);
__m128i results = _mm_cmpeq_epi8(key_vec, _mm_min_epu8(key_vec, indices));
int mask = (1 << self->numChildren) - 1;
uint32_t bitfield = _mm_movemask_epi8(results) & mask;
return bitfield == 0 ? nullptr : self->children[std::countr_zero(bitfield)];
#elif defined(HAS_ARM_NEON)
uint8x16_t indices;
memcpy(&indices, self->index, sizeof(self->index));
// 0xff for each leq
auto results = vcleq_u8(vdupq_n_u8(child), indices);
assume(self->numChildren <= Node16::kMaxNodes);
uint64_t mask = self->numChildren == 16
? uint64_t(-1)
: (uint64_t(1) << (self->numChildren * 4)) - 1;
// 0xf for each 0xff (within mask)
uint64_t bitfield =
vget_lane_u64(
vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)),
0) &
mask;
return bitfield == 0 ? nullptr
: self->children[std::countr_zero(bitfield) / 4];
#else
for (int i = 0; i < self->numChildren; ++i) {
if (i > 0) {
assert(self->index[i - 1] < self->index[i]);
}
if (self->index[i] >= child) {
return self->children[i];
}
}
return nullptr;
#endif
}
TaggedNodePointer getChildGeq(Node48 *self, int child) {
int c = self->bitSet.firstSetGeq(child);
return c < 0 ? nullptr : self->children[self->index[c]];
}
TaggedNodePointer getChildGeq(Node256 *self, int child) {
int c = self->bitSet.firstSetGeq(child);
return c < 0 ? nullptr : self->children[c];
}
TaggedNodePointer getChildGeq(Node *self, int child) {
switch (self->getType()) {
case Type_Node0:
return getChildGeq(static_cast<Node0 *>(self), child);
case Type_Node3:
return getChildGeq(static_cast<Node3 *>(self), child);
case Type_Node16:
return getChildGeq(static_cast<Node16 *>(self), child);
case Type_Node48:
return getChildGeq(static_cast<Node48 *>(self), child);
case Type_Node256:
return getChildGeq(static_cast<Node256 *>(self), child);
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
TaggedNodePointer getFirstChild(Node0 *) { return nullptr; }
TaggedNodePointer getFirstChild(Node3 *self) {
// Improves scan performance
__builtin_prefetch(self->children[1]);
return self->children[0];
}
TaggedNodePointer getFirstChild(Node16 *self) {
// Improves scan performance
__builtin_prefetch(self->children[1]);
return self->children[0];
}
TaggedNodePointer getFirstChild(Node48 *self) {
return self->children[self->index[self->bitSet.firstSetGeq(0)]];
}
TaggedNodePointer getFirstChild(Node256 *self) {
return self->children[self->bitSet.firstSetGeq(0)];
}
TaggedNodePointer getFirstChild(Node *self) {
// Only require that the node-specific overloads are covered
// GCOVR_EXCL_START
switch (self->getType()) {
case Type_Node0:
return getFirstChild(static_cast<Node0 *>(self));
case Type_Node3:
return getFirstChild(static_cast<Node3 *>(self));
case Type_Node16:
return getFirstChild(static_cast<Node16 *>(self));
case Type_Node48:
return getFirstChild(static_cast<Node48 *>(self));
case Type_Node256:
return getFirstChild(static_cast<Node256 *>(self));
default:
__builtin_unreachable();
}
// GCOVR_EXCL_STOP
}
// self must not be the root
void maybeDecreaseCapacity(Node *&self, WriteContext *writeContext,
ConflictSet::Impl *impl);
void consumePartialKeyFull(TaggedNodePointer &self, TrivialSpan &key,
InternalVersionT writeVersion,
WriteContext *writeContext,
ConflictSet::Impl *impl) {
// Handle an existing partial key
int commonLen = std::min<int>(self->partialKeyLen, key.size());
int partialKeyIndex =
longestCommonPrefix(self->partialKey(), key.data(), commonLen);
if (partialKeyIndex < self->partialKeyLen) {
Node *old = self;
// Since root cannot have a partial key
assert(old->parent != nullptr);
InternalVersionT oldMaxVersion = exchangeMaxVersion(old, writeVersion);
// *self will have one child (old)
auto *newSelf = writeContext->allocate<Node3>(
partialKeyIndex, getMaxCapacity(1, 0, partialKeyIndex));
newSelf->parent = old->parent;
newSelf->parentsIndex = old->parentsIndex;
newSelf->partialKeyLen = partialKeyIndex;
newSelf->entryPresent = false;
newSelf->numChildren = 1;
memcpy(newSelf->partialKey(), old->partialKey(), newSelf->partialKeyLen);
uint8_t oldDistinguishingByte = old->partialKey()[partialKeyIndex];
old->parent = newSelf;
old->parentsIndex = oldDistinguishingByte;
newSelf->index[0] = oldDistinguishingByte;
newSelf->children[0] = old;
newSelf->childMaxVersion[0] = oldMaxVersion;
self = newSelf;
memmove(old->partialKey(), old->partialKey() + partialKeyIndex + 1,
old->partialKeyLen - (partialKeyIndex + 1));
old->partialKeyLen -= partialKeyIndex + 1;
// Maintain memory capacity invariant
maybeDecreaseCapacity(old, writeContext, impl);
}
key = key.subspan(partialKeyIndex, key.size() - partialKeyIndex);
}
// Consume any partial key of `self`, and update `self` and
// `key` such that `self` is along the search path of `key`
inline __attribute__((always_inline)) void
consumePartialKey(TaggedNodePointer &self, TrivialSpan &key,
InternalVersionT writeVersion, WriteContext *writeContext,
ConflictSet::Impl *impl) {
if (self->partialKeyLen > 0) {
consumePartialKeyFull(self, key, writeVersion, writeContext, impl);
}
}
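// Worked example (illustrative): suppose `self` has partial key "abc" and the
// remaining search key is "abXY". longestCommonPrefix stops at index 2, so
// the node is split: a new Node3 parent keeps "ab", the old node's 'c' byte
// becomes its index in that parent (its partial key becomes empty), and the
// search continues from the new Node3 with key "XY".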
// Return the next node along the search path of key, consuming bytes of key
// such that the search path of the result + key is the same as the search path
// of self + key before the call. Creates a node if necessary. Updates
// `maxVersion` for result.
TaggedNodePointer &getOrCreateChild(TaggedNodePointer &self, TrivialSpan &key,
InternalVersionT newMaxVersion,
WriteContext *writeContext,
ConflictSet::Impl *impl) {
int index = key.front();
key = key.subspan(1, key.size() - 1);
// Fast path for if it exists already
switch (self->getType()) {
case Type_Node0:
break;
case Type_Node3: {
auto *self3 = static_cast<Node3 *>(self);
int i = getNodeIndex(self3, index);
if (i >= 0) {
consumePartialKey(self3->children[i], key, newMaxVersion, writeContext,
impl);
self3->childMaxVersion[i] = newMaxVersion;
return self3->children[i];
}
} break;
case Type_Node16: {
auto *self16 = static_cast<Node16 *>(self);
int i = getNodeIndex(self16, index);
if (i >= 0) {
consumePartialKey(self16->children[i], key, newMaxVersion, writeContext,
impl);
self16->childMaxVersion[i] = newMaxVersion;
return self16->children[i];
}
} break;
case Type_Node48: {
auto *self48 = static_cast<Node48 *>(self);
int secondIndex = self48->index[index];
if (secondIndex >= 0) {
consumePartialKey(self48->children[secondIndex], key, newMaxVersion,
writeContext, impl);
self48->childMaxVersion[secondIndex] = newMaxVersion;
self48->maxOfMax[secondIndex >> Node48::kMaxOfMaxShift] =
std::max(self48->maxOfMax[secondIndex >> Node48::kMaxOfMaxShift],
newMaxVersion);
return self48->children[secondIndex];
}
} break;
case Type_Node256: {
auto *self256 = static_cast<Node256 *>(self);
if (auto &result = self256->children[index]; result != nullptr) {
consumePartialKey(result, key, newMaxVersion, writeContext, impl);
self256->childMaxVersion[index] = newMaxVersion;
self256->maxOfMax[index >> Node256::kMaxOfMaxShift] = std::max(
self256->maxOfMax[index >> Node256::kMaxOfMaxShift], newMaxVersion);
return result;
}
} break;
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
auto *newChild = writeContext->allocate<Node0>(
key.size(), getMaxCapacity(0, 1, key.size()));
newChild->numChildren = 0;
newChild->entryPresent = false; // Will be set to true by the caller
newChild->partialKeyLen = key.size();
newChild->parentsIndex = index;
memcpy(newChild->partialKey(), key.data(), key.size());
key = {};
switch (self->getType()) {
case Type_Node0: {
auto *self0 = static_cast<Node0 *>(self);
auto *newSelf = writeContext->allocate<Node3>(
self->partialKeyLen, getMaxCapacity(1, 1, self->partialKeyLen));
newSelf->copyChildrenAndKeyFrom(*self0);
writeContext->deferRelease(self0, newSelf);
self = newSelf;
goto insert3;
}
case Type_Node3: {
if (self->numChildren == Node3::kMaxNodes) {
auto *self3 = static_cast<Node3 *>(self);
auto *newSelf = writeContext->allocate<Node16>(
self->partialKeyLen,
getMaxCapacity(4, self->entryPresent, self->partialKeyLen));
newSelf->copyChildrenAndKeyFrom(*self3);
writeContext->deferRelease(self3, newSelf);
self = newSelf;
goto insert16;
}
insert3:
auto *self3 = static_cast<Node3 *>(self);
int i = self->numChildren - 1;
for (; i >= 0; --i) {
if (self3->index[i] < index) {
break;
}
self3->index[i + 1] = self3->index[i];
self3->children[i + 1] = self3->children[i];
self3->childMaxVersion[i + 1] = self3->childMaxVersion[i];
}
self3->index[i + 1] = index;
auto &result = self3->children[i + 1];
self3->childMaxVersion[i + 1] = newMaxVersion;
result = newChild;
++self->numChildren;
newChild->parent = self;
return result;
}
case Type_Node16: {
if (self->numChildren == Node16::kMaxNodes) {
auto *self16 = static_cast<Node16 *>(self);
auto *newSelf = writeContext->allocate<Node48>(
self->partialKeyLen,
getMaxCapacity(17, self->entryPresent, self->partialKeyLen));
newSelf->copyChildrenAndKeyFrom(*self16);
writeContext->deferRelease(self16, newSelf);
self = newSelf;
goto insert48;
}
insert16:
assert(self->getType() == Type_Node16);
auto *self16 = static_cast<Node16 *>(self);
int i = self->numChildren - 1;
for (; i >= 0; --i) {
if (self16->index[i] < index) {
break;
}
self16->index[i + 1] = self16->index[i];
self16->children[i + 1] = self16->children[i];
self16->childMaxVersion[i + 1] = self16->childMaxVersion[i];
}
self16->index[i + 1] = index;
auto &result = self16->children[i + 1];
self16->childMaxVersion[i + 1] = newMaxVersion;
result = newChild;
++self->numChildren;
newChild->parent = self;
return result;
}
case Type_Node48: {
if (self->numChildren == Node48::kMaxNodes) {
auto *self48 = static_cast<Node48 *>(self);
auto *newSelf = writeContext->allocate<Node256>(
self->partialKeyLen,
getMaxCapacity(49, self->entryPresent, self->partialKeyLen));
newSelf->copyChildrenAndKeyFrom(*self48);
writeContext->deferRelease(self48, newSelf);
self = newSelf;
goto insert256;
}
insert48:
auto *self48 = static_cast<Node48 *>(self);
self48->bitSet.set(index);
auto nextFree = self48->numChildren++;
self48->index[index] = nextFree;
self48->reverseIndex[nextFree] = index;
auto &result = self48->children[nextFree];
self48->childMaxVersion[nextFree] = newMaxVersion;
self48->maxOfMax[nextFree >> Node48::kMaxOfMaxShift] = std::max(
newMaxVersion, self48->maxOfMax[nextFree >> Node48::kMaxOfMaxShift]);
result = newChild;
newChild->parent = self;
return result;
}
case Type_Node256: {
insert256:
auto *self256 = static_cast<Node256 *>(self);
++self->numChildren;
self256->bitSet.set(index);
auto &result = self256->children[index];
self256->childMaxVersion[index] = newMaxVersion;
self256->maxOfMax[index >> Node256::kMaxOfMaxShift] = std::max(
newMaxVersion, self256->maxOfMax[index >> Node256::kMaxOfMaxShift]);
result = newChild;
newChild->parent = self;
return result;
}
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
Node *nextPhysical(Node *node) {
Node *nextChild = getFirstChild(node);
if (nextChild != nullptr) {
return nextChild;
}
for (;;) {
int index = node->parentsIndex;
node = node->parent;
if (node == nullptr) {
return nullptr;
}
Node *nextChild = getChildGeq(node, index + 1);
if (nextChild != nullptr) {
return nextChild;
}
}
}
Node *nextLogical(Node *node) {
Node *nextChild = getFirstChild(node);
if (nextChild != nullptr) {
node = nextChild;
goto downLeftSpine;
}
for (;;) {
int index = node->parentsIndex;
node = node->parent;
if (node == nullptr) {
return nullptr;
}
Node *nextChild = getChildGeq(node, index + 1);
if (nextChild != nullptr) {
node = nextChild;
goto downLeftSpine;
}
}
downLeftSpine:
for (; !node->entryPresent; node = getFirstChild(node)) {
}
return node;
}
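// Traversal sketch (illustrative): nextPhysical visits every node in
// depth-first order - descend to the first child if any, otherwise climb
// until some ancestor has a child greater than the byte we came from.
// nextLogical does the same but then follows first children down the left
// spine until it reaches a node with entryPresent, i.e. the next logical key.

// Reallocates self with a partial key capacity in [minCapacity, maxCapacity]
// and swaps the copy into the tree, deferring release of the old node.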
void freeAndMakeCapacityBetween(Node *&self, int minCapacity, int maxCapacity,
WriteContext *writeContext,
ConflictSet::Impl *impl) {
switch (self->getType()) {
case Type_Node0: {
auto *self0 = (Node0 *)self;
auto *newSelf = writeContext->allocate<Node0>(minCapacity, maxCapacity);
newSelf->copyChildrenAndKeyFrom(*self0);
getInTree(self, impl) = newSelf;
writeContext->deferRelease(self0, newSelf);
self = newSelf;
} break;
case Type_Node3: {
auto *self3 = (Node3 *)self;
auto *newSelf = writeContext->allocate<Node3>(minCapacity, maxCapacity);
newSelf->copyChildrenAndKeyFrom(*self3);
getInTree(self, impl) = newSelf;
writeContext->deferRelease(self3, newSelf);
self = newSelf;
} break;
case Type_Node16: {
auto *self16 = (Node16 *)self;
auto *newSelf = writeContext->allocate<Node16>(minCapacity, maxCapacity);
newSelf->copyChildrenAndKeyFrom(*self16);
getInTree(self, impl) = newSelf;
writeContext->deferRelease(self16, newSelf);
self = newSelf;
} break;
case Type_Node48: {
auto *self48 = (Node48 *)self;
auto *newSelf = writeContext->allocate<Node48>(minCapacity, maxCapacity);
newSelf->copyChildrenAndKeyFrom(*self48);
getInTree(self, impl) = newSelf;
writeContext->deferRelease(self48, newSelf);
self = newSelf;
} break;
case Type_Node256: {
auto *self256 = (Node256 *)self;
auto *newSelf = writeContext->allocate<Node256>(minCapacity, maxCapacity);
newSelf->copyChildrenAndKeyFrom(*self256);
getInTree(self, impl) = newSelf;
writeContext->deferRelease(self256, newSelf);
self = newSelf;
} break;
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
// Fix larger-than-desired capacities. self must not be the root
void maybeDecreaseCapacity(Node *&self, WriteContext *writeContext,
ConflictSet::Impl *impl) {
const int maxCapacity = getMaxCapacity(self);
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "maybeDecreaseCapacity: current: %d, max: %d, key: %s\n",
self->getCapacity(), maxCapacity,
getSearchPathPrintable(self).c_str());
#endif
if (self->getCapacity() <= maxCapacity) {
return;
}
freeAndMakeCapacityBetween(self, self->partialKeyLen, maxCapacity,
writeContext, impl);
}
#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
__attribute__((target("avx512f"))) void rezero16(InternalVersionT *vs,
InternalVersionT zero) {
uint32_t z;
memcpy(&z, &zero, sizeof(z));
const auto zvec = _mm512_set1_epi32(z);
auto m = _mm512_cmplt_epi32_mask(
_mm512_sub_epi32(_mm512_loadu_epi32(vs), zvec), _mm512_setzero_epi32());
_mm512_mask_storeu_epi32(vs, m, zvec);
}
__attribute__((target("default")))
#endif
void rezero16(InternalVersionT *vs, InternalVersionT zero) {
for (int i = 0; i < 16; ++i) {
vs[i] = std::max(vs[i], zero);
}
}
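// Sketch of what rezero accomplishes (illustrative): with 32-bit versions,
// stale values are clamped up to the current `zero` so that wraparound
// comparisons stay within the correct window. E.g. with zero == 100:
//
//   InternalVersionT vs[16] = {/* 99, 100, 101, ... */};
//   rezero16(vs, InternalVersionT{100});
//   // vs is now {100, 100, 101, ...}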
#if USE_64_BIT
void rezero(Node *, InternalVersionT) {}
#else
void rezero(Node *n, InternalVersionT z) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "rezero to %" PRId64 ": %s\n", z.toInt64(),
getSearchPathPrintable(n).c_str());
#endif
if (n->entryPresent) {
n->entry.pointVersion = std::max(n->entry.pointVersion, z);
n->entry.rangeVersion = std::max(n->entry.rangeVersion, z);
}
switch (n->getType()) {
case Type_Node0: {
} break;
case Type_Node3: {
auto *self = static_cast<Node3 *>(n);
for (int i = 0; i < 3; ++i) {
self->childMaxVersion[i] = std::max(self->childMaxVersion[i], z);
}
} break;
case Type_Node16: {
auto *self = static_cast<Node16 *>(n);
rezero16(self->childMaxVersion, z);
} break;
case Type_Node48: {
auto *self = static_cast<Node48 *>(n);
for (int i = 0; i < 48; i += 16) {
rezero16(self->childMaxVersion + i, z);
}
for (auto &m : self->maxOfMax) {
m = std::max(m, z);
}
} break;
case Type_Node256: {
auto *self = static_cast<Node256 *>(n);
for (int i = 0; i < 256; i += 16) {
rezero16(self->childMaxVersion + i, z);
}
rezero16(self->maxOfMax, z);
} break;
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
#endif
void mergeWithChild(TaggedNodePointer &self, WriteContext *writeContext,
Node3 *self3, ConflictSet::Impl *impl) {
assert(!self3->entryPresent);
Node *child = self3->children[0];
const int minCapacity = self3->partialKeyLen + 1 + child->partialKeyLen;
const int maxCapacity =
getMaxCapacity(child->numChildren, child->entryPresent, minCapacity);
if (minCapacity > child->getCapacity()) {
freeAndMakeCapacityBetween(child, minCapacity, maxCapacity, writeContext,
impl);
}
// Merge partial key with child
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Merge %s into %s\n", getSearchPathPrintable(self).c_str(),
getSearchPathPrintable(child).c_str());
#endif
InternalVersionT childMaxVersion = self3->childMaxVersion[0];
// Construct new partial key for child
memmove(child->partialKey() + self3->partialKeyLen + 1, child->partialKey(),
child->partialKeyLen);
memcpy(child->partialKey(), self3->partialKey(), self->partialKeyLen);
child->partialKey()[self3->partialKeyLen] = self3->index[0];
child->partialKeyLen += 1 + self3->partialKeyLen;
child->parent = self->parent;
child->parentsIndex = self->parentsIndex;
// Max versions are stored in the parent, so we need to update it now
// that we have a new parent. Safe to call since the root never has a partial
// key.
setMaxVersion(child, std::max(childMaxVersion, writeContext->zero));
self = child;
writeContext->deferRelease(self3, self);
}
bool needsDownsize(Node *n) {
static int minTable[] = {0, kMinChildrenNode3, kMinChildrenNode16,
kMinChildrenNode48, kMinChildrenNode256};
return n->numChildren + n->entryPresent < minTable[n->getType()];
}
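// Example (illustrative): a Node16 whose numChildren + entryPresent falls
// below kMinChildrenNode16 == 4 is downsized to a Node3, preserving the
// surplus invariant from the memory-usage analysis above.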
void downsize(Node3 *self, WriteContext *writeContext,
ConflictSet::Impl *impl) {
if (self->numChildren == 0) {
auto *newSelf = writeContext->allocate<Node0>(
self->partialKeyLen, getMaxCapacity(0, 1, self->partialKeyLen));
newSelf->copyChildrenAndKeyFrom(*self);
getInTree(self, impl) = newSelf;
writeContext->deferRelease(self, newSelf);
} else {
assert(self->numChildren == 1 && !self->entryPresent);
mergeWithChild(getInTree(self, impl), writeContext, self, impl);
}
}
void downsize(Node16 *self, WriteContext *writeContext,
ConflictSet::Impl *impl) {
assert(self->numChildren + int(self->entryPresent) < kMinChildrenNode16);
auto *newSelf = writeContext->allocate<Node3>(
self->partialKeyLen,
getMaxCapacity(kMinChildrenNode16 - 1, 0, self->partialKeyLen));
newSelf->copyChildrenAndKeyFrom(*self);
getInTree(self, impl) = newSelf;
writeContext->deferRelease(self, newSelf);
}
void downsize(Node48 *self, WriteContext *writeContext,
ConflictSet::Impl *impl) {
assert(self->numChildren + int(self->entryPresent) < kMinChildrenNode48);
auto *newSelf = writeContext->allocate<Node16>(
self->partialKeyLen,
getMaxCapacity(kMinChildrenNode48 - 1, 0, self->partialKeyLen));
newSelf->copyChildrenAndKeyFrom(*self);
getInTree(self, impl) = newSelf;
writeContext->deferRelease(self, newSelf);
}
void downsize(Node256 *self, WriteContext *writeContext,
ConflictSet::Impl *impl) {
assert(self->numChildren + int(self->entryPresent) < kMinChildrenNode256);
  auto *newSelf = writeContext->allocate<Node48>(
      self->partialKeyLen,
      getMaxCapacity(kMinChildrenNode256 - 1, 0, self->partialKeyLen));
  newSelf->copyChildrenAndKeyFrom(*self);
  getInTree(self, impl) = newSelf;
  writeContext->deferRelease(self, newSelf);
}
void downsize(Node *self, WriteContext *writeContext, ConflictSet::Impl *impl) {
switch (self->getType()) {
case Type_Node0: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
case Type_Node3:
downsize(static_cast<Node3 *>(self), writeContext, impl);
break;
case Type_Node16:
downsize(static_cast<Node16 *>(self), writeContext, impl);
break;
case Type_Node48:
downsize(static_cast<Node48 *>(self), writeContext, impl);
break;
case Type_Node256:
downsize(static_cast<Node256 *>(self), writeContext, impl);
break;
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
// Preconditions: self is not the root, and `self->entryPresent`. May
// invalidate nodes along the search path to self, and children of
// self->parent. Returns a pointer to the node after self.
Node *erase(Node *self, WriteContext *writeContext, ConflictSet::Impl *impl,
bool logical) {
++writeContext->accum.entries_erased;
assert(self->parent != nullptr);
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Erase: %s\n", getSearchPathPrintable(self).c_str());
#endif
Node *parent = self->parent;
uint8_t parentsIndex = self->parentsIndex;
auto *result = logical ? nextLogical(self) : nextPhysical(self);
removeKey(self);
assert(self->entryPresent);
self->entryPresent = false;
if (self->numChildren != 0) {
if (needsDownsize(self)) {
downsize(self, writeContext, impl);
}
while (self->releaseDeferred) {
self = self->forwardTo;
}
maybeDecreaseCapacity(self, writeContext, impl);
if (result != nullptr) {
while (result->releaseDeferred) {
result = result->forwardTo;
}
}
return result;
}
assert(self->getType() == Type_Node0);
writeContext->release((Node0 *)self);
switch (parent->getType()) {
case Type_Node0: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
case Type_Node3: {
auto *parent3 = static_cast<Node3 *>(parent);
int nodeIndex = getNodeIndex(parent3, parentsIndex);
assert(nodeIndex >= 0);
--parent->numChildren;
for (int i = nodeIndex; i < parent->numChildren; ++i) {
parent3->index[i] = parent3->index[i + 1];
parent3->children[i] = parent3->children[i + 1];
parent3->childMaxVersion[i] = parent3->childMaxVersion[i + 1];
}
if (needsDownsize(parent3)) {
downsize(parent3, writeContext, impl);
}
} break;
case Type_Node16: {
auto *parent16 = static_cast<Node16 *>(parent);
int nodeIndex = getNodeIndex(parent16, parentsIndex);
assert(nodeIndex >= 0);
--parent->numChildren;
for (int i = nodeIndex; i < parent->numChildren; ++i) {
parent16->index[i] = parent16->index[i + 1];
parent16->children[i] = parent16->children[i + 1];
parent16->childMaxVersion[i] = parent16->childMaxVersion[i + 1];
}
if (needsDownsize(parent16)) {
downsize(parent16, writeContext, impl);
}
} break;
case Type_Node48: {
auto *parent48 = static_cast<Node48 *>(parent);
parent48->bitSet.reset(parentsIndex);
int8_t toRemoveChildrenIndex =
std::exchange(parent48->index[parentsIndex], -1);
auto lastChildrenIndex = --parent48->numChildren;
assert(toRemoveChildrenIndex >= 0);
assert(lastChildrenIndex >= 0);
if (toRemoveChildrenIndex != lastChildrenIndex) {
parent48->children[toRemoveChildrenIndex] =
parent48->children[lastChildrenIndex];
parent48->childMaxVersion[toRemoveChildrenIndex] =
parent48->childMaxVersion[lastChildrenIndex];
parent48->maxOfMax[toRemoveChildrenIndex >> Node48::kMaxOfMaxShift] =
std::max(parent48->maxOfMax[toRemoveChildrenIndex >>
Node48::kMaxOfMaxShift],
parent48->childMaxVersion[toRemoveChildrenIndex]);
auto parentIndex =
parent48->children[toRemoveChildrenIndex]->parentsIndex;
parent48->index[parentIndex] = toRemoveChildrenIndex;
parent48->reverseIndex[toRemoveChildrenIndex] = parentIndex;
}
parent48->childMaxVersion[lastChildrenIndex] = writeContext->zero;
if (needsDownsize(parent48)) {
downsize(parent48, writeContext, impl);
}
} break;
case Type_Node256: {
auto *parent256 = static_cast<Node256 *>(parent);
parent256->bitSet.reset(parentsIndex);
parent256->children[parentsIndex] = nullptr;
--parent->numChildren;
if (needsDownsize(parent256)) {
downsize(parent256, writeContext, impl);
}
} break;
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
while (parent->releaseDeferred) {
parent = parent->forwardTo;
}
maybeDecreaseCapacity(parent, writeContext, impl);
if (result != nullptr) {
while (result->releaseDeferred) {
result = result->forwardTo;
}
}
return result;
}
TaggedNodePointer nextSibling(Node *node) {
for (;;) {
if (node->parent == nullptr) {
return nullptr;
}
auto next = getChildGeq(node->parent, node->parentsIndex + 1);
if (next == nullptr) {
node = node->parent;
} else {
return next;
}
}
}
#ifdef HAS_AVX
uint32_t compare16(const InternalVersionT *vs, InternalVersionT rv) {
#if USE_64_BIT
uint32_t compared = 0;
for (int i = 0; i < 16; ++i) {
compared |= (vs[i] > rv) << i;
}
return compared;
#else
uint32_t compared = 0;
__m128i w[4]; // GCOVR_EXCL_LINE
memcpy(w, vs, sizeof(w));
uint32_t r; // GCOVR_EXCL_LINE
memcpy(&r, &rv, sizeof(r));
const auto rvVec = _mm_set1_epi32(r);
const auto zero = _mm_setzero_si128();
for (int i = 0; i < 4; ++i) {
compared |= _mm_movemask_ps(
__m128(_mm_cmpgt_epi32(_mm_sub_epi32(w[i], rvVec), zero)))
<< (i * 4);
}
return compared;
#endif
}
__attribute__((target("avx512f"))) uint32_t
compare16_avx512(const InternalVersionT *vs, InternalVersionT rv) {
#if USE_64_BIT
int64_t r;
memcpy(&r, &rv, sizeof(r));
uint32_t low =
_mm512_cmpgt_epi64_mask(_mm512_loadu_epi64(vs), _mm512_set1_epi64(r));
uint32_t high =
_mm512_cmpgt_epi64_mask(_mm512_loadu_epi64(vs + 8), _mm512_set1_epi64(r));
return low | (high << 8);
#else
uint32_t r;
memcpy(&r, &rv, sizeof(r));
return _mm512_cmpgt_epi32_mask(
_mm512_sub_epi32(_mm512_loadu_epi32(vs), _mm512_set1_epi32(r)),
_mm512_setzero_epi32());
#endif
}
#endif
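// Both compare16 variants rely on the same wrap-around trick as rezero16:
// subtracting rv first turns the modular version ordering into a plain
// signed comparison against zero. One lane, as a scalar sketch (name
// hypothetical):
//
//   bool laneConflicts(uint32_t v, uint32_t rv) {
//     return int32_t(v - rv) > 0; // "v > rv" within the version window
//   }
//
// The result packs one such bit per element: bit i is set iff vs[i] > rv.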
// Returns true if vs[i] <= readVersion for all i such that begin <= is[i] < end
// Preconditions: begin <= end, end - begin < 256
template <bool kAVX512>
bool scan16(const InternalVersionT *vs, const uint8_t *is, int begin, int end,
InternalVersionT readVersion) {
assert(begin <= end);
assert(end - begin < 256);
#ifdef HAS_ARM_NEON
uint8x16_t indices;
memcpy(&indices, is, 16);
// 0xff for each in bounds
auto results =
vcltq_u8(vsubq_u8(indices, vdupq_n_u8(begin)), vdupq_n_u8(end - begin));
// 0xf for each 0xff
uint64_t mask = vget_lane_u64(
vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)), 0);
uint32x4_t w4[4];
memcpy(w4, vs, sizeof(w4));
uint32_t rv;
memcpy(&rv, &readVersion, sizeof(rv));
const auto rvVec = vdupq_n_u32(rv);
int32x4_t z;
memset(&z, 0, sizeof(z));
uint16x4_t conflicting[4];
for (int i = 0; i < 4; ++i) {
conflicting[i] =
vmovn_u32(vcgtq_s32(vreinterpretq_s32_u32(vsubq_u32(w4[i], rvVec)), z));
}
auto combined =
vcombine_u8(vmovn_u16(vcombine_u16(conflicting[0], conflicting[1])),
vmovn_u16(vcombine_u16(conflicting[2], conflicting[3])));
uint64_t compared = vget_lane_u64(
vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(combined), 4)), 0);
return !(compared & mask);
#elif defined(HAS_AVX)
__m128i indices;
memcpy(&indices, is, 16);
indices = _mm_sub_epi8(indices, _mm_set1_epi8(begin));
uint32_t mask = ~_mm_movemask_epi8(_mm_cmpeq_epi8(
indices, _mm_max_epu8(indices, _mm_set1_epi8(end - begin))));
uint32_t compared = 0;
if constexpr (kAVX512) {
compared = compare16_avx512(vs, readVersion);
} else {
compared = compare16(vs, readVersion);
}
return !(compared & mask);
#else
const unsigned shiftUpperBound = end - begin;
const unsigned shiftAmount = begin;
auto inBounds = [&](unsigned c) { return c - shiftAmount < shiftUpperBound; };
uint32_t compared = 0;
for (int i = 0; i < 16; ++i) {
compared |= (vs[i] > readVersion) << i;
}
uint32_t mask = 0;
for (int i = 0; i < 16; ++i) {
mask |= inBounds(is[i]) << i;
}
return !(compared & mask);
#endif
}
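// The bounds test in scan16 uses the standard unsigned-wrap trick: for
// 0 <= begin <= end, one unsigned comparison checks both bounds at once.
// A sketch (name hypothetical):
//
//   bool inBoundsSketch(unsigned c, unsigned begin, unsigned end) {
//     return c - begin < end - begin; // c < begin wraps to a huge value
//   }
//
// For example, with begin = 3 and end = 7: c = 2 yields 0xffffffff (out of
// bounds) while c = 5 yields 2 (in bounds).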
// Returns true if vs[i] <= readVersion for all i such that begin <= i < end
template <bool kAVX512>
bool scan16(const InternalVersionT *vs, int begin, int end,
InternalVersionT readVersion) {
assert(0 <= begin && begin < 16);
assert(0 <= end && end <= 16);
assert(begin <= end);
#if defined(HAS_ARM_NEON)
uint32x4_t w4[4];
memcpy(w4, vs, sizeof(w4));
uint32_t rv;
memcpy(&rv, &readVersion, sizeof(rv));
const auto rvVec = vdupq_n_u32(rv);
int32x4_t z;
memset(&z, 0, sizeof(z));
uint16x4_t conflicting[4];
for (int i = 0; i < 4; ++i) {
conflicting[i] =
vmovn_u32(vcgtq_s32(vreinterpretq_s32_u32(vsubq_u32(w4[i], rvVec)), z));
}
auto combined =
vcombine_u8(vmovn_u16(vcombine_u16(conflicting[0], conflicting[1])),
vmovn_u16(vcombine_u16(conflicting[2], conflicting[3])));
uint64_t conflict = vget_lane_u64(
vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(combined), 4)), 0);
conflict &= end == 16 ? -1 : (uint64_t(1) << (end << 2)) - 1;
conflict >>= begin << 2;
return !conflict;
#elif defined(HAS_AVX)
uint32_t conflict;
if constexpr (kAVX512) {
conflict = compare16_avx512(vs, readVersion);
} else {
conflict = compare16(vs, readVersion);
}
conflict &= (1 << end) - 1;
conflict >>= begin;
return !conflict;
#else
uint64_t conflict = 0;
for (int i = 0; i < 16; ++i) {
conflict |= (vs[i] > readVersion) << i;
}
conflict &= (1 << end) - 1;
conflict >>= begin;
return !conflict;
#endif
}
// Returns whether the max version among all keys starting with
// searchpath(n) + [child], for child in (begin, end), is <= readVersion.
// Does not account for the range version of firstGt(searchpath(n) + [end - 1])
template <bool kAVX512>
bool checkMaxBetweenExclusiveImpl(Node0 *, int, int, InternalVersionT,
ReadContext *readContext) {
++readContext->range_read_node_scan_accum;
return true;
}
template <bool kAVX512>
bool checkMaxBetweenExclusiveImpl(Node3 *n, int begin, int end,
InternalVersionT readVersion,
ReadContext *readContext) {
++readContext->range_read_node_scan_accum;
assume(-1 <= begin);
assume(begin <= 256);
assume(-1 <= end);
assume(end <= 256);
assume(begin < end);
assert(!(begin == -1 && end == 256));
  auto *self = n;
++begin;
const unsigned shiftUpperBound = end - begin;
const unsigned shiftAmount = begin;
auto inBounds = [&](unsigned c) { return c - shiftAmount < shiftUpperBound; };
uint32_t mask = 0;
for (int i = 0; i < Node3::kMaxNodes; ++i) {
mask |= inBounds(self->index[i]) << i;
}
mask &= (1 << self->numChildren) - 1;
if (!mask) {
return true;
}
Node *child = self->children[std::countr_zero(mask)];
const bool firstRangeOk =
!child->entryPresent || child->entry.rangeVersion <= readVersion;
uint32_t compared = 0;
for (int i = 0; i < Node3::kMaxNodes; ++i) {
compared |= (self->childMaxVersion[i] > readVersion) << i;
}
return !(compared & mask) && firstRangeOk;
}
template <bool kAVX512>
bool checkMaxBetweenExclusiveImpl(Node16 *n, int begin, int end,
InternalVersionT readVersion,
ReadContext *readContext) {
++readContext->range_read_node_scan_accum;
assume(-1 <= begin);
assume(begin <= 256);
assume(-1 <= end);
assume(end <= 256);
assume(begin < end);
assert(!(begin == -1 && end == 256));
  auto *self = n;
++begin;
assert(begin <= end);
assert(end - begin < 256);
#ifdef HAS_ARM_NEON
uint8x16_t indices;
memcpy(&indices, self->index, 16);
// 0xff for each in bounds
auto results =
vcltq_u8(vsubq_u8(indices, vdupq_n_u8(begin)), vdupq_n_u8(end - begin));
// 0xf for each 0xff
uint64_t mask = vget_lane_u64(
vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)), 0);
mask &= self->numChildren == 16
? uint64_t(-1)
: (uint64_t(1) << (self->numChildren << 2)) - 1;
if (!mask) {
return true;
}
Node *child = self->children[std::countr_zero(mask) >> 2];
const bool firstRangeOk =
!child->entryPresent || child->entry.rangeVersion <= readVersion;
uint32x4_t w4[4];
memcpy(w4, self->childMaxVersion, sizeof(w4));
uint32_t rv;
memcpy(&rv, &readVersion, sizeof(rv));
const auto rvVec = vdupq_n_u32(rv);
int32x4_t z;
memset(&z, 0, sizeof(z));
uint16x4_t conflicting[4];
for (int i = 0; i < 4; ++i) {
conflicting[i] =
vmovn_u32(vcgtq_s32(vreinterpretq_s32_u32(vsubq_u32(w4[i], rvVec)), z));
}
auto combined =
vcombine_u8(vmovn_u16(vcombine_u16(conflicting[0], conflicting[1])),
vmovn_u16(vcombine_u16(conflicting[2], conflicting[3])));
uint64_t compared = vget_lane_u64(
vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(combined), 4)), 0);
return !(compared & mask) && firstRangeOk;
#elif defined(HAS_AVX)
__m128i indices;
memcpy(&indices, self->index, 16);
indices = _mm_sub_epi8(indices, _mm_set1_epi8(begin));
uint32_t mask =
0xffff & ~_mm_movemask_epi8(_mm_cmpeq_epi8(
indices, _mm_max_epu8(indices, _mm_set1_epi8(end - begin))));
mask &= (1 << self->numChildren) - 1;
if (!mask) {
return true;
}
Node *child = self->children[std::countr_zero(mask)];
const bool firstRangeOk =
!child->entryPresent || child->entry.rangeVersion <= readVersion;
uint32_t compared = 0;
if constexpr (kAVX512) {
compared = compare16_avx512(self->childMaxVersion, readVersion);
} else {
compared = compare16(self->childMaxVersion, readVersion);
}
return !(compared & mask) && firstRangeOk;
#else
const unsigned shiftUpperBound = end - begin;
const unsigned shiftAmount = begin;
auto inBounds = [&](unsigned c) { return c - shiftAmount < shiftUpperBound; };
uint32_t mask = 0;
for (int i = 0; i < 16; ++i) {
mask |= inBounds(self->index[i]) << i;
}
mask &= (1 << self->numChildren) - 1;
if (!mask) {
return true;
}
Node *child = self->children[std::countr_zero(mask)];
const bool firstRangeOk =
!child->entryPresent || child->entry.rangeVersion <= readVersion;
uint32_t compared = 0;
for (int i = 0; i < 16; ++i) {
compared |= (self->childMaxVersion[i] > readVersion) << i;
}
return !(compared & mask) && firstRangeOk;
#endif
}
template <bool kAVX512>
bool checkMaxBetweenExclusiveImpl(Node48 *n, int begin, int end,
InternalVersionT readVersion,
ReadContext *readContext) {
++readContext->range_read_node_scan_accum;
assume(-1 <= begin);
assume(begin <= 256);
assume(-1 <= end);
assume(end <= 256);
assume(begin < end);
assert(!(begin == -1 && end == 256));
  auto *self = n;
{
int c = self->bitSet.firstSetGeq(begin + 1);
if (c >= 0 && c < end) {
Node *child = self->children[self->index[c]];
if (child->entryPresent && child->entry.rangeVersion > readVersion) {
return false;
}
begin = c;
} else {
return true;
}
// [begin, end) is now the half-open interval of children we're interested
// in.
assert(begin < end);
}
// Check all pages
static_assert(Node48::kMaxOfMaxPageSize == 16);
for (int i = 0; i < Node48::kMaxOfMaxTotalPages; ++i) {
if (self->maxOfMax[i] > readVersion) {
if (!scan16<kAVX512>(self->childMaxVersion +
(i << Node48::kMaxOfMaxShift),
self->reverseIndex + (i << Node48::kMaxOfMaxShift),
begin, end, readVersion)) {
return false;
}
}
}
return true;
}
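// A scalar sketch of the page scan above: a Node48's children are stored
// densely in insertion order, reverseIndex maps each dense slot back to its
// key byte, and every 16-slot page keeps a precomputed max so that pages
// with no version above readVersion are skipped wholesale.
//
//   for (int page = 0; page < Node48::kMaxOfMaxTotalPages; ++page) {
//     if (maxOfMax[page] <= readVersion) continue; // page is clean
//     for (int slot = page * 16; slot < (page + 1) * 16; ++slot)
//       if (begin <= reverseIndex[slot] && reverseIndex[slot] < end &&
//           childMaxVersion[slot] > readVersion)
//         return false; // potential conflict
//   }
//   return true;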
template <bool kAVX512>
bool checkMaxBetweenExclusiveImpl(Node256 *n, int begin, int end,
InternalVersionT readVersion,
ReadContext *readContext) {
++readContext->range_read_node_scan_accum;
assume(-1 <= begin);
assume(begin <= 256);
assume(-1 <= end);
assume(end <= 256);
assume(begin < end);
assert(!(begin == -1 && end == 256));
static_assert(Node256::kMaxOfMaxTotalPages == 16);
  auto *self = n;
{
int c = self->bitSet.firstSetGeq(begin + 1);
if (c >= 0 && c < end) {
Node *child = self->children[c];
if (child->entryPresent && child->entry.rangeVersion > readVersion) {
return false;
}
begin = c;
} else {
return true;
}
// [begin, end) is now the half-open interval of children we're interested
// in.
assert(begin < end);
}
const int firstPage = begin >> Node256::kMaxOfMaxShift;
const int lastPage = (end - 1) >> Node256::kMaxOfMaxShift;
// Check the only page if there's only one
if (firstPage == lastPage) {
if (self->maxOfMax[firstPage] <= readVersion) {
return true;
}
const int intraPageBegin = begin & (Node256::kMaxOfMaxPageSize - 1);
const int intraPageEnd = end - (lastPage << Node256::kMaxOfMaxShift);
return scan16<kAVX512>(self->childMaxVersion +
(firstPage << Node256::kMaxOfMaxShift),
intraPageBegin, intraPageEnd, readVersion);
}
// Check the first page
if (self->maxOfMax[firstPage] > readVersion) {
const int intraPageBegin = begin & (Node256::kMaxOfMaxPageSize - 1);
if (!scan16<kAVX512>(self->childMaxVersion +
(firstPage << Node256::kMaxOfMaxShift),
intraPageBegin, 16, readVersion)) {
return false;
}
}
// Check the last page
if (self->maxOfMax[lastPage] > readVersion) {
const int intraPageEnd = end - (lastPage << Node256::kMaxOfMaxShift);
if (!scan16<kAVX512>(self->childMaxVersion +
(lastPage << Node256::kMaxOfMaxShift),
0, intraPageEnd, readVersion)) {
return false;
}
}
// Check inner pages
return scan16<kAVX512>(self->maxOfMax, firstPage + 1, lastPage, readVersion);
}
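// Worked example for the page decomposition above, assuming 16-entry pages
// (kMaxOfMaxShift == 4): begin = 20, end = 70 gives firstPage = 20 >> 4 = 1
// and lastPage = 69 >> 4 = 4. Page 1 is scanned from intra-page offset
// 20 & 15 == 4, page 4 is scanned up to offset 70 - 64 == 6, and the fully
// covered pages [2, 4) are checked in one shot by running scan16 over the
// maxOfMax array itself.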
bool checkMaxBetweenExclusive(Node0 *n, int begin, int end,
InternalVersionT readVersion,
ReadContext *readContext) {
return checkMaxBetweenExclusiveImpl<false>(n, begin, end, readVersion,
readContext);
}
bool checkMaxBetweenExclusive(Node3 *n, int begin, int end,
InternalVersionT readVersion,
ReadContext *readContext) {
return checkMaxBetweenExclusiveImpl<false>(n, begin, end, readVersion,
readContext);
}
#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
__attribute__((target("avx512f"))) bool
checkMaxBetweenExclusive(Node16 *n, int begin, int end,
InternalVersionT readVersion,
ReadContext *readContext) {
return checkMaxBetweenExclusiveImpl<true>(n, begin, end, readVersion,
readContext);
}
__attribute__((target("default")))
#endif
bool checkMaxBetweenExclusive(Node16 *n, int begin, int end,
InternalVersionT readVersion, ReadContext *readContext) {
return checkMaxBetweenExclusiveImpl<false>(n, begin, end, readVersion,
readContext);
}
#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
__attribute__((target("avx512f"))) bool
checkMaxBetweenExclusive(Node48 *n, int begin, int end,
InternalVersionT readVersion,
ReadContext *readContext) {
return checkMaxBetweenExclusiveImpl<true>(n, begin, end, readVersion,
readContext);
}
__attribute__((target("default")))
#endif
bool checkMaxBetweenExclusive(Node48 *n, int begin, int end,
InternalVersionT readVersion, ReadContext *readContext) {
return checkMaxBetweenExclusiveImpl<false>(n, begin, end, readVersion,
readContext);
}
#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
__attribute__((target("avx512f"))) bool
checkMaxBetweenExclusive(Node256 *n, int begin, int end,
InternalVersionT readVersion,
ReadContext *readContext) {
return checkMaxBetweenExclusiveImpl<true>(n, begin, end, readVersion,
readContext);
}
__attribute__((target("default")))
#endif
bool checkMaxBetweenExclusive(Node256 *n, int begin, int end,
InternalVersionT readVersion, ReadContext *readContext) {
return checkMaxBetweenExclusiveImpl<false>(n, begin, end, readVersion,
readContext);
}
#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
__attribute__((target("avx512f"))) bool
checkMaxBetweenExclusive(Node *n, int begin, int end,
InternalVersionT readVersion,
ReadContext *readContext) {
switch (n->getType()) {
case Type_Node0:
return checkMaxBetweenExclusiveImpl<true>(static_cast<Node0 *>(n), begin,
end, readVersion, readContext);
case Type_Node3:
return checkMaxBetweenExclusiveImpl<true>(static_cast<Node3 *>(n), begin,
end, readVersion, readContext);
case Type_Node16:
return checkMaxBetweenExclusiveImpl<true>(static_cast<Node16 *>(n), begin,
end, readVersion, readContext);
case Type_Node48:
return checkMaxBetweenExclusiveImpl<true>(static_cast<Node48 *>(n), begin,
end, readVersion, readContext);
case Type_Node256:
return checkMaxBetweenExclusiveImpl<true>(static_cast<Node256 *>(n), begin,
end, readVersion, readContext);
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
__attribute__((target("default")))
#endif
bool checkMaxBetweenExclusive(Node *n, int begin, int end,
InternalVersionT readVersion, ReadContext *readContext) {
switch (n->getType()) {
case Type_Node0:
return checkMaxBetweenExclusiveImpl<false>(static_cast<Node0 *>(n), begin,
end, readVersion, readContext);
case Type_Node3:
return checkMaxBetweenExclusiveImpl<false>(static_cast<Node3 *>(n), begin,
end, readVersion, readContext);
case Type_Node16:
return checkMaxBetweenExclusiveImpl<false>(static_cast<Node16 *>(n), begin,
end, readVersion, readContext);
case Type_Node48:
return checkMaxBetweenExclusiveImpl<false>(static_cast<Node48 *>(n), begin,
end, readVersion, readContext);
case Type_Node256:
return checkMaxBetweenExclusiveImpl<false>(static_cast<Node256 *>(n), begin,
end, readVersion, readContext);
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
TrivialSpan getSearchPath(Arena &arena, Node *n) {
assert(n != nullptr);
auto result = vector<uint8_t>(arena);
for (;;) {
for (int i = n->partialKeyLen - 1; i >= 0; --i) {
result.push_back(n->partialKey()[i]);
}
if (n->parent == nullptr) {
break;
}
result.push_back(n->parentsIndex);
n = n->parent;
}
std::reverse(result.begin(), result.end());
return {result.begin(), result.size()};
}
// Return true if the max version among all keys that start with key + [child],
// where begin < child < end, is <= readVersion.
//
// Precondition: transitively, no child of n has a search path that's a longer
// prefix of key than n
template <class NodeT>
bool checkRangeStartsWith(NodeT *nTyped, TrivialSpan key, int begin, int end,
InternalVersionT readVersion,
ReadContext *readContext) {
Node *n;
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "%s(%02x,%02x)*\n", printable(key).c_str(), begin, end);
#endif
auto remaining = key;
if (remaining.size() == 0) {
return checkMaxBetweenExclusive(nTyped, begin, end, readVersion,
readContext);
}
Node *child = getChild(nTyped, remaining[0]);
if (child == nullptr) {
auto c = getChildGeq(nTyped, remaining[0]);
if (c != nullptr) {
n = c;
goto downLeftSpine;
} else {
n = nextSibling(nTyped);
if (n == nullptr) {
return true;
}
goto downLeftSpine;
}
}
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
assert(n->partialKeyLen > 0);
{
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
if (i < commonLen) {
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
goto downLeftSpine;
} else {
n = nextSibling(n);
if (n == nullptr) {
return true;
}
goto downLeftSpine;
}
}
assert(n->partialKeyLen > remaining.size());
if (begin < n->partialKey()[remaining.size()] &&
n->partialKey()[remaining.size()] < end) {
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
return false;
}
return maxVersion(n) <= readVersion;
}
return true;
}
__builtin_unreachable(); // GCOVR_EXCL_LINE
downLeftSpine:
for (; !n->entryPresent; n = getFirstChild(n)) {
}
return n->entry.rangeVersion <= readVersion;
}
#ifdef __x86_64__
// Explicitly instantiate with the avx512f target attribute so the compiler
// can inline compare16_avx512, and generally use avx512f within more
// functions
template __attribute__((target("avx512f"))) bool
scan16<true>(const InternalVersionT *vs, const uint8_t *is, int begin, int end,
InternalVersionT readVersion);
template __attribute__((target("avx512f"))) bool
scan16<true>(const InternalVersionT *vs, int begin, int end,
InternalVersionT readVersion);
template __attribute__((target("avx512f"))) bool
checkMaxBetweenExclusiveImpl<true>(Node16 *n, int begin, int end,
InternalVersionT readVersion, ReadContext *);
template __attribute__((target("avx512f"))) bool
checkMaxBetweenExclusiveImpl<true>(Node48 *n, int begin, int end,
InternalVersionT readVersion, ReadContext *);
template __attribute__((target("avx512f"))) bool
checkMaxBetweenExclusiveImpl<true>(Node256 *n, int begin, int end,
InternalVersionT readVersion, ReadContext *);
#endif
// Returns a pointer to the pointer to the newly inserted node in the tree.
// The caller must set the `entryPresent` and `entry` fields. All nodes along
// the search path of the result will have `maxVersion` set to `writeVersion`
// as a postcondition. Nodes along the search path may be invalidated. Callers
// must ensure that the max version of the `self` argument is updated.
[[nodiscard]] TaggedNodePointer *
insert(TaggedNodePointer *self, TrivialSpan key, InternalVersionT writeVersion,
WriteContext *writeContext, ConflictSet::Impl *impl) {
for (; key.size() != 0; ++writeContext->accum.insert_iterations) {
self = &getOrCreateChild(*self, key, writeVersion, writeContext, impl);
}
return self;
}
void eraseTree(Node *root, WriteContext *writeContext) {
Arena arena;
auto toFree = vector<Node *>(arena);
toFree.push_back(root);
while (toFree.size() > 0) {
auto *n = toFree.back();
toFree.pop_back();
writeContext->accum.entries_erased += n->entryPresent;
++writeContext->accum.nodes_released;
removeKey(n);
switch (n->getType()) {
case Type_Node0: {
auto *n0 = static_cast<Node0 *>(n);
writeContext->release(n0);
} break;
case Type_Node3: {
auto *n3 = static_cast<Node3 *>(n);
for (int i = 0; i < n3->numChildren; ++i) {
toFree.push_back(n3->children[i]);
}
writeContext->release(n3);
} break;
case Type_Node16: {
auto *n16 = static_cast<Node16 *>(n);
for (int i = 0; i < n16->numChildren; ++i) {
toFree.push_back(n16->children[i]);
}
writeContext->release(n16);
} break;
case Type_Node48: {
auto *n48 = static_cast<Node48 *>(n);
for (int i = 0; i < n48->numChildren; ++i) {
toFree.push_back(n48->children[i]);
}
writeContext->release(n48);
} break;
case Type_Node256: {
auto *n256 = static_cast<Node256 *>(n);
auto *out = toFree.unsafePrepareAppend(n256->numChildren).data();
n256->bitSet.forEachSet([&](int i) { *out++ = n256->children[i]; });
assert(out == toFree.end());
writeContext->release(n256);
} break;
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
}
void addPointWrite(TaggedNodePointer &root, TrivialSpan key,
InternalVersionT writeVersion, WriteContext *writeContext,
ConflictSet::Impl *impl) {
++writeContext->accum.point_writes;
auto n = *insert(&root, key, writeVersion, writeContext, impl);
if (!n->entryPresent) {
++writeContext->accum.entries_inserted;
auto *p = nextLogical(n);
addKey(n);
n->entryPresent = true;
n->entry.pointVersion = writeVersion;
n->entry.rangeVersion =
p == nullptr ? writeContext->zero
: std::max(p->entry.rangeVersion, writeContext->zero);
} else {
assert(writeVersion >= n->entry.pointVersion);
n->entry.pointVersion = writeVersion;
}
}
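// Sketch of the entry-creation rule above, with hypothetical keys and
// versions: a new entry inherits its rangeVersion from the next logical
// entry, because the new key lands inside the gap that entry's rangeVersion
// already covers, and splitting that gap must not lower what a range read
// over it observes.
//
//   entries before: "b" (rangeVersion 5), "d" (rangeVersion 7)
//   addPointWrite("c", 9): "c" gets pointVersion 9 and rangeVersion 7,
//   inherited from "d".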
#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
__attribute__((target("avx512f"))) InternalVersionT horizontalMaxUpTo16(
InternalVersionT *vs, [[maybe_unused]] InternalVersionT z, int len) {
assume(len <= 16);
#if USE_64_BIT
// Hope it gets vectorized
InternalVersionT max = vs[0];
for (int i = 1; i < len; ++i) {
max = std::max(vs[i], max);
}
return max;
#else
uint32_t zero;
memcpy(&zero, &z, sizeof(zero));
auto zeroVec = _mm512_set1_epi32(zero);
auto max = InternalVersionT(
zero +
_mm512_reduce_max_epi32(_mm512_sub_epi32(
_mm512_mask_loadu_epi32(zeroVec, _mm512_int2mask((1 << len) - 1), vs),
zeroVec)));
return max;
#endif
}
__attribute__((target("default")))
#endif
InternalVersionT
horizontalMaxUpTo16(InternalVersionT *vs, InternalVersionT, int len) {
assume(len <= 16);
InternalVersionT max = vs[0];
for (int i = 1; i < len; ++i) {
max = std::max(vs[i], max);
}
return max;
}
#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
__attribute__((target("avx512f"))) InternalVersionT
horizontalMax16(InternalVersionT *vs, [[maybe_unused]] InternalVersionT z) {
#if USE_64_BIT
// Hope it gets vectorized
InternalVersionT max = vs[0];
for (int i = 1; i < 16; ++i) {
max = std::max(vs[i], max);
}
return max;
#else
uint32_t zero; // GCOVR_EXCL_LINE
memcpy(&zero, &z, sizeof(zero));
auto zeroVec = _mm512_set1_epi32(zero);
return InternalVersionT(zero + _mm512_reduce_max_epi32(_mm512_sub_epi32(
_mm512_loadu_epi32(vs), zeroVec)));
#endif
}
__attribute__((target("default")))
#endif
InternalVersionT
horizontalMax16(InternalVersionT *vs, InternalVersionT) {
InternalVersionT max = vs[0];
for (int i = 1; i < 16; ++i) {
max = std::max(vs[i], max);
}
return max;
}
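// The avx512f variants of the two horizontal-max helpers above rebase by
// `zero` before reducing, so the signed reduce-max agrees with the modular
// version order; conceptually:
//
//   max_i vs[i] == zero + max_i int32_t(vs[i] - zero)
//
// which holds as long as every vs[i] lies within the version window
// starting at zero (the invariant rezero maintains).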
// Precondition: `node->entryPresent`, and node is not the root
void fixupMaxVersion(Node *node, WriteContext *writeContext) {
  assert(node->parent);
  assert(node->entryPresent);
  InternalVersionT max =
      std::max(node->entry.pointVersion, writeContext->zero);
switch (node->getType()) {
case Type_Node0:
break;
case Type_Node3: {
auto *self3 = static_cast<Node3 *>(node);
max = std::max(max,
horizontalMaxUpTo16(self3->childMaxVersion,
writeContext->zero, self3->numChildren));
} break;
case Type_Node16: {
auto *self16 = static_cast<Node16 *>(node);
max = std::max(max, horizontalMaxUpTo16(self16->childMaxVersion,
writeContext->zero,
self16->numChildren));
} break;
case Type_Node48: {
auto *self48 = static_cast<Node48 *>(node);
for (auto v : self48->maxOfMax) {
max = std::max(v, max);
}
} break;
case Type_Node256: {
auto *self256 = static_cast<Node256 *>(node);
max = std::max(
max, horizontalMax16(self256->childMaxVersion, writeContext->zero));
} break;
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
setMaxVersion(node, max);
}
struct AddedWriteRange {
Node *beginNode;
Node *endNode;
};
AddedWriteRange addWriteRange(Node *beginRoot, TrivialSpan begin, Node *endRoot,
TrivialSpan end, InternalVersionT writeVersion,
WriteContext *writeContext,
ConflictSet::Impl *impl) {
++writeContext->accum.range_writes;
Node *beginNode = *insert(&getInTree(beginRoot, impl), begin, writeVersion,
writeContext, impl);
addKey(beginNode);
if (!beginNode->entryPresent) {
++writeContext->accum.entries_inserted;
auto *p = nextLogical(beginNode);
beginNode->entry.rangeVersion =
p == nullptr ? writeContext->zero
: std::max(p->entry.rangeVersion, writeContext->zero);
while (endRoot->releaseDeferred) {
endRoot = endRoot->forwardTo;
}
beginNode->entryPresent = true;
}
beginNode->entry.pointVersion = writeVersion;
Node *endNode =
*insert(&getInTree(endRoot, impl), end, writeVersion, writeContext, impl);
addKey(endNode);
if (!endNode->entryPresent) {
++writeContext->accum.entries_inserted;
auto *p = nextLogical(endNode);
endNode->entry.pointVersion =
p == nullptr ? writeContext->zero
: std::max(p->entry.rangeVersion, writeContext->zero);
while (beginNode->releaseDeferred) {
beginNode = beginNode->forwardTo;
}
endNode->entryPresent = true;
}
endNode->entry.rangeVersion = writeVersion;
return {beginNode, endNode};
}
void eraseInRange(Node *beginNode, Node *endNode, WriteContext *writeContext,
ConflictSet::Impl *impl) {
// Erase nodes in range
assert(!beginNode->endOfRange);
assert(!endNode->endOfRange);
endNode->endOfRange = true;
Node *iter = beginNode;
for (iter = nextLogical(iter); !iter->endOfRange;
iter = erase(iter, writeContext, impl, /*logical*/ true)) {
assert(!iter->endOfRange);
assert(!iter->releaseDeferred);
}
assert(iter->endOfRange);
iter->endOfRange = false;
// Inserting end trashed the last node's maxVersion. Fix that. Safe to call
// since the end key always has non-zero size.
fixupMaxVersion(iter, writeContext);
}
void addWriteRange(TaggedNodePointer &root, TrivialSpan begin, TrivialSpan end,
InternalVersionT writeVersion, WriteContext *writeContext,
ConflictSet::Impl *impl) {
int lcp = longestCommonPrefix(begin.data(), end.data(),
std::min(begin.size(), end.size()));
if (lcp == begin.size() && end.size() == begin.size() + 1 &&
end.back() == 0) {
return addPointWrite(root, begin, writeVersion, writeContext, impl);
}
auto useAsRoot =
insert(&root, begin.subspan(0, lcp), writeVersion, writeContext, impl);
auto [beginNode, endNode] = addWriteRange(
*useAsRoot, begin.subspan(lcp, begin.size() - lcp), *useAsRoot,
end.subspan(lcp, end.size() - lcp), writeVersion, writeContext, impl);
eraseInRange(beginNode, endNode, writeContext, impl);
}
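// Sketch of the decomposition above, with hypothetical keys:
// addWriteRange("apple", "apply", v) walks the shared prefix "appl" once
// via insert(), adds the two tails "e" and "y" under that node with the
// two-root addWriteRange overload, and then eraseInRange drops every entry
// strictly between them. The degenerate range [k, k + '\0') contains
// exactly the key k, so it is routed to addPointWrite instead.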
Node *firstGeqPhysical(Node *n, const TrivialSpan key) {
auto remaining = key;
for (;;) {
if (remaining.size() == 0) {
return n;
}
Node *child = getChild(n, remaining[0]);
if (child == nullptr) {
auto c = getChildGeq(n, remaining[0]);
if (c != nullptr) {
n = c;
return n;
} else {
n = nextSibling(n);
if (n == nullptr) {
// This line is genuinely unreachable from any entry point of the
// final library, since we can't remove a key without introducing a
// key after it, and the only production caller of firstGeq is for
// resuming the setOldestVersion scan.
return nullptr; // GCOVR_EXCL_LINE
}
return n;
}
}
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
if (i < commonLen) {
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
return n;
} else {
n = nextSibling(n);
return n;
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > remaining.size()) {
// n is the first physical node greater than remaining, and there's no
// eq node
return n;
}
}
}
}
#ifndef __has_attribute
#define __has_attribute(x) 0
#endif
#if __has_attribute(musttail)
#define MUSTTAIL __attribute__((musttail))
#else
#define MUSTTAIL
#endif
#if __has_attribute(preserve_none)
#define PRESERVE_NONE __attribute__((preserve_none))
#else
#define PRESERVE_NONE
#endif
#if __has_attribute(musttail) && __has_attribute(preserve_none)
constexpr bool kEnableInterleaved = true;
#else
constexpr bool kEnableInterleaved = false;
#endif
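// The state machines below use these attributes to interleave many tree
// walks at once: jobs sit in a circular doubly-linked list, and each step
// issues a prefetch for the node it will touch next, then tail-calls into
// another job's continuation, so one query's cache-miss latency is hidden
// behind useful work for the others. A minimal sketch of the pattern
// (names hypothetical):
//
//   PRESERVE_NONE void step(Job *job, Context *ctx) {
//     // ...advance this query by one node...
//     __builtin_prefetch(job->n);  // kick off the likely cache miss
//     job = job->next;             // rotate to a different in-flight query
//     MUSTTAIL return job->continuation(job, ctx); // no stack growth
//   }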
namespace check {
typedef PRESERVE_NONE void (*Continuation)(struct Job *, struct Context *);
// State relevant to an individual query
struct Job {
void setResult(bool ok) {
*result = ok ? ConflictSet::Commit : ConflictSet::Conflict;
}
void init(const ConflictSet::ReadRange *read, ConflictSet::Result *result,
Node *root, int64_t oldestVersionFullPrecision);
Node *n;
TrivialSpan begin;
InternalVersionT maxV;
TrivialSpan end; // range read only
TrivialSpan remaining; // range read only
Node *child; // range read only
int lcp; // range read only
Node *commonPrefixNode; // range read only
InternalVersionT readVersion;
ConflictSet::Result *result;
Continuation continuation;
Job *prev;
Job *next;
};
// State relevant to every query
struct Context {
int count;
int64_t oldestVersionFullPrecision;
Node *root;
const ConflictSet::ReadRange *queries;
ConflictSet::Result *results;
int64_t started;
ReadContext readContext;
};
PRESERVE_NONE void keepGoing(Job *job, Context *context) {
job = job->next;
MUSTTAIL return job->continuation(job, context);
}
PRESERVE_NONE void complete(Job *job, Context *context) {
if (context->started == context->count) {
if (job->prev == job) {
return;
}
job->prev->next = job->next;
job->next->prev = job->prev;
job = job->next;
MUSTTAIL return job->continuation(job, context);
} else {
int temp = context->started++;
job->init(context->queries + temp, context->results + temp, context->root,
context->oldestVersionFullPrecision);
MUSTTAIL return job->continuation(job, context);
}
}
template <class NodeT>
PRESERVE_NONE void down_left_spine(Job *job, Context *context);
static Continuation downLeftSpineTable[] = {
down_left_spine<Node0>, down_left_spine<Node3>, down_left_spine<Node16>,
down_left_spine<Node48>, down_left_spine<Node256>};
template <class NodeT>
PRESERVE_NONE void down_left_spine(Job *job, Context *context) {
assert(job->n->getType() == NodeT::kType);
NodeT *n = static_cast<NodeT *>(job->n);
if (n->entryPresent) {
job->setResult(n->entry.rangeVersion <= job->readVersion);
MUSTTAIL return complete(job, context);
}
auto child = getFirstChild(n);
job->n = child;
__builtin_prefetch(job->n);
job->continuation = downLeftSpineTable[child.getType()];
MUSTTAIL return keepGoing(job, context);
}
namespace point_read_state_machine {
PRESERVE_NONE void begin(Job *, Context *);
template <class NodeT> PRESERVE_NONE void iter(Job *, Context *);
static Continuation iterTable[] = {iter<Node0>, iter<Node3>, iter<Node16>,
iter<Node48>, iter<Node256>};
void begin(Job *job, Context *context) {
++context->readContext.point_read_accum;
if (job->begin.size() == 0) [[unlikely]] {
// We don't erase the root
assert(job->n->entryPresent);
job->setResult(job->n->entry.pointVersion <= job->readVersion);
MUSTTAIL return complete(job, context);
}
auto [taggedChild, maxV] = getChildAndMaxVersion(job->n, job->begin[0]);
job->maxV = maxV;
Node *child = taggedChild;
if (child == nullptr) [[unlikely]] {
auto c = getChildGeq(job->n, job->begin[0]);
if (c != nullptr) {
job->n = c;
job->continuation = downLeftSpineTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
} else {
// The root never has a next sibling
job->setResult(true);
MUSTTAIL return complete(job, context);
}
}
job->continuation = iterTable[taggedChild.getType()];
job->n = child;
__builtin_prefetch(child);
MUSTTAIL return keepGoing(job, context);
}
template <class NodeT> void iter(Job *job, Context *context) {
assert(NodeT::kType == job->n->getType());
NodeT *n = static_cast<NodeT *>(job->n);
job->begin = job->begin.subspan(1, job->begin.size() - 1);
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, job->begin.size());
int i = longestCommonPrefix(n->partialKey(), job->begin.data(), commonLen);
if (i < commonLen) [[unlikely]] {
auto c = n->partialKey()[i] <=> job->begin[i];
if (c > 0) {
MUSTTAIL return down_left_spine<NodeT>(job, context);
} else {
auto s = nextSibling(n);
job->n = s;
if (job->n == nullptr) {
job->setResult(true);
MUSTTAIL return complete(job, context);
}
job->continuation = downLeftSpineTable[s.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
job->begin = job->begin.subspan(commonLen, job->begin.size() - commonLen);
} else if (n->partialKeyLen > job->begin.size()) [[unlikely]] {
// n is the first physical node greater than remaining, and there's no
// eq node
MUSTTAIL return down_left_spine<NodeT>(job, context);
}
}
if (job->maxV <= job->readVersion) {
job->setResult(true);
++context->readContext.point_read_short_circuit_accum;
MUSTTAIL return complete(job, context);
}
++context->readContext.point_read_iterations_accum;
if (job->begin.size() == 0) [[unlikely]] {
if (n->entryPresent) {
job->setResult(n->entry.pointVersion <= job->readVersion);
MUSTTAIL return complete(job, context);
}
auto c = getFirstChild(n);
job->n = c;
job->continuation = downLeftSpineTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
}
auto [taggedChild, maxV] = getChildAndMaxVersion(n, job->begin[0]);
job->maxV = maxV;
Node *child = taggedChild;
if (child == nullptr) [[unlikely]] {
auto c = getChildGeq(n, job->begin[0]);
if (c != nullptr) {
job->n = c;
job->continuation = downLeftSpineTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
} else {
auto c = nextSibling(job->n);
job->n = c;
if (job->n == nullptr) {
job->setResult(true);
MUSTTAIL return complete(job, context);
}
      job->continuation = downLeftSpineTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
}
}
job->continuation = iterTable[taggedChild.getType()];
job->n = child;
__builtin_prefetch(child);
MUSTTAIL return keepGoing(job, context);
}
} // namespace point_read_state_machine
namespace prefix_read_state_machine {
PRESERVE_NONE void begin(Job *, Context *);
template <class NodeT> PRESERVE_NONE void iter(Job *, Context *);
static Continuation iterTable[] = {iter<Node0>, iter<Node3>, iter<Node16>,
iter<Node48>, iter<Node256>};
void begin(Job *job, Context *context) {
++context->readContext.prefix_read_accum;
// There's no way to encode a prefix read of ""
assert(job->begin.size() > 0);
auto [taggedChild, maxV] = getChildAndMaxVersion(job->n, job->begin[0]);
job->maxV = maxV;
Node *child = taggedChild;
if (child == nullptr) [[unlikely]] {
auto c = getChildGeq(job->n, job->begin[0]);
if (c != nullptr) {
job->n = c;
job->continuation = downLeftSpineTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
} else {
// The root never has a next sibling
job->setResult(true);
MUSTTAIL return complete(job, context);
}
}
job->continuation = iterTable[taggedChild.getType()];
job->n = child;
__builtin_prefetch(child);
MUSTTAIL return keepGoing(job, context);
}
template <class NodeT> void iter(Job *job, Context *context) {
assert(NodeT::kType == job->n->getType());
NodeT *n = static_cast<NodeT *>(job->n);
job->begin = job->begin.subspan(1, job->begin.size() - 1);
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, job->begin.size());
int i = longestCommonPrefix(n->partialKey(), job->begin.data(), commonLen);
if (i < commonLen) [[unlikely]] {
auto c = n->partialKey()[i] <=> job->begin[i];
if (c > 0) {
MUSTTAIL return down_left_spine<NodeT>(job, context);
} else {
auto c = nextSibling(n);
job->n = c;
if (job->n == nullptr) {
job->setResult(true);
MUSTTAIL return complete(job, context);
}
job->continuation = downLeftSpineTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
job->begin = job->begin.subspan(commonLen, job->begin.size() - commonLen);
} else if (n->partialKeyLen > job->begin.size()) [[unlikely]] {
// n is the first physical node greater than remaining, and there's no
// eq node. All physical nodes that start with prefix are reachable from
// n.
if (job->maxV > job->readVersion) {
job->setResult(false);
MUSTTAIL return complete(job, context);
}
MUSTTAIL return down_left_spine<NodeT>(job, context);
}
}
if (job->maxV <= job->readVersion) {
job->setResult(true);
++context->readContext.prefix_read_short_circuit_accum;
MUSTTAIL return complete(job, context);
}
++context->readContext.prefix_read_iterations_accum;
if (job->begin.size() == 0) [[unlikely]] {
job->setResult(job->maxV <= job->readVersion);
MUSTTAIL return complete(job, context);
}
auto [taggedChild, maxV] = getChildAndMaxVersion(n, job->begin[0]);
job->maxV = maxV;
Node *child = taggedChild;
if (child == nullptr) [[unlikely]] {
auto c = getChildGeq(n, job->begin[0]);
if (c != nullptr) {
job->n = c;
job->continuation = downLeftSpineTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
} else {
auto c = nextSibling(job->n);
job->n = c;
if (job->n == nullptr) {
job->setResult(true);
MUSTTAIL return complete(job, context);
}
job->continuation = downLeftSpineTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
}
}
job->continuation = iterTable[taggedChild.getType()];
job->n = child;
__builtin_prefetch(child);
MUSTTAIL return keepGoing(job, context);
}
} // namespace prefix_read_state_machine
namespace range_read_state_machine {
PRESERVE_NONE void begin(Job *, Context *);
template <class NodeT> PRESERVE_NONE void common_prefix_iter(Job *, Context *);
template <class NodeT>
PRESERVE_NONE void done_common_prefix_iter(Job *, Context *);
static Continuation commonPrefixIterTable[] = {
common_prefix_iter<Node0>, common_prefix_iter<Node3>,
common_prefix_iter<Node16>, common_prefix_iter<Node48>,
common_prefix_iter<Node256>};
static Continuation doneCommonPrefixIterTable[] = {
done_common_prefix_iter<Node0>, done_common_prefix_iter<Node3>,
done_common_prefix_iter<Node16>, done_common_prefix_iter<Node48>,
done_common_prefix_iter<Node256>};
template <class NodeT> PRESERVE_NONE void left_side_iter(Job *, Context *);
template <class NodeT>
PRESERVE_NONE void left_side_down_left_spine(Job *, Context *);
static Continuation leftSideDownLeftSpineTable[] = {
left_side_down_left_spine<Node0>, left_side_down_left_spine<Node3>,
left_side_down_left_spine<Node16>, left_side_down_left_spine<Node48>,
left_side_down_left_spine<Node256>};
PRESERVE_NONE void done_left_side_iter(Job *, Context *);
static Continuation leftSideIterTable[] = {
left_side_iter<Node0>, left_side_iter<Node3>, left_side_iter<Node16>,
left_side_iter<Node48>, left_side_iter<Node256>};
template <class NodeT> PRESERVE_NONE void right_side_iter(Job *, Context *);
static Continuation rightSideIterTable[] = {
right_side_iter<Node0>, right_side_iter<Node3>, right_side_iter<Node16>,
right_side_iter<Node48>, right_side_iter<Node256>};
PRESERVE_NONE void begin(Job *job, Context *context) {
job->lcp = longestCommonPrefix(job->begin.data(), job->end.data(),
std::min(job->begin.size(), job->end.size()));
if (job->lcp == job->begin.size() &&
job->end.size() == job->begin.size() + 1 && job->end.back() == 0) {
// Call directly since we have nothing to prefetch
MUSTTAIL return check::point_read_state_machine::begin(job, context);
}
if (job->lcp == job->begin.size() - 1 &&
job->end.size() == job->begin.size() &&
job->begin.back() + 1 == job->end.back()) {
// Call directly since we have nothing to prefetch
MUSTTAIL return check::prefix_read_state_machine::begin(job, context);
}
++context->readContext.range_read_accum;
job->remaining = job->begin.subspan(0, job->lcp);
if (job->remaining.size() == 0) {
MUSTTAIL return doneCommonPrefixIterTable[job->n->getType()](job, context);
}
auto [c, maxV] = getChildAndMaxVersion(job->n, job->remaining[0]);
job->maxV = maxV;
job->child = c;
if (job->child == nullptr) {
MUSTTAIL return doneCommonPrefixIterTable[job->n->getType()](job, context);
}
job->continuation = commonPrefixIterTable[c.getType()];
__builtin_prefetch(job->child);
MUSTTAIL return keepGoing(job, context);
}
// Advance down common prefix, but stay on a physical path in the tree
template <class NodeT> void common_prefix_iter(Job *job, Context *context) {
assert(NodeT::kType == job->child->getType());
NodeT *child = static_cast<NodeT *>(job->child);
if (child->partialKeyLen > 0) {
int cl = std::min<int>(child->partialKeyLen, job->remaining.size() - 1);
int i =
longestCommonPrefix(child->partialKey(), job->remaining.data() + 1, cl);
if (i != child->partialKeyLen) {
MUSTTAIL return doneCommonPrefixIterTable[job->n->getType()](job,
context);
}
}
job->n = child;
job->remaining = job->remaining.subspan(1 + child->partialKeyLen,
job->remaining.size() -
(1 + child->partialKeyLen));
if (job->maxV <= job->readVersion) {
job->setResult(true);
++context->readContext.range_read_short_circuit_accum;
MUSTTAIL return complete(job, context);
}
++context->readContext.range_read_iterations_accum;
if (job->remaining.size() == 0) {
MUSTTAIL return done_common_prefix_iter<NodeT>(job, context);
}
auto [c, maxV] = getChildAndMaxVersion(child, job->remaining[0]);
job->maxV = maxV;
job->child = c;
if (job->child == nullptr) {
MUSTTAIL return done_common_prefix_iter<NodeT>(job, context);
}
job->continuation = commonPrefixIterTable[c.getType()];
__builtin_prefetch(job->child);
MUSTTAIL return keepGoing(job, context);
}
template <class NodeT>
PRESERVE_NONE void done_common_prefix_iter(Job *job, Context *context) {
assert(NodeT::kType == job->n->getType());
NodeT *n = static_cast<NodeT *>(job->n);
{
Arena arena;
assert(getSearchPath(arena, n) <=>
job->begin.subspan(0, job->lcp - job->remaining.size()) ==
0);
}
const int consumed = job->lcp - job->remaining.size();
assume(consumed >= 0);
job->begin = job->begin.subspan(consumed, job->begin.size() - consumed);
job->end = job->end.subspan(consumed, job->end.size() - consumed);
job->lcp -= consumed;
job->commonPrefixNode = n;
if (job->lcp == job->begin.size()) {
job->remaining = job->end;
if (job->lcp == 0) {
if (n->entryPresent && n->entry.pointVersion > job->readVersion) {
job->setResult(false);
MUSTTAIL return complete(job, context);
}
if (!checkMaxBetweenExclusive(n, -1, job->remaining[0], job->readVersion,
&context->readContext)) {
job->setResult(false);
MUSTTAIL return complete(job, context);
}
}
    // This is a hack: back lcp off by one so that right_side_iter's
    // partial-key offset comparisons against job->lcp account for the byte
    // consumed when stepping from the common-prefix node into its child.
    --job->lcp;
auto c = getChild(n, job->remaining[0]);
Node *child = c;
if (child == nullptr) {
auto c = getChildGeq(n, job->remaining[0]);
if (c != nullptr) {
job->n = c;
job->continuation = downLeftSpineTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
} else {
auto c = nextSibling(job->n);
job->n = c;
if (job->n == nullptr) {
job->setResult(true);
MUSTTAIL return complete(job, context);
}
MUSTTAIL return downLeftSpineTable[c.getType()](job, context);
}
}
job->n = child;
job->continuation = rightSideIterTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
}
// If this were not true we would have returned above
assert(job->begin.size() > 0);
if (!checkRangeStartsWith(n, job->begin.subspan(0, job->lcp),
job->begin[job->lcp], job->end[job->lcp],
job->readVersion, &context->readContext)) {
job->setResult(false);
MUSTTAIL return complete(job, context);
}
job->remaining = job->begin;
auto [c, maxV] = getChildAndMaxVersion(n, job->remaining[0]);
job->maxV = maxV;
Node *child = c;
if (child == nullptr) {
auto c = getChildGeq(n, job->remaining[0]);
if (c != nullptr) {
job->n = c;
job->continuation = leftSideDownLeftSpineTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
} else {
auto c = nextSibling(job->n);
job->n = c;
if (job->n == nullptr) {
MUSTTAIL return done_left_side_iter(job, context);
}
job->continuation = leftSideDownLeftSpineTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
}
}
job->n = child;
job->continuation = leftSideIterTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
}
// Checks that the max version among all keys that start with key[:prefixLen]
// and are >= key is <= readVersion; on a conflict, sets the result and
// completes the job.
template <class NodeT>
PRESERVE_NONE void left_side_iter(Job *job, Context *context) {
assert(NodeT::kType == job->n->getType());
NodeT *n = static_cast<NodeT *>(job->n);
job->remaining = job->remaining.subspan(1, job->remaining.size() - 1);
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, job->remaining.size());
int i =
longestCommonPrefix(n->partialKey(), job->remaining.data(), commonLen);
if (i < commonLen) {
auto c = n->partialKey()[i] <=> job->remaining[i];
if (c > 0) {
if (n->parent == job->commonPrefixNode) {
if (i < job->lcp) {
MUSTTAIL return left_side_down_left_spine<NodeT>(job, context);
}
}
if (n->entryPresent && n->entry.rangeVersion > job->readVersion) {
job->setResult(false);
MUSTTAIL return complete(job, context);
}
if (job->maxV > job->readVersion) {
job->setResult(false);
MUSTTAIL return complete(job, context);
}
MUSTTAIL return done_left_side_iter(job, context);
} else {
auto c = nextSibling(n);
job->n = c;
if (job->n == nullptr) {
MUSTTAIL return done_left_side_iter(job, context);
}
job->continuation = leftSideDownLeftSpineTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
job->remaining =
job->remaining.subspan(commonLen, job->remaining.size() - commonLen);
} else if (n->partialKeyLen > job->remaining.size()) {
if (n->entryPresent && n->entry.rangeVersion > job->readVersion) {
job->setResult(false);
MUSTTAIL return complete(job, context);
}
if (job->maxV > job->readVersion) {
job->setResult(false);
MUSTTAIL return complete(job, context);
}
MUSTTAIL return done_left_side_iter(job, context);
}
}
if (job->maxV <= job->readVersion) {
MUSTTAIL return done_left_side_iter(job, context);
}
++context->readContext.range_read_iterations_accum;
if (job->remaining.size() == 0) {
assert(job->maxV > job->readVersion);
job->setResult(false);
MUSTTAIL return complete(job, context);
}
if (!checkMaxBetweenExclusive(n, job->remaining[0], 256, job->readVersion,
&context->readContext)) {
job->setResult(false);
MUSTTAIL return complete(job, context);
}
auto [c, maxV] = getChildAndMaxVersion(job->n, job->remaining[0]);
job->maxV = maxV;
Node *child = c;
if (child == nullptr) {
auto c = getChildGeq(n, job->remaining[0]);
if (c != nullptr) {
job->n = c;
MUSTTAIL return done_left_side_iter(job, context);
} else {
auto c = nextSibling(job->n);
job->n = c;
if (job->n == nullptr) {
MUSTTAIL return done_left_side_iter(job, context);
}
job->continuation = leftSideDownLeftSpineTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
}
}
job->n = child;
job->continuation = leftSideIterTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
}
PRESERVE_NONE void done_left_side_iter(Job *job, Context *context) {
job->n = job->commonPrefixNode;
job->remaining = job->end;
auto c = getChild(job->n, job->remaining[0]);
Node *child = c;
if (child == nullptr) {
auto c = getChildGeq(job->n, job->remaining[0]);
if (c != nullptr) {
job->n = c;
job->continuation = downLeftSpineTable[c.getType()];
MUSTTAIL return keepGoing(job, context);
} else {
auto c = nextSibling(job->n);
job->n = c;
if (job->n == nullptr) {
job->setResult(true);
MUSTTAIL return complete(job, context);
}
job->continuation = downLeftSpineTable[c.getType()];
MUSTTAIL return keepGoing(job, context);
}
}
job->n = child;
job->continuation = rightSideIterTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
}
template <class NodeT>
void left_side_down_left_spine(Job *job, Context *context) {
assert(job->n->getType() == NodeT::kType);
NodeT *n = static_cast<NodeT *>(job->n);
if (n->entryPresent) {
if (n->entry.rangeVersion > job->readVersion) {
job->setResult(false);
MUSTTAIL return complete(job, context);
}
MUSTTAIL return done_left_side_iter(job, context);
}
auto c = getFirstChild(n);
job->n = c;
job->continuation = leftSideDownLeftSpineTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
}
// Checks that the max version among all keys that start with key[:prefixLen]
// and are < key is <= readVersion; on a conflict, sets the result and
// completes the job.
template <class NodeT>
PRESERVE_NONE void right_side_iter(Job *job, Context *context) {
assert(NodeT::kType == job->n->getType());
NodeT *n = static_cast<NodeT *>(job->n);
job->remaining = job->remaining.subspan(1, job->remaining.size() - 1);
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, job->remaining.size());
int i =
longestCommonPrefix(n->partialKey(), job->remaining.data(), commonLen);
if (i < commonLen) {
auto c = n->partialKey()[i] <=> job->remaining[i];
if (c > 0) {
MUSTTAIL return down_left_spine<NodeT>(job, context);
} else {
if ((n->parent != job->commonPrefixNode || i >= job->lcp) &&
n->entryPresent && n->entry.rangeVersion > job->readVersion) {
job->setResult(false);
MUSTTAIL return complete(job, context);
}
if ((n->parent != job->commonPrefixNode || i >= job->lcp) &&
maxVersion(n) > job->readVersion) {
job->setResult(false);
MUSTTAIL return complete(job, context);
}
auto c = nextSibling(job->n);
job->n = c;
if (job->n == nullptr) {
job->setResult(true);
MUSTTAIL return complete(job, context);
}
job->continuation = downLeftSpineTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
job->remaining =
job->remaining.subspan(commonLen, job->remaining.size() - commonLen);
} else if (n->partialKeyLen > job->remaining.size()) {
MUSTTAIL return down_left_spine<NodeT>(job, context);
}
}
++context->readContext.range_read_iterations_accum;
if (job->remaining.size() == 0) {
MUSTTAIL return down_left_spine<NodeT>(job, context);
}
if (n->entryPresent && n->entry.pointVersion > job->readVersion) {
job->setResult(false);
MUSTTAIL return complete(job, context);
}
if (!checkMaxBetweenExclusive(n, -1, job->remaining[0], job->readVersion,
&context->readContext)) {
job->setResult(false);
MUSTTAIL return complete(job, context);
}
if (n->entryPresent && n->entry.rangeVersion > job->readVersion) {
job->setResult(false);
MUSTTAIL return complete(job, context);
}
auto c = getChild(job->n, job->remaining[0]);
Node *child = c;
if (child == nullptr) {
auto c = getChildGeq(n, job->remaining[0]);
if (c != nullptr) {
job->n = c;
job->continuation = downLeftSpineTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
} else {
auto c = nextSibling(job->n);
job->n = c;
if (job->n == nullptr) {
job->setResult(true);
MUSTTAIL return complete(job, context);
}
job->continuation = downLeftSpineTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
}
}
job->n = child;
job->continuation = rightSideIterTable[c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
}
} // namespace range_read_state_machine
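// Chooses the first continuation for one read: a read version below the
// window fails fast with TooOld, an empty end key dispatches to the
// point-read state machine, and anything else to the range-read state
// machine.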
void Job::init(const ConflictSet::ReadRange *read, ConflictSet::Result *result,
Node *root, int64_t oldestVersionFullPrecision) {
auto begin = TrivialSpan(read->begin.p, read->begin.len);
auto end = TrivialSpan(read->end.p, read->end.len);
if (read->readVersion < oldestVersionFullPrecision) [[unlikely]] {
*result = ConflictSet::TooOld;
continuation = complete;
} else if (end.size() == 0) {
this->begin = begin;
this->n = root;
this->readVersion = InternalVersionT(read->readVersion);
this->result = result;
continuation = check::point_read_state_machine::begin;
} else {
this->begin = begin;
this->end = end;
this->n = root;
this->readVersion = InternalVersionT(read->readVersion);
this->result = result;
continuation = check::range_read_state_machine::begin;
}
}
} // namespace check
namespace interleaved_insert {
typedef PRESERVE_NONE void (*Continuation)(struct Job *, struct Context *);
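// Records where a search stopped: the node to insert under and the unconsumed
// key suffix. Range writes carry both the begin and end insertion points.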
struct Result {
Result() = default;
Result(Node *insertionPoint, TrivialSpan remaining)
: insertionPoint(insertionPoint), remaining(remaining),
endInsertionPoint(nullptr) {}
Result(Node *insertionPoint, TrivialSpan remaining, Node *endInsertionPoint,
TrivialSpan endRemaining)
: insertionPoint(insertionPoint), remaining(remaining),
endInsertionPoint(endInsertionPoint), endRemaining(endRemaining) {}
Node *insertionPoint;
TrivialSpan remaining;
Node *endInsertionPoint; // Range write only
TrivialSpan endRemaining; // Range write only
Result *nextRangeWrite; // Linked list to skip over point writes in phase 3.
// Populated in phase 2
};
static_assert(std::is_trivial_v<Result>);
// State relevant to an individual insertion
struct Job {
TrivialSpan remaining;
Node *n;
TaggedNodePointer child;
int childIndex;
Result *result;
TrivialSpan begin; // Range write only
TrivialSpan end; // Range write only
Node *endNode; // Range write only
int commonPrefixLen; // Range write only
// State for context switching machinery - not application specific
Continuation continuation;
Job *prev;
Job *next;
void init(Context *, int index);
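  // Looks up the child for the given key byte in `self`, recording both the
  // child pointer and its slot in `childIndex`. Returns false if there is no
  // such child.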
bool getChildAndIndex(Node0 *, uint8_t) { return false; }
bool getChildAndIndex(Node3 *self, uint8_t index) {
childIndex = getNodeIndex(self, index);
if (childIndex >= 0) {
child = self->children[childIndex];
return true;
}
return false;
}
bool getChildAndIndex(Node16 *self, uint8_t index) {
childIndex = getNodeIndex(self, index);
if (childIndex >= 0) {
child = self->children[childIndex];
return true;
}
return false;
}
bool getChildAndIndex(Node48 *self, uint8_t index) {
childIndex = self->index[index];
if (childIndex >= 0) {
child = self->children[childIndex];
return true;
}
return false;
}
  bool getChildAndIndex(Node256 *self, uint8_t i) {
    child = self->children[i];
    if (child != nullptr) {
      childIndex = i;
      return true;
    }
    return false;
  }
bool getChildAndIndex(Node *self, uint8_t index) {
switch (self->getType()) {
case Type_Node0:
return getChildAndIndex(static_cast<Node0 *>(self), index);
case Type_Node3:
return getChildAndIndex(static_cast<Node3 *>(self), index);
case Type_Node16:
return getChildAndIndex(static_cast<Node16 *>(self), index);
case Type_Node48:
return getChildAndIndex(static_cast<Node48 *>(self), index);
case Type_Node256:
return getChildAndIndex(static_cast<Node256 *>(self), index);
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
};
// State relevant to every insertion
struct Context {
int count;
int64_t started;
const ConflictSet::WriteRange *writes;
Node *root;
InternalVersionT writeVersion;
Result *results;
int64_t iterations = 0;
};
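// In-progress jobs form a circular doubly-linked list. keepGoing rotates to
// the next job, giving the prefetch issued by the previous step time to
// complete; complete either reuses this job's slot for the next pending write
// or unlinks the finished job from the ring.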
PRESERVE_NONE void keepGoing(Job *job, Context *context) {
job = job->next;
MUSTTAIL return job->continuation(job, context);
}
PRESERVE_NONE void complete(Job *job, Context *context) {
if (context->started == context->count) {
if (job->prev == job) {
return;
}
job->prev->next = job->next;
job->next->prev = job->prev;
job = job->next;
MUSTTAIL return job->continuation(job, context);
} else {
int temp = context->started++;
job->init(context, temp);
MUSTTAIL return job->continuation(job, context);
}
}
template <class NodeTFrom, class NodeTTo>
PRESERVE_NONE void pointIter(Job *, Context *);
template <class NodeTFrom> struct PointIterTable {
static constexpr Continuation table[] = {
pointIter<NodeTFrom, Node0>, pointIter<NodeTFrom, Node3>,
pointIter<NodeTFrom, Node16>, pointIter<NodeTFrom, Node48>,
pointIter<NodeTFrom, Node256>};
};
static constexpr Continuation const *pointIterTable[] = {
PointIterTable<Node0>::table, PointIterTable<Node3>::table,
PointIterTable<Node16>::table, PointIterTable<Node48>::table,
PointIterTable<Node256>::table,
};
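// Each continuation is specialized on both the current node type and the
// child node type via these two-level tables, so every step is a direct tail
// call with no type switch on the hot path.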
template <class NodeTFrom, class NodeTTo>
void pointIter(Job *job, Context *context) {
assert(NodeTFrom::kType == job->n->getType());
NodeTFrom *n = static_cast<NodeTFrom *>(job->n);
assert(NodeTTo::kType == job->child->getType());
NodeTTo *child = static_cast<NodeTTo *>(job->child);
auto key = job->remaining.subspan(1, job->remaining.size() - 1);
if (child->partialKeyLen > 0) {
int commonLen = std::min<int>(child->partialKeyLen, key.size());
int partialKeyIndex =
longestCommonPrefix(child->partialKey(), key.data(), commonLen);
if (partialKeyIndex < child->partialKeyLen) {
*job->result = {job->n, job->remaining};
MUSTTAIL return complete(job, context);
}
}
// child is on the search path. Commit to advancing and updating max version
job->n = child;
job->remaining =
key.subspan(child->partialKeyLen, key.size() - child->partialKeyLen);
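  // Update max-version bookkeeping on the edge just traversed: Node3/Node16
  // keep one max version per child slot, while Node48/Node256 additionally
  // maintain coarser maxOfMax buckets, each summarizing a group of child
  // slots.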
if constexpr (std::is_same_v<NodeTFrom, Node3> ||
std::is_same_v<NodeTFrom, Node16>) {
n->childMaxVersion[job->childIndex] = context->writeVersion;
} else if constexpr (std::is_same_v<NodeTFrom, Node48> ||
std::is_same_v<NodeTFrom, Node256>) {
n->maxOfMax[job->childIndex >> NodeTFrom::kMaxOfMaxShift] =
std::max(n->maxOfMax[job->childIndex >> NodeTFrom::kMaxOfMaxShift],
context->writeVersion);
n->childMaxVersion[job->childIndex] = context->writeVersion;
}
if (job->remaining.size() == 0) [[unlikely]] {
*job->result = {job->n, job->remaining};
MUSTTAIL return complete(job, context);
}
++context->iterations;
if (!job->getChildAndIndex(child, job->remaining.front())) [[unlikely]] {
*job->result = {job->n, job->remaining};
MUSTTAIL return complete(job, context);
}
job->continuation = PointIterTable<NodeTTo>::table[job->child.getType()];
__builtin_prefetch(job->child);
MUSTTAIL return keepGoing(job, context);
}
template <class NodeTFrom, class NodeTTo>
PRESERVE_NONE void prefixIter(Job *, Context *);
template <class NodeTFrom, class NodeTTo>
PRESERVE_NONE void beginIter(Job *, Context *);
template <class NodeTFrom, class NodeTTo>
PRESERVE_NONE void endIter(Job *, Context *);
template <class NodeTFrom> struct PrefixIterTable {
static constexpr Continuation table[] = {
prefixIter<NodeTFrom, Node0>, prefixIter<NodeTFrom, Node3>,
prefixIter<NodeTFrom, Node16>, prefixIter<NodeTFrom, Node48>,
prefixIter<NodeTFrom, Node256>};
};
static constexpr Continuation const *prefixIterTable[] = {
PrefixIterTable<Node0>::table, PrefixIterTable<Node3>::table,
PrefixIterTable<Node16>::table, PrefixIterTable<Node48>::table,
PrefixIterTable<Node256>::table,
};
template <class NodeTFrom> struct BeginIterTable {
static constexpr Continuation table[] = {
beginIter<NodeTFrom, Node0>, beginIter<NodeTFrom, Node3>,
beginIter<NodeTFrom, Node16>, beginIter<NodeTFrom, Node48>,
beginIter<NodeTFrom, Node256>};
};
static constexpr Continuation const *beginIterTable[] = {
BeginIterTable<Node0>::table, BeginIterTable<Node3>::table,
BeginIterTable<Node16>::table, BeginIterTable<Node48>::table,
BeginIterTable<Node256>::table,
};
template <class NodeTFrom> struct EndIterTable {
static constexpr Continuation table[] = {
endIter<NodeTFrom, Node0>, endIter<NodeTFrom, Node3>,
endIter<NodeTFrom, Node16>, endIter<NodeTFrom, Node48>,
endIter<NodeTFrom, Node256>};
};
static constexpr Continuation const *endIterTable[] = {
EndIterTable<Node0>::table, EndIterTable<Node3>::table,
EndIterTable<Node16>::table, EndIterTable<Node48>::table,
EndIterTable<Node256>::table,
};
template <class NodeTFrom, class NodeTTo>
void prefixIter(Job *job, Context *context) {
assert(NodeTFrom::kType == job->n->getType());
NodeTFrom *n = static_cast<NodeTFrom *>(job->n);
assert(NodeTTo::kType == job->child->getType());
NodeTTo *child = static_cast<NodeTTo *>(job->child);
auto key = job->remaining.subspan(1, job->remaining.size() - 1);
if (child->partialKeyLen > 0) {
int commonLen = std::min<int>(child->partialKeyLen, key.size());
int partialKeyIndex =
longestCommonPrefix(child->partialKey(), key.data(), commonLen);
if (partialKeyIndex < child->partialKeyLen) {
goto noNodeOnSearchPath;
}
}
// child is on the search path. Commit to advancing and updating max version
job->n = child;
job->remaining =
key.subspan(child->partialKeyLen, key.size() - child->partialKeyLen);
if constexpr (std::is_same_v<NodeTFrom, Node3> ||
std::is_same_v<NodeTFrom, Node16>) {
n->childMaxVersion[job->childIndex] = context->writeVersion;
} else if constexpr (std::is_same_v<NodeTFrom, Node48> ||
std::is_same_v<NodeTFrom, Node256>) {
n->maxOfMax[job->childIndex >> NodeTFrom::kMaxOfMaxShift] =
std::max(n->maxOfMax[job->childIndex >> NodeTFrom::kMaxOfMaxShift],
context->writeVersion);
n->childMaxVersion[job->childIndex] = context->writeVersion;
}
if (job->remaining.size() == 0) [[unlikely]] {
job->endNode = job->n;
job->begin = job->begin.subspan(job->commonPrefixLen,
job->begin.size() - job->commonPrefixLen);
job->end = job->end.subspan(job->commonPrefixLen,
job->end.size() - job->commonPrefixLen);
if (job->begin.size() == 0) [[unlikely]] {
goto gotoEndIter;
} else if (!job->getChildAndIndex(child, job->begin.front())) [[unlikely]] {
goto gotoEndIter;
} else {
job->continuation = BeginIterTable<NodeTTo>::table[job->child.getType()];
__builtin_prefetch(job->child);
MUSTTAIL return keepGoing(job, context);
}
}
++context->iterations;
if (!job->getChildAndIndex(child, job->remaining.front())) [[unlikely]] {
goto noNodeOnSearchPath;
}
job->continuation = PrefixIterTable<NodeTTo>::table[job->child.getType()];
__builtin_prefetch(job->child);
MUSTTAIL return keepGoing(job, context);
noNodeOnSearchPath: {
int prefixLen = job->commonPrefixLen - job->remaining.size();
assert(prefixLen >= 0);
assert(job->n != nullptr);
*job->result = {
job->n,
job->begin.subspan(prefixLen, job->begin.size() - prefixLen),
job->n,
job->end.subspan(prefixLen, job->end.size() - prefixLen),
};
MUSTTAIL return complete(job, context);
}
gotoEndIter:
if (!job->getChildAndIndex(child, job->end.front())) [[unlikely]] {
*job->result = {
job->n,
job->begin,
job->n,
job->end,
};
MUSTTAIL return complete(job, context);
} else {
job->continuation = EndIterTable<NodeTTo>::table[job->child.getType()];
__builtin_prefetch(job->child);
MUSTTAIL return keepGoing(job, context);
}
}
template <class NodeTFrom, class NodeTTo>
void beginIter(Job *job, Context *context) {
assert(NodeTFrom::kType == job->n->getType());
NodeTFrom *n = static_cast<NodeTFrom *>(job->n);
assert(NodeTTo::kType == job->child->getType());
NodeTTo *child = static_cast<NodeTTo *>(job->child);
auto key = job->begin.subspan(1, job->begin.size() - 1);
if (child->partialKeyLen > 0) {
int commonLen = std::min<int>(child->partialKeyLen, key.size());
int partialKeyIndex =
longestCommonPrefix(child->partialKey(), key.data(), commonLen);
if (partialKeyIndex < child->partialKeyLen) {
goto gotoEndIter;
}
}
// child is on the search path. Commit to advancing and updating max version
job->n = child;
job->begin =
key.subspan(child->partialKeyLen, key.size() - child->partialKeyLen);
if constexpr (std::is_same_v<NodeTFrom, Node3> ||
std::is_same_v<NodeTFrom, Node16>) {
n->childMaxVersion[job->childIndex] = context->writeVersion;
} else if constexpr (std::is_same_v<NodeTFrom, Node48> ||
std::is_same_v<NodeTFrom, Node256>) {
n->maxOfMax[job->childIndex >> NodeTFrom::kMaxOfMaxShift] =
std::max(n->maxOfMax[job->childIndex >> NodeTFrom::kMaxOfMaxShift],
context->writeVersion);
n->childMaxVersion[job->childIndex] = context->writeVersion;
}
if (job->begin.size() == 0) [[unlikely]] {
goto gotoEndIter;
}
++context->iterations;
if (!job->getChildAndIndex(child, job->begin.front())) [[unlikely]] {
goto gotoEndIter;
}
job->continuation = BeginIterTable<NodeTTo>::table[job->child.getType()];
__builtin_prefetch(job->child);
MUSTTAIL return keepGoing(job, context);
gotoEndIter:
if (!job->getChildAndIndex(job->endNode, job->end.front())) [[unlikely]] {
*job->result = {
job->n,
job->begin,
job->endNode,
job->end,
};
MUSTTAIL return complete(job, context);
} else {
MUSTTAIL return endIterTable[job->endNode->getType()][job->child.getType()](
job, context);
}
}
template <class NodeTFrom, class NodeTTo>
void endIter(Job *job, Context *context) {
assert(NodeTFrom::kType == job->endNode->getType());
NodeTFrom *endNode = static_cast<NodeTFrom *>(job->endNode);
assert(NodeTTo::kType == job->child->getType());
NodeTTo *child = static_cast<NodeTTo *>(job->child);
auto key = job->end.subspan(1, job->end.size() - 1);
if (child->partialKeyLen > 0) {
int commonLen = std::min<int>(child->partialKeyLen, key.size());
int partialKeyIndex =
longestCommonPrefix(child->partialKey(), key.data(), commonLen);
if (partialKeyIndex < child->partialKeyLen) {
*job->result = {job->n, job->begin, job->endNode, job->end};
assert(job->endNode != nullptr);
MUSTTAIL return complete(job, context);
}
}
// child is on the search path. Commit to advancing and updating max version
job->endNode = child;
job->end =
key.subspan(child->partialKeyLen, key.size() - child->partialKeyLen);
if constexpr (std::is_same_v<NodeTFrom, Node3> ||
std::is_same_v<NodeTFrom, Node16>) {
endNode->childMaxVersion[job->childIndex] = context->writeVersion;
} else if constexpr (std::is_same_v<NodeTFrom, Node48> ||
std::is_same_v<NodeTFrom, Node256>) {
endNode->maxOfMax[job->childIndex >> NodeTFrom::kMaxOfMaxShift] = std::max(
endNode->maxOfMax[job->childIndex >> NodeTFrom::kMaxOfMaxShift],
context->writeVersion);
endNode->childMaxVersion[job->childIndex] = context->writeVersion;
}
if (job->end.size() == 0) [[unlikely]] {
*job->result = {job->n, job->begin, job->endNode, job->end};
assert(job->endNode != nullptr);
MUSTTAIL return complete(job, context);
}
++context->iterations;
if (!job->getChildAndIndex(child, job->end.front())) [[unlikely]] {
*job->result = {job->n, job->begin, job->endNode, job->end};
assert(job->endNode != nullptr);
MUSTTAIL return complete(job, context);
}
job->continuation = EndIterTable<NodeTTo>::table[job->child.getType()];
__builtin_prefetch(job->child);
MUSTTAIL return keepGoing(job, context);
}
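// Classifies one write and picks its first continuation: an empty end key is
// a point write, as is a range of the form [k, k + '\x00'). Otherwise the
// search first descends the common prefix of begin and end, then forks into
// separate begin-side and end-side descents.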
void Job::init(Context *context, int index) {
result = context->results + index;
n = context->root;
if (context->writes[index].end.len == 0) {
goto pointWrite;
}
begin = TrivialSpan(context->writes[index].begin.p,
context->writes[index].begin.len);
end =
TrivialSpan(context->writes[index].end.p, context->writes[index].end.len);
commonPrefixLen = longestCommonPrefix(begin.data(), end.data(),
std::min(begin.size(), end.size()));
if (commonPrefixLen == begin.size() && end.size() == begin.size() + 1 &&
end.back() == 0) {
goto pointWrite;
}
remaining = TrivialSpan(context->writes[index].begin.p, commonPrefixLen);
if (commonPrefixLen > 0) {
// common prefix iter will set endNode
if (!getChildAndIndex(n, remaining.front())) [[unlikely]] {
*result = {
n,
begin,
n,
end,
};
continuation = complete;
} else {
continuation = prefixIterTable[n->getType()][child.getType()];
}
} else if (begin.size() > 0 && getChildAndIndex(n, begin.front())) {
endNode = n;
continuation = beginIterTable[n->getType()][child.getType()];
} else {
assert(end.size() > 0);
endNode = n;
if (!getChildAndIndex(n, end.front())) [[unlikely]] {
*result = {
n,
begin,
n,
end,
};
continuation = complete;
} else {
continuation = endIterTable[n->getType()][child.getType()];
}
}
return;
pointWrite:
remaining = TrivialSpan(context->writes[index].begin.p,
context->writes[index].begin.len);
if (remaining.size() == 0) [[unlikely]] {
*result = {n, remaining};
continuation = complete;
} else {
if (!getChildAndIndex(n, remaining.front())) [[unlikely]] {
*result = {n, remaining};
continuation = complete;
} else {
continuation = pointIterTable[n->getType()][child.getType()];
}
}
}
} // namespace interleaved_insert
// Sequential implementations
namespace {
// Logically this is the same as performing firstGeq and then checking the
// point version (on an exact match) or the range version (otherwise), but
// this version short circuits as soon as it can prove that there's no
// conflict.
bool checkPointRead(Node *n, const TrivialSpan key,
InternalVersionT readVersion, ReadContext *readContext) {
++readContext->point_read_accum;
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check point read: %s\n", printable(key).c_str());
#endif
auto remaining = key;
for (;; ++readContext->point_read_iterations_accum) {
if (remaining.size() == 0) {
if (n->entryPresent) {
return n->entry.pointVersion <= readVersion;
}
n = getFirstChild(n);
goto downLeftSpine;
}
auto [c, maxV] = getChildAndMaxVersion(n, remaining[0]);
Node *child = c;
if (child == nullptr) {
auto c = getChildGeq(n, remaining[0]);
if (c != nullptr) {
n = c;
goto downLeftSpine;
} else {
n = nextSibling(n);
if (n == nullptr) {
return true;
}
goto downLeftSpine;
}
}
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
if (i < commonLen) {
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
goto downLeftSpine;
} else {
n = nextSibling(n);
if (n == nullptr) {
return true;
}
goto downLeftSpine;
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > remaining.size()) {
// n is the first physical node greater than remaining, and there's no
// eq node
goto downLeftSpine;
}
}
if (maxV <= readVersion) {
++readContext->point_read_short_circuit_accum;
return true;
}
}
downLeftSpine:
for (; !n->entryPresent; n = getFirstChild(n)) {
}
return n->entry.rangeVersion <= readVersion;
}
// Logically this is the same as performing firstGeq and then checking the max
// version of the subtree for this prefix (or the range version of the next
// entry if the prefix doesn't exist), but this version short circuits as soon
// as it can prove that there's no conflict.
bool checkPrefixRead(Node *n, const TrivialSpan key,
InternalVersionT readVersion, ReadContext *readContext) {
++readContext->prefix_read_accum;
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check prefix read: %s\n", printable(key).c_str());
#endif
auto remaining = key;
for (;; ++readContext->prefix_read_iterations_accum) {
if (remaining.size() == 0) {
// There's no way to encode a prefix read of "", so n is not the root
return maxVersion(n) <= readVersion;
}
auto [c, maxV] = getChildAndMaxVersion(n, remaining[0]);
Node *child = c;
if (child == nullptr) {
auto c = getChildGeq(n, remaining[0]);
if (c != nullptr) {
n = c;
goto downLeftSpine;
} else {
n = nextSibling(n);
if (n == nullptr) {
return true;
}
goto downLeftSpine;
}
}
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
if (i < commonLen) {
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
goto downLeftSpine;
} else {
n = nextSibling(n);
if (n == nullptr) {
return true;
}
goto downLeftSpine;
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > remaining.size()) {
// n is the first physical node greater than remaining, and there's no
// eq node. All physical nodes that start with prefix are reachable from
// n.
if (maxVersion(n) > readVersion) {
return false;
}
goto downLeftSpine;
}
}
if (maxV <= readVersion) {
++readContext->prefix_read_short_circuit_accum;
return true;
}
}
downLeftSpine:
for (; !n->entryPresent; n = getFirstChild(n)) {
}
return n->entry.rangeVersion <= readVersion;
}
// Return true if the max version among all keys that start with
// key[:prefixLen] and are >= key is <= readVersion
bool checkRangeLeftSide(Node *n, TrivialSpan key, int prefixLen,
InternalVersionT readVersion,
ReadContext *readContext) {
auto remaining = key;
int searchPathLen = 0;
for (;; ++readContext->range_read_iterations_accum) {
if (remaining.size() == 0) {
assert(searchPathLen >= prefixLen);
return maxVersion(n) <= readVersion;
}
if (searchPathLen >= prefixLen) {
if (!checkMaxBetweenExclusive(n, remaining[0], 256, readVersion,
readContext)) {
return false;
}
}
auto [c, maxV] = getChildAndMaxVersion(n, remaining[0]);
Node *child = c;
if (child == nullptr) {
auto c = getChildGeq(n, remaining[0]);
if (c != nullptr) {
if (searchPathLen < prefixLen) {
n = c;
goto downLeftSpine;
}
n = c;
return maxVersion(n) <= readVersion;
} else {
n = nextSibling(n);
if (n == nullptr) {
return true;
}
goto downLeftSpine;
}
}
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
++searchPathLen;
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
searchPathLen += i;
if (i < commonLen) {
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
if (searchPathLen < prefixLen) {
goto downLeftSpine;
}
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
return false;
}
return maxVersion(n) <= readVersion;
} else {
n = nextSibling(n);
if (n == nullptr) {
return true;
}
goto downLeftSpine;
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > remaining.size()) {
assert(searchPathLen >= prefixLen);
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
return false;
}
return maxVersion(n) <= readVersion;
}
}
if (maxV <= readVersion) {
return true;
}
}
downLeftSpine:
for (; !n->entryPresent; n = getFirstChild(n)) {
}
return n->entry.rangeVersion <= readVersion;
}
// Return true if the max version among all keys that start with
// key[:prefixLen] and are < key is <= readVersion
bool checkRangeRightSide(Node *n, TrivialSpan key, int prefixLen,
InternalVersionT readVersion,
ReadContext *readContext) {
auto remaining = key;
int searchPathLen = 0;
for (;; ++readContext->range_read_iterations_accum) {
assert(searchPathLen <= key.size());
if (remaining.size() == 0) {
goto downLeftSpine;
}
if (searchPathLen >= prefixLen) {
if (n->entryPresent && n->entry.pointVersion > readVersion) {
return false;
}
if (!checkMaxBetweenExclusive(n, -1, remaining[0], readVersion,
readContext)) {
return false;
}
}
if (searchPathLen > prefixLen && n->entryPresent &&
n->entry.rangeVersion > readVersion) {
return false;
}
Node *child = getChild(n, remaining[0]);
if (child == nullptr) {
auto c = getChildGeq(n, remaining[0]);
if (c != nullptr) {
n = c;
goto downLeftSpine;
} else {
goto backtrack;
}
}
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
++searchPathLen;
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
searchPathLen += i;
if (i < commonLen) {
++searchPathLen;
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
goto downLeftSpine;
} else {
if (searchPathLen > prefixLen && n->entryPresent &&
n->entry.rangeVersion > readVersion) {
return false;
}
goto backtrack;
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > remaining.size()) {
goto downLeftSpine;
}
}
}
backtrack:
for (;;) {
// searchPathLen > prefixLen implies n is not the root
if (searchPathLen > prefixLen && maxVersion(n) > readVersion) {
return false;
}
if (n->parent == nullptr) {
return true;
}
auto next = getChildGeq(n->parent, n->parentsIndex + 1);
if (next == nullptr) {
searchPathLen -= 1 + n->partialKeyLen;
n = n->parent;
} else {
n = next;
goto downLeftSpine;
}
}
downLeftSpine:
for (; !n->entryPresent; n = getFirstChild(n)) {
}
return n->entry.rangeVersion <= readVersion;
}
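// A range read [begin, end) is decomposed before checking: [k, k + '\x00')
// degenerates to a point read, and [k, k') with k' equal to k with its last
// byte incremented degenerates to a prefix read. Otherwise we descend the
// common prefix, then separately check the keys between begin[lcp] and
// end[lcp], the left side (>= begin), and the right side (< end).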
bool checkRangeRead(Node *n, TrivialSpan begin, TrivialSpan end,
InternalVersionT readVersion, ReadContext *readContext) {
int lcp = longestCommonPrefix(begin.data(), end.data(),
std::min(begin.size(), end.size()));
if (lcp == begin.size() && end.size() == begin.size() + 1 &&
end.back() == 0) {
return checkPointRead(n, begin, readVersion, readContext);
}
if (lcp == begin.size() - 1 && end.size() == begin.size() &&
begin.back() + 1 == end.back()) {
return checkPrefixRead(n, begin, readVersion, readContext);
}
++readContext->range_read_accum;
auto remaining = begin.subspan(0, lcp);
Arena arena;
// Advance down common prefix, but stay on a physical path in the tree
for (;; ++readContext->range_read_iterations_accum) {
assert(getSearchPath(arena, n) <=>
begin.subspan(0, lcp - remaining.size()) ==
0);
if (remaining.size() == 0) {
break;
}
auto [c, v] = getChildAndMaxVersion(n, remaining[0]);
Node *child = c;
if (child == nullptr) {
break;
}
if (child->partialKeyLen > 0) {
int cl = std::min<int>(child->partialKeyLen, remaining.size() - 1);
int i =
longestCommonPrefix(child->partialKey(), remaining.data() + 1, cl);
if (i != child->partialKeyLen) {
break;
}
}
if (v <= readVersion) {
++readContext->range_read_short_circuit_accum;
return true;
}
n = child;
remaining =
remaining.subspan(1 + child->partialKeyLen,
remaining.size() - (1 + child->partialKeyLen));
}
assert(getSearchPath(arena, n) <=> begin.subspan(0, lcp - remaining.size()) ==
0);
const int consumed = lcp - remaining.size();
assume(consumed >= 0);
begin = begin.subspan(consumed, begin.size() - consumed);
end = end.subspan(consumed, end.size() - consumed);
lcp -= consumed;
if (lcp == begin.size()) {
return checkRangeRightSide(n, end, lcp, readVersion, readContext);
}
  // This makes it safe to check maxVersion within checkRangeLeftSide: if n
  // were the root with an empty begin key, we would have returned above,
  // since lcp would equal begin.size().
  assert(!(n->parent == nullptr && begin.size() == 0));
return checkRangeStartsWith(n, begin.subspan(0, lcp), begin[lcp], end[lcp],
readVersion, readContext) &&
checkRangeLeftSide(n, begin, lcp + 1, readVersion, readContext) &&
checkRangeRightSide(n, end, lcp + 1, readVersion, readContext);
}
} // namespace
struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
// We still have the sequential implementation for compilers that don't
// support preserve_none and musttail
void useSequential(const ReadRange *reads, Result *result, int count,
check::Context &context) {
for (int i = 0; i < count; ++i) {
if (reads[i].readVersion < oldestVersionFullPrecision) [[unlikely]] {
result[i] = TooOld;
} else {
bool ok;
if (reads[i].end.len == 0) {
ok = checkPointRead(
root, TrivialSpan(reads[i].begin.p, reads[i].begin.len),
InternalVersionT(reads[i].readVersion), &context.readContext);
} else {
ok = checkRangeRead(
root, TrivialSpan(reads[i].begin.p, reads[i].begin.len),
TrivialSpan(reads[i].end.p, reads[i].end.len),
InternalVersionT(reads[i].readVersion), &context.readContext);
}
result[i] = ok ? Commit : Conflict;
}
}
}
void check(const ReadRange *reads, Result *result, int count) {
assert(oldestVersionFullPrecision >=
newestVersionFullPrecision - kNominalVersionWindow);
if (count == 0) {
return;
}
int64_t check_byte_accum = 0;
check::Context context;
context.readContext.impl = this;
if constexpr (kEnableInterleaved) {
if (count == 1) {
useSequential(reads, result, count, context);
} else {
constexpr int kConcurrent = 16;
check::Job inProgress[kConcurrent];
context.count = count;
context.oldestVersionFullPrecision = oldestVersionFullPrecision;
context.root = root;
context.queries = reads;
context.results = result;
int64_t started = std::min(kConcurrent, count);
context.started = started;
for (int i = 0; i < started; i++) {
inProgress[i].init(reads + i, result + i, root,
oldestVersionFullPrecision);
}
for (int i = 0; i < started - 1; i++) {
inProgress[i].next = inProgress + i + 1;
}
for (int i = 1; i < started; i++) {
inProgress[i].prev = inProgress + i - 1;
}
inProgress[0].prev = inProgress + started - 1;
inProgress[started - 1].next = inProgress;
// Kick off the sequence of tail calls that finally returns once all
// jobs are done
inProgress->continuation(inProgress, &context);
#ifndef NDEBUG
Arena arena;
auto *results2 = new (arena) Result[count];
check::Context context2;
context2.readContext.impl = this;
useSequential(reads, results2, count, context2);
      assert(memcmp(result, results2, count * sizeof(result[0])) == 0);
assert(context.readContext == context2.readContext);
#endif
}
} else {
useSequential(reads, result, count, context);
}
for (int i = 0; i < count; ++i) {
assert(reads[i].readVersion >= 0);
assert(reads[i].readVersion <= newestVersionFullPrecision);
const auto &r = reads[i];
check_byte_accum += r.begin.len + r.end.len;
context.readContext.commits_accum += result[i] == Commit;
context.readContext.conflicts_accum += result[i] == Conflict;
context.readContext.too_olds_accum += result[i] == TooOld;
}
point_read_total.add(context.readContext.point_read_accum);
prefix_read_total.add(context.readContext.prefix_read_accum);
range_read_total.add(context.readContext.range_read_accum);
range_read_node_scan_total.add(
context.readContext.range_read_node_scan_accum);
point_read_short_circuit_total.add(
context.readContext.point_read_short_circuit_accum);
prefix_read_short_circuit_total.add(
context.readContext.prefix_read_short_circuit_accum);
range_read_short_circuit_total.add(
context.readContext.range_read_short_circuit_accum);
point_read_iterations_total.add(
context.readContext.point_read_iterations_accum);
prefix_read_iterations_total.add(
context.readContext.prefix_read_iterations_accum);
range_read_iterations_total.add(
context.readContext.range_read_iterations_accum);
commits_total.add(context.readContext.commits_accum);
conflicts_total.add(context.readContext.conflicts_accum);
too_olds_total.add(context.readContext.too_olds_accum);
check_bytes_total.add(check_byte_accum);
}
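  // Applies a batch of writes in three phases: concurrent searches that don't
  // mutate the tree, then insertions (upsized nodes leave forwarding
  // pointers), then erasure of the entries covered by each written range.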
void interleavedWrites(const WriteRange *writes, int count,
InternalVersionT writeVersion) {
// Phase 1: Search for insertion points concurrently, without modifying the
// structure of the tree.
assert(count > 1);
constexpr int kStackResultMax = 100;
interleaved_insert::Result stackResults[kStackResultMax];
constexpr int kConcurrent = 16;
interleaved_insert::Job inProgress[kConcurrent];
interleaved_insert::Context context;
context.writeVersion = writeVersion;
context.count = count;
context.root = root;
context.writes = writes;
context.results = stackResults;
if (count > kStackResultMax) [[unlikely]] {
context.results = (interleaved_insert::Result *)safe_malloc(
count * sizeof(interleaved_insert::Result));
}
int64_t started = std::min(kConcurrent, count);
context.started = started;
for (int i = 0; i < started; i++) {
inProgress[i].init(&context, i);
}
for (int i = 0; i < started - 1; i++) {
inProgress[i].next = inProgress + i + 1;
}
for (int i = 1; i < started; i++) {
inProgress[i].prev = inProgress + i - 1;
}
inProgress[0].prev = inProgress + started - 1;
inProgress[started - 1].next = inProgress;
// Kick off the sequence of tail calls that finally returns once all jobs
// are done
inProgress->continuation(inProgress, &context);
writeContext.accum.insert_iterations += context.iterations;
// Phase 2: Perform insertions. Nodes may be upsized during this phase, but
// old nodes get forwarding pointers installed and are released after
// phase 2. The search path of a node does not change in this phase.
interleaved_insert::Result *firstRangeWrite = nullptr;
interleaved_insert::Result *lastRangeWrite;
for (int i = 0; i < count; ++i) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
{
Node *b = context.results[i].insertionPoint;
Node *e = context.results[i].endInsertionPoint;
while (b->releaseDeferred) {
b = b->forwardTo;
}
if (e != nullptr) {
while (e->releaseDeferred) {
          e = e->forwardTo;
}
}
fprintf(stderr, "search path: %s, begin: %s\n",
getSearchPathPrintable(b).c_str(),
printable(context.results[i].remaining).c_str());
fprintf(stderr, "search path: %s, end: %s\n",
getSearchPathPrintable(e).c_str(),
printable(context.results[i].endRemaining).c_str());
}
#endif
while (context.results[i].insertionPoint->releaseDeferred) {
context.results[i].insertionPoint =
context.results[i].insertionPoint->forwardTo;
}
if (context.results[i].endInsertionPoint == nullptr) {
addPointWrite(getInTree(context.results[i].insertionPoint, this),
context.results[i].remaining, writeVersion, &writeContext,
this);
} else {
if (firstRangeWrite == nullptr) {
firstRangeWrite = context.results + i;
} else {
lastRangeWrite->nextRangeWrite = context.results + i;
}
lastRangeWrite = context.results + i;
while (context.results[i].endInsertionPoint->releaseDeferred) {
context.results[i].endInsertionPoint =
context.results[i].endInsertionPoint->forwardTo;
}
auto [beginNode, endNode] = addWriteRange(
context.results[i].insertionPoint, context.results[i].remaining,
context.results[i].endInsertionPoint,
context.results[i].endRemaining, writeVersion, &writeContext, this);
context.results[i].insertionPoint = beginNode;
context.results[i].endInsertionPoint = endNode;
}
}
if (firstRangeWrite != nullptr) {
lastRangeWrite->nextRangeWrite = nullptr;
}
// Phase 3: Erase nodes within written ranges. Going left to right ensures
// that nothing later is on the search path of anything earlier, so we don't
// encounter invalidated nodes.
for (auto *iter = firstRangeWrite; iter != nullptr;
iter = iter->nextRangeWrite) {
if (iter->endInsertionPoint != nullptr) {
while (iter->insertionPoint->releaseDeferred) {
iter->insertionPoint = iter->insertionPoint->forwardTo;
}
while (iter->endInsertionPoint->releaseDeferred) {
iter->endInsertionPoint = iter->endInsertionPoint->forwardTo;
}
eraseInRange(iter->insertionPoint, iter->endInsertionPoint,
&writeContext, this);
}
}
if (count > kStackResultMax) [[unlikely]] {
safe_free(context.results, count * sizeof(interleaved_insert::Result));
}
}
void insertPointWritesOrSorted(const WriteRange *writes, int count,
InternalVersionT writeVersion) {
#ifndef NDEBUG
bool allPointWrites = true;
for (int i = 0; i < count; ++i) {
allPointWrites = allPointWrites && writes[i].end.len == 0;
}
bool sorted = true;
for (int i = 1; i < count; ++i) {
sorted = sorted && writes[i - 1] < writes[i];
}
assert(allPointWrites || sorted);
#endif
if (kEnableInterleaved && count > 1) {
interleavedWrites(writes, count, InternalVersionT(writeVersion));
} else {
for (int i = 0; i < count; ++i) {
const auto &w = writes[i];
auto begin = TrivialSpan(w.begin.p, w.begin.len);
auto end = TrivialSpan(w.end.p, w.end.len);
if (w.end.len > 0) {
addWriteRange(root, begin, end, InternalVersionT(writeVersion),
&writeContext, this);
} else {
addPointWrite(root, begin, InternalVersionT(writeVersion),
&writeContext, this);
}
}
}
}
void addWrites(const WriteRange *writes, int count, int64_t writeVersion) {
#if !USE_64_BIT
// There could be other conflict sets in the same thread. We need
// InternalVersionT::zero to be correct for this conflict set for the
// lifetime of the current call frame.
InternalVersionT::zero = writeContext.zero = oldestVersion;
#endif
assert(writeVersion >= newestVersionFullPrecision);
assert(writeContext.accum.entries_erased == 0);
assert(writeContext.accum.entries_inserted == 0);
if (oldestExtantVersion < writeVersion - kMaxCorrectVersionWindow)
[[unlikely]] {
if (writeVersion > newestVersionFullPrecision + kNominalVersionWindow) {
eraseTree(root, &writeContext);
init(writeVersion - kNominalVersionWindow);
}
newestVersionFullPrecision = writeVersion;
newest_version.set(newestVersionFullPrecision);
if (newestVersionFullPrecision - kNominalVersionWindow >
oldestVersionFullPrecision) {
setOldestVersion(newestVersionFullPrecision - kNominalVersionWindow);
}
while (oldestExtantVersion <
newestVersionFullPrecision - kMaxCorrectVersionWindow) {
gcScanStep(1000);
}
} else {
newestVersionFullPrecision = writeVersion;
newest_version.set(newestVersionFullPrecision);
if (newestVersionFullPrecision - kNominalVersionWindow >
oldestVersionFullPrecision) {
setOldestVersion(newestVersionFullPrecision - kNominalVersionWindow);
}
}
for (int i = 0; i < count; ++i) {
writeContext.accum.write_bytes += writes[i].begin.len + writes[i].end.len;
}
if (count > 0) {
int firstNotInserted = 0;
bool batchHasOnlyPointWrites = writes[0].end.len == 0;
bool batchIsSorted = true;
for (int i = 1; i < count; ++i) {
batchIsSorted = batchIsSorted && writes[i - 1] < writes[i];
batchHasOnlyPointWrites =
batchHasOnlyPointWrites && writes[i].end.len == 0;
if (!(batchIsSorted || batchHasOnlyPointWrites)) {
insertPointWritesOrSorted(writes + firstNotInserted,
i - firstNotInserted,
InternalVersionT(writeVersion));
firstNotInserted = i;
batchHasOnlyPointWrites = writes[i].end.len == 0;
batchIsSorted = true;
}
}
assert(batchIsSorted || batchHasOnlyPointWrites);
insertPointWritesOrSorted(writes + firstNotInserted,
count - firstNotInserted,
InternalVersionT(writeVersion));
}
writeContext.releaseDeferred();
    // Run gc at no less than 200% of the rate we're inserting entries, and
    // always budget at least some gc so that oldestExtantVersion can keep
    // advancing.
keyUpdates += std::max<int64_t>(writeContext.accum.entries_inserted -
writeContext.accum.entries_erased,
1) *
2;
point_writes_total.add(writeContext.accum.point_writes);
range_writes_total.add(writeContext.accum.range_writes);
nodes_allocated_total.add(writeContext.accum.nodes_allocated);
nodes_released_total.add(writeContext.accum.nodes_released);
entries_inserted_total.add(writeContext.accum.entries_inserted);
entries_erased_total.add(writeContext.accum.entries_erased);
insert_iterations_total.add(writeContext.accum.insert_iterations);
write_bytes_total.add(writeContext.accum.write_bytes);
memset(&writeContext.accum, 0, sizeof(writeContext.accum));
}
// Spends up to `fuel` gc'ing, and returns its unused fuel. Reclaims memory
// and updates oldestExtantVersion after spending enough fuel.
int64_t gcScanStep(int64_t fuel) {
Node *n = firstGeqPhysical(root, removalKey);
// There's no way to erase removalKey without introducing a key after it
assert(n != nullptr);
// Don't erase the root
if (n == root) {
rezero(n, oldestVersion);
n = nextPhysical(n);
}
int64_t set_oldest_iterations_accum = 0;
for (; fuel > 0 && n != nullptr; ++set_oldest_iterations_accum) {
rezero(n, oldestVersion);
// The "make sure gc keeps up with writes" calculations assume that
// we're scanning key by key, not node by node. Make sure we only spend
// fuel when there's a logical entry.
fuel -= n->entryPresent;
if (n->entryPresent && std::max(n->entry.pointVersion,
n->entry.rangeVersion) <= oldestVersion) {
// Any transaction n would have prevented from committing is
// going to fail with TooOld anyway.
// There's no way to insert a range such that range version of the
// right node is greater than the point version of the left node
assert(n->entry.rangeVersion <= oldestVersion);
n = erase(n, &writeContext, this, /*logical*/ false);
} else {
n = nextPhysical(n);
}
}
writeContext.releaseDeferred();
gc_iterations_total.add(set_oldest_iterations_accum);
if (n == nullptr) {
removalKey = {};
if (removalBufferSize > kMaxRemovalBufferSize) {
safe_free(removalBuffer, removalBufferSize);
removalBufferSize = kMinRemovalBufferSize;
removalBuffer = (uint8_t *)safe_malloc(removalBufferSize);
}
oldestExtantVersion = oldestVersionAtGcBegin;
oldest_extant_version.set(oldestExtantVersion);
oldestVersionAtGcBegin = oldestVersionFullPrecision;
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr,
"new oldestExtantVersion: %" PRId64
", new oldestVersionAtGcBegin: %" PRId64 "\n",
oldestExtantVersion, oldestVersionAtGcBegin);
#endif
} else {
// Store the current search path to resume the scan later
saveRemovalKey(n);
}
return fuel;
}
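  // Saves n's search path into removalBuffer so the next gcScanStep can
  // resume where this one stopped. The key is reconstructed child-to-root, so
  // the buffer fills from its end, and growing it copies the bytes written so
  // far to the tail of the new allocation.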
void saveRemovalKey(Node *n) {
uint8_t *cursor = removalBuffer + removalBufferSize;
int size = 0;
auto reserve = [&](int delta) {
if (size + delta > removalBufferSize) [[unlikely]] {
int newBufSize = std::max(removalBufferSize * 2, size + delta);
uint8_t *newBuf = (uint8_t *)safe_malloc(newBufSize);
memcpy(newBuf + newBufSize - size, cursor, size);
safe_free(removalBuffer, removalBufferSize);
removalBuffer = newBuf;
removalBufferSize = newBufSize;
cursor = newBuf + newBufSize - size;
}
};
for (;;) {
auto partialKey = TrivialSpan{n->partialKey(), n->partialKeyLen};
reserve(partialKey.size());
size += partialKey.size();
cursor -= partialKey.size();
memcpy(cursor, partialKey.data(), partialKey.size());
if (n->parent == nullptr) {
break;
}
reserve(1);
++size;
--cursor;
*cursor = n->parentsIndex;
n = n->parent;
}
removalKey = {cursor, size};
}
void setOldestVersion(int64_t newOldestVersion) {
assert(newOldestVersion >= 0);
assert(newOldestVersion <= newestVersionFullPrecision);
// If addWrites advances oldestVersion to keep within valid window, a
// subsequent setOldestVersion can be legitimately called with a version
// older than `oldestVersionFullPrecision`. < instead of <= so that we can
// do garbage collection work without advancing the oldest version.
if (newOldestVersion < oldestVersionFullPrecision) {
return;
}
InternalVersionT oldestVersion{newOldestVersion};
this->oldestVersionFullPrecision = newOldestVersion;
this->oldestVersion = oldestVersion;
#if !USE_64_BIT
InternalVersionT::zero = writeContext.zero = oldestVersion;
#endif
#ifdef NDEBUG
// This is here for performance reasons, since we want to amortize the
// cost of storing the search path as a string. In tests, we want to
// exercise the rest of the code often.
if (keyUpdates < 100) {
return;
}
#endif
keyUpdates = gcScanStep(keyUpdates);
nodes_allocated_total.add(
std::exchange(writeContext.accum.nodes_allocated, 0));
nodes_released_total.add(
std::exchange(writeContext.accum.nodes_released, 0));
entries_inserted_total.add(
std::exchange(writeContext.accum.entries_inserted, 0));
entries_erased_total.add(
std::exchange(writeContext.accum.entries_erased, 0));
oldest_version.set(oldestVersionFullPrecision);
}
int64_t getBytes() const { return totalBytes; }
void init(int64_t oldestVersion) {
this->oldestVersion = InternalVersionT(oldestVersion);
oldestVersionFullPrecision = oldestExtantVersion = oldestVersionAtGcBegin =
newestVersionFullPrecision = oldestVersion;
oldest_version.set(oldestVersionFullPrecision);
newest_version.set(newestVersionFullPrecision);
oldest_extant_version.set(oldestExtantVersion);
writeContext.~WriteContext();
new (&writeContext) WriteContext();
// Leave removalBuffer as is
removalKey = {};
keyUpdates = 10;
// Insert ""
root = writeContext.allocate<Node0>(0, 0);
root->numChildren = 0;
root->parent = nullptr;
root->entryPresent = false;
root->partialKeyLen = 0;
addKey(root);
root->entryPresent = true;
root->entry.pointVersion = this->oldestVersion;
root->entry.rangeVersion = this->oldestVersion;
#if !USE_64_BIT
InternalVersionT::zero = writeContext.zero = this->oldestVersion;
#endif
// Intentionally not resetting totalBytes
}
explicit Impl(int64_t oldestVersion) {
assert(oldestVersion >= 0);
init(oldestVersion);
metrics = initMetrics(metricsList, metricsCount);
}
~Impl() {
eraseTree(root, &writeContext);
safe_free(metrics, metricsCount * sizeof(metrics[0]));
safe_free(removalBuffer, removalBufferSize);
}
WriteContext writeContext;
static constexpr int kMinRemovalBufferSize = 1 << 10;
// Eventually downsize if larger than this value
static constexpr int kMaxRemovalBufferSize = 1 << 16;
uint8_t *removalBuffer = (uint8_t *)safe_malloc(kMinRemovalBufferSize);
int removalBufferSize = kMinRemovalBufferSize;
TrivialSpan removalKey;
int64_t keyUpdates;
TaggedNodePointer root;
InternalVersionT oldestVersion;
int64_t oldestVersionFullPrecision;
int64_t oldestExtantVersion;
int64_t oldestVersionAtGcBegin;
int64_t newestVersionFullPrecision;
int64_t totalBytes = 0;
MetricsV1 *metrics;
int metricsCount = 0;
Metric *metricsList = nullptr;
#define GAUGE(name, help) \
Gauge name { metricsList, metricsCount, #name, help }
#define COUNTER(name, help) \
Counter name { metricsList, metricsCount, #name, help }
// ==================== METRICS DEFINITIONS ====================
COUNTER(point_read_total, "Total number of point reads checked");
COUNTER(point_read_short_circuit_total,
"Total number of point reads that did not require a full search to "
"check");
COUNTER(point_read_iterations_total,
"Total number of iterations of the main loop for point read checks");
COUNTER(prefix_read_total, "Total number of prefix reads checked");
COUNTER(prefix_read_short_circuit_total,
"Total number of prefix reads that did not require a full search to "
"check");
COUNTER(prefix_read_iterations_total,
"Total number of iterations of the main loop for prefix read checks");
COUNTER(range_read_total, "Total number of range reads checked");
COUNTER(range_read_short_circuit_total,
"Total number of range reads that did not require a full search to "
"check");
COUNTER(range_read_iterations_total,
"Total number of iterations of the main loops for range read checks");
COUNTER(range_read_node_scan_total,
"Total number of scans of individual nodes while "
"checking a range read");
COUNTER(commits_total,
"Total number of checks where the result is \"commit\"");
COUNTER(conflicts_total,
"Total number of checks where the result is \"conflict\"");
COUNTER(too_olds_total,
"Total number of checks where the result is \"too old\"");
COUNTER(check_bytes_total, "Total number of key bytes checked");
COUNTER(point_writes_total, "Total number of point writes");
COUNTER(range_writes_total,
"Total number of range writes (includes prefix writes)");
GAUGE(memory_bytes, "Total number of bytes in use");
COUNTER(nodes_allocated_total,
"The total number of physical tree nodes allocated");
COUNTER(nodes_released_total,
"The total number of physical tree nodes released");
COUNTER(insert_iterations_total,
"The total number of iterations of the main loop for insertion. "
"Includes searches where the entry already existed, and so insertion "
"did not take place");
COUNTER(entries_inserted_total,
"The total number of entries inserted in the tree");
COUNTER(entries_erased_total,
"The total number of entries erased from the tree");
COUNTER(gc_iterations_total, "The total number of iterations of the main "
"loop for garbage collection");
COUNTER(write_bytes_total, "Total number of key bytes in calls to addWrites");
GAUGE(oldest_version,
"The lowest version that doesn't result in \"TooOld\" for checks");
GAUGE(newest_version, "The version of the most recent call to addWrites");
GAUGE(oldest_extant_version, "A lower bound on the lowest version "
"associated with an existing entry");
// ==================== END METRICS DEFINITIONS ====================
#undef GAUGE
#undef COUNTER
void getMetricsV1(MetricsV1 **metrics, int *count) {
*metrics = this->metrics;
*count = metricsCount;
}
};
TaggedNodePointer &getInTree(Node *n, ConflictSet::Impl *impl) {
return n->parent == nullptr ? impl->root
: getChildExists(n->parent, n->parentsIndex);
}
// Internal entry points. Public entry points should just delegate to these
void internal_check(ConflictSet::Impl *impl,
const ConflictSet::ReadRange *reads,
ConflictSet::Result *results, int count) {
impl->check(reads, results, count);
}
void internal_addWrites(ConflictSet::Impl *impl,
const ConflictSet::WriteRange *writes, int count,
int64_t writeVersion) {
mallocBytesDelta = 0;
impl->addWrites(writes, count, writeVersion);
impl->totalBytes += mallocBytesDelta;
impl->memory_bytes.set(impl->totalBytes);
#if SHOW_MEMORY
if (impl->totalBytes != mallocBytes) {
abort();
}
#endif
}
void internal_setOldestVersion(ConflictSet::Impl *impl, int64_t oldestVersion) {
mallocBytesDelta = 0;
impl->setOldestVersion(oldestVersion);
impl->totalBytes += mallocBytesDelta;
impl->memory_bytes.set(impl->totalBytes);
#if SHOW_MEMORY
if (impl->totalBytes != mallocBytes) {
abort();
}
#endif
}
ConflictSet::Impl *internal_create(int64_t oldestVersion) {
mallocBytesDelta = 0;
auto *result = new (safe_malloc(sizeof(ConflictSet::Impl)))
ConflictSet::Impl{oldestVersion};
result->totalBytes += mallocBytesDelta;
return result;
}
void internal_destroy(ConflictSet::Impl *impl) {
impl->~Impl();
safe_free(impl, sizeof(ConflictSet::Impl));
}
int64_t internal_getBytes(ConflictSet::Impl *impl) { return impl->getBytes(); }
void internal_getMetricsV1(ConflictSet::Impl *impl,
ConflictSet::MetricsV1 **metrics, int *count) {
impl->getMetricsV1(metrics, count);
}
double internal_getMetricValue(const ConflictSet::MetricsV1 *metric) {
return ((Metric *)metric->p)->value.load(std::memory_order_relaxed);
}
// ==================== END IMPLEMENTATION ====================
// GCOVR_EXCL_START
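// Reference implementation of "first logical entry >= key", used by the
// validation helpers below.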
Node *firstGeqLogical(Node *n, const TrivialSpan key) {
auto remaining = key;
for (;;) {
if (remaining.size() == 0) {
if (n->entryPresent) {
return n;
}
n = getFirstChild(n);
goto downLeftSpine;
}
Node *child = getChild(n, remaining[0]);
if (child == nullptr) {
auto c = getChildGeq(n, remaining[0]);
if (c != nullptr) {
n = c;
goto downLeftSpine;
} else {
n = nextSibling(n);
if (n == nullptr) {
return nullptr;
}
goto downLeftSpine;
}
}
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
if (i < commonLen) {
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
goto downLeftSpine;
} else {
        n = nextSibling(n);
        if (n == nullptr) {
          return nullptr;
        }
        goto downLeftSpine;
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > remaining.size()) {
// n is the first physical node greater than remaining, and there's no
// eq node
goto downLeftSpine;
}
}
}
downLeftSpine:
for (; !n->entryPresent; n = getFirstChild(n)) {
}
return n;
}
void ConflictSet::check(const ReadRange *reads, Result *results,
int count) const {
internal_check(impl, reads, results, count);
}
void ConflictSet::addWrites(const WriteRange *writes, int count,
int64_t writeVersion) {
internal_addWrites(impl, writes, count, writeVersion);
}
void ConflictSet::setOldestVersion(int64_t oldestVersion) {
internal_setOldestVersion(impl, oldestVersion);
}
int64_t ConflictSet::getBytes() const { return internal_getBytes(impl); }
void ConflictSet::getMetricsV1(MetricsV1 **metrics, int *count) const {
return internal_getMetricsV1(impl, metrics, count);
}
double ConflictSet::MetricsV1::getValue() const {
return internal_getMetricValue(this);
}
ConflictSet::ConflictSet(int64_t oldestVersion)
: impl(internal_create(oldestVersion)) {}
ConflictSet::~ConflictSet() {
if (impl) {
internal_destroy(impl);
}
}
ConflictSet::ConflictSet(ConflictSet &&other) noexcept
: impl(std::exchange(other.impl, nullptr)) {}
ConflictSet &ConflictSet::operator=(ConflictSet &&other) noexcept {
impl = std::exchange(other.impl, nullptr);
return *this;
}
using ConflictSet_Result = ConflictSet::Result;
using ConflictSet_Key = ConflictSet::Key;
using ConflictSet_ReadRange = ConflictSet::ReadRange;
using ConflictSet_WriteRange = ConflictSet::WriteRange;
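// A minimal usage sketch of the C ABI (illustrative only; key lifetimes,
// error handling, and real version management are omitted):
//
//   void *cs = ConflictSet_create(/*oldestVersion=*/0);
//   ConflictSet_WriteRange w{};
//   w.begin.p = (const uint8_t *)"a";
//   w.begin.len = 1; // end.len == 0 makes this a point write
//   ConflictSet_addWrites(cs, &w, 1, /*writeVersion=*/1);
//   ConflictSet_ReadRange r{};
//   r.begin.p = (const uint8_t *)"a";
//   r.begin.len = 1; // end.len == 0 makes this a point read
//   r.readVersion = 0; // reads below the write's version conflict
//   ConflictSet_Result result;
//   ConflictSet_check(cs, &r, &result, 1); // expect a conflict
//   ConflictSet_destroy(cs);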
extern "C" {
__attribute__((__visibility__("default"))) void
ConflictSet_check(void *cs, const ConflictSet_ReadRange *reads,
ConflictSet_Result *results, int count) {
internal_check((ConflictSet::Impl *)cs, reads, results, count);
}
__attribute__((__visibility__("default"))) void
ConflictSet_addWrites(void *cs, const ConflictSet_WriteRange *writes, int count,
int64_t writeVersion) {
internal_addWrites((ConflictSet::Impl *)cs, writes, count, writeVersion);
}
__attribute__((__visibility__("default"))) void
ConflictSet_setOldestVersion(void *cs, int64_t oldestVersion) {
internal_setOldestVersion((ConflictSet::Impl *)cs, oldestVersion);
}
__attribute__((__visibility__("default"))) void *
ConflictSet_create(int64_t oldestVersion) {
return internal_create(oldestVersion);
}
__attribute__((__visibility__("default"))) void ConflictSet_destroy(void *cs) {
internal_destroy((ConflictSet::Impl *)cs);
}
__attribute__((__visibility__("default"))) int64_t
ConflictSet_getBytes(void *cs) {
return internal_getBytes((ConflictSet::Impl *)cs);
}
}
// Make sure the ABI is well-defined
static_assert(std::is_standard_layout_v<ConflictSet::Result>);
static_assert(std::is_standard_layout_v<ConflictSet::Key>);
static_assert(std::is_standard_layout_v<ConflictSet::ReadRange>);
static_assert(std::is_standard_layout_v<ConflictSet::WriteRange>);
static_assert(std::is_standard_layout_v<ConflictSet::MetricsV1>);
namespace {
std::string getSearchPathPrintable(Node *n) {
Arena arena;
if (n == nullptr) {
return "<end>";
}
auto result = vector<char>(arena);
for (;;) {
for (int i = n->partialKeyLen - 1; i >= 0; --i) {
result.push_back(n->partialKey()[i]);
}
if (n->parent == nullptr) {
break;
}
result.push_back(n->parentsIndex);
n = n->parent;
}
std::reverse(result.begin(), result.end());
if (result.size() > 0) {
return printable(std::string_view((const char *)&result[0],
result.size())); // NOLINT
} else {
return std::string();
}
}
std::string getPartialKeyPrintable(Node *n) {
if (n == nullptr) {
return "<end>";
}
auto result = std::string((const char *)&n->parentsIndex,
n->parent == nullptr ? 0 : 1) +
std::string((const char *)n->partialKey(), n->partialKeyLen);
return printable(result); // NOLINT
}
std::string strinc(std::string_view str, bool &ok) {
int index;
for (index = str.size() - 1; index >= 0; index--)
if ((uint8_t &)(str[index]) != 255)
break;
  // A string consisting only of zero or more '\xff' bytes has no valid
  // increment; report failure through `ok`.
if (index < 0) {
ok = false;
return {};
}
ok = true;
auto r = std::string(str.substr(0, index + 1));
((uint8_t &)r[r.size() - 1])++;
return r;
}
std::string getSearchPath(Node *n) {
assert(n != nullptr);
Arena arena;
auto result = getSearchPath(arena, n);
return std::string((const char *)result.data(), result.size());
}
[[maybe_unused]] void debugPrintDot(FILE *file, Node *node,
ConflictSet::Impl *impl) {
constexpr int kSeparation = 3;
struct DebugDotPrinter {
explicit DebugDotPrinter(FILE *file, ConflictSet::Impl *impl)
: file(file), impl(impl) {}
void print(Node *n, int y = 0) {
assert(n != nullptr);
if (n->entryPresent) {
fprintf(file,
" k_%p [label=\"m=%" PRId64 " p=%" PRId64 " r=%" PRId64
"\n%s\", pos=\"%d,%d!\"];\n",
(void *)n, n->parent == nullptr ? -1 : maxVersion(n).toInt64(),
n->entry.pointVersion.toInt64(),
n->entry.rangeVersion.toInt64(),
getPartialKeyPrintable(n).c_str(), x, y);
} else {
fprintf(file, " k_%p [label=\"m=%" PRId64 "\n%s\", pos=\"%d,%d!\"];\n",
(void *)n, n->parent == nullptr ? -1 : maxVersion(n).toInt64(),
getPartialKeyPrintable(n).c_str(), x, y);
}
x += kSeparation;
for (auto c = getChildGeq(n, 0); c != nullptr;
c = getChildGeq(n, c->parentsIndex + 1)) {
fprintf(file, " k_%p -> k_%p;\n", (void *)n, (void *)c);
print(c, y - kSeparation);
}
}
int x = 0;
FILE *file;
ConflictSet::Impl *impl;
};
fprintf(file, "digraph ConflictSet {\n");
fprintf(file, " node [shape = box];\n");
assert(node != nullptr);
DebugDotPrinter printer{file, impl};
printer.print(node);
fprintf(file, "}\n");
}
void checkParentPointers(Node *node, bool &success) {
for (auto child = getChildGeq(node, 0); child != nullptr;
child = getChildGeq(node, child->parentsIndex + 1)) {
if (child->parent != node) {
fprintf(stderr, "%s child %d has parent pointer %p. Expected %p\n",
getSearchPathPrintable(node).c_str(), child->parentsIndex,
(void *)child->parent, (void *)node);
success = false;
}
checkParentPointers(child, success);
}
}
Node *firstGeq(Node *n, std::string_view key) {
return firstGeqLogical(n,
TrivialSpan((const uint8_t *)key.data(), key.size()));
}
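// With 64-bit versions this check compiles away; with compressed 32-bit
// versions, every version stored in the tree must stay >= the oldest
// extant version, presumably so that windowed comparisons remain
// well-defined.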
#if USE_64_BIT
void checkVersionsGeqOldestExtant(Node *, InternalVersionT) {}
#else
void checkVersionsGeqOldestExtant(Node *n,
InternalVersionT oldestExtantVersion) {
if (n->entryPresent) {
assert(n->entry.pointVersion >= oldestExtantVersion);
assert(n->entry.rangeVersion >= oldestExtantVersion);
}
switch (n->getType()) {
case Type_Node0: {
} break;
case Type_Node3: {
[[maybe_unused]] auto *self = static_cast<Node3 *>(n);
for (int i = 0; i < 3; ++i) {
assert(self->childMaxVersion[i] >= oldestExtantVersion);
}
} break;
case Type_Node16: {
[[maybe_unused]] auto *self = static_cast<Node16 *>(n);
for (int i = 0; i < 16; ++i) {
assert(self->childMaxVersion[i] >= oldestExtantVersion);
}
} break;
case Type_Node48: {
auto *self = static_cast<Node48 *>(n);
for (int i = 0; i < 48; ++i) {
assert(self->childMaxVersion[i] >= oldestExtantVersion);
}
for ([[maybe_unused]] auto m : self->maxOfMax) {
assert(m >= oldestExtantVersion);
}
} break;
case Type_Node256: {
auto *self = static_cast<Node256 *>(n);
for (int i = 0; i < 256; ++i) {
assert(self->childMaxVersion[i] >= oldestExtantVersion);
}
for ([[maybe_unused]] auto m : self->maxOfMax) {
assert(m >= oldestExtantVersion);
}
} break;
default:
abort();
}
}
#endif
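// Recomputes the expected max version of `node`'s subtree from first
// principles (its point version, each child's subtree max, and each
// child's range version) and complains if the stored maxVersion disagrees.
// The root and nodes whose stored maxVersion has already fallen to
// oldestVersion are exempt, presumably because stale maxima are allowed to
// lag once they can no longer affect conflict checks.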
[[maybe_unused]] InternalVersionT
checkMaxVersion(Node *root, Node *node, InternalVersionT oldestVersion,
bool &success, ConflictSet::Impl *impl) {
checkVersionsGeqOldestExtant(node,
InternalVersionT(impl->oldestExtantVersion));
auto expected = oldestVersion;
if (node->entryPresent) {
expected = std::max(expected, node->entry.pointVersion);
}
for (auto child = getChildGeq(node, 0); child != nullptr;
child = getChildGeq(node, child->parentsIndex + 1)) {
expected = std::max(
expected, checkMaxVersion(root, child, oldestVersion, success, impl));
if (child->entryPresent) {
expected = std::max(expected, child->entry.rangeVersion);
}
}
auto key = getSearchPath(root);
bool ok;
auto inc = strinc(key, ok);
if (ok) {
auto borrowed = firstGeq(root, inc);
if (borrowed != nullptr) {
expected = std::max(expected, borrowed->entry.rangeVersion);
}
}
if (node->parent && maxVersion(node) > oldestVersion &&
maxVersion(node) != expected) {
fprintf(stderr, "%s has max version %" PRId64 " . Expected %" PRId64 "\n",
getSearchPathPrintable(node).c_str(), maxVersion(node).toInt64(),
expected.toInt64());
success = false;
}
return expected;
}
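// Returns the number of entries reachable from `node`, and flags any child
// subtree containing none: such a subtree should have been pruned when its
// last entry was erased.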
[[maybe_unused]] int64_t checkEntriesExist(Node *node, bool &success) {
int64_t total = node->entryPresent;
for (auto child = getChildGeq(node, 0); child != nullptr;
child = getChildGeq(node, child->parentsIndex + 1)) {
int64_t e = checkEntriesExist(child, success);
total += e;
if (e == 0) {
fprintf(stderr, "%s has child %02x with no reachable entries\n",
getSearchPathPrintable(node).c_str(), child->parentsIndex);
success = false;
}
}
return total;
}
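// Checks the occupancy invariants that bound memory usage: each node type
// must hold at least its minimum number of children + entries, and a
// node's partial-key capacity may not exceed
// (numChildren + entryPresent) * (partialKeyLen + 1).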
[[maybe_unused]] void checkMemoryBoundInvariants(Node *node, bool &success) {
int minNumChildren;
switch (node->getType()) {
case Type_Node0:
minNumChildren = kMinChildrenNode0;
break;
case Type_Node3:
minNumChildren = kMinChildrenNode3;
break;
case Type_Node16:
minNumChildren = kMinChildrenNode16;
break;
case Type_Node48:
minNumChildren = kMinChildrenNode48;
break;
case Type_Node256:
minNumChildren = kMinChildrenNode256;
break;
default:
abort();
}
if (node->numChildren + int(node->entryPresent) < minNumChildren) {
fprintf(stderr,
"%s has %d children + %d entries, which is less than the minimum "
"required %d\n",
getSearchPathPrintable(node).c_str(), node->numChildren,
int(node->entryPresent), minNumChildren);
success = false;
}
const int maxCapacity =
(node->numChildren + int(node->entryPresent)) * (node->partialKeyLen + 1);
if (node->getCapacity() > maxCapacity) {
fprintf(stderr, "%s has d capacity %d, which is more than the allowed %d\n",
getSearchPathPrintable(node).c_str(), node->getCapacity(),
maxCapacity);
success = false;
}
for (auto child = getChildGeq(node, 0); child != nullptr;
child = getChildGeq(node, child->parentsIndex + 1)) {
checkMemoryBoundInvariants(child, success);
}
}
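// Runs all structural invariant checks; the fuzz harness calls this after
// every step.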
[[maybe_unused]] bool checkCorrectness(Node *node,
InternalVersionT oldestVersion,
ConflictSet::Impl *impl) {
bool success = true;
if (node->partialKeyLen > 0) {
fprintf(stderr, "Root cannot have a partial key\n");
success = false;
}
checkParentPointers(node, success);
checkMaxVersion(node, node, oldestVersion, success, impl);
checkEntriesExist(node, success);
checkMemoryBoundInvariants(node, success);
return success;
}
} // namespace
#if SHOW_MEMORY
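// Global counters for the optional memory-accounting build. addNode /
// removeNode / addKey / removeKey update them from the tree code, and
// PeakPrinter reports the peaks at program exit.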
int64_t nodeBytes = 0;
int64_t peakNodeBytes = 0;
int64_t partialCapacityBytes = 0;
int64_t peakPartialCapacityBytes = 0;
int64_t totalKeys = 0;
int64_t peakKeys = 0;
int64_t keyBytes = 0;
int64_t peakKeyBytes = 0;
int64_t getNodeSize(struct Node *n) {
switch (n->getType()) {
case Type_Node0:
return sizeof(Node0);
case Type_Node3:
return sizeof(Node3);
case Type_Node16:
return sizeof(Node16);
case Type_Node48:
return sizeof(Node48);
case Type_Node256:
return sizeof(Node256);
default:
abort();
}
}
int64_t getSearchPathLength(Node *n) {
assert(n != nullptr);
int64_t result = 0;
for (;;) {
result += n->partialKeyLen;
if (n->parent == nullptr) {
break;
}
++result;
n = n->parent;
}
return result;
}
void addNode(Node *n) {
nodeBytes += getNodeSize(n);
partialCapacityBytes += n->getCapacity();
if (nodeBytes > peakNodeBytes) {
peakNodeBytes = nodeBytes;
}
if (partialCapacityBytes > peakPartialCapacityBytes) {
peakPartialCapacityBytes = partialCapacityBytes;
}
}
void removeNode(Node *n) {
nodeBytes -= getNodeSize(n);
partialCapacityBytes -= n->getCapacity();
}
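// addKey / removeKey are evidently called before entryPresent is toggled,
// so the guards below make them no-ops unless the entry's presence is
// actually changing. Key bytes are counted as full search-path lengths,
// i.e. as if keys were stored without prefix compression.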
void addKey(Node *n) {
if (!n->entryPresent) {
++totalKeys;
keyBytes += getSearchPathLength(n);
if (totalKeys > peakKeys) {
peakKeys = totalKeys;
}
if (keyBytes > peakKeyBytes) {
peakKeyBytes = keyBytes;
}
}
}
void removeKey(Node *n) {
if (n->entryPresent) {
--totalKeys;
keyBytes -= getSearchPathLength(n);
}
}
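// Reports the accounting above from a static object's destructor, i.e.
// after main() returns.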
struct __attribute__((visibility("default"))) PeakPrinter {
~PeakPrinter() {
printf("--- radix_tree ---\n");
printf("malloc bytes: %g\n", double(mallocBytes));
printf("Peak malloc bytes: %g\n", double(peakMallocBytes));
printf("Node bytes: %g\n", double(nodeBytes));
printf("Peak node bytes: %g\n", double(peakNodeBytes));
printf("Expected worst case node bytes: %g\n",
double(peakKeys * kBytesPerKey));
printf("Key bytes: %g\n", double(keyBytes));
printf("Peak key bytes: %g (not sharing common prefixes)\n",
double(peakKeyBytes));
printf("Partial key capacity bytes: %g\n", double(partialCapacityBytes));
printf("Peak partial key capacity bytes: %g\n",
double(peakPartialCapacityBytes));
}
} peakPrinter;
#endif
#ifdef ENABLE_MAIN
#include "third_party/nanobench.h"
template <int kN> void benchRezero() {
static_assert(kN % 16 == 0);
ankerl::nanobench::Bench bench;
InternalVersionT vs[kN];
InternalVersionT zero;
bench.run("rezero" + std::to_string(kN), [&]() {
bench.doNotOptimizeAway(vs);
bench.doNotOptimizeAway(zero);
for (int i = 0; i < kN; i += 16) {
rezero16(vs + i, zero);
}
});
}
template <int kN> void benchScan1() {
static_assert(kN % 16 == 0);
ankerl::nanobench::Bench bench;
InternalVersionT vs[kN];
uint8_t is[kN];
uint8_t begin;
uint8_t end;
InternalVersionT v;
bench.run("scan" + std::to_string(kN), [&]() {
bench.doNotOptimizeAway(vs);
bench.doNotOptimizeAway(is);
bench.doNotOptimizeAway(begin);
bench.doNotOptimizeAway(end);
bench.doNotOptimizeAway(v);
for (int i = 0; i < kN; i += 16) {
scan16</*kAVX512=*/true>(vs + i, is + i, begin, end, v);
}
});
}
template <int kN> void benchScan2() {
static_assert(kN % 16 == 0);
ankerl::nanobench::Bench bench;
InternalVersionT vs[kN];
uint8_t begin;
uint8_t end;
InternalVersionT v;
bench.run("scan" + std::to_string(kN), [&]() {
bench.doNotOptimizeAway(vs);
bench.doNotOptimizeAway(begin);
bench.doNotOptimizeAway(end);
bench.doNotOptimizeAway(v);
for (int i = 0; i < kN; i += 16) {
scan16</*kAVX512=*/true>(vs + i, begin, end, v);
}
});
}
void benchHorizontal16() {
ankerl::nanobench::Bench bench;
InternalVersionT vs[16];
for (int i = 0; i < 16; ++i) {
vs[i] = InternalVersionT(rand() % 1000 + 1000);
}
#if !USE_64_BIT
InternalVersionT::zero = InternalVersionT(rand() % 1000);
#endif
bench.run("horizontal16", [&]() {
bench.doNotOptimizeAway(horizontalMax16(vs, InternalVersionT::zero));
});
int x = rand() % 15 + 1;
bench.run("horizontalUpTo16", [&]() {
bench.doNotOptimizeAway(horizontalMaxUpTo16(vs, InternalVersionT::zero, x));
});
}
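// Benchmarks longestCommonPrefix on two zero-filled buffers, i.e. the
// worst case where the entire length matches.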
void benchLCP(int len) {
ankerl::nanobench::Bench bench;
std::vector<uint8_t> lhs(len);
std::vector<uint8_t> rhs(len);
bench.run("lcp " + std::to_string(len), [&]() {
bench.doNotOptimizeAway(lhs);
bench.doNotOptimizeAway(rhs);
bench.doNotOptimizeAway(longestCommonPrefix(lhs.data(), rhs.data(), len));
});
}
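// Builds a small tree by hand and dumps it as Graphviz; handy for
// eyeballing node layout after a few writes.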
void printTree() {
int64_t writeVersion = 0;
ConflictSet::Impl cs{writeVersion};
ReferenceImpl refImpl{writeVersion};
ConflictSet::WriteRange write;
write.begin = "and"_s;
write.end = "ant"_s;
cs.addWrites(&write, 1, ++writeVersion);
write.begin = "any"_s;
write.end = ""_s;
cs.addWrites(&write, 1, ++writeVersion);
write.begin = "are"_s;
write.end = ""_s;
cs.addWrites(&write, 1, ++writeVersion);
write.begin = "art"_s;
write.end = ""_s;
cs.addWrites(&write, 1, ++writeVersion);
debugPrintDot(stdout, cs.root, &cs);
}
int main(void) { benchHorizontal16(); }
#endif
#ifdef ENABLE_FUZZ
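// libFuzzer entry point: drives two independent TestDriver instances from
// the same Arbitrary byte stream and validates the full set of structural
// invariants after every step, dumping the offending tree as DOT before
// aborting.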
extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
Arbitrary arbitrary({data, size});
TestDriver<ConflictSet::Impl> driver1{arbitrary};
TestDriver<ConflictSet::Impl> driver2{arbitrary};
bool done1 = false;
bool done2 = false;
for (;;) {
if (!done1) {
done1 = driver1.next();
if (!driver1.ok) {
debugPrintDot(stdout, driver1.cs.root, &driver1.cs);
fflush(stdout);
abort();
}
if (!checkCorrectness(driver1.cs.root, driver1.cs.oldestVersion,
&driver1.cs)) {
debugPrintDot(stdout, driver1.cs.root, &driver1.cs);
fflush(stdout);
abort();
}
}
if (!done2) {
done2 = driver2.next();
if (!driver2.ok) {
debugPrintDot(stdout, driver2.cs.root, &driver2.cs);
fflush(stdout);
abort();
}
if (!checkCorrectness(driver2.cs.root, driver2.cs.oldestVersion,
&driver2.cs)) {
debugPrintDot(stdout, driver2.cs.root, &driver2.cs);
fflush(stdout);
abort();
}
}
if (done1 && done2) {
break;
}
}
return 0;
}
#endif
// GCOVR_EXCL_STOP