Files
conflict-set/ConflictSet.cpp

4699 lines
147 KiB
C++

/*
Copyright 2024 Andrew Noyes
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "ConflictSet.h"
#include "Internal.h"
#include "LongestCommonPrefix.h"
#include "Metrics.h"
#include <algorithm>
#include <bit>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <inttypes.h>
#include <limits>
#include <span>
#include <string>
#include <string_view>
#include <sys/time.h>
#include <type_traits>
#include <utility>
#ifdef HAS_AVX
#include <immintrin.h>
#elif defined(HAS_ARM_NEON)
#include <arm_neon.h>
#endif
#ifndef __SANITIZE_THREAD__
#if defined(__has_feature)
#if __has_feature(thread_sanitizer)
#define __SANITIZE_THREAD__
#endif
#endif
#endif
#include <memcheck.h>
using namespace weaselab;
// Use assert for checking potentially complex properties during tests.
// Use assume to hint simple properties to the optimizer.
//
// In debug builds `assume(e)` is `assert(e)`. In release builds (NDEBUG) it
// tells the optimizer that `e` holds; if `e` is actually false the behavior is
// undefined.
//
// The non-__builtin_assume fallback expands to a single void expression (like
// `assert`) instead of a bare `if` statement. A bare `if` would capture a
// following `else` in code such as `if (c) assume(x); else ...`, silently
// changing control flow in release builds only.
// TODO use the c++23 version when that's available
#ifdef NDEBUG
#if __has_builtin(__builtin_assume)
#define assume(e) __builtin_assume(e)
#else
#define assume(e) ((e) ? void(0) : __builtin_unreachable())
#endif
#else
#define assume assert
#endif
// Memory-accounting hooks. addNode/removeNode are called by
// BoundedFreeListAllocator when it mallocs/frees a node; addKey/removeKey are
// presumably called when keys are added/removed (callers not visible here --
// confirm). With SHOW_MEMORY they are only declared here and must be defined
// elsewhere; otherwise they are empty constexpr no-ops that compile away.
#if SHOW_MEMORY
void addNode(struct Node *);
void removeNode(struct Node *);
void addKey(struct Node *);
void removeKey(struct Node *);
#else
constexpr void addNode(struct Node *) {}
constexpr void removeNode(struct Node *) {}
constexpr void addKey(struct Node *) {}
constexpr void removeKey(struct Node *) {}
#endif
// ==================== BEGIN IMPLEMENTATION ====================
// NOTE(review): presumably the span of versions the conflict set is expected
// to keep live at once -- confirm against its uses.
constexpr int64_t kNominalVersionWindow = 2e9;
// Largest distance between two full-precision versions for which the 32-bit
// wrapping comparison in InternalVersionT::operator<=> still orders them
// correctly (the difference must fit in an int32_t).
constexpr int64_t kMaxCorrectVersionWindow =
    std::numeric_limits<int32_t>::max();
static_assert(kNominalVersionWindow <= kMaxCorrectVersionWindow);
// Build switch: nonzero stores full 64-bit versions in InternalVersionT; the
// default (0) stores the low 32 bits and compares with wraparound.
#ifndef USE_64_BIT
#define USE_64_BIT 0
#endif
// A version number as stored in the tree.
//
// With USE_64_BIT the full int64_t is stored and compared directly. Otherwise
// only the low 32 bits are kept, and comparison looks at the sign of the
// wrapped 32-bit difference. That is correct only while the compared versions
// are within kMaxCorrectVersionWindow of each other.
struct InternalVersionT {
  // Leaves `value` uninitialized, which keeps the type trivial.
  constexpr InternalVersionT() = default;
  // In 32-bit mode this keeps only the low 32 bits of `value`.
  constexpr explicit InternalVersionT(int64_t value) : value(value) {}
  // Returns the stored value widened to int64_t. In 32-bit mode this is the
  // truncated (unsigned) value, not the original full-precision version.
  constexpr int64_t toInt64() const { return value; } // GCOVR_EXCL_LINE
  constexpr auto operator<=>(const InternalVersionT &rhs) const {
#if USE_64_BIT
    return value <=> rhs.value;
#else
    // Maintains ordering after overflow, as long as the full-precision versions
    // are within `kMaxCorrectVersionWindow` of each other.
    return int32_t(value - rhs.value) <=> 0;
#endif
  }
  constexpr bool operator==(const InternalVersionT &) const = default;
  // Shared "zero" version. In 32-bit mode it is a mutable thread_local rather
  // than a constant; the code that assigns it is not in this section.
#if USE_64_BIT
  static const InternalVersionT zero;
#else
  static thread_local InternalVersionT zero;
#endif
private:
#if USE_64_BIT
  int64_t value;
#else
  uint32_t value;
#endif
};
#if USE_64_BIT
const InternalVersionT InternalVersionT::zero{0};
#else
// No initializer: thread storage duration means it starts zero-initialized.
thread_local InternalVersionT InternalVersionT::zero;
#endif
// Pair of versions stored with a key (Node::entry). Per the names, one applies
// to the key as a point and one to a range; the exact semantics are defined by
// the read/write paths elsewhere in the file.
struct Entry {
  InternalVersionT pointVersion;
  InternalVersionT rangeVersion;
};
// A set of 256 bits, one per possible key byte, packed into four 64-bit words.
// Bit i lives in words[i >> 6] at position (i & 63).
struct BitSet {
  // Returns whether bit i is set. Requires 0 <= i < 256.
  bool test(int i) const;
  // Sets bit i. Requires 0 <= i < 256.
  void set(int i);
  // Clears bit i. Requires 0 <= i < 256.
  void reset(int i);
  // Returns the index of the first set bit >= i, or -1 if there is none.
  int firstSetGeq(int i) const;
  // Calls `f` with the index of each bit set, in increasing order.
  template <class F> void forEachSet(F f) const {
    // Peel off the lowest set bit of each word until it is empty; see section
    // 3.1 in https://arxiv.org/pdf/1709.07821.pdf for details about this
    // approach.
    for (int w = 0; w < 4; ++w) {
      const int base = w * 64;
      for (uint64_t bits = words[w]; bits != 0; bits &= bits - 1) {
        f(base + std::countr_zero(bits));
      }
    }
  }
  // Clears every bit. BitSet has no constructor, so this must be called before
  // first use.
  void init() { std::fill_n(words, 4, uint64_t(0)); }

private:
  // Single-bit mask selecting bit i within its word.
  static constexpr uint64_t bitInWord(int i) { return uint64_t(1) << (i & 63); }
  uint64_t words[4];
};
bool BitSet::test(int i) const {
  assert(0 <= i);
  assert(i < 256);
  return (words[i >> 6] & bitInWord(i)) != 0;
}
void BitSet::set(int i) {
  assert(0 <= i);
  assert(i < 256);
  words[i >> 6] |= bitInWord(i);
}
void BitSet::reset(int i) {
  assert(0 <= i);
  assert(i < 256);
  words[i >> 6] &= ~bitInWord(i);
}
// Returns the index of the first set bit at position >= i, or -1 if there is
// none. `i` must be non-negative but may be >= 256, in which case no word is
// examined and -1 is returned.
int BitSet::firstSetGeq(int i) const {
  assume(0 <= i);
  // In the first word examined only bits at or above (i & 63) count; every
  // later word is considered in full.
  uint64_t keep = ~uint64_t(0) << (i & 63);
  for (int w = i >> 6; w < 4; ++w, keep = ~uint64_t(0)) {
    if (const uint64_t hits = words[w] & keep) {
      return w * 64 + std::countr_zero(hits);
    }
  }
  return -1;
}
// Concrete node kind. TaggedNodePointer packs this value into the low three
// bits of an 8-byte-aligned node address, so every enumerator must be < 8.
enum Type : int8_t {
  Type_Node0,
  Type_Node3,
  Type_Node16,
  Type_Node48,
  Type_Node256,
};
template <class T> struct BoundedFreeListAllocator;
// A node pointer with the node's Type packed into its low three bits. Node
// addresses must be 8-byte aligned (asserted in the private constructor) so
// those bits are free. Conversions to a concrete node pointer type assert that
// the stored tag matches.
struct TaggedNodePointer {
  // Default construction leaves `p` uninitialized; value-initialization
  // (`TaggedNodePointer{}`) zeroes it, i.e. yields null.
  TaggedNodePointer() = default;
  // Unchecked view as the common Node base.
  operator struct Node *() { return (struct Node *)withoutType(); }
  // Checked downcasts to each concrete node type.
  operator struct Node0 *() {
    assert(getType() == Type_Node0);
    return (struct Node0 *)withoutType();
  }
  operator struct Node3 *() {
    assert(getType() == Type_Node3);
    return (struct Node3 *)withoutType();
  }
  operator struct Node16 *() {
    assert(getType() == Type_Node16);
    return (struct Node16 *)withoutType();
  }
  operator struct Node48 *() {
    assert(getType() == Type_Node48);
    return (struct Node48 *)withoutType();
  }
  operator struct Node256 *() {
    assert(getType() == Type_Node256);
    return (struct Node256 *)withoutType();
  }
  /*implicit*/ TaggedNodePointer(std::nullptr_t) : p(0) {}
  // Construction from a concrete node pointer tags it with that static type.
  /*implicit*/ TaggedNodePointer(Node0 *x)
      : TaggedNodePointer((struct Node *)x, Type_Node0) {}
  /*implicit*/ TaggedNodePointer(Node3 *x)
      : TaggedNodePointer((struct Node *)x, Type_Node3) {}
  /*implicit*/ TaggedNodePointer(Node16 *x)
      : TaggedNodePointer((struct Node *)x, Type_Node16) {}
  /*implicit*/ TaggedNodePointer(Node48 *x)
      : TaggedNodePointer((struct Node *)x, Type_Node48) {}
  /*implicit*/ TaggedNodePointer(Node256 *x)
      : TaggedNodePointer((struct Node *)x, Type_Node256) {}
  // Null checks look at the whole word (address and tag).
  bool operator!=(std::nullptr_t) { return p != 0; }
  bool operator==(std::nullptr_t) { return p == 0; }
  bool operator==(const TaggedNodePointer &) const = default;
  // Compares only the untagged address with `n`.
  bool operator==(Node *n) const { return (uintptr_t)n == withoutType(); }
  Node *operator->() { return (Node *)withoutType(); }
  // Returns the tag. Requires a non-null pointer; defined after Node.
  Type getType();
  TaggedNodePointer(const TaggedNodePointer &) = default;
  TaggedNodePointer &operator=(const TaggedNodePointer &) = default;
  // Tags `n` with its own n->getType(); defined after Node.
  /*implicit*/ TaggedNodePointer(Node *n);
private:
  // Inside this constructor the parameter `p` shadows the member; the member
  // is accessed as `this->p`.
  TaggedNodePointer(struct Node *p, Type t) : p((uintptr_t)p) {
    assert((this->p & 7) == 0);
    this->p |= t;
    assume(p != 0);
  }
  // The node address with the tag bits cleared.
  uintptr_t withoutType() const { return p & ~uintptr_t(7); }
  // Node address | Type tag; 0 represents null.
  uintptr_t p;
};
// Fields common to every node type. Each concrete type (Node0, Node3, Node16,
// Node48, Node256) derives from Node and stores its partial key bytes directly
// after the derived object.
struct Node {
  // The fields from `entry` through `parentsIndex` are copied as raw bytes by
  // copyChildrenAndKeyFrom (see kNodeCopyBegin / kNodeCopySize), so their order
  // and contiguity matter.
  /* begin section that's copied to the next node */
  // Versions for the key ending at this node; presumably meaningful only when
  // entryPresent is set -- confirm against readers.
  Entry entry;
  // Parent node. maxVersion()/setMaxVersion() assert it is non-null for
  // non-root nodes; presumably null for the root -- confirm.
  Node *parent;
  // Number of partialKey() bytes in use; must not exceed getCapacity().
  int32_t partialKeyLen;
  int16_t numChildren;
  bool entryPresent;
  // Temp variable used to signal the end of the range during addWriteRange
  bool endOfRange;
  // Key byte under which this node is stored in its parent.
  uint8_t parentsIndex;
  /* end section that's copied to the next node */
  // Start of the partial key bytes that follow the concrete node object.
  uint8_t *partialKey();
  Type getType() const { return type; }
  // Number of partial key bytes allocated after the node.
  int32_t getCapacity() const { return partialKeyCapacity; }
private:
  template <class T> friend struct BoundedFreeListAllocator;
  // These are publicly readable, but should only be written by
  // BoundedFreeListAllocator
  Type type;
  int32_t partialKeyCapacity;
};
// Tags `n` with the type recorded in the node itself. `n` must be non-null.
TaggedNodePointer::TaggedNodePointer(Node *n)
    : TaggedNodePointer(n, n->getType()) {}
// Recovers the node type from the low three tag bits. Requires non-null.
Type TaggedNodePointer::getType() {
  assert(p != 0);
  constexpr uintptr_t kTagMask = 7;
  return static_cast<Type>(p & kTagMask);
}
// Byte range of Node that copyChildrenAndKeyFrom copies verbatim: from `entry`
// through `parentsIndex` inclusive. `type` and `partialKeyCapacity` lie outside
// it because they describe the destination allocation and are written only by
// BoundedFreeListAllocator.
// NOTE: Node mixes public and private data members, so it is not
// standard-layout and offsetof on it is only conditionally supported by the
// standard; GCC/Clang accept it.
constexpr int kNodeCopyBegin = offsetof(Node, entry);
constexpr int kNodeCopySize =
    offsetof(Node, parentsIndex) + sizeof(Node::parentsIndex) - kNodeCopyBegin;
// copyChildrenAndKeyFrom is responsible for copying all
// public members of Node, copying the partial key, logically copying the
// children (converting representation if necessary), and updating all the
// children's parent pointers. The caller must then insert the new node into the
// tree.
// Node with no children. Its partial key bytes follow the struct.
struct Node0 : Node {
  constexpr static auto kType = Type_Node0;
  uint8_t *partialKey() { return (uint8_t *)(this + 1); }
  void copyChildrenAndKeyFrom(const Node0 &other);
  void copyChildrenAndKeyFrom(const struct Node3 &other);
  // Bytes occupied by this allocation: the struct plus key capacity.
  size_t size() const { return sizeof(Node0) + getCapacity(); }
};
// Node with up to 3 children. Slot i holds the child stored under key byte
// index[i] and that child's max version in childMaxVersion[i]; only the first
// numChildren slots are in use.
struct Node3 : Node {
  constexpr static auto kMaxNodes = 3;
  constexpr static auto kType = Type_Node3;
  TaggedNodePointer children[kMaxNodes];
  InternalVersionT childMaxVersion[kMaxNodes];
  // Sorted
  uint8_t index[kMaxNodes];
  uint8_t *partialKey() { return (uint8_t *)(this + 1); }
  void copyChildrenAndKeyFrom(const Node0 &other);
  void copyChildrenAndKeyFrom(const Node3 &other);
  void copyChildrenAndKeyFrom(const struct Node16 &other);
  // Bytes occupied by this allocation: the struct plus key capacity.
  size_t size() const { return sizeof(Node3) + getCapacity(); }
};
// Node with up to 16 children, laid out like Node3: slot i holds the child
// under key byte index[i] and its max version in childMaxVersion[i]. The
// 16-byte index array is compared with one SIMD instruction in getNodeIndex.
struct Node16 : Node {
  constexpr static auto kType = Type_Node16;
  constexpr static auto kMaxNodes = 16;
  TaggedNodePointer children[kMaxNodes];
  InternalVersionT childMaxVersion[kMaxNodes];
  // Sorted
  uint8_t index[kMaxNodes];
  uint8_t *partialKey() { return (uint8_t *)(this + 1); }
  void copyChildrenAndKeyFrom(const Node3 &other);
  void copyChildrenAndKeyFrom(const Node16 &other);
  void copyChildrenAndKeyFrom(const struct Node48 &other);
  // Bytes occupied by this allocation: the struct plus key capacity.
  size_t size() const { return sizeof(Node16) + getCapacity(); }
};
// Node with up to 48 children stored compactly in slots. bitSet marks which
// key bytes have a child, index[byte] is that child's slot (-1 when absent),
// and reverseIndex[slot] is the key byte stored in a slot.
struct Node48 : Node {
  constexpr static auto kType = Type_Node48;
  constexpr static auto kMaxNodes = 48;
  // Slots are grouped into pages of this many for the maxOfMax summary.
  constexpr static int kMaxOfMaxPageSize = 16;
  constexpr static int kMaxOfMaxShift =
      std::countr_zero(uint32_t(kMaxOfMaxPageSize));
  constexpr static int kMaxOfMaxTotalPages = kMaxNodes / kMaxOfMaxPageSize;
  BitSet bitSet;
  TaggedNodePointer children[kMaxNodes];
  InternalVersionT childMaxVersion[kMaxNodes];
  // maxOfMax[page] summarizes childMaxVersion for the slots in that page. It is
  // updated as a running maximum (setMaxVersion only raises it), so it may
  // exceed the page's true maximum.
  InternalVersionT maxOfMax[kMaxOfMaxTotalPages];
  uint8_t reverseIndex[kMaxNodes];
  int8_t index[256];
  uint8_t *partialKey() { return (uint8_t *)(this + 1); }
  void copyChildrenAndKeyFrom(const Node16 &other);
  void copyChildrenAndKeyFrom(const Node48 &other);
  void copyChildrenAndKeyFrom(const struct Node256 &other);
  // Bytes occupied by this allocation: the struct plus key capacity.
  size_t size() const { return sizeof(Node48) + getCapacity(); }
};
// Node with up to 256 children indexed directly by key byte; bitSet marks
// which bytes have a child.
struct Node256 : Node {
  constexpr static auto kType = Type_Node256;
  constexpr static auto kMaxNodes = 256;
  // Key bytes are grouped into pages of this many for the maxOfMax summary.
  constexpr static int kMaxOfMaxPageSize = 16;
  constexpr static int kMaxOfMaxShift =
      std::countr_zero(uint32_t(kMaxOfMaxPageSize));
  constexpr static int kMaxOfMaxTotalPages = kMaxNodes / kMaxOfMaxPageSize;
  BitSet bitSet;
  TaggedNodePointer children[kMaxNodes];
  InternalVersionT childMaxVersion[kMaxNodes];
  // maxOfMax[page] summarizes childMaxVersion for the key bytes in that page,
  // maintained as a running maximum (setMaxVersion only raises it).
  InternalVersionT maxOfMax[kMaxOfMaxTotalPages];
  uint8_t *partialKey() { return (uint8_t *)(this + 1); }
  void copyChildrenAndKeyFrom(const Node48 &other);
  void copyChildrenAndKeyFrom(const Node256 &other);
  // Bytes occupied by this allocation: the struct plus key capacity.
  size_t size() const { return sizeof(Node256) + getCapacity(); }
};
// Copies the shared Node header fields and the partial key from `other`.
inline void Node0::copyChildrenAndKeyFrom(const Node0 &other) {
  char *dst = reinterpret_cast<char *>(this);
  const char *src = reinterpret_cast<const char *>(&other);
  memcpy(dst + kNodeCopyBegin, src + kNodeCopyBegin, kNodeCopySize);
  // partialKeyLen was copied from `other` just above.
  memcpy(partialKey(), &other + 1, partialKeyLen);
}
// Shrinks a Node3 into this Node0: copies the shared Node header fields and
// the partial key; `other`'s child arrays are not copied.
inline void Node0::copyChildrenAndKeyFrom(const Node3 &other) {
  char *dst = reinterpret_cast<char *>(this);
  const char *src = reinterpret_cast<const char *>(&other);
  memcpy(dst + kNodeCopyBegin, src + kNodeCopyBegin, kNodeCopySize);
  // `other`'s key bytes start right after the Node3 object.
  memcpy(partialKey(), &other + 1, partialKeyLen);
}
// Grows a Node0 into this Node3: copies the shared Node header fields and the
// partial key. A Node0 has no children, so the child arrays are left as-is.
inline void Node3::copyChildrenAndKeyFrom(const Node0 &other) {
  char *dst = reinterpret_cast<char *>(this);
  const char *src = reinterpret_cast<const char *>(&other);
  memcpy(dst + kNodeCopyBegin, src + kNodeCopyBegin, kNodeCopySize);
  // `other`'s key bytes start right after the Node0 object.
  memcpy(partialKey(), &other + 1, partialKeyLen);
}
// Copies a Node3 into this Node3 and re-parents its children.
inline void Node3::copyChildrenAndKeyFrom(const Node3 &other) {
  memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
         kNodeCopySize);
  // Copies everything after the Node base in one go (children,
  // childMaxVersion, index and any padding). This relies on `children`
  // starting right at the end of the Node base.
  memcpy(children, other.children, sizeof(*this) - sizeof(Node));
  memcpy(partialKey(), &other + 1, partialKeyLen);
  for (int i = 0; i < numChildren; ++i) {
    assert(children[i]->parent == &other);
    children[i]->parent = this;
  }
}
inline void Node3::copyChildrenAndKeyFrom(const Node16 &other) {
memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
kNodeCopySize);
memcpy(index, other.index, kMaxNodes);
memcpy(children, other.children, kMaxNodes * sizeof(children[0])); // NOLINT
memcpy(childMaxVersion, other.childMaxVersion,
kMaxNodes * sizeof(childMaxVersion[0]));
memcpy(partialKey(), &other + 1, partialKeyLen);
for (int i = 0; i < numChildren; ++i) {
assert(children[i]->parent == &other);
children[i]->parent = this;
}
}
inline void Node16::copyChildrenAndKeyFrom(const Node3 &other) {
memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
kNodeCopySize);
memcpy(index, other.index, Node3::kMaxNodes);
memcpy(children, other.children,
Node3::kMaxNodes * sizeof(children[0])); // NOLINT
memcpy(childMaxVersion, other.childMaxVersion,
Node3::kMaxNodes * sizeof(childMaxVersion[0]));
memcpy(partialKey(), &other + 1, partialKeyLen);
assert(numChildren == Node3::kMaxNodes);
for (int i = 0; i < Node3::kMaxNodes; ++i) {
assert(children[i]->parent == &other);
children[i]->parent = this;
}
}
inline void Node16::copyChildrenAndKeyFrom(const Node16 &other) {
memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
kNodeCopySize);
memcpy(index, other.index, sizeof(index));
for (int i = 0; i < numChildren; ++i) {
children[i] = other.children[i];
childMaxVersion[i] = other.childMaxVersion[i];
assert(children[i]->parent == &other);
children[i]->parent = this;
}
memcpy(partialKey(), &other + 1, partialKeyLen);
}
// Shrinks a Node48 into this Node16. Children are packed into slots in
// increasing key-byte order, which keeps `index` sorted.
inline void Node16::copyChildrenAndKeyFrom(const Node48 &other) {
  memcpy(reinterpret_cast<char *>(this) + kNodeCopyBegin,
         reinterpret_cast<const char *>(&other) + kNodeCopyBegin,
         kNodeCopySize);
  int slot = 0;
  other.bitSet.forEachSet([&](int byte) {
    // Suppress a false positive -Waggressive-loop-optimizations warning
    // in gcc
    assume(slot < Node16::kMaxNodes);
    const int from = other.index[byte];
    index[slot] = byte;
    children[slot] = other.children[from];
    childMaxVersion[slot] = other.childMaxVersion[from];
    assert(children[slot]->parent == &other);
    children[slot]->parent = this;
    ++slot;
  });
  memcpy(partialKey(), &other + 1, partialKeyLen);
}
// Grows a full Node16 into this Node48. Slot i of the new node holds the child
// that was at position i of `other`; bitSet, index and reverseIndex map between
// key bytes and slots. Every lookup/child array is (re)initialized here, so the
// result does not depend on what the allocation previously contained.
inline void Node48::copyChildrenAndKeyFrom(const Node16 &other) {
  memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
         kNodeCopySize);
  assert(numChildren == Node16::kMaxNodes);
  memset(index, -1, sizeof(index));
  memset(children, 0, sizeof(children));
  const auto z = InternalVersionT::zero;
  for (auto &v : childMaxVersion) {
    v = z;
  }
  // maxOfMax is accumulated with std::max below, so it must start from a known
  // baseline rather than whatever bytes the allocation held (the
  // Node256-from-Node48 conversion resets it the same way).
  for (auto &v : maxOfMax) {
    v = z;
  }
  memcpy(partialKey(), &other + 1, partialKeyLen);
  bitSet.init();
  int i = 0;
  // `other` is full (asserted above), so all 16 index entries are live.
  for (auto x : other.index) {
    bitSet.set(x);
    index[x] = i;
    children[i] = other.children[i];
    childMaxVersion[i] = other.childMaxVersion[i];
    assert(children[i]->parent == &other);
    children[i]->parent = this;
    reverseIndex[i] = x;
    maxOfMax[i >> Node48::kMaxOfMaxShift] =
        std::max(maxOfMax[i >> Node48::kMaxOfMaxShift], childMaxVersion[i]);
    ++i;
  }
}
inline void Node48::copyChildrenAndKeyFrom(const Node48 &other) {
memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
kNodeCopySize);
bitSet = other.bitSet;
memcpy(index, other.index, sizeof(index));
memset(children, 0, sizeof(children));
const auto z = InternalVersionT::zero;
for (auto &v : childMaxVersion) {
v = z;
}
for (int i = 0; i < numChildren; ++i) {
children[i] = other.children[i];
childMaxVersion[i] = other.childMaxVersion[i];
assert(children[i]->parent == &other);
children[i]->parent = this;
}
memcpy(reverseIndex, other.reverseIndex, sizeof(reverseIndex));
memcpy(maxOfMax, other.maxOfMax, sizeof(maxOfMax));
memcpy(partialKey(), &other + 1, partialKeyLen);
}
// Shrinks a Node256 into this Node48. Children are packed into slots in
// increasing key-byte order; index and reverseIndex map between key bytes and
// slots. Every lookup/child array is (re)initialized here, so the result does
// not depend on what the allocation previously contained.
inline void Node48::copyChildrenAndKeyFrom(const Node256 &other) {
  memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
         kNodeCopySize);
  memset(index, -1, sizeof(index));
  memset(children, 0, sizeof(children));
  const auto z = InternalVersionT::zero;
  for (auto &v : childMaxVersion) {
    v = z;
  }
  // maxOfMax is accumulated with std::max below, so it must start from a known
  // baseline rather than whatever bytes the allocation held.
  for (auto &v : maxOfMax) {
    v = z;
  }
  bitSet = other.bitSet;
  int i = 0;
  bitSet.forEachSet([&](int c) {
    // Suppress a false positive -Waggressive-loop-optimizations warning
    // in gcc.
    assume(i < Node48::kMaxNodes);
    index[c] = i;
    children[i] = other.children[c];
    childMaxVersion[i] = other.childMaxVersion[c];
    assert(children[i]->parent == &other);
    children[i]->parent = this;
    reverseIndex[i] = c;
    maxOfMax[i >> Node48::kMaxOfMaxShift] =
        std::max(maxOfMax[i >> Node48::kMaxOfMaxShift], childMaxVersion[i]);
    ++i;
  });
  memcpy(partialKey(), &other + 1, partialKeyLen);
}
inline void Node256::copyChildrenAndKeyFrom(const Node48 &other) {
memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
kNodeCopySize);
bitSet = other.bitSet;
memset(children, 0, sizeof(children));
const auto z = InternalVersionT::zero;
for (auto &v : childMaxVersion) {
v = z;
}
for (auto &v : maxOfMax) {
v = z;
}
bitSet.forEachSet([&](int c) {
children[c] = other.children[other.index[c]];
childMaxVersion[c] = other.childMaxVersion[other.index[c]];
assert(children[c]->parent == &other);
children[c]->parent = this;
maxOfMax[c >> Node256::kMaxOfMaxShift] =
std::max(maxOfMax[c >> Node256::kMaxOfMaxShift], childMaxVersion[c]);
});
memcpy(partialKey(), &other + 1, partialKeyLen);
}
inline void Node256::copyChildrenAndKeyFrom(const Node256 &other) {
memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
kNodeCopySize);
memset(children, 0, sizeof(children));
const auto z = InternalVersionT::zero;
for (auto &v : childMaxVersion) {
v = z;
}
bitSet = other.bitSet;
bitSet.forEachSet([&](int c) {
children[c] = other.children[c];
childMaxVersion[c] = other.childMaxVersion[c];
assert(children[c]->parent == &other);
children[c]->parent = this;
});
memcpy(maxOfMax, other.maxOfMax, sizeof(maxOfMax));
memcpy(partialKey(), &other + 1, partialKeyLen);
}
namespace {
// Debugging helpers defined later in the file. Presumably they return the key
// path from the root to `n` (as a printable string / as raw bytes) -- confirm
// at their definitions.
std::string getSearchPathPrintable(Node *n);
std::string getSearchPath(Node *n);
} // namespace
// Bound memory usage following the analysis in the ART paper
// Each node with an entry present gets a budget of kBytesPerKey. Node0 always
// has an entry present.
// Induction hypothesis is that each node's surplus is >= kMinNodeSurplus
#if USE_64_BIT
constexpr int kBytesPerKey = 144;
constexpr int kMinNodeSurplus = 104;
#else
constexpr int kBytesPerKey = 112;
constexpr int kMinNodeSurplus = 80;
#endif
// Count the entry itself as a child.
//
// kMinChildrenNodeX is the fewest children (entry included) a node of type X
// is budgeted for. Each such child contributes kMinNodeSurplus, and the
// static_asserts below check that this pays for the node itself with at least
// kMinNodeSurplus left over.
//
// All arithmetic is done in int, with sizeof cast explicitly. Previously the
// Node0 check `kBytesPerKey - sizeof(Node0) >= kMinNodeSurplus` was evaluated
// in size_t, so a Node0 larger than kBytesPerKey would wrap to a huge value
// and the assert could never fail.
constexpr int kMinChildrenNode0 = 1;
constexpr int kMinChildrenNode3 = 2;
constexpr int kMinChildrenNode16 = 4;
constexpr int kMinChildrenNode48 = 17;
constexpr int kMinChildrenNode256 = 49;
constexpr int kNode256Surplus =
    kMinChildrenNode256 * kMinNodeSurplus - int(sizeof(Node256));
static_assert(kNode256Surplus >= kMinNodeSurplus);
constexpr int kNode48Surplus =
    kMinChildrenNode48 * kMinNodeSurplus - int(sizeof(Node48));
static_assert(kNode48Surplus >= kMinNodeSurplus);
constexpr int kNode16Surplus =
    kMinChildrenNode16 * kMinNodeSurplus - int(sizeof(Node16));
static_assert(kNode16Surplus >= kMinNodeSurplus);
constexpr int kNode3Surplus =
    kMinChildrenNode3 * kMinNodeSurplus - int(sizeof(Node3));
static_assert(kNode3Surplus >= kMinNodeSurplus);
static_assert(kBytesPerKey - int(sizeof(Node0)) >= kMinNodeSurplus);
// setOldestVersion will additionally try to maintain this property:
// `(children + entryPresent) * length >= capacity`
//
// Which should give us the budget to pay for the key bytes. (children +
// entryPresent) is a lower bound on how many keys these bytes are a prefix of

// Byte threshold for each BoundedFreeListAllocator's free list: once it holds
// at least this much, release() frees nodes immediately instead of caching.
constexpr int64_t kFreeListMaxMemory = 1 << 20;
// Per-node-type allocator that caches released nodes on an intrusive singly
// linked free list: the "next" pointer is stored in the first sizeof(void *)
// bytes of each freed node. Once the list holds at least kFreeListMaxMemory
// bytes, further releases are freed immediately. Valgrind client requests mark
// cached nodes inaccessible while they sit on the list.
template <class T> struct BoundedFreeListAllocator {
  // The intrusive next-pointer must fit inside a node.
  static_assert(sizeof(T) >= sizeof(void *));
  static_assert(std::derived_from<T, Node>);
  // Nodes are reused and copied as raw bytes.
  static_assert(std::is_trivial_v<T>);
  // Returns a node of type T with at least `partialKeyCapacity` key bytes. It
  // pops one entry off the free list if present, reusing it only if its
  // capacity is large enough; otherwise it mallocs a new node. The returned
  // node's contents other than `type`/`partialKeyCapacity` are unspecified.
  T *allocate_helper(int partialKeyCapacity) {
    if (freeList != nullptr) {
      T *n = (T *)freeList;
      VALGRIND_MAKE_MEM_DEFINED(freeList, sizeof(freeList));
      // Follow the next-pointer stored at the start of the cached node.
      memcpy(&freeList, freeList, sizeof(freeList));
      VALGRIND_MAKE_MEM_UNDEFINED(n, sizeof(T));
      // `type` and `partialKeyCapacity` lie past the bytes release() overwrote
      // with the next-pointer, so they still describe this allocation.
      VALGRIND_MAKE_MEM_DEFINED(&n->partialKeyCapacity,
                                sizeof(n->partialKeyCapacity));
      VALGRIND_MAKE_MEM_DEFINED(&n->type, sizeof(n->type));
      assert(n->type == T::kType);
      VALGRIND_MAKE_MEM_UNDEFINED(n + 1, n->partialKeyCapacity);
      freeListBytes -= sizeof(T) + n->partialKeyCapacity;
      if (n->partialKeyCapacity >= partialKeyCapacity) {
        return n;
      } else {
        // The intent is to filter out too-small nodes in the freelist
        removeNode(n);
        safe_free(n, sizeof(T) + n->partialKeyCapacity);
      }
    }
    auto *result = (T *)safe_malloc(sizeof(T) + partialKeyCapacity);
    result->type = T::kType;
    result->partialKeyCapacity = partialKeyCapacity;
    addNode(result);
    return result;
  }
  // allocate_helper plus resetting: clears endOfRange and, for node types with
  // children, zeroes `children` and sets every childMaxVersion (and maxOfMax
  // for Node48/Node256) to InternalVersionT::zero. Other Node fields are left
  // for the caller to fill in.
  T *allocate(int partialKeyCapacity) {
    T *result = allocate_helper(partialKeyCapacity);
    result->endOfRange = false;
    if constexpr (!std::is_same_v<T, Node0>) {
      memset(result->children, 0, sizeof(result->children));
      const auto z = InternalVersionT::zero;
      for (auto &v : result->childMaxVersion) {
        v = z;
      }
    }
    if constexpr (std::is_same_v<T, Node48> || std::is_same_v<T, Node256>) {
      const auto z = InternalVersionT::zero;
      for (auto &v : result->maxOfMax) {
        v = z;
      }
    }
    return result;
  }
  // Caches `p` on the free list, or frees it right away if the list already
  // holds at least kFreeListMaxMemory bytes.
  void release(T *p) {
    if (freeListBytes >= kFreeListMaxMemory) {
      removeNode(p);
      return safe_free(p, sizeof(T) + p->partialKeyCapacity);
    }
    // Push: store the current head in the node's first bytes.
    memcpy((void *)p, &freeList, sizeof(freeList));
    freeList = p;
    freeListBytes += sizeof(T) + p->partialKeyCapacity;
    VALGRIND_MAKE_MEM_NOACCESS(freeList, sizeof(T) + p->partialKeyCapacity);
  }
  BoundedFreeListAllocator() = default;
  BoundedFreeListAllocator(const BoundedFreeListAllocator &) = delete;
  BoundedFreeListAllocator &
  operator=(const BoundedFreeListAllocator &) = delete;
  BoundedFreeListAllocator(BoundedFreeListAllocator &&) = delete;
  BoundedFreeListAllocator &operator=(BoundedFreeListAllocator &&) = delete;
  // Frees every node still cached on the free list.
  ~BoundedFreeListAllocator() {
    for (void *iter = freeList; iter != nullptr;) {
      VALGRIND_MAKE_MEM_DEFINED(iter, sizeof(Node));
      auto *tmp = (T *)iter;
      memcpy(&iter, iter, sizeof(void *));
      removeNode((tmp));
      safe_free(tmp, sizeof(T) + tmp->partialKeyCapacity);
    }
  }
private:
  // Total bytes (nodes plus key capacity) currently cached on the free list.
  int64_t freeListBytes = 0;
  // Head of the intrusive free list, or nullptr when empty.
  void *freeList = nullptr;
};
uint8_t *Node::partialKey() {
switch (type) {
case Type_Node0:
return ((Node0 *)this)->partialKey();
case Type_Node3:
return ((Node3 *)this)->partialKey();
case Type_Node16:
return ((Node16 *)this)->partialKey();
case Type_Node48:
return ((Node48 *)this)->partialKey();
case Type_Node256:
return ((Node256 *)this)->partialKey();
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
// A type that's plumbed along the check call tree. Lifetime ends after each
// check call.
// The *_accum fields are counters, presumably folded into metrics after the
// call -- the code that increments and reports them is not in this section.
struct ReadContext {
  int64_t point_read_accum = 0;
  int64_t prefix_read_accum = 0;
  int64_t range_read_accum = 0;
  int64_t point_read_short_circuit_accum = 0;
  int64_t prefix_read_short_circuit_accum = 0;
  int64_t range_read_short_circuit_accum = 0;
  int64_t point_read_iterations_accum = 0;
  int64_t prefix_read_iterations_accum = 0;
  int64_t range_read_iterations_accum = 0;
  int64_t range_read_node_scan_accum = 0;
  int64_t commits_accum = 0;
  int64_t conflicts_accum = 0;
  int64_t too_olds_accum = 0;
  // Owning conflict set for this check call.
  ConflictSet::Impl *impl;
};
// A type that's plumbed along the non-const call tree. Same lifetime as
// ConflictSet::Impl
struct WriteContext {
  // Counters accumulated over the context's lifetime; zeroed by the
  // constructor. Where they are read is not in this section.
  struct Accum {
    int64_t entries_erased;
    int64_t insert_iterations;
    int64_t entries_inserted;
    int64_t nodes_allocated;
    int64_t nodes_released;
    int64_t point_writes;
    int64_t range_writes;
    int64_t write_bytes;
  } accum;
#if USE_64_BIT
  static constexpr InternalVersionT zero{0};
#else
  // Cache a copy of InternalVersionT::zero, so we don't need to do the TLS
  // lookup as often.
  InternalVersionT zero;
#endif
  WriteContext() { memset(&accum, 0, sizeof(accum)); }

  // Allocates a T with room for `c` partial key bytes from the allocator for
  // that node type, counting it in accum.nodes_allocated.
  template <class T> T *allocate(int c) {
    // Without this check, an unsupported T would instantiate a non-void
    // function that falls off the end (undefined behavior) instead of failing
    // to compile.
    static_assert(kIsConcreteNode<T>,
                  "WriteContext::allocate: T must be a concrete node type");
    ++accum.nodes_allocated;
    if constexpr (std::is_same_v<T, Node0>) {
      return node0.allocate(c);
    } else if constexpr (std::is_same_v<T, Node3>) {
      return node3.allocate(c);
    } else if constexpr (std::is_same_v<T, Node16>) {
      return node16.allocate(c);
    } else if constexpr (std::is_same_v<T, Node48>) {
      return node48.allocate(c);
    } else if constexpr (std::is_same_v<T, Node256>) {
      return node256.allocate(c);
    }
  }
  // Returns `c` to the allocator for its node type, counting it in
  // accum.nodes_released. T must be the node's concrete type, not Node.
  template <class T> void release(T *c) {
    static_assert(!std::is_same_v<T, Node>);
    // Any other unsupported T would be counted as released but silently
    // leaked.
    static_assert(kIsConcreteNode<T>,
                  "WriteContext::release: T must be a concrete node type");
    ++accum.nodes_released;
    if constexpr (std::is_same_v<T, Node0>) {
      return node0.release(c);
    } else if constexpr (std::is_same_v<T, Node3>) {
      return node3.release(c);
    } else if constexpr (std::is_same_v<T, Node16>) {
      return node16.release(c);
    } else if constexpr (std::is_same_v<T, Node48>) {
      return node48.release(c);
    } else if constexpr (std::is_same_v<T, Node256>) {
      return node256.release(c);
    }
  }

private:
  // True for exactly the node types that have an allocator below.
  template <class T>
  static constexpr bool kIsConcreteNode =
      std::is_same_v<T, Node0> || std::is_same_v<T, Node3> ||
      std::is_same_v<T, Node16> || std::is_same_v<T, Node48> ||
      std::is_same_v<T, Node256>;
  BoundedFreeListAllocator<Node0> node0;
  BoundedFreeListAllocator<Node3> node3;
  BoundedFreeListAllocator<Node16> node16;
  BoundedFreeListAllocator<Node48> node48;
  BoundedFreeListAllocator<Node256> node256;
};
// Returns the slot of key byte `index` among the first numChildren entries of
// self->index, or -1 if no child is stored under that byte.
int getNodeIndex(Node3 *self, uint8_t index) {
  assume(self->numChildren >= 1);
  assume(self->numChildren <= 3);
  const int count = self->numChildren;
  int slot = 0;
  while (slot < count && self->index[slot] != index) {
    ++slot;
  }
  return slot == count ? -1 : slot;
}
// Like getNodeIndex, but the caller guarantees a child exists under `index`;
// if it does not, behavior is undefined.
int getNodeIndexExists(Node3 *self, uint8_t index) {
  assume(self->numChildren >= 1);
  assume(self->numChildren <= 3);
  const uint8_t *first = self->index;
  const uint8_t *last = first + self->numChildren;
  for (const uint8_t *it = first; it != last; ++it) {
    if (*it == index) {
      return int(it - first);
    }
  }
  __builtin_unreachable(); // GCOVR_EXCL_LINE
}
// Returns the slot of key byte `index` among the first numChildren entries of
// self->index, or -1 if no child is stored under that byte. With HAS_AVX (SSE2
// intrinsics) or HAS_ARM_NEON, all 16 index bytes are compared at once and
// slots at or beyond numChildren are masked out; otherwise a linear scan.
int getNodeIndex(Node16 *self, uint8_t index) {
#ifdef HAS_AVX
  // Based on https://www.the-paper-trail.org/post/art-paper-notes/
  // key_vec is 16 repeated copies of the searched-for byte, one for every
  // possible position in child_keys that needs to be searched.
  __m128i key_vec = _mm_set1_epi8(index);
  // Compare all child_keys to 'index' in parallel. Don't worry if some of the
  // keys aren't valid, we'll mask the results to only consider the valid ones
  // below.
  __m128i indices;
  memcpy(&indices, self->index, Node16::kMaxNodes);
  __m128i results = _mm_cmpeq_epi8(key_vec, indices);
  // Build a mask to select only the first node->num_children values from the
  // comparison (because the other values are meaningless)
  uint32_t mask = (1 << self->numChildren) - 1;
  // Change the results of the comparison into a bitfield, masking off any
  // invalid comparisons.
  uint32_t bitfield = _mm_movemask_epi8(results) & mask;
  // No match if there are no '1's in the bitfield.
  if (bitfield == 0)
    return -1;
  // Bit i of the bitfield corresponds to slot i, so the first match is found
  // by counting the trailing zeros.
  return std::countr_zero(bitfield);
#elif defined(HAS_ARM_NEON)
  // Based on
  // https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
  uint8x16_t indices;
  memcpy(&indices, self->index, Node16::kMaxNodes);
  // 0xff for each match
  uint16x8_t results =
      vreinterpretq_u16_u8(vceqq_u8(vdupq_n_u8(index), indices));
  assume(self->numChildren <= Node16::kMaxNodes);
  // Four bits per slot after the narrowing shift below; a full node needs all
  // 64 bits (a 64-bit shift would be undefined).
  uint64_t mask = self->numChildren == 16
                      ? uint64_t(-1)
                      : (uint64_t(1) << (self->numChildren * 4)) - 1;
  // 0xf for each match in valid range
  uint64_t bitfield =
      vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) & mask;
  if (bitfield == 0)
    return -1;
  // Divide by 4 to convert a nibble position back into a slot.
  return std::countr_zero(bitfield) / 4;
#else
  for (int i = 0; i < self->numChildren; ++i) {
    if (self->index[i] == index) {
      return i;
    }
  }
  return -1;
#endif
}
// Like getNodeIndex(Node16 *, uint8_t), but the caller guarantees a child
// exists under `index`. If it does not, behavior is undefined: the
// no-match case is hinted away with assume / __builtin_unreachable.
int getNodeIndexExists(Node16 *self, uint8_t index) {
#ifdef HAS_AVX
  __m128i key_vec = _mm_set1_epi8(index);
  __m128i indices;
  memcpy(&indices, self->index, Node16::kMaxNodes);
  __m128i results = _mm_cmpeq_epi8(key_vec, indices);
  // Keep only the comparisons for the first numChildren slots.
  uint32_t mask = (1 << self->numChildren) - 1;
  uint32_t bitfield = _mm_movemask_epi8(results) & mask;
  assume(bitfield != 0);
  // Lowest set bit = first matching slot.
  return std::countr_zero(bitfield);
#elif defined(HAS_ARM_NEON)
  // Based on
  // https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
  uint8x16_t indices;
  memcpy(&indices, self->index, Node16::kMaxNodes);
  // 0xff for each match
  uint16x8_t results =
      vreinterpretq_u16_u8(vceqq_u8(vdupq_n_u8(index), indices));
  assume(self->numChildren <= Node16::kMaxNodes);
  uint64_t mask = self->numChildren == 16
                      ? uint64_t(-1)
                      : (uint64_t(1) << (self->numChildren * 4)) - 1;
  // 0xf for each match in valid range
  uint64_t bitfield =
      vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) & mask;
  assume(bitfield != 0);
  // Four bits per slot.
  return std::countr_zero(bitfield) / 4;
#else
  for (int i = 0; i < self->numChildren; ++i) {
    if (self->index[i] == index) {
      return i;
    }
  }
  __builtin_unreachable(); // GCOVR_EXCL_LINE
#endif
}
// Precondition - an entry for index must exist in the node
TaggedNodePointer &getChildExists(Node3 *self, uint8_t index) {
  const int slot = getNodeIndexExists(self, index);
  return self->children[slot];
}
// Precondition - an entry for index must exist in the node
TaggedNodePointer &getChildExists(Node16 *self, uint8_t index) {
  const int slot = getNodeIndexExists(self, index);
  return self->children[slot];
}
// Precondition - an entry for index must exist in the node.
// index[] maps the key byte to its compact slot.
TaggedNodePointer &getChildExists(Node48 *self, uint8_t index) {
  assert(self->bitSet.test(index));
  const int slot = self->index[index];
  return self->children[slot];
}
// Precondition - an entry for index must exist in the node.
// Children are stored directly by key byte.
TaggedNodePointer &getChildExists(Node256 *self, uint8_t index) {
  assert(self->bitSet.test(index));
  TaggedNodePointer &child = self->children[index];
  return child;
}
// Precondition - an entry for index must exist in the node.
// Dispatches on the dynamic node type. A Node0 has no children, so reaching
// it (or an unknown type) violates the precondition.
TaggedNodePointer &getChildExists(Node *self, uint8_t index) {
  const Type kind = self->getType();
  if (kind == Type_Node3) {
    return getChildExists(static_cast<Node3 *>(self), index);
  }
  if (kind == Type_Node16) {
    return getChildExists(static_cast<Node16 *>(self), index);
  }
  if (kind == Type_Node48) {
    return getChildExists(static_cast<Node48 *>(self), index);
  }
  if (kind == Type_Node256) {
    return getChildExists(static_cast<Node256 *>(self), index);
  }
  __builtin_unreachable(); // GCOVR_EXCL_LINE
}
// Returns the max version that `n`'s parent records for `n`.
// Precondition `n` is not the root
InternalVersionT maxVersion(Node *n) {
  const int byte = n->parentsIndex;
  Node *parent = n->parent;
  assert(parent != nullptr);
  switch (parent->getType()) {
  case Type_Node3: {
    auto *p3 = static_cast<Node3 *>(parent);
    return p3->childMaxVersion[getNodeIndexExists(p3, byte)];
  }
  case Type_Node16: {
    auto *p16 = static_cast<Node16 *>(parent);
    return p16->childMaxVersion[getNodeIndexExists(p16, byte)];
  }
  case Type_Node48: {
    auto *p48 = static_cast<Node48 *>(parent);
    assert(p48->bitSet.test(byte));
    const int slot = p48->index[byte];
    return p48->childMaxVersion[slot];
  }
  case Type_Node256: {
    auto *p256 = static_cast<Node256 *>(parent);
    assert(p256->bitSet.test(byte));
    return p256->childMaxVersion[byte];
  }
  case Type_Node0: // GCOVR_EXCL_LINE
  default:         // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
}
// Replaces the max version that `n`'s parent records for `n` with `newMax`
// and returns the previous value.
// Precondition `n` is not the root
//
// For Node48 and Node256 parents the per-page maxOfMax summary is also raised
// to cover `newMax`, matching setMaxVersion. Previously only childMaxVersion
// was written, so exchanging in a larger version left maxOfMax below an entry
// it summarizes. Raising it is always safe because maxOfMax is only ever
// maintained as a running maximum.
InternalVersionT exchangeMaxVersion(Node *n, InternalVersionT newMax) {
  int index = n->parentsIndex;
  n = n->parent;
  assert(n != nullptr);
  switch (n->getType()) {
  case Type_Node0: // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  case Type_Node3: {
    auto *n3 = static_cast<Node3 *>(n);
    int i = getNodeIndexExists(n3, index);
    return std::exchange(n3->childMaxVersion[i], newMax);
  }
  case Type_Node16: {
    auto *n16 = static_cast<Node16 *>(n);
    int i = getNodeIndexExists(n16, index);
    return std::exchange(n16->childMaxVersion[i], newMax);
  }
  case Type_Node48: {
    auto *n48 = static_cast<Node48 *>(n);
    assert(n48->bitSet.test(index));
    int i = n48->index[index];
    n48->maxOfMax[i >> Node48::kMaxOfMaxShift] = std::max<InternalVersionT>(
        n48->maxOfMax[i >> Node48::kMaxOfMaxShift], newMax);
    return std::exchange(n48->childMaxVersion[i], newMax);
  }
  case Type_Node256: {
    auto *n256 = static_cast<Node256 *>(n);
    assert(n256->bitSet.test(index));
    n256->maxOfMax[index >> Node256::kMaxOfMaxShift] =
        std::max<InternalVersionT>(
            n256->maxOfMax[index >> Node256::kMaxOfMaxShift], newMax);
    return std::exchange(n256->childMaxVersion[index], newMax);
  }
  default: // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
}
// Records `newMax` as the max version that `n`'s parent keeps for `n`. For
// Node48/Node256 parents the page's maxOfMax is raised to cover it as well.
// Precondition `n` is not the root
void setMaxVersion(Node *n, InternalVersionT newMax) {
  assert(newMax >= InternalVersionT::zero);
  const int byte = n->parentsIndex;
  Node *parent = n->parent;
  assert(parent != nullptr);
  switch (parent->getType()) {
  case Type_Node3: {
    auto *p3 = static_cast<Node3 *>(parent);
    p3->childMaxVersion[getNodeIndexExists(p3, byte)] = newMax;
    break;
  }
  case Type_Node16: {
    auto *p16 = static_cast<Node16 *>(parent);
    p16->childMaxVersion[getNodeIndexExists(p16, byte)] = newMax;
    break;
  }
  case Type_Node48: {
    auto *p48 = static_cast<Node48 *>(parent);
    assert(p48->bitSet.test(byte));
    const int slot = p48->index[byte];
    p48->childMaxVersion[slot] = newMax;
    InternalVersionT &pageMax = p48->maxOfMax[slot >> Node48::kMaxOfMaxShift];
    pageMax = std::max(pageMax, newMax);
    break;
  }
  case Type_Node256: {
    auto *p256 = static_cast<Node256 *>(parent);
    assert(p256->bitSet.test(byte));
    p256->childMaxVersion[byte] = newMax;
    InternalVersionT &pageMax =
        p256->maxOfMax[byte >> Node256::kMaxOfMaxShift];
    pageMax = std::max(pageMax, newMax);
    break;
  }
  case Type_Node0: // GCOVR_EXCL_LINE
  default:         // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
}
TaggedNodePointer &getInTree(Node *n, ConflictSet::Impl *);
// getChild: returns the child of `self` reached by byte `index`, or null if
// `self` has no child for that byte.
TaggedNodePointer getChild(Node0 *, uint8_t) { return nullptr; }
TaggedNodePointer getChild(Node3 *self, uint8_t index) {
  const int slot = getNodeIndex(self, index);
  if (slot < 0) {
    return nullptr;
  }
  return self->children[slot];
}
TaggedNodePointer getChild(Node16 *self, uint8_t index) {
  const int slot = getNodeIndex(self, index);
  if (slot < 0) {
    return nullptr;
  }
  return self->children[slot];
}
TaggedNodePointer getChild(Node48 *self, uint8_t index) {
  // index[] maps a byte to a slot in children[]; negative means absent.
  const int slot = self->index[index];
  if (slot < 0) {
    return nullptr;
  }
  return self->children[slot];
}
TaggedNodePointer getChild(Node256 *self, uint8_t index) {
  // Children are stored directly by byte; an absent child is null.
  return self->children[index];
}
TaggedNodePointer getChild(Node *self, uint8_t index) {
  switch (self->getType()) {
  case Type_Node0:
    return getChild(static_cast<Node0 *>(self), index);
  case Type_Node3:
    return getChild(static_cast<Node3 *>(self), index);
  case Type_Node16:
    return getChild(static_cast<Node16 *>(self), index);
  case Type_Node48:
    return getChild(static_cast<Node48 *>(self), index);
  case Type_Node256:
    return getChild(static_cast<Node256 *>(self), index);
  default: // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
}
// A child pointer together with the max version its parent records for it.
// When there is no such child both members are value-initialized (`{}`).
struct ChildAndMaxVersion {
  TaggedNodePointer child;
  InternalVersionT maxVersion;
};
// getChildAndMaxVersion: looks up the child of `self` for byte `index` and the
// max version stored alongside it in `self`.
ChildAndMaxVersion getChildAndMaxVersion(Node0 *, uint8_t) { return {}; }
ChildAndMaxVersion getChildAndMaxVersion(Node3 *self, uint8_t index) {
  const int slot = getNodeIndex(self, index);
  if (slot >= 0) {
    return {self->children[slot], self->childMaxVersion[slot]};
  }
  return {};
}
ChildAndMaxVersion getChildAndMaxVersion(Node16 *self, uint8_t index) {
  const int slot = getNodeIndex(self, index);
  if (slot >= 0) {
    return {self->children[slot], self->childMaxVersion[slot]};
  }
  return {};
}
ChildAndMaxVersion getChildAndMaxVersion(Node48 *self, uint8_t index) {
  // Negative slot means no child for this byte.
  const int slot = self->index[index];
  if (slot >= 0) {
    return {self->children[slot], self->childMaxVersion[slot]};
  }
  return {};
}
ChildAndMaxVersion getChildAndMaxVersion(Node256 *self, uint8_t index) {
  return {self->children[index], self->childMaxVersion[index]};
}
ChildAndMaxVersion getChildAndMaxVersion(Node *self, uint8_t index) {
  switch (self->getType()) {
  case Type_Node0:
    return getChildAndMaxVersion(static_cast<Node0 *>(self), index);
  case Type_Node3:
    return getChildAndMaxVersion(static_cast<Node3 *>(self), index);
  case Type_Node16:
    return getChildAndMaxVersion(static_cast<Node16 *>(self), index);
  case Type_Node48:
    return getChildAndMaxVersion(static_cast<Node48 *>(self), index);
  case Type_Node256:
    return getChildAndMaxVersion(static_cast<Node256 *>(self), index);
  default: // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
}
// getChildGeq: returns the first child (in slot order) whose byte is >=
// `child`, or null if there is none.
Node *getChildGeq(Node0 *, int) { return nullptr; }
Node *getChildGeq(Node3 *n, int child) {
  assume(n->numChildren >= 1);
  assume(n->numChildren <= 3);
  // Skip every slot whose byte is below `child`; the first slot left over
  // (if any) is the answer.
  int slot = 0;
  while (slot < n->numChildren && n->index[slot] < child) {
    ++slot;
  }
  if (slot == n->numChildren) {
    return nullptr;
  }
  return n->children[slot];
}
// Returns the child of `self` with the smallest byte >= `child`, or null if
// there is none. All three paths pick the lowest matching slot, which is the
// smallest byte because index[0..numChildren) is sorted ascending (the scalar
// path asserts this).
Node *getChildGeq(Node16 *self, int child) {
  // No byte can be >= 256. This also keeps `child` representable as a uint8
  // lane below.
  if (child > 255) {
    return nullptr;
  }
#ifdef HAS_AVX
  __m128i key_vec = _mm_set1_epi8(child);
  __m128i indices;
  memcpy(&indices, self->index, Node16::kMaxNodes);
  // Lane is 0xff where child == min(child, index), i.e. child <= index as
  // unsigned bytes.
  __m128i results = _mm_cmpeq_epi8(key_vec, _mm_min_epu8(key_vec, indices));
  // Ignore lanes at or beyond numChildren.
  int mask = (1 << self->numChildren) - 1;
  uint32_t bitfield = _mm_movemask_epi8(results) & mask;
  return bitfield == 0 ? nullptr : self->children[std::countr_zero(bitfield)];
#elif defined(HAS_ARM_NEON)
  uint8x16_t indices;
  memcpy(&indices, self->index, sizeof(self->index));
  // 0xff for each leq
  auto results = vcleq_u8(vdupq_n_u8(child), indices);
  assume(self->numChildren <= Node16::kMaxNodes);
  // 4 mask bits per lane, covering only the first numChildren lanes.
  uint64_t mask = self->numChildren == 16
                      ? uint64_t(-1)
                      : (uint64_t(1) << (self->numChildren * 4)) - 1;
  // 0xf for each 0xff (within mask)
  uint64_t bitfield =
      vget_lane_u64(
          vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)),
          0) &
      mask;
  // Divide by 4 to turn a bit position back into a lane (slot) number.
  return bitfield == 0 ? nullptr
                       : self->children[std::countr_zero(bitfield) / 4];
#else
  for (int i = 0; i < self->numChildren; ++i) {
    if (i > 0) {
      assert(self->index[i - 1] < self->index[i]);
    }
    if (self->index[i] >= child) {
      return self->children[i];
    }
  }
  return nullptr;
#endif
}
Node *getChildGeq(Node48 *self, int child) {
  // bitSet marks which bytes have a child; a negative result means none >= child.
  const int byte = self->bitSet.firstSetGeq(child);
  if (byte >= 0) {
    return self->children[self->index[byte]];
  }
  return nullptr;
}
Node *getChildGeq(Node256 *self, int child) {
  const int byte = self->bitSet.firstSetGeq(child);
  if (byte >= 0) {
    return self->children[byte];
  }
  return nullptr;
}
// Dispatches to the type-specific getChildGeq for `self`.
Node *getChildGeq(Node *self, int child) {
  switch (self->getType()) {
  case Type_Node0:
    return getChildGeq(static_cast<Node0 *>(self), child);
  case Type_Node3:
    return getChildGeq(static_cast<Node3 *>(self), child);
  case Type_Node16:
    return getChildGeq(static_cast<Node16 *>(self), child);
  case Type_Node48:
    return getChildGeq(static_cast<Node48 *>(self), child);
  case Type_Node256:
    return getChildGeq(static_cast<Node256 *>(self), child);
  default: // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
}
// Precondition: self has a child
Node *getFirstChildExists(Node3 *self) {
  assert(self->numChildren > 0);
  return self->children[0];
}
// Precondition: self has a child
Node *getFirstChildExists(Node16 *self) {
  assert(self->numChildren > 0);
  return self->children[0];
}
// Precondition: self has a child. Checked in debug builds like the Node3 and
// Node16 overloads: on an empty node firstSetGeq(0) returns a negative value
// and the index[] lookup below would be out of bounds.
Node *getFirstChildExists(Node48 *self) {
  assert(self->numChildren > 0);
  return self->children[self->index[self->bitSet.firstSetGeq(0)]];
}
// Precondition: self has a child. On an empty node firstSetGeq(0) returns a
// negative value and children[] would be indexed out of bounds.
Node *getFirstChildExists(Node256 *self) {
  assert(self->numChildren > 0);
  return self->children[self->bitSet.firstSetGeq(0)];
}
// Precondition: self has a child (so it cannot be a Node0)
Node *getFirstChildExists(Node *self) {
  // Only require that the node-specific overloads are covered
  // GCOVR_EXCL_START
  switch (self->getType()) {
  case Type_Node0:
    __builtin_unreachable();
  case Type_Node3:
    return getFirstChildExists(static_cast<Node3 *>(self));
  case Type_Node16:
    return getFirstChildExists(static_cast<Node16 *>(self));
  case Type_Node48:
    return getFirstChildExists(static_cast<Node48 *>(self));
  case Type_Node256:
    return getFirstChildExists(static_cast<Node256 *>(self));
  default:
    __builtin_unreachable();
  }
  // GCOVR_EXCL_STOP
}
// Slow path of consumePartialKey, for a `self` with a non-empty partial key.
// Strips from the front of `key` the prefix it shares with self's partial key.
// If that shared prefix is shorter than the partial key (the key diverges, or
// ends, inside it), self is split:
// - A new Node3 holding the shared prefix takes self's place in the tree.
// - The old node becomes the Node3's only child, keeping only the bytes after
//   the distinguishing byte as its partial key.
// In the split case, the parent's recorded max version for this slot is set to
// `writeVersion`, and the new Node3 records the old value for its child.
void consumePartialKeyFull(TaggedNodePointer &self,
                           std::span<const uint8_t> &key,
                           InternalVersionT writeVersion, WriteContext *tls) {
  // Handle an existing partial key
  int commonLen = std::min<int>(self->partialKeyLen, key.size());
  int partialKeyIndex =
      longestCommonPrefix(self->partialKey(), key.data(), commonLen);
  if (partialKeyIndex < self->partialKeyLen) {
    Node *old = self;
    // Since root cannot have a partial key
    assert(old->parent != nullptr);
    // The parent's slot will now lead to newSelf, whose subtree includes this
    // write; the previous value still describes old's subtree.
    InternalVersionT oldMaxVersion = exchangeMaxVersion(old, writeVersion);
    // *self will have one child (old)
    auto *newSelf = tls->allocate<Node3>(partialKeyIndex);
    newSelf->parent = old->parent;
    newSelf->parentsIndex = old->parentsIndex;
    newSelf->partialKeyLen = partialKeyIndex;
    newSelf->entryPresent = false;
    newSelf->numChildren = 1;
    memcpy(newSelf->partialKey(), old->partialKey(), newSelf->partialKeyLen);
    // First byte of old's partial key past the shared prefix; old hangs from
    // newSelf under this byte.
    uint8_t oldDistinguishingByte = old->partialKey()[partialKeyIndex];
    old->parent = newSelf;
    old->parentsIndex = oldDistinguishingByte;
    newSelf->index[0] = oldDistinguishingByte;
    newSelf->children[0] = old;
    newSelf->childMaxVersion[0] = oldMaxVersion;
    self = newSelf;
    // Drop the shared prefix and the distinguishing byte from old's key.
    memmove(old->partialKey(), old->partialKey() + partialKeyIndex + 1,
            old->partialKeyLen - (partialKeyIndex + 1));
    old->partialKeyLen -= partialKeyIndex + 1;
    // We would consider decreasing capacity here, but we can't invalidate
    // old since it's not on the search path. setOldestVersion will clean it
    // up.
  }
  // The distinguishing byte (if any) stays in `key` for the caller.
  key = key.subspan(partialKeyIndex, key.size() - partialKeyIndex);
}
// Makes `self` lie on the search path of `key` by consuming self's partial
// key (if it has one) from the front of `key`, splitting `self` when `key`
// diverges inside it. Nodes without a partial key are left untouched, which
// keeps this inlined fast path to a single test.
inline __attribute__((always_inline)) void
consumePartialKey(TaggedNodePointer &self, std::span<const uint8_t> &key,
                  InternalVersionT writeVersion, WriteContext *tls) {
  const bool hasPartialKey = self->partialKeyLen > 0;
  if (hasPartialKey) {
    consumePartialKeyFull(self, key, writeVersion, tls);
  }
}
// Return the next node along the search path of key, consuming bytes of key
// such that the search path of the result + key is the same as the search path
// of self + key before the call. Creates a node if necessary. Updates
// `maxVersion` for result.
//
// Precondition: `key` is non-empty; its first byte selects the child.
// The returned reference is the slot in (the possibly replaced) `self` that
// holds the child. The max version recorded for that slot is set to
// `newMaxVersion`, and for Node48/Node256 the group `maxOfMax` is raised too.
// - If the child exists, its partial key is consumed from `key`. This may
//   split the child, so the returned node can be a fresh Node3.
// - Otherwise a new Node0 is created whose partial key is the rest of `key`,
//   and `key` is left empty.
// `self` may be replaced by the next larger node type to make room.
TaggedNodePointer &getOrCreateChild(TaggedNodePointer &self,
                                    std::span<const uint8_t> &key,
                                    InternalVersionT newMaxVersion,
                                    WriteContext *tls) {
  int index = key.front();
  key = key.subspan(1, key.size() - 1);
  // Fast path for if it exists already
  switch (self->getType()) {
  case Type_Node0:
    break;
  case Type_Node3: {
    auto *self3 = static_cast<Node3 *>(self);
    int i = getNodeIndex(self3, index);
    if (i >= 0) {
      consumePartialKey(self3->children[i], key, newMaxVersion, tls);
      self3->childMaxVersion[i] = newMaxVersion;
      return self3->children[i];
    }
  } break;
  case Type_Node16: {
    auto *self16 = static_cast<Node16 *>(self);
    int i = getNodeIndex(self16, index);
    if (i >= 0) {
      consumePartialKey(self16->children[i], key, newMaxVersion, tls);
      self16->childMaxVersion[i] = newMaxVersion;
      return self16->children[i];
    }
  } break;
  case Type_Node48: {
    auto *self48 = static_cast<Node48 *>(self);
    // Node48 maps the byte to a dense slot; negative means absent.
    int secondIndex = self48->index[index];
    if (secondIndex >= 0) {
      consumePartialKey(self48->children[secondIndex], key, newMaxVersion, tls);
      self48->childMaxVersion[secondIndex] = newMaxVersion;
      self48->maxOfMax[secondIndex >> Node48::kMaxOfMaxShift] =
          std::max(self48->maxOfMax[secondIndex >> Node48::kMaxOfMaxShift],
                   newMaxVersion);
      return self48->children[secondIndex];
    }
  } break;
  case Type_Node256: {
    auto *self256 = static_cast<Node256 *>(self);
    if (auto &result = self256->children[index]; result != nullptr) {
      consumePartialKey(result, key, newMaxVersion, tls);
      self256->childMaxVersion[index] = newMaxVersion;
      self256->maxOfMax[index >> Node256::kMaxOfMaxShift] = std::max(
          self256->maxOfMax[index >> Node256::kMaxOfMaxShift], newMaxVersion);
      return result;
    }
  } break;
  default: // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
  // Slow path: no child for `index` yet. Build a leaf holding the rest of the
  // key as its partial key; it consumes all of `key`.
  auto *newChild = tls->allocate<Node0>(key.size());
  newChild->numChildren = 0;
  newChild->entryPresent = false;
  newChild->partialKeyLen = key.size();
  newChild->parentsIndex = index;
  memcpy(newChild->partialKey(), key.data(), key.size());
  key = {};
  // Insert newChild into self. A full node is first replaced by the next
  // larger type, then control jumps (goto) to that type's insert code.
  switch (self->getType()) {
  case Type_Node0: {
    auto *self0 = static_cast<Node0 *>(self);
    auto *newSelf = tls->allocate<Node3>(self->partialKeyLen);
    newSelf->copyChildrenAndKeyFrom(*self0);
    tls->release(self0);
    self = newSelf;
    goto insert3;
  }
  case Type_Node3: {
    if (self->numChildren == Node3::kMaxNodes) {
      auto *self3 = static_cast<Node3 *>(self);
      auto *newSelf = tls->allocate<Node16>(self->partialKeyLen);
      newSelf->copyChildrenAndKeyFrom(*self3);
      tls->release(self3);
      self = newSelf;
      goto insert16;
    }
  insert3:
    auto *self3 = static_cast<Node3 *>(self);
    // Shift slots with larger bytes up by one to keep index[] sorted, then
    // place the new child in the gap.
    int i = self->numChildren - 1;
    for (; i >= 0; --i) {
      if (self3->index[i] < index) {
        break;
      }
      self3->index[i + 1] = self3->index[i];
      self3->children[i + 1] = self3->children[i];
      self3->childMaxVersion[i + 1] = self3->childMaxVersion[i];
    }
    self3->index[i + 1] = index;
    auto &result = self3->children[i + 1];
    self3->childMaxVersion[i + 1] = newMaxVersion;
    result = newChild;
    ++self->numChildren;
    newChild->parent = self;
    return result;
  }
  case Type_Node16: {
    if (self->numChildren == Node16::kMaxNodes) {
      auto *self16 = static_cast<Node16 *>(self);
      auto *newSelf = tls->allocate<Node48>(self->partialKeyLen);
      newSelf->copyChildrenAndKeyFrom(*self16);
      tls->release(self16);
      self = newSelf;
      goto insert48;
    }
  insert16:
    assert(self->getType() == Type_Node16);
    auto *self16 = static_cast<Node16 *>(self);
    // Same sorted insertion as for Node3.
    int i = self->numChildren - 1;
    for (; i >= 0; --i) {
      if (self16->index[i] < index) {
        break;
      }
      self16->index[i + 1] = self16->index[i];
      self16->children[i + 1] = self16->children[i];
      self16->childMaxVersion[i + 1] = self16->childMaxVersion[i];
    }
    self16->index[i + 1] = index;
    auto &result = self16->children[i + 1];
    self16->childMaxVersion[i + 1] = newMaxVersion;
    result = newChild;
    ++self->numChildren;
    newChild->parent = self;
    return result;
  }
  case Type_Node48: {
    if (self->numChildren == 48) {
      auto *self48 = static_cast<Node48 *>(self);
      auto *newSelf = tls->allocate<Node256>(self->partialKeyLen);
      newSelf->copyChildrenAndKeyFrom(*self48);
      tls->release(self48);
      self = newSelf;
      goto insert256;
    }
  insert48:
    auto *self48 = static_cast<Node48 *>(self);
    // Node48 appends at the next free dense slot. index[] maps byte -> slot
    // and reverseIndex[] maps slot -> byte.
    self48->bitSet.set(index);
    auto nextFree = self48->numChildren++;
    self48->index[index] = nextFree;
    self48->reverseIndex[nextFree] = index;
    auto &result = self48->children[nextFree];
    self48->childMaxVersion[nextFree] = newMaxVersion;
    self48->maxOfMax[nextFree >> Node48::kMaxOfMaxShift] = std::max(
        newMaxVersion, self48->maxOfMax[nextFree >> Node48::kMaxOfMaxShift]);
    result = newChild;
    newChild->parent = self;
    return result;
  }
  case Type_Node256: {
  insert256:
    auto *self256 = static_cast<Node256 *>(self);
    // Node256 stores the child directly at its byte.
    ++self->numChildren;
    self256->bitSet.set(index);
    auto &result = self256->children[index];
    self256->childMaxVersion[index] = newMaxVersion;
    self256->maxOfMax[index >> Node256::kMaxOfMaxShift] = std::max(
        newMaxVersion, self256->maxOfMax[index >> Node256::kMaxOfMaxShift]);
    result = newChild;
    newChild->parent = self;
    return result;
  }
  default: // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
}
// Returns the node after `node` in a pre-order walk of the tree (children
// visited in increasing byte order), or null if `node` is the last one.
// Looks first at node's own smallest child. Failing that, it looks for the
// next larger sibling of node, or of the nearest ancestor that has one.
Node *nextPhysical(Node *node) {
  int nextByte = 0;
  while (true) {
    if (Node *candidate = getChildGeq(node, nextByte)) {
      return candidate;
    }
    nextByte = node->parentsIndex + 1;
    node = node->parent;
    if (node == nullptr) {
      return nullptr;
    }
  }
}
// Returns the first node after `node` in physical (pre-order) order that
// holds an entry, or null if there is none.
Node *nextLogical(Node *node) {
  do {
    node = nextPhysical(node);
  } while (node != nullptr && !node->entryPresent);
  return node;
}
// Invalidates `self`, replacing it with a node of at least capacity.
// Does not return nodes to freelists when kUseFreeList is false.
void freeAndMakeCapacityAtLeast(Node *&self, int capacity, WriteContext *tls,
ConflictSet::Impl *impl,
const bool kUseFreeList) {
switch (self->getType()) {
case Type_Node0: {
auto *self0 = (Node0 *)self;
auto *newSelf = tls->allocate<Node0>(capacity);
newSelf->copyChildrenAndKeyFrom(*self0);
getInTree(self, impl) = newSelf;
if (kUseFreeList) {
tls->release(self0);
} else {
removeNode(self0);
safe_free(self0, self0->size());
}
self = newSelf;
} break;
case Type_Node3: {
auto *self3 = (Node3 *)self;
auto *newSelf = tls->allocate<Node3>(capacity);
newSelf->copyChildrenAndKeyFrom(*self3);
getInTree(self, impl) = newSelf;
if (kUseFreeList) {
tls->release(self3);
} else {
removeNode(self3);
safe_free(self3, self3->size());
}
self = newSelf;
} break;
case Type_Node16: {
auto *self16 = (Node16 *)self;
auto *newSelf = tls->allocate<Node16>(capacity);
newSelf->copyChildrenAndKeyFrom(*self16);
getInTree(self, impl) = newSelf;
if (kUseFreeList) {
tls->release(self16);
} else {
removeNode(self16);
safe_free(self16, self16->size());
}
self = newSelf;
} break;
case Type_Node48: {
auto *self48 = (Node48 *)self;
auto *newSelf = tls->allocate<Node48>(capacity);
newSelf->copyChildrenAndKeyFrom(*self48);
getInTree(self, impl) = newSelf;
if (kUseFreeList) {
tls->release(self48);
} else {
removeNode(self48);
safe_free(self48, self48->size());
}
self = newSelf;
} break;
case Type_Node256: {
auto *self256 = (Node256 *)self;
auto *newSelf = tls->allocate<Node256>(capacity);
newSelf->copyChildrenAndKeyFrom(*self256);
getInTree(self, impl) = newSelf;
if (kUseFreeList) {
tls->release(self256);
} else {
removeNode(self256);
safe_free(self256, self256->size());
}
self = newSelf;
} break;
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
// Shrinks the partial-key capacity of `self` when it exceeds
// (children + entry) * (partialKeyLen + 1). The replacement bypasses the
// freelists, since recycling the node would not give back the memory spent
// on excess partial-key capacity.
void maybeDecreaseCapacity(Node *&self, WriteContext *tls,
                           ConflictSet::Impl *impl) {
  const int liveSlots = self->numChildren + int(self->entryPresent);
  const int maxCapacity = liveSlots * (self->partialKeyLen + 1);
#if DEBUG_VERBOSE && !defined(NDEBUG)
  fprintf(stderr, "maybeDecreaseCapacity: current: %d, max: %d, key: %s\n",
          self->getCapacity(), maxCapacity,
          getSearchPathPrintable(self).c_str());
#endif
  if (self->getCapacity() > maxCapacity) {
    freeAndMakeCapacityAtLeast(self, maxCapacity, tls, impl, false);
  }
}
// rezero16: raises each of the 16 versions starting at `vs` to at least
// `zero`. When HAS_AVX is set (and not under ThreadSanitizer) two definitions
// are compiled: an avx512f one, and the portable loop tagged
// target("default"). GCC/Clang function multiversioning then selects between
// them at runtime based on CPU support.
// NOTE(review): the avx512f version copies only 4 bytes of `zero`, so it
// assumes a 32-bit InternalVersionT. The callers visible here are all in the
// 32-bit build of rezero.
#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
__attribute__((target("avx512f"))) void rezero16(InternalVersionT *vs,
                                                 InternalVersionT zero) {
  uint32_t z;
  memcpy(&z, &zero, sizeof(z));
  const auto zvec = _mm512_set1_epi32(z);
  // Lanes where (vs[i] - zero) is negative as a signed 32-bit value, i.e.
  // vs[i] precedes zero in wraparound order, are overwritten with zero.
  auto m = _mm512_cmplt_epi32_mask(
      _mm512_sub_epi32(_mm512_loadu_epi32(vs), zvec), _mm512_setzero_epi32());
  _mm512_mask_storeu_epi32(vs, m, zvec);
}
__attribute__((target("default")))
#endif
void rezero16(InternalVersionT *vs, InternalVersionT zero) {
  for (int i = 0; i < 16; ++i) {
    vs[i] = std::max(vs[i], zero);
  }
}
// rezero: raises every version stored in `n` to at least `z`. That covers its
// own entry's point and range versions (if it has an entry) and every
// recorded child max version and maxOfMax summary, including slots that are
// not currently in use. A no-op when USE_64_BIT is set.
// NOTE(review): presumably needed because 32-bit versions compare with
// wraparound, so stale small values could later appear newer. Confirm.
#if USE_64_BIT
void rezero(Node *, InternalVersionT) {}
#else
void rezero(Node *n, InternalVersionT z) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
  fprintf(stderr, "rezero to %" PRId64 ": %s\n", z.toInt64(),
          getSearchPathPrintable(n).c_str());
#endif
  if (n->entryPresent) {
    n->entry.pointVersion = std::max(n->entry.pointVersion, z);
    n->entry.rangeVersion = std::max(n->entry.rangeVersion, z);
  }
  switch (n->getType()) {
  case Type_Node0: {
  } break;
  case Type_Node3: {
    auto *self = static_cast<Node3 *>(n);
    for (int i = 0; i < 3; ++i) {
      self->childMaxVersion[i] = std::max(self->childMaxVersion[i], z);
    }
  } break;
  case Type_Node16: {
    auto *self = static_cast<Node16 *>(n);
    rezero16(self->childMaxVersion, z);
  } break;
  case Type_Node48: {
    auto *self = static_cast<Node48 *>(n);
    // All 48 slots, in chunks of 16.
    for (int i = 0; i < 48; i += 16) {
      rezero16(self->childMaxVersion + i, z);
    }
    for (auto &m : self->maxOfMax) {
      m = std::max(m, z);
    }
  } break;
  case Type_Node256: {
    auto *self = static_cast<Node256 *>(n);
    // All 256 slots, in chunks of 16.
    for (int i = 0; i < 256; i += 16) {
      rezero16(self->childMaxVersion + i, z);
    }
    // Assumes Node256::maxOfMax has exactly 16 entries. TODO confirm.
    rezero16(self->maxOfMax, z);
  } break;
  default: // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
}
#endif
// Replaces self3 in the tree with its only child. `self` must be the tree slot
// that points at self3. The child's new partial key is self3's partial key,
// then the byte under which the child hung from self3, then the child's old
// partial key. The child takes over self3's parent link and max version, and
// self3 is released.
// Preconditions: self3 has exactly one child and no entry, and is not the
// root. If the child must be reallocated to fit the longer partial key and it
// is `dontInvalidate`, `dontInvalidate` is updated to the reallocated node.
void mergeWithChild(TaggedNodePointer &self, WriteContext *tls,
                    ConflictSet::Impl *impl, Node *&dontInvalidate,
                    Node3 *self3) {
  assert(!self3->entryPresent);
  assert(self3->numChildren == 1);
  Node *child = self3->children[0];
  const int minCapacity = self3->partialKeyLen + 1 + child->partialKeyLen;
  if (minCapacity > child->getCapacity()) {
    const bool update = child == dontInvalidate;
    freeAndMakeCapacityAtLeast(child, minCapacity, tls, impl, true);
    if (update) {
      dontInvalidate = child;
    }
  }
  // Merge partial key with child
#if DEBUG_VERBOSE && !defined(NDEBUG)
  fprintf(stderr, "Merge %s into %s\n", getSearchPathPrintable(self).c_str(),
          getSearchPathPrintable(child).c_str());
#endif
  InternalVersionT childMaxVersion = self3->childMaxVersion[0];
  // Construct new partial key for child. All reads go through self3 rather
  // than `self`, so this does not depend on `self` aliasing self3.
  memmove(child->partialKey() + self3->partialKeyLen + 1, child->partialKey(),
          child->partialKeyLen);
  memcpy(child->partialKey(), self3->partialKey(), self3->partialKeyLen);
  child->partialKey()[self3->partialKeyLen] = self3->index[0];
  child->partialKeyLen += 1 + self3->partialKeyLen;
  child->parent = self3->parent;
  child->parentsIndex = self3->parentsIndex;
  // Max versions are stored in the parent, so we need to update it now that
  // we have a new parent. Safe to call since the root never has a partial
  // key, so self3 (whose key the child absorbs) is not the root.
  setMaxVersion(child, std::max(childMaxVersion, tls->zero));
  self = child;
  tls->release(self3);
}
// Returns true when `n` holds fewer live slots (children plus its own entry)
// than the minimum for its node type, i.e. it should become a smaller type.
bool needsDownsize(Node *n) {
  // Indexed by getType(), in Node0, Node3, Node16, Node48, Node256 order.
  // A minimum of 0 means a Node0 never needs downsizing.
  static const int kMinSlotsByType[] = {0, kMinChildrenNode3,
                                        kMinChildrenNode16, kMinChildrenNode48,
                                        kMinChildrenNode256};
  const int liveSlots = n->numChildren + n->entryPresent;
  return liveSlots < kMinSlotsByType[n->getType()];
}
// Shrinks an underfull Node3.
// - With no children left, it is replaced by a Node0.
// - Otherwise it must have exactly one child and no entry, and it is merged
//   into that child. `dontInvalidate` is updated if the merge reallocates
//   the node it points to.
void downsize(Node3 *self, WriteContext *tls, ConflictSet::Impl *impl,
              Node *&dontInvalidate) {
  if (self->numChildren != 0) {
    assert(self->numChildren == 1 && !self->entryPresent);
    mergeWithChild(getInTree(self, impl), tls, impl, dontInvalidate, self);
    return;
  }
  auto *leaf = tls->allocate<Node0>(self->partialKeyLen);
  leaf->copyChildrenAndKeyFrom(*self);
  getInTree(self, impl) = leaf;
  tls->release(self);
}
// Replaces an underfull Node16 with an equivalent Node3.
void downsize(Node16 *self, WriteContext *tls, ConflictSet::Impl *impl) {
  assert(self->numChildren + int(self->entryPresent) < kMinChildrenNode16);
  auto *smaller = tls->allocate<Node3>(self->partialKeyLen);
  smaller->copyChildrenAndKeyFrom(*self);
  getInTree(self, impl) = smaller;
  tls->release(self);
}
// Replaces an underfull Node48 with an equivalent Node16.
void downsize(Node48 *self, WriteContext *tls, ConflictSet::Impl *impl) {
  assert(self->numChildren + int(self->entryPresent) < kMinChildrenNode48);
  auto *smaller = tls->allocate<Node16>(self->partialKeyLen);
  smaller->copyChildrenAndKeyFrom(*self);
  getInTree(self, impl) = smaller;
  tls->release(self);
}
// Replaces an underfull Node256 with an equivalent Node48.
void downsize(Node256 *self, WriteContext *tls, ConflictSet::Impl *impl) {
  assert(self->numChildren + int(self->entryPresent) < kMinChildrenNode256);
  auto *smaller = tls->allocate<Node48>(self->partialKeyLen);
  smaller->copyChildrenAndKeyFrom(*self);
  getInTree(self, impl) = smaller;
  tls->release(self);
}
// Dispatches to the type-specific downsize. A Node0 never needs downsizing.
void downsize(Node *self, WriteContext *tls, ConflictSet::Impl *impl,
              Node *&dontInvalidate) {
  switch (self->getType()) {
  case Type_Node0: // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  case Type_Node3:
    return downsize(static_cast<Node3 *>(self), tls, impl, dontInvalidate);
  case Type_Node16:
    return downsize(static_cast<Node16 *>(self), tls, impl);
  case Type_Node48:
    return downsize(static_cast<Node48 *>(self), tls, impl);
  case Type_Node256:
    return downsize(static_cast<Node256 *>(self), tls, impl);
  default: // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
}
// Precondition: self is not the root. May invalidate nodes along the search
// path to self. May invalidate children of self->parent. Returns a pointer to
// the node after self. Precondition: `self->entryPresent`
//
// Removes the entry stored at `self`. `logical` chooses how "the node after
// self" is defined: the next node that has an entry (nextLogical), or the
// next node in physical order (nextPhysical). It is computed before any
// restructuring. It is then passed to downsize as `dontInvalidate`, so a
// Node3 merge that reallocates it keeps the returned pointer current.
Node *erase(Node *self, WriteContext *tls, ConflictSet::Impl *impl,
            bool logical) {
  ++tls->accum.entries_erased;
  assert(self->parent != nullptr);
#if DEBUG_VERBOSE && !defined(NDEBUG)
  fprintf(stderr, "Erase: %s\n", getSearchPathPrintable(self).c_str());
#endif
  // Save the parent link now; `self` may be released below.
  Node *parent = self->parent;
  uint8_t parentsIndex = self->parentsIndex;
  auto *result = logical ? nextLogical(self) : nextPhysical(self);
  removeKey(self);
  assert(self->entryPresent);
  self->entryPresent = false;
  // A node that still has children stays in the tree without its entry, but
  // may now be below the minimum occupancy for its type.
  if (self->numChildren != 0) {
    if (needsDownsize(self)) {
      downsize(self, tls, impl, result);
    }
    return result;
  }
  // A node with neither entry nor children is dropped entirely; it is
  // expected to be a Node0 at this point.
  assert(self->getType() == Type_Node0);
  tls->release((Node0 *)self);
  // Unlink the released node from its parent, then shrink the parent if it
  // became underfull.
  switch (parent->getType()) {
  case Type_Node0: // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  case Type_Node3: {
    auto *parent3 = static_cast<Node3 *>(parent);
    int nodeIndex = getNodeIndex(parent3, parentsIndex);
    assert(nodeIndex >= 0);
    --parent->numChildren;
    // Shift later slots down over the removed one, keeping index[] sorted.
    for (int i = nodeIndex; i < parent->numChildren; ++i) {
      parent3->index[i] = parent3->index[i + 1];
      parent3->children[i] = parent3->children[i + 1];
      parent3->childMaxVersion[i] = parent3->childMaxVersion[i + 1];
    }
    if (needsDownsize(parent3)) {
      downsize(parent3, tls, impl, result);
    }
  } break;
  case Type_Node16: {
    auto *parent16 = static_cast<Node16 *>(parent);
    int nodeIndex = getNodeIndex(parent16, parentsIndex);
    assert(nodeIndex >= 0);
    --parent->numChildren;
    // Shift later slots down over the removed one, keeping index[] sorted.
    for (int i = nodeIndex; i < parent->numChildren; ++i) {
      parent16->index[i] = parent16->index[i + 1];
      parent16->children[i] = parent16->children[i + 1];
      parent16->childMaxVersion[i] = parent16->childMaxVersion[i + 1];
    }
    if (needsDownsize(parent16)) {
      downsize(parent16, tls, impl, result);
    }
  } break;
  case Type_Node48: {
    auto *parent48 = static_cast<Node48 *>(parent);
    parent48->bitSet.reset(parentsIndex);
    int8_t toRemoveChildrenIndex =
        std::exchange(parent48->index[parentsIndex], -1);
    auto lastChildrenIndex = --parent48->numChildren;
    assert(toRemoveChildrenIndex >= 0);
    assert(lastChildrenIndex >= 0);
    // Keep slots [0, numChildren) dense. Move the last slot into the hole and
    // repoint index[]/reverseIndex[] for the child that moved. maxOfMax is
    // only raised here, so it stays an upper bound for the hole's group.
    if (toRemoveChildrenIndex != lastChildrenIndex) {
      parent48->children[toRemoveChildrenIndex] =
          parent48->children[lastChildrenIndex];
      parent48->childMaxVersion[toRemoveChildrenIndex] =
          parent48->childMaxVersion[lastChildrenIndex];
      parent48->maxOfMax[toRemoveChildrenIndex >> Node48::kMaxOfMaxShift] =
          std::max(parent48->maxOfMax[toRemoveChildrenIndex >>
                                      Node48::kMaxOfMaxShift],
                   parent48->childMaxVersion[toRemoveChildrenIndex]);
      auto parentIndex =
          parent48->children[toRemoveChildrenIndex]->parentsIndex;
      parent48->index[parentIndex] = toRemoveChildrenIndex;
      parent48->reverseIndex[toRemoveChildrenIndex] = parentIndex;
    }
    // The vacated last slot no longer holds a child.
    parent48->childMaxVersion[lastChildrenIndex] = tls->zero;
    if (needsDownsize(parent48)) {
      downsize(parent48, tls, impl, result);
    }
  } break;
  case Type_Node256: {
    auto *parent256 = static_cast<Node256 *>(parent);
    // Children are stored by byte, so clearing the bit and pointer suffices.
    parent256->bitSet.reset(parentsIndex);
    parent256->children[parentsIndex] = nullptr;
    --parent->numChildren;
    if (needsDownsize(parent256)) {
      downsize(parent256, tls, impl, result);
    }
  } break;
  default: // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
  return result;
}
// Returns the first node after the entire subtree of `node` in pre-order.
// That is the next larger sibling of `node`, or of its nearest ancestor that
// has one. Returns null when no such node exists.
Node *nextSibling(Node *node) {
  for (Node *parent = node->parent; parent != nullptr;
       parent = node->parent) {
    if (Node *sibling = getChildGeq(parent, node->parentsIndex + 1)) {
      return sibling;
    }
    node = parent;
  }
  return nullptr;
}
#ifdef HAS_AVX
// Returns a mask whose bit i (for i in 0..15) is set iff vs[i] > rv.
// In 32-bit mode the test is (vs[i] - rv) > 0 on signed 32-bit lanes, i.e.
// a wraparound comparison rather than plain unsigned order.
uint32_t compare16(const InternalVersionT *vs, InternalVersionT rv) {
#if USE_64_BIT
  uint32_t compared = 0;
  for (int i = 0; i < 16; ++i) {
    compared |= (vs[i] > rv) << i;
  }
  return compared;
#else
  uint32_t compared = 0;
  __m128i w[4]; // GCOVR_EXCL_LINE
  memcpy(w, vs, sizeof(w));
  uint32_t r; // GCOVR_EXCL_LINE
  memcpy(&r, &rv, sizeof(r));
  const auto rvVec = _mm_set1_epi32(r);
  const auto zero = _mm_setzero_si128();
  // Each 128-bit chunk holds 4 versions. movemask_ps collects the sign bit of
  // each all-ones/all-zeros 32-bit compare lane, giving 4 mask bits per chunk.
  for (int i = 0; i < 4; ++i) {
    compared |= _mm_movemask_ps(
                    __m128(_mm_cmpgt_epi32(_mm_sub_epi32(w[i], rvVec), zero)))
                << (i * 4);
  }
  return compared;
#endif
}
// Same contract as compare16, computed with AVX-512 mask compares (one for
// 32-bit versions, two 8-lane compares for 64-bit versions). Compiled with
// target("avx512f"); scan16 only calls it when its kAVX512 parameter is true.
__attribute__((target("avx512f"))) uint32_t
compare16_avx512(const InternalVersionT *vs, InternalVersionT rv) {
#if USE_64_BIT
  int64_t r;
  memcpy(&r, &rv, sizeof(r));
  uint32_t low =
      _mm512_cmpgt_epi64_mask(_mm512_loadu_epi64(vs), _mm512_set1_epi64(r));
  uint32_t high =
      _mm512_cmpgt_epi64_mask(_mm512_loadu_epi64(vs + 8), _mm512_set1_epi64(r));
  return low | (high << 8);
#else
  uint32_t r;
  memcpy(&r, &rv, sizeof(r));
  return _mm512_cmpgt_epi32_mask(
      _mm512_sub_epi32(_mm512_loadu_epi32(vs), _mm512_set1_epi32(r)),
      _mm512_setzero_epi32());
#endif
}
#endif
// Returns true if v[i] <= readVersion for all i such that begin <= is[i] < end
// Preconditions: begin <= end, end - begin < 256
// All 16 lanes of `vs` and `is` are read and considered; there is no
// live-slot count here. The bounds test is the unsigned (is[i] - begin) <
// (end - begin), which rejects bytes below `begin` via wraparound.
// NOTE(review): the NEON path always treats versions as 32-bit, unlike
// compare16, which handles USE_64_BIT. Confirm NEON + USE_64_BIT is never
// built.
template <bool kAVX512>
bool scan16(const InternalVersionT *vs, const uint8_t *is, int begin, int end,
            InternalVersionT readVersion) {
  assert(begin <= end);
  assert(end - begin < 256);
#ifdef HAS_ARM_NEON
  uint8x16_t indices;
  memcpy(&indices, is, 16);
  // 0xff for each in bounds
  auto results =
      vcltq_u8(vsubq_u8(indices, vdupq_n_u8(begin)), vdupq_n_u8(end - begin));
  // 0xf for each 0xff
  uint64_t mask = vget_lane_u64(
      vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)), 0);
  uint32x4_t w4[4];
  memcpy(w4, vs, sizeof(w4));
  uint32_t rv;
  memcpy(&rv, &readVersion, sizeof(rv));
  const auto rvVec = vdupq_n_u32(rv);
  int32x4_t z;
  memset(&z, 0, sizeof(z));
  // Per lane: (vs[i] - rv) > 0 as signed 32-bit, narrowed to 16 bits.
  uint16x4_t conflicting[4];
  for (int i = 0; i < 4; ++i) {
    conflicting[i] =
        vmovn_u32(vcgtq_s32(vreinterpretq_s32_u32(vsubq_u32(w4[i], rvVec)), z));
  }
  // Narrow to one byte per lane, then to 4 bits per lane to match `mask`.
  auto combined =
      vcombine_u8(vmovn_u16(vcombine_u16(conflicting[0], conflicting[1])),
                  vmovn_u16(vcombine_u16(conflicting[2], conflicting[3])));
  uint64_t compared = vget_lane_u64(
      vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(combined), 4)), 0);
  return !(compared & mask);
#elif defined(HAS_AVX)
  __m128i indices;
  memcpy(&indices, is, 16);
  indices = _mm_sub_epi8(indices, _mm_set1_epi8(begin));
  // Bit i set iff (is[i] - begin) < (end - begin) as unsigned bytes. The bits
  // above 16 are also set by the ~, but `compared` only uses the low 16.
  uint32_t mask = ~_mm_movemask_epi8(_mm_cmpeq_epi8(
      indices, _mm_max_epu8(indices, _mm_set1_epi8(end - begin))));
  uint32_t compared = 0;
  if constexpr (kAVX512) {
    compared = compare16_avx512(vs, readVersion);
  } else {
    compared = compare16(vs, readVersion);
  }
  return !(compared & mask);
#else
  const unsigned shiftUpperBound = end - begin;
  const unsigned shiftAmount = begin;
  auto inBounds = [&](unsigned c) { return c - shiftAmount < shiftUpperBound; };
  uint32_t compared = 0;
  for (int i = 0; i < 16; ++i) {
    compared |= (vs[i] > readVersion) << i;
  }
  uint32_t mask = 0;
  for (int i = 0; i < 16; ++i) {
    mask |= inBounds(is[i]) << i;
  }
  return !(compared & mask);
#endif
}
// Returns true if v[i] <= readVersion for all i such that begin <= i < end
// Requires 0 <= begin < 16 and begin <= end <= 16 (asserted below). Each path
// builds a per-lane conflict mask for all 16 lanes, clears lanes >= end, and
// shifts out lanes < begin.
// NOTE(review): the NEON path always treats versions as 32-bit, unlike
// compare16, which handles USE_64_BIT. Confirm NEON + USE_64_BIT is never
// built.
template <bool kAVX512>
bool scan16(const InternalVersionT *vs, int begin, int end,
            InternalVersionT readVersion) {
  assert(0 <= begin && begin < 16);
  assert(0 <= end && end <= 16);
  assert(begin <= end);
#if defined(HAS_ARM_NEON)
  uint32x4_t w4[4];
  memcpy(w4, vs, sizeof(w4));
  uint32_t rv;
  memcpy(&rv, &readVersion, sizeof(rv));
  const auto rvVec = vdupq_n_u32(rv);
  int32x4_t z;
  memset(&z, 0, sizeof(z));
  uint16x4_t conflicting[4];
  for (int i = 0; i < 4; ++i) {
    conflicting[i] =
        vmovn_u32(vcgtq_s32(vreinterpretq_s32_u32(vsubq_u32(w4[i], rvVec)), z));
  }
  auto combined =
      vcombine_u8(vmovn_u16(vcombine_u16(conflicting[0], conflicting[1])),
                  vmovn_u16(vcombine_u16(conflicting[2], conflicting[3])));
  // 4 bits per lane, so lane positions are scaled by 4 (<< 2) below.
  uint64_t conflict = vget_lane_u64(
      vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(combined), 4)), 0);
  conflict &= end == 16 ? -1 : (uint64_t(1) << (end << 2)) - 1;
  conflict >>= begin << 2;
  return !conflict;
#elif defined(HAS_AVX)
  uint32_t conflict;
  if constexpr (kAVX512) {
    conflict = compare16_avx512(vs, readVersion);
  } else {
    conflict = compare16(vs, readVersion);
  }
  conflict &= (1 << end) - 1;
  conflict >>= begin;
  return !conflict;
#else
  uint64_t conflict = 0;
  for (int i = 0; i < 16; ++i) {
    conflict |= (vs[i] > readVersion) << i;
  }
  conflict &= (1 << end) - 1;
  conflict >>= begin;
  return !conflict;
#endif
}
// Return whether or not the max version among all keys starting with the search
// path of n + [child], where child in (begin, end) is <= readVersion. Does not
// account for the range version of firstGt(searchpath(n) + [end - 1])
template <bool kAVX512>
bool checkMaxBetweenExclusiveImpl(Node *n, int begin, int end,
InternalVersionT readVersion,
ReadContext *tls) {
++tls->range_read_node_scan_accum;
assume(-1 <= begin);
assume(begin <= 256);
assume(-1 <= end);
assume(end <= 256);
assume(begin < end);
assert(!(begin == -1 && end == 256));
switch (n->getType()) {
case Type_Node0:
return true;
case Type_Node3: {
auto *self = static_cast<Node3 *>(n);
++begin;
const unsigned shiftUpperBound = end - begin;
const unsigned shiftAmount = begin;
auto inBounds = [&](unsigned c) {
return c - shiftAmount < shiftUpperBound;
};
uint32_t mask = 0;
for (int i = 0; i < Node3::kMaxNodes; ++i) {
mask |= inBounds(self->index[i]) << i;
}
mask &= (1 << self->numChildren) - 1;
if (!mask) {
return true;
}
Node *child = self->children[std::countr_zero(mask)];
const bool firstRangeOk =
!child->entryPresent || child->entry.rangeVersion <= readVersion;
uint32_t compared = 0;
for (int i = 0; i < Node3::kMaxNodes; ++i) {
compared |= (self->childMaxVersion[i] > readVersion) << i;
}
return !(compared & mask) && firstRangeOk;
}
case Type_Node16: {
auto *self = static_cast<Node16 *>(n);
++begin;
assert(begin <= end);
assert(end - begin < 256);
#ifdef HAS_ARM_NEON
uint8x16_t indices;
memcpy(&indices, self->index, 16);
// 0xff for each in bounds
auto results =
vcltq_u8(vsubq_u8(indices, vdupq_n_u8(begin)), vdupq_n_u8(end - begin));
// 0xf for each 0xff
uint64_t mask = vget_lane_u64(
vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)), 0);
mask &= self->numChildren == 16
? uint64_t(-1)
: (uint64_t(1) << (self->numChildren << 2)) - 1;
if (!mask) {
return true;
}
Node *child = self->children[std::countr_zero(mask) >> 2];
const bool firstRangeOk =
!child->entryPresent || child->entry.rangeVersion <= readVersion;
uint32x4_t w4[4];
memcpy(w4, self->childMaxVersion, sizeof(w4));
uint32_t rv;
memcpy(&rv, &readVersion, sizeof(rv));
const auto rvVec = vdupq_n_u32(rv);
int32x4_t z;
memset(&z, 0, sizeof(z));
uint16x4_t conflicting[4];
for (int i = 0; i < 4; ++i) {
conflicting[i] = vmovn_u32(
vcgtq_s32(vreinterpretq_s32_u32(vsubq_u32(w4[i], rvVec)), z));
}
auto combined =
vcombine_u8(vmovn_u16(vcombine_u16(conflicting[0], conflicting[1])),
vmovn_u16(vcombine_u16(conflicting[2], conflicting[3])));
uint64_t compared = vget_lane_u64(
vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(combined), 4)), 0);
return !(compared & mask) && firstRangeOk;
#elif defined(HAS_AVX)
__m128i indices;
memcpy(&indices, self->index, 16);
indices = _mm_sub_epi8(indices, _mm_set1_epi8(begin));
uint32_t mask =
0xffff &
~_mm_movemask_epi8(_mm_cmpeq_epi8(
indices, _mm_max_epu8(indices, _mm_set1_epi8(end - begin))));
mask &= (1 << self->numChildren) - 1;
if (!mask) {
return true;
}
Node *child = self->children[std::countr_zero(mask)];
const bool firstRangeOk =
!child->entryPresent || child->entry.rangeVersion <= readVersion;
uint32_t compared = 0;
if constexpr (kAVX512) {
compared = compare16_avx512(self->childMaxVersion, readVersion);
} else {
compared = compare16(self->childMaxVersion, readVersion);
}
return !(compared & mask) && firstRangeOk;
#else
const unsigned shiftUpperBound = end - begin;
const unsigned shiftAmount = begin;
auto inBounds = [&](unsigned c) {
return c - shiftAmount < shiftUpperBound;
};
uint32_t mask = 0;
for (int i = 0; i < 16; ++i) {
mask |= inBounds(self->index[i]) << i;
}
mask &= (1 << self->numChildren) - 1;
if (!mask) {
return true;
}
Node *child = self->children[std::countr_zero(mask)];
const bool firstRangeOk =
!child->entryPresent || child->entry.rangeVersion <= readVersion;
uint32_t compared = 0;
for (int i = 0; i < 16; ++i) {
compared |= (self->childMaxVersion[i] > readVersion) << i;
}
return !(compared & mask) && firstRangeOk;
#endif
}
case Type_Node48: {
auto *self = static_cast<Node48 *>(n);
{
int c = self->bitSet.firstSetGeq(begin + 1);
if (c >= 0 && c < end) {
Node *child = self->children[self->index[c]];
if (child->entryPresent && child->entry.rangeVersion > readVersion) {
return false;
}
begin = c;
} else {
return true;
}
// [begin, end) is now the half-open interval of children we're interested
// in.
assert(begin < end);
}
// Check all pages
static_assert(Node48::kMaxOfMaxPageSize == 16);
for (int i = 0; i < Node48::kMaxOfMaxTotalPages; ++i) {
if (self->maxOfMax[i] > readVersion) {
if (!scan16<kAVX512>(self->childMaxVersion +
(i << Node48::kMaxOfMaxShift),
self->reverseIndex + (i << Node48::kMaxOfMaxShift),
begin, end, readVersion)) {
return false;
}
}
}
return true;
}
case Type_Node256: {
static_assert(Node256::kMaxOfMaxTotalPages == 16);
auto *self = static_cast<Node256 *>(n);
{
int c = self->bitSet.firstSetGeq(begin + 1);
if (c >= 0 && c < end) {
Node *child = self->children[c];
if (child->entryPresent && child->entry.rangeVersion > readVersion) {
return false;
}
begin = c;
} else {
return true;
}
// [begin, end) is now the half-open interval of children we're interested
// in.
assert(begin < end);
}
const int firstPage = begin >> Node256::kMaxOfMaxShift;
const int lastPage = (end - 1) >> Node256::kMaxOfMaxShift;
// Check the only page if there's only one
if (firstPage == lastPage) {
if (self->maxOfMax[firstPage] <= readVersion) {
return true;
}
const int intraPageBegin = begin & (Node256::kMaxOfMaxPageSize - 1);
const int intraPageEnd = end - (lastPage << Node256::kMaxOfMaxShift);
return scan16<kAVX512>(self->childMaxVersion +
(firstPage << Node256::kMaxOfMaxShift),
intraPageBegin, intraPageEnd, readVersion);
}
// Check the first page
if (self->maxOfMax[firstPage] > readVersion) {
const int intraPageBegin = begin & (Node256::kMaxOfMaxPageSize - 1);
if (!scan16<kAVX512>(self->childMaxVersion +
(firstPage << Node256::kMaxOfMaxShift),
intraPageBegin, 16, readVersion)) {
return false;
}
}
// Check the last page
if (self->maxOfMax[lastPage] > readVersion) {
const int intraPageEnd = end - (lastPage << Node256::kMaxOfMaxShift);
if (!scan16<kAVX512>(self->childMaxVersion +
(lastPage << Node256::kMaxOfMaxShift),
0, intraPageEnd, readVersion)) {
return false;
}
}
// Check inner pages
return scan16<kAVX512>(self->maxOfMax, firstPage + 1, lastPage,
readVersion);
}
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
// Entry point for checkMaxBetweenExclusiveImpl.
//
// When AVX is available and ThreadSanitizer is not in use, two definitions are
// compiled via target-attribute function multiversioning:
//   - an avx512f-targeted one that instantiates the implementation with
//     kAVX512 = true;
//   - a "default" one that instantiates it with kAVX512 = false.
// Without AVX (or under TSan) only the plain kAVX512 = false definition exists.
#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
__attribute__((target("avx512f"))) bool
checkMaxBetweenExclusive(Node *n, int begin, int end,
InternalVersionT readVersion, ReadContext *tls) {
return checkMaxBetweenExclusiveImpl<true>(n, begin, end, readVersion, tls);
}
// The attribute below applies to the definition that follows the #endif.
__attribute__((target("default")))
#endif
bool checkMaxBetweenExclusive(Node *n, int begin, int end,
InternalVersionT readVersion, ReadContext *tls) {
return checkMaxBetweenExclusiveImpl<false>(n, begin, end, readVersion, tls);
}
// Reconstructs the full search path (key bytes) from the root down to `n`.
// Bytes are gathered leaf-to-root: each node's partial key backwards, followed
// by the byte its parent indexes it with. The buffer is reversed once at the
// end. Storage comes from `arena`.
Vector<uint8_t> getSearchPath(Arena &arena, Node *n) {
  assert(n != nullptr);
  auto path = vector<uint8_t>(arena);
  Node *cur = n;
  while (true) {
    for (int j = cur->partialKeyLen; j > 0; --j) {
      path.push_back(cur->partialKey()[j - 1]);
    }
    if (cur->parent == nullptr) {
      break;
    }
    path.push_back(cur->parentsIndex);
    cur = cur->parent;
  }
  std::reverse(path.begin(), path.end());
  return path;
} // GCOVR_EXCL_LINE
// Return true if the max version among all keys that start with key + [child],
// where begin < child < end, is <= readVersion.
//
// `key` is interpreted relative to n (it is the part of the key below n's
// search path). `begin` and `end` are exclusive bounds on the byte that
// immediately follows `key`.
//
// Precondition: transitively, no child of n has a search path that's a longer
// prefix of key than n
bool checkRangeStartsWith(Node *n, std::span<const uint8_t> key, int begin,
int end, InternalVersionT readVersion,
ReadContext *tls) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "%s(%02x,%02x)*\n", printable(key).c_str(), begin, end);
#endif
auto remaining = key;
// n's search path is exactly the prefix: scan n's children strictly between
// begin and end.
if (remaining.size() == 0) {
return checkMaxBetweenExclusive(n, begin, end, readVersion, tls);
}
Node *child = getChild(n, remaining[0]);
if (child == nullptr) {
// No node continues the prefix, so no stored key starts with it. The queried
// keys all fall in the gap before the next physical node; that gap's version
// is the rangeVersion of the first entry on that node's down-left spine.
auto c = getChildGeq(n, remaining[0]);
if (c != nullptr) {
n = c;
goto downLeftSpine;
} else {
n = nextSibling(n);
if (n == nullptr) {
return true;
}
goto downLeftSpine;
}
}
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
// By the precondition, child's partial key must extend past the end of the
// prefix (otherwise n would not be the deepest node on the prefix's path).
assert(n->partialKeyLen > 0);
{
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
if (i < commonLen) {
// The partial key diverges from the prefix, so the prefix is absent.
// Check the gap that contains it, as above.
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
goto downLeftSpine;
} else {
n = nextSibling(n);
if (n == nullptr) {
return true;
}
goto downLeftSpine;
}
}
// The prefix ends inside n's partial key. The byte right after the prefix
// plays the role of "child": if it lies in (begin, end), every key under n
// is in the queried set, so check n's own range and its subtree max.
assert(n->partialKeyLen > int(remaining.size()));
if (begin < n->partialKey()[remaining.size()] &&
n->partialKey()[remaining.size()] < end) {
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
return false;
}
return maxVersion(n) <= readVersion;
}
return true;
}
__builtin_unreachable(); // GCOVR_EXCL_LINE
downLeftSpine:
// Follow first children until reaching a node with an entry. Its rangeVersion
// covers the gap immediately preceding it.
for (; !n->entryPresent; n = getFirstChildExists(n)) {
}
return n->entry.rangeVersion <= readVersion;
}
namespace {
// Return true if the max version among all keys that start with key[:prefixLen]
// that are < key is <= readVersion
//
// The walk descends toward `key`. Once the search path is at least prefixLen
// bytes deep, n's own point entry and every child of n that sorts before the
// next byte of key are checked with checkMaxBetweenExclusive. If key leaves
// the tree, the walk either finishes at the next greater node (downLeftSpine)
// or climbs back up checking subtrees (backtrack).
bool checkRangeRightSide(Node *n, std::span<const uint8_t> key, int prefixLen,
InternalVersionT readVersion, ReadContext *tls) {
auto remaining = key;
// Tracks the search-path depth of the walk in bytes. It is compared against
// prefixLen to decide which nodes lie within key[:prefixLen].
int searchPathLen = 0;
for (;; ++tls->range_read_iterations_accum) {
assert(searchPathLen <= int(key.size()));
if (remaining.size() == 0) {
goto downLeftSpine;
}
if (searchPathLen >= prefixLen) {
if (n->entryPresent && n->entry.pointVersion > readVersion) {
return false;
}
// Children strictly below remaining[0] hold keys < key.
if (!checkMaxBetweenExclusive(n, -1, remaining[0], readVersion, tls)) {
return false;
}
}
if (searchPathLen > prefixLen && n->entryPresent &&
n->entry.rangeVersion > readVersion) {
return false;
}
Node *child = getChild(n, remaining[0]);
if (child == nullptr) {
auto c = getChildGeq(n, remaining[0]);
if (c != nullptr) {
n = c;
goto downLeftSpine;
} else {
goto backtrack;
}
}
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
++searchPathLen;
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
searchPathLen += i;
if (i < commonLen) {
++searchPathLen;
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
// n's subtree is entirely > key.
goto downLeftSpine;
} else {
// n's subtree is entirely < key.
if (searchPathLen > prefixLen && n->entryPresent &&
n->entry.rangeVersion > readVersion) {
return false;
}
goto backtrack;
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// key ends inside n's partial key, so n's subtree is entirely > key.
goto downLeftSpine;
}
}
}
backtrack:
// Climb out of n. If n is deeper than the prefix, its entire subtree is < key
// and must be checked via maxVersion. Then continue at n's next sibling (down
// its left spine), or move up one level if there is none.
for (;;) {
// searchPathLen > prefixLen implies n is not the root
if (searchPathLen > prefixLen && maxVersion(n) > readVersion) {
return false;
}
if (n->parent == nullptr) {
return true;
}
auto next = getChildGeq(n->parent, n->parentsIndex + 1);
if (next == nullptr) {
searchPathLen -= 1 + n->partialKeyLen;
n = n->parent;
} else {
searchPathLen -= n->partialKeyLen;
n = next;
searchPathLen += n->partialKeyLen;
goto downLeftSpine;
}
}
downLeftSpine:
// The first entry at or below n bounds the gap preceding it. Its rangeVersion
// covers the remaining keys < key.
for (; !n->entryPresent; n = getFirstChildExists(n)) {
}
return n->entry.rangeVersion <= readVersion;
}
} // namespace
#ifdef __x86_64__
// Explicitly instantiate with target avx512f attribute so the compiler can
// inline compare16_32bit_avx512, and generally use avx512f within more
// functions
//
// These instantiations are only emitted on x86-64. The `true` template argument
// selects the AVX-512 code paths inside each template.
template __attribute__((target("avx512f"))) bool
scan16<true>(const InternalVersionT *vs, const uint8_t *is, int begin, int end,
InternalVersionT readVersion);
template __attribute__((target("avx512f"))) bool
scan16<true>(const InternalVersionT *vs, int begin, int end,
InternalVersionT readVersion);
template __attribute__((target("avx512f"))) bool
checkMaxBetweenExclusiveImpl<true>(Node *n, int begin, int end,
InternalVersionT readVersion, ReadContext *);
#endif
// Returns a pointer to the pointer to the newly inserted node in the tree.
// Caller must set `entryPresent`, and `entry` fields. All nodes along the
// search path of the result will have `maxVersion` set to `writeVersion` as a
// postcondition. Nodes along the search path may be invalidated. Callers must
// ensure that the max version of the self argument is updated.
//
// Each getOrCreateChild step presumably consumes a leading part of `key`; the
// loop stops once key is empty. insert_iterations counts the steps taken.
[[nodiscard]] TaggedNodePointer *insert(TaggedNodePointer *self,
                                        std::span<const uint8_t> key,
                                        InternalVersionT writeVersion,
                                        WriteContext *tls) {
  TaggedNodePointer *slot = self;
  while (!key.empty()) {
    slot = &getOrCreateChild(*slot, key, writeVersion, tls);
    ++tls->accum.insert_iterations;
  }
  return slot;
}
// Releases every node in the subtree rooted at `root`, including root itself.
// It uses an explicit arena-backed stack rather than recursion. For each node
// it bumps tls->accum.nodes_released, and it bumps tls->accum.entries_erased
// when the node carries an entry.
void eraseTree(Node *root, WriteContext *tls) {
  Arena scratch;
  auto pending = vector<Node *>(scratch);
  pending.push_back(root);
  while (pending.size() != 0) {
    Node *node = pending.back();
    pending.pop_back();
    tls->accum.entries_erased += node->entryPresent;
    ++tls->accum.nodes_released;
    removeKey(node);
    switch (node->getType()) {
    case Type_Node0:
      tls->release(static_cast<Node0 *>(node));
      break;
    case Type_Node3: {
      auto *node3 = static_cast<Node3 *>(node);
      for (int j = 0; j < node3->numChildren; ++j) {
        pending.push_back(node3->children[j]);
      }
      tls->release(node3);
      break;
    }
    case Type_Node16: {
      auto *node16 = static_cast<Node16 *>(node);
      for (int j = 0; j < node16->numChildren; ++j) {
        pending.push_back(node16->children[j]);
      }
      tls->release(node16);
      break;
    }
    case Type_Node48: {
      auto *node48 = static_cast<Node48 *>(node);
      for (int j = 0; j < node48->numChildren; ++j) {
        pending.push_back(node48->children[j]);
      }
      tls->release(node48);
      break;
    }
    case Type_Node256: {
      auto *node256 = static_cast<Node256 *>(node);
      // Children are indexed by byte, so reserve exactly numChildren slots and
      // fill them by walking the set bits.
      auto *dst = pending.unsafePrepareAppend(node256->numChildren).data();
      node256->bitSet.forEachSet(
          [&](int j) { *dst++ = node256->children[j]; });
      assert(dst == pending.end());
      tls->release(node256);
      break;
    }
    default:                   // GCOVR_EXCL_LINE
      __builtin_unreachable(); // GCOVR_EXCL_LINE
    }
  }
}
// Records a point write of `key` at `writeVersion`.
//
// If key already has an entry, only its point version is bumped. Otherwise a
// fresh entry is created. Its range version is inherited from the next
// logical entry (floored at tls->zero), or is tls->zero when no such entry
// exists.
void addPointWrite(TaggedNodePointer &root, std::span<const uint8_t> key,
                   InternalVersionT writeVersion, WriteContext *tls) {
  ++tls->accum.point_writes;
  auto node = *insert(&root, key, writeVersion, tls);
  if (node->entryPresent) {
    // Versions written to an existing key must not go backwards.
    assert(writeVersion >= node->entry.pointVersion);
    node->entry.pointVersion = writeVersion;
    return;
  }
  ++tls->accum.entries_inserted;
  // Look up the successor before this node is marked present.
  auto *successor = nextLogical(node);
  addKey(node);
  node->entryPresent = true;
  node->entry.pointVersion = writeVersion;
  if (successor != nullptr) {
    node->entry.rangeVersion =
        std::max(successor->entry.rangeVersion, tls->zero);
  } else {
    node->entry.rangeVersion = tls->zero;
  }
}
// Returns the max of vs[0..len), for len <= 16.
//
// The default (scalar) version reads vs[0] unconditionally, so it presumably
// requires len >= 1. TODO confirm that callers never pass 0.
//
// The avx512f version (32-bit versions only) subtracts z from every lane and
// takes a signed 32-bit max, then adds z back. This compares versions relative
// to z. Lanes at or beyond `len` are loaded as z, so they contribute z. Its
// result is therefore never below z, whereas the scalar version can return a
// value below z.
#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
__attribute__((target("avx512f"))) InternalVersionT horizontalMaxUpTo16(
InternalVersionT *vs, [[maybe_unused]] InternalVersionT z, int len) {
assume(len <= 16);
#if USE_64_BIT
// Hope it gets vectorized
InternalVersionT max = vs[0];
for (int i = 1; i < len; ++i) {
max = std::max(vs[i], max);
}
return max;
#else
uint32_t zero;
memcpy(&zero, &z, sizeof(zero));
auto zeroVec = _mm512_set1_epi32(zero);
// Masked load: lanes >= len keep zeroVec, i.e. 0 after the subtraction.
auto max = InternalVersionT(
zero +
_mm512_reduce_max_epi32(_mm512_sub_epi32(
_mm512_mask_loadu_epi32(zeroVec, _mm512_int2mask((1 << len) - 1), vs),
zeroVec)));
return max;
#endif
}
__attribute__((target("default")))
#endif
InternalVersionT
horizontalMaxUpTo16(InternalVersionT *vs, InternalVersionT, int len) {
assume(len <= 16);
InternalVersionT max = vs[0];
for (int i = 1; i < len; ++i) {
max = std::max(vs[i], max);
}
return max;
}
// Returns the max of exactly 16 versions vs[0..16).
//
// The avx512f version (32-bit versions only) compares relative to `z`. It
// subtracts z from every lane, takes a signed 32-bit max, then adds z back.
// The scalar default ignores z and uses InternalVersionT's own ordering.
#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
__attribute__((target("avx512f"))) InternalVersionT
horizontalMax16(InternalVersionT *vs, [[maybe_unused]] InternalVersionT z) {
#if USE_64_BIT
// Hope it gets vectorized
InternalVersionT max = vs[0];
for (int i = 1; i < 16; ++i) {
max = std::max(vs[i], max);
}
return max;
#else
uint32_t zero; // GCOVR_EXCL_LINE
memcpy(&zero, &z, sizeof(zero));
auto zeroVec = _mm512_set1_epi32(zero);
return InternalVersionT(zero + _mm512_reduce_max_epi32(_mm512_sub_epi32(
_mm512_loadu_epi32(vs), zeroVec)));
#endif
}
__attribute__((target("default")))
#endif
InternalVersionT
horizontalMax16(InternalVersionT *vs, InternalVersionT) {
InternalVersionT max = vs[0];
for (int i = 1; i < 16; ++i) {
max = std::max(vs[i], max);
}
return max;
}
// Recomputes and stores (via setMaxVersion) the max version of `node`. That is
// the max of its own point version, tls->zero, and the max versions recorded
// for all of its children.
//
// Precondition: `node->entryPresent`, and node is not the root
void fixupMaxVersion(Node *node, WriteContext *tls) {
  assert(node->parent);
  assert(node->entryPresent);
  InternalVersionT max = std::max(node->entry.pointVersion, tls->zero);
  switch (node->getType()) {
  case Type_Node0:
    break;
  case Type_Node3: {
    auto *self3 = static_cast<Node3 *>(node);
    max = std::max(max, horizontalMaxUpTo16(self3->childMaxVersion, tls->zero,
                                            self3->numChildren));
  } break;
  case Type_Node16: {
    auto *self16 = static_cast<Node16 *>(node);
    max = std::max(max, horizontalMaxUpTo16(self16->childMaxVersion, tls->zero,
                                            self16->numChildren));
  } break;
  case Type_Node48: {
    // maxOfMax holds the max of each page of childMaxVersion.
    auto *self48 = static_cast<Node48 *>(node);
    for (auto v : self48->maxOfMax) {
      max = std::max(v, max);
    }
  } break;
  case Type_Node256: {
    // Node256 has 256 childMaxVersion slots summarized by exactly 16 maxOfMax
    // pages. Reducing over childMaxVersion[0..16) would only cover children
    // 0..15 and underestimate the max, so reduce over the page maxima instead.
    static_assert(Node256::kMaxOfMaxTotalPages == 16);
    auto *self256 = static_cast<Node256 *>(node);
    max = std::max(max, horizontalMax16(self256->maxOfMax, tls->zero));
  } break;
  default:                   // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
  setMaxVersion(node, max);
}
// Records a write of the key range [begin, end) at writeVersion.
//
// A range of the form [k, k + '\0') is handled as a point write of k.
// Otherwise entries are ensured for both begin and end, every logical entry
// strictly between them is erased, and end's rangeVersion (the version of the
// gap before end) is set to writeVersion.
void addWriteRange(TaggedNodePointer &root, std::span<const uint8_t> begin,
std::span<const uint8_t> end, InternalVersionT writeVersion,
WriteContext *tls, ConflictSet::Impl *impl) {
int lcp = longestCommonPrefix(begin.data(), end.data(),
std::min(begin.size(), end.size()));
if (lcp == int(begin.size()) && end.size() == begin.size() + 1 &&
end.back() == 0) {
return addPointWrite(root, begin, writeVersion, tls);
}
++tls->accum.range_writes;
// begin is a prefix of end, so after trimming lcp bytes begin is empty and
// beginNode is *useAsRoot itself.
const bool beginIsPrefix = lcp == int(begin.size());
// Insert the shared prefix once, then insert begin and end relative to it.
auto useAsRoot = insert(&root, begin.subspan(0, lcp), writeVersion, tls);
begin = begin.subspan(lcp, begin.size() - lcp);
end = end.subspan(lcp, end.size() - lcp);
Node *beginNode = *insert(useAsRoot, begin, writeVersion, tls);
addKey(beginNode);
if (!beginNode->entryPresent) {
++tls->accum.entries_inserted;
// A new begin entry inherits the gap version it now splits.
auto *p = nextLogical(beginNode);
beginNode->entry.rangeVersion =
p == nullptr ? tls->zero : std::max(p->entry.rangeVersion, tls->zero);
beginNode->entryPresent = true;
}
beginNode->entry.pointVersion = writeVersion;
Node *endNode = *insert(useAsRoot, end, writeVersion, tls);
addKey(endNode);
if (!endNode->entryPresent) {
++tls->accum.entries_inserted;
// end itself is not written: a new end entry's point version is the version
// of the gap that previously covered it.
auto *p = nextLogical(endNode);
endNode->entry.pointVersion =
p == nullptr ? tls->zero : std::max(p->entry.rangeVersion, tls->zero);
if (beginIsPrefix) {
// beginNode may have been invalidated when inserting end
beginNode = *useAsRoot;
assert(beginNode->entryPresent);
}
endNode->entryPresent = true;
}
endNode->entry.rangeVersion = writeVersion;
// Erase nodes in range
// endOfRange temporarily marks endNode as the stopping point for the loop.
assert(!beginNode->endOfRange);
assert(!endNode->endOfRange);
endNode->endOfRange = true;
Node *iter = beginNode;
for (iter = nextLogical(iter); !iter->endOfRange;
iter = erase(iter, tls, impl, /*logical*/ true)) {
assert(!iter->endOfRange);
}
assert(iter->endOfRange);
iter->endOfRange = false;
// Inserting end trashed the last node's maxVersion. Fix that. Safe to call
// since the end key always has non-zero size.
fixupMaxVersion(iter, tls);
}
// Starting the descent at `n`, returns the first physical node whose search
// path is >= key. Returns `n`'s descendant equal to key when one exists;
// otherwise returns the next greater physical node (possibly nullptr).
Node *firstGeqPhysical(Node *n, const std::span<const uint8_t> key) {
  std::span<const uint8_t> rest = key;
  while (!rest.empty()) {
    Node *child = getChild(n, rest[0]);
    if (child == nullptr) {
      Node *geq = getChildGeq(n, rest[0]);
      if (geq != nullptr) {
        return geq;
      }
      // nextSibling returning nullptr here is genuinely unreachable from any
      // entry point of the final library, since we can't remove a key without
      // introducing a key after it, and the only production caller of firstGeq
      // is for resuming the setOldestVersion scan.
      return nextSibling(n);
    }
    n = child;
    rest = rest.subspan(1);
    if (n->partialKeyLen == 0) {
      continue;
    }
    const int commonLen = std::min<int>(n->partialKeyLen, rest.size());
    const int matched =
        longestCommonPrefix(n->partialKey(), rest.data(), commonLen);
    if (matched < commonLen) {
      // key diverges inside n's partial key: n's whole subtree is either
      // greater than key (return n) or less than key (skip past n).
      if (n->partialKey()[matched] > rest[matched]) {
        return n;
      }
      return nextSibling(n);
    }
    if (commonLen == n->partialKeyLen) {
      // partial key matches
      rest = rest.subspan(commonLen);
    } else if (n->partialKeyLen > int(rest.size())) {
      // n is the first physical node greater than remaining, and there's no
      // eq node
      return n;
    }
  }
  return n;
}
// Portability shims for the read-check state machines below.
// - __has_attribute evaluates to 0 on compilers that lack it.
// - MUSTTAIL requests a guaranteed tail call where supported. Otherwise it
//   expands to nothing, and the !__has_attribute(musttail) code paths
//   (context->job / context->done) are used instead.
// - PRESERVE_NONE selects the preserve_none calling convention where supported.
#ifndef __has_attribute
#define __has_attribute(x) 0
#endif
#if __has_attribute(musttail)
#define MUSTTAIL __attribute__((musttail))
#else
#define MUSTTAIL
#endif
#if __has_attribute(preserve_none)
#define PRESERVE_NONE __attribute__((preserve_none))
#else
#define PRESERVE_NONE
#endif
// Signature shared by every state of the read-check state machines. A state
// advances one job and then hands control to the next state/job.
typedef PRESERVE_NONE void (*Continuation)(struct CheckJob *,
struct CheckContext *);
// State relevant to an individual query
struct CheckJob {
// Writes Commit to *result when ok, Conflict otherwise.
void setResult(bool ok) {
*result = ok ? ConflictSet::Commit : ConflictSet::Conflict;
}
void init(const ConflictSet::ReadRange *read, ConflictSet::Result *result,
Node *root, int64_t oldestVersionFullPrecision);
// Node the walk is currently at (or about to visit).
Node *n;
// Not-yet-consumed suffix of the read's begin key.
std::span<const uint8_t> begin;
// Max version of the node being entered, from getChildAndMaxVersion.
InternalVersionT maxV;
std::span<const uint8_t> end; // range read only
std::span<const uint8_t> remaining; // range read only
Node *child; // range read only
int lcp; // range read only
Node *commonPrefixNode; // range read only
InternalVersionT readVersion;
ConflictSet::Result *result;
// State to run next for this job.
Continuation continuation;
// Neighbours in the circular list of in-flight jobs.
CheckJob *prev;
CheckJob *next;
};
// State relevant to every query
struct CheckContext {
// Total number of queries in this batch.
int count;
int64_t oldestVersionFullPrecision;
Node *root;
// Parallel arrays of length `count`.
const ConflictSet::ReadRange *queries;
ConflictSet::Result *results;
// Number of queries already handed to a job (see complete()).
int64_t started;
ReadContext *tls;
#if !__has_attribute(musttail)
// Without guaranteed tail calls, states return after recording the next job
// to run here. `done` is set once the last job completes.
CheckJob *job;
bool done;
#endif
};
// Yields to the next job in the circular job list. With musttail this directly
// tail-calls that job's continuation. Otherwise it records the next job in
// context->job and returns (presumably to a loop that reads context->job).
PRESERVE_NONE void keepGoing(CheckJob *job, CheckContext *context) {
#if __has_attribute(musttail)
  CheckJob *const following = job->next;
  MUSTTAIL return following->continuation(following, context);
#else
  context->job = job->next;
#endif
}
// Called when `job` has written its result.
//
// While unstarted queries remain, the job slot is reused for the next query.
// Otherwise the job is unlinked from the circular list. When it was the last
// job, processing stops (and context->done is set without musttail).
PRESERVE_NONE void complete(CheckJob *job, CheckContext *context) {
  if (context->started != context->count) {
    const int queryIndex = context->started++;
    job->init(context->queries + queryIndex, context->results + queryIndex,
              context->root, context->oldestVersionFullPrecision);
    MUSTTAIL return keepGoing(job, context);
  }
  if (job->prev == job) {
#if !__has_attribute(musttail)
    context->done = true;
#endif
    return;
  }
  job->prev->next = job->next;
  job->next->prev = job->prev;
  // Resume from the predecessor so its (new) next is the job after this one.
  MUSTTAIL return keepGoing(job->prev, context);
}
namespace check_point_read_state_machine {
// Point-read state machine. Each state performs one step of the walk for one
// job, prefetches the node it will touch next, and yields to the next job via
// keepGoing (presumably so that memory latency of several reads overlaps).
PRESERVE_NONE void begin(CheckJob *, CheckContext *);
template <class NodeT> PRESERVE_NONE void iter(CheckJob *, CheckContext *);
PRESERVE_NONE void down_left_spine(CheckJob *, CheckContext *);
// Indexed by getType() of the node being entered. The order must match the
// node type enumeration.
static Continuation iterTable[] = {iter<Node0>, iter<Node3>, iter<Node16>,
                                   iter<Node48>, iter<Node256>};
// Entry state: job->n is the root and job->begin is the full key being read.
void begin(CheckJob *job, CheckContext *context) {
  ++context->tls->point_read_accum;
#if DEBUG_VERBOSE && !defined(NDEBUG)
  fprintf(stderr, "Check point read: %s\n", printable(job->begin).c_str());
#endif
  if (job->begin.size() == 0) [[unlikely]] {
    // We don't erase the root
    assert(job->n->entryPresent);
    job->setResult(job->n->entry.pointVersion <= job->readVersion);
    MUSTTAIL return complete(job, context);
  }
  auto [taggedChild, maxV] = getChildAndMaxVersion(job->n, job->begin[0]);
  job->maxV = maxV;
  Node *child = taggedChild;
  if (child == nullptr) [[unlikely]] {
    // The key is absent. Its version is the range version of the next entry,
    // found on the down-left spine of the next greater node.
    auto c = getChildGeq(job->n, job->begin[0]);
    if (c != nullptr) {
      job->n = c;
      job->continuation = down_left_spine;
      __builtin_prefetch(job->n);
      MUSTTAIL return keepGoing(job, context);
    } else {
      // The root never has a next sibling
      job->setResult(true);
      MUSTTAIL return complete(job, context);
    }
  }
  job->continuation = iterTable[taggedChild.getType()];
  job->n = child;
  __builtin_prefetch(child);
  MUSTTAIL return keepGoing(job, context);
}
// Visits job->n (of static type NodeT), which was reached via job->begin[0].
template <class NodeT> void iter(CheckJob *job, CheckContext *context) {
  assert(NodeT::kType == job->n->getType());
  NodeT *n = static_cast<NodeT *>(job->n);
  job->begin = job->begin.subspan(1, job->begin.size() - 1);
  if (n->partialKeyLen > 0) {
    int commonLen = std::min<int>(n->partialKeyLen, job->begin.size());
    int i = longestCommonPrefix(n->partialKey(), job->begin.data(), commonLen);
    if (i < commonLen) [[unlikely]] {
      auto c = n->partialKey()[i] <=> job->begin[i];
      if (c > 0) {
        job->continuation = down_left_spine;
        MUSTTAIL return down_left_spine(job, context);
      } else {
        job->n = nextSibling(n);
        if (job->n == nullptr) {
          job->setResult(true);
          MUSTTAIL return complete(job, context);
        }
        job->continuation = down_left_spine;
        __builtin_prefetch(job->n);
        MUSTTAIL return keepGoing(job, context);
      }
    }
    if (commonLen == n->partialKeyLen) {
      // partial key matches
      job->begin = job->begin.subspan(commonLen, job->begin.size() - commonLen);
    } else if (n->partialKeyLen > int(job->begin.size())) [[unlikely]] {
      // n is the first physical node greater than remaining, and there's no
      // eq node
      job->continuation = down_left_spine;
      MUSTTAIL return down_left_spine(job, context);
    }
  }
  ++context->tls->point_read_iterations_accum;
  // Nothing under n is newer than the read: no conflict.
  if (job->maxV <= job->readVersion) {
    job->setResult(true);
    ++context->tls->point_read_short_circuit_accum;
    MUSTTAIL return complete(job, context);
  }
  if (job->begin.size() == 0) [[unlikely]] {
    if (n->entryPresent) {
      job->setResult(n->entry.pointVersion <= job->readVersion);
      MUSTTAIL return complete(job, context);
    }
    job->n = getFirstChildExists(n);
    job->continuation = down_left_spine;
    __builtin_prefetch(job->n);
    MUSTTAIL return keepGoing(job, context);
  }
  auto [taggedChild, maxV] = getChildAndMaxVersion(n, job->begin[0]);
  job->maxV = maxV;
  Node *child = taggedChild;
  if (child == nullptr) [[unlikely]] {
    auto c = getChildGeq(n, job->begin[0]);
    if (c != nullptr) {
      job->n = c;
      job->continuation = down_left_spine;
      __builtin_prefetch(job->n);
      MUSTTAIL return keepGoing(job, context);
    } else {
      job->n = nextSibling(job->n);
      if (job->n == nullptr) {
        job->setResult(true);
        MUSTTAIL return complete(job, context);
      }
      job->continuation = down_left_spine;
      __builtin_prefetch(job->n);
      MUSTTAIL return keepGoing(job, context);
    }
  }
  job->continuation = iterTable[taggedChild.getType()];
  job->n = child;
  __builtin_prefetch(child);
  MUSTTAIL return keepGoing(job, context);
}
// Follows first children until an entry is found. That entry's rangeVersion
// covers the gap where the absent key would lie.
void down_left_spine(CheckJob *job, CheckContext *context) {
  if (job->n->entryPresent) {
    job->setResult(job->n->entry.rangeVersion <= job->readVersion);
    MUSTTAIL return complete(job, context);
  }
  job->n = getFirstChildExists(job->n);
  __builtin_prefetch(job->n);
  MUSTTAIL return keepGoing(job, context);
}
} // namespace check_point_read_state_machine
namespace check_prefix_read_state_machine {
// Prefix-read state machine: checks all keys starting with job->begin. It has
// the same step/prefetch/yield structure as the point-read machine.
PRESERVE_NONE void begin(CheckJob *, CheckContext *);
template <class NodeT> PRESERVE_NONE void iter(CheckJob *, CheckContext *);
PRESERVE_NONE void down_left_spine(CheckJob *, CheckContext *);
// Indexed by getType() of the node being entered. The order must match the
// node type enumeration.
static Continuation iterTable[] = {iter<Node0>, iter<Node3>, iter<Node16>,
                                   iter<Node48>, iter<Node256>};
// Entry state: job->n is the root and job->begin is the full prefix.
void begin(CheckJob *job, CheckContext *context) {
  ++context->tls->prefix_read_accum;
#if DEBUG_VERBOSE && !defined(NDEBUG)
  fprintf(stderr, "Check prefix read: %s\n", printable(job->begin).c_str());
#endif
  // There's no way to encode a prefix read of ""
  assert(job->begin.size() > 0);
  auto [taggedChild, maxV] = getChildAndMaxVersion(job->n, job->begin[0]);
  job->maxV = maxV;
  Node *child = taggedChild;
  if (child == nullptr) [[unlikely]] {
    auto c = getChildGeq(job->n, job->begin[0]);
    if (c != nullptr) {
      job->n = c;
      job->continuation = down_left_spine;
      __builtin_prefetch(job->n);
      MUSTTAIL return keepGoing(job, context);
    } else {
      // The root never has a next sibling
      job->setResult(true);
      MUSTTAIL return complete(job, context);
    }
  }
  job->continuation = iterTable[taggedChild.getType()];
  job->n = child;
  __builtin_prefetch(child);
  MUSTTAIL return keepGoing(job, context);
}
// Visits job->n (of static type NodeT), which was reached via job->begin[0].
template <class NodeT> void iter(CheckJob *job, CheckContext *context) {
  assert(NodeT::kType == job->n->getType());
  NodeT *n = static_cast<NodeT *>(job->n);
  job->begin = job->begin.subspan(1, job->begin.size() - 1);
  if (n->partialKeyLen > 0) {
    int commonLen = std::min<int>(n->partialKeyLen, job->begin.size());
    int i = longestCommonPrefix(n->partialKey(), job->begin.data(), commonLen);
    if (i < commonLen) [[unlikely]] {
      auto c = n->partialKey()[i] <=> job->begin[i];
      if (c > 0) {
        job->continuation = down_left_spine;
        MUSTTAIL return down_left_spine(job, context);
      } else {
        job->n = nextSibling(n);
        if (job->n == nullptr) {
          job->setResult(true);
          MUSTTAIL return complete(job, context);
        }
        job->continuation = down_left_spine;
        __builtin_prefetch(job->n);
        MUSTTAIL return keepGoing(job, context);
      }
    }
    if (commonLen == n->partialKeyLen) {
      // partial key matches
      job->begin = job->begin.subspan(commonLen, job->begin.size() - commonLen);
    } else if (n->partialKeyLen > int(job->begin.size())) [[unlikely]] {
      // n is the first physical node greater than remaining, and there's no
      // eq node. All physical nodes that start with prefix are reachable from
      // n.
      if (job->maxV > job->readVersion) {
        job->setResult(false);
        MUSTTAIL return complete(job, context);
      }
      job->continuation = down_left_spine;
      MUSTTAIL return down_left_spine(job, context);
    }
  }
  ++context->tls->prefix_read_iterations_accum;
  if (job->maxV <= job->readVersion) {
    job->setResult(true);
    ++context->tls->prefix_read_short_circuit_accum;
    MUSTTAIL return complete(job, context);
  }
  if (job->begin.size() == 0) [[unlikely]] {
    // Every key with the prefix lives under n. Because of the short-circuit
    // above, maxV > readVersion here, so this records a conflict.
    job->setResult(job->maxV <= job->readVersion);
    MUSTTAIL return complete(job, context);
  }
  auto [taggedChild, maxV] = getChildAndMaxVersion(n, job->begin[0]);
  job->maxV = maxV;
  Node *child = taggedChild;
  if (child == nullptr) [[unlikely]] {
    auto c = getChildGeq(n, job->begin[0]);
    if (c != nullptr) {
      job->n = c;
      job->continuation = down_left_spine;
      __builtin_prefetch(job->n);
      MUSTTAIL return keepGoing(job, context);
    } else {
      job->n = nextSibling(job->n);
      if (job->n == nullptr) {
        job->setResult(true);
        MUSTTAIL return complete(job, context);
      }
      job->continuation = down_left_spine;
      __builtin_prefetch(job->n);
      MUSTTAIL return keepGoing(job, context);
    }
  }
  job->continuation = iterTable[taggedChild.getType()];
  job->n = child;
  __builtin_prefetch(child);
  MUSTTAIL return keepGoing(job, context);
}
// Follows first children until an entry is found, and checks that entry's
// rangeVersion (the gap preceding it).
void down_left_spine(CheckJob *job, CheckContext *context) {
  if (job->n->entryPresent) {
    job->setResult(job->n->entry.rangeVersion <= job->readVersion);
    MUSTTAIL return complete(job, context);
  }
  job->n = getFirstChildExists(job->n);
  __builtin_prefetch(job->n);
  MUSTTAIL return keepGoing(job, context);
}
} // namespace check_prefix_read_state_machine
namespace check_range_read_state_machine {
// Forward declarations of the range-read states and their dispatch tables.
PRESERVE_NONE void begin(CheckJob *, CheckContext *);
template <class NodeT>
PRESERVE_NONE void common_prefix_iter(CheckJob *, CheckContext *);
PRESERVE_NONE void done_common_prefix_iter(CheckJob *, CheckContext *);
// Indexed by getType() of the child being entered while walking the common
// prefix of begin and end.
static Continuation commonPrefixIterTable[] = {
common_prefix_iter<Node0>, common_prefix_iter<Node3>,
common_prefix_iter<Node16>, common_prefix_iter<Node48>,
common_prefix_iter<Node256>};
template <class NodeT, bool kFirst>
PRESERVE_NONE void left_side_iter(CheckJob *, CheckContext *);
PRESERVE_NONE void left_side_down_left_spine(CheckJob *, CheckContext *);
PRESERVE_NONE void done_left_side_iter(CheckJob *, CheckContext *);
// leftSideIterTable[kFirst][getType()]. done_common_prefix_iter enters with
// kFirst = true.
static Continuation leftSideIterTable[2][5] = {
{left_side_iter<Node0, false>, left_side_iter<Node3, false>,
left_side_iter<Node16, false>, left_side_iter<Node48, false>,
left_side_iter<Node256, false>},
{left_side_iter<Node0, true>, left_side_iter<Node3, true>,
left_side_iter<Node16, true>, left_side_iter<Node48, true>,
left_side_iter<Node256, true>},
};
// Entry state for a range read [job->begin, job->end).
//
// Ranges that are really point reads ([k, k + '\0')) or prefix reads
// ([p + b, p + (b + 1))) are forwarded to the specialised machines. Otherwise
// the walk starts down the common prefix of begin and end.
PRESERVE_NONE void begin(CheckJob *job, CheckContext *context) {
  const size_t beginLen = job->begin.size();
  const size_t endLen = job->end.size();
  job->lcp = longestCommonPrefix(job->begin.data(), job->end.data(),
                                 std::min(beginLen, endLen));
  const bool isPointRead = job->lcp == int(beginLen) &&
                           endLen == beginLen + 1 && job->end.back() == 0;
  if (isPointRead) {
    job->continuation = check_point_read_state_machine::begin;
    // Call directly since we have nothing to prefetch
    MUSTTAIL return job->continuation(job, context);
  }
  const bool isPrefixRead =
      job->lcp == int(beginLen - 1) && endLen == beginLen &&
      int(job->begin.back()) + 1 == int(job->end.back());
  if (isPrefixRead) {
    job->continuation = check_prefix_read_state_machine::begin;
    // Call directly since we have nothing to prefetch
    MUSTTAIL return job->continuation(job, context);
  }
  ++context->tls->range_read_accum;
  job->remaining = job->begin.subspan(0, job->lcp);
  if (job->remaining.empty()) {
    MUSTTAIL return done_common_prefix_iter(job, context);
  }
  auto [firstChild, firstMaxV] =
      getChildAndMaxVersion(job->n, job->remaining[0]);
  job->maxV = firstMaxV;
  job->child = firstChild;
  if (job->child == nullptr) {
    MUSTTAIL return done_common_prefix_iter(job, context);
  }
  job->continuation = commonPrefixIterTable[firstChild.getType()];
  __builtin_prefetch(job->child);
  MUSTTAIL return keepGoing(job, context);
}
// Advance down common prefix, but stay on a physical path in the tree
//
// On entry, job->child is the child of job->n selected by job->remaining[0],
// and job->maxV is its max version. If the child's whole partial key matches
// the remaining common prefix, descend into it. Otherwise stop and hand off to
// done_common_prefix_iter with job->n unchanged.
template <class NodeT>
void common_prefix_iter(CheckJob *job, CheckContext *context) {
assert(NodeT::kType == job->child->getType());
NodeT *child = static_cast<NodeT *>(job->child);
if (child->partialKeyLen > 0) {
int cl = std::min<int>(child->partialKeyLen, job->remaining.size() - 1);
int i =
longestCommonPrefix(child->partialKey(), job->remaining.data() + 1, cl);
if (i != child->partialKeyLen) {
MUSTTAIL return done_common_prefix_iter(job, context);
}
}
job->n = child;
job->remaining = job->remaining.subspan(1 + child->partialKeyLen,
job->remaining.size() -
(1 + child->partialKeyLen));
++context->tls->range_read_iterations_accum;
// Every key in [begin, end) starts with the prefix consumed so far, so if
// child's subtree max is <= readVersion the read cannot conflict.
if (job->maxV <= job->readVersion) {
job->setResult(true);
++context->tls->range_read_short_circuit_accum;
MUSTTAIL return complete(job, context);
}
if (job->remaining.size() == 0) {
MUSTTAIL return done_common_prefix_iter(job, context);
}
auto [c, maxV] = getChildAndMaxVersion(child, job->remaining[0]);
job->maxV = maxV;
job->child = c;
if (job->child == nullptr) {
MUSTTAIL return done_common_prefix_iter(job, context);
}
job->continuation = commonPrefixIterTable[c.getType()];
__builtin_prefetch(job->child);
MUSTTAIL return keepGoing(job, context);
}
// Runs once the walk down the common prefix of [begin, end) has stopped at
// job->n.
//
// It trims the consumed bytes from begin/end (so they become relative to
// job->n) and records job->n as commonPrefixNode. If begin is a prefix of end,
// the read finishes with checkRangeRightSide. Otherwise the children strictly
// between begin[lcp] and end[lcp] are checked with checkRangeStartsWith, and
// the left-side walk for begin starts.
PRESERVE_NONE void done_common_prefix_iter(CheckJob *job,
                                           CheckContext *context) {
#ifndef NDEBUG
  {
    // Debug-only sanity check: job->n's search path equals the consumed part
    // of begin. Guarded so release builds don't construct an Arena on this
    // path.
    Arena arena;
    assert(getSearchPath(arena, job->n) <=>
               job->begin.subspan(0, job->lcp - job->remaining.size()) ==
           0);
  }
#endif
  const int consumed = job->lcp - job->remaining.size();
  assume(consumed >= 0);
  job->begin = job->begin.subspan(consumed, int(job->begin.size()) - consumed);
  job->end = job->end.subspan(consumed, int(job->end.size()) - consumed);
  job->lcp -= consumed;
  job->commonPrefixNode = job->n;
  if (job->lcp == int(job->begin.size())) {
    // begin is a prefix of end.
    job->setResult(checkRangeRightSide(job->n, job->end, job->lcp,
                                       job->readVersion, context->tls));
    MUSTTAIL return complete(job, context);
  }
  // If this were not true we would have returned above
  assert(job->begin.size() > 0);
  if (!checkRangeStartsWith(job->n, job->begin.subspan(0, job->lcp),
                            job->begin[job->lcp], job->end[job->lcp],
                            job->readVersion, context->tls)) {
    job->setResult(false);
    MUSTTAIL return complete(job, context);
  }
  job->remaining = job->begin;
  auto [c, maxV] = getChildAndMaxVersion(job->n, job->remaining[0]);
  job->maxV = maxV;
  Node *child = c;
  if (child == nullptr) {
    auto geq = getChildGeq(job->n, job->remaining[0]);
    if (geq != nullptr) {
      job->n = geq;
      job->continuation = left_side_down_left_spine;
      __builtin_prefetch(job->n);
      MUSTTAIL return keepGoing(job, context);
    } else {
      job->n = nextSibling(job->n);
      if (job->n == nullptr) {
        job->continuation = done_left_side_iter;
        MUSTTAIL return job->continuation(job, context);
      }
      job->continuation = left_side_down_left_spine;
      __builtin_prefetch(job->n);
      MUSTTAIL return keepGoing(job, context);
    }
  }
  job->n = child;
  job->continuation = leftSideIterTable[true][c.getType()];
  __builtin_prefetch(job->n);
  MUSTTAIL return keepGoing(job, context);
}
// Return true if the max version among all keys that start with key[:prefixLen]
// that are >= key is <= readVersion
//
// State-machine form: one step of the left-side walk along job->remaining
// (the rest of begin). On entry job->n (of type NodeT) was reached via
// job->remaining[0], and job->maxV is the max version associated with it.
// kFirst is true only for the step entered from done_common_prefix_iter
// (leftSideIterTable[true]); later steps use leftSideIterTable[false].
// Finishes via done_left_side_iter (left side clean), complete() with
// setResult(false), or continues down the tree / left spine.
template <class NodeT, bool kFirst>
PRESERVE_NONE void left_side_iter(CheckJob *job, CheckContext *context) {
assert(NodeT::kType == job->n->getType());
NodeT *n = static_cast<NodeT *>(job->n);
// Drop the index byte that selected n.
job->remaining = job->remaining.subspan(1, job->remaining.size() - 1);
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, job->remaining.size());
int i =
longestCommonPrefix(n->partialKey(), job->remaining.data(), commonLen);
if (i < commonLen) {
auto c = n->partialKey()[i] <=> job->remaining[i];
if (c > 0) {
// n's partial key is greater than begin at the first mismatch, so every
// key under n sorts after begin.
if constexpr (kFirst) {
if (i < job->lcp) {
job->continuation = left_side_down_left_spine;
MUSTTAIL return job->continuation(job, context);
}
}
if (n->entryPresent && n->entry.rangeVersion > job->readVersion) {
job->setResult(false);
MUSTTAIL return complete(job, context);
}
if (job->maxV > job->readVersion) {
job->setResult(false);
MUSTTAIL return complete(job, context);
}
job->continuation = done_left_side_iter;
MUSTTAIL return job->continuation(job, context);
} else {
// Every key under n sorts before begin: continue with the leftmost
// entry after n's subtree, if any.
job->n = nextSibling(n);
if (job->n == nullptr) {
job->continuation = done_left_side_iter;
MUSTTAIL return job->continuation(job, context);
}
job->continuation = left_side_down_left_spine;
MUSTTAIL return job->continuation(job, context);
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
job->remaining =
job->remaining.subspan(commonLen, job->remaining.size() - commonLen);
} else if (n->partialKeyLen > int(job->remaining.size())) {
// begin ends inside n's partial key, so every key under n sorts after
// begin; n's own entry and max version decide the left side.
if (n->entryPresent && n->entry.rangeVersion > job->readVersion) {
job->setResult(false);
MUSTTAIL return complete(job, context);
}
if (job->maxV > job->readVersion) {
job->setResult(false);
MUSTTAIL return complete(job, context);
}
job->continuation = done_left_side_iter;
MUSTTAIL return job->continuation(job, context);
}
}
++context->tls->range_read_iterations_accum;
if (job->maxV <= job->readVersion) {
job->continuation = done_left_side_iter;
MUSTTAIL return job->continuation(job, context);
}
if (job->remaining.size() == 0) {
// NOTE(review): job->maxV > job->readVersion is guaranteed by the check
// just above, so only the first branch can be taken here.
if (job->maxV > job->readVersion) {
job->setResult(false);
MUSTTAIL return complete(job, context);
} else {
job->continuation = done_left_side_iter;
MUSTTAIL return job->continuation(job, context);
}
}
// Children of n strictly after remaining[0] are all to the right of begin.
if (!checkMaxBetweenExclusive(n, job->remaining[0], 256, job->readVersion,
context->tls)) {
job->setResult(false);
MUSTTAIL return complete(job, context);
}
auto [c, maxV] = getChildAndMaxVersion(job->n, job->remaining[0]);
job->maxV = maxV;
Node *child = c;
if (child == nullptr) {
auto c = getChildGeq(n, job->remaining[0]);
if (c != nullptr) {
// The following children were already covered by
// checkMaxBetweenExclusive above.
job->n = c;
job->continuation = done_left_side_iter;
MUSTTAIL return job->continuation(job, context);
} else {
job->n = nextSibling(job->n);
if (job->n == nullptr) {
job->continuation = done_left_side_iter;
MUSTTAIL return job->continuation(job, context);
}
job->continuation = left_side_down_left_spine;
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
}
}
job->n = child;
job->continuation = leftSideIterTable[false][c.getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
}
// Final step of a range read whose left side passed: the verdict is whatever
// the right-side check from the common-prefix node reports, starting one byte
// past the common prefix.
PRESERVE_NONE void done_left_side_iter(CheckJob *job, CheckContext *context) {
  const auto rightSideOk =
      checkRangeRightSide(job->commonPrefixNode, job->end, job->lcp + 1,
                          job->readVersion, context->tls);
  job->setResult(rightSideOk);
  MUSTTAIL return complete(job, context);
}
// Walks down first children (one node per step, prefetching each) until it
// reaches a node with an entry. That entry's range version is then compared
// against the read version: newer fails the read (setResult(false)),
// otherwise the left side is finished and control passes to
// done_left_side_iter. The caller has already set this function as the
// job's continuation, so keepGoing resumes here.
void left_side_down_left_spine(CheckJob *job, CheckContext *context) {
  if (!job->n->entryPresent) {
    job->n = getFirstChildExists(job->n);
    __builtin_prefetch(job->n);
    MUSTTAIL return keepGoing(job, context);
  }
  if (job->n->entry.rangeVersion > job->readVersion) {
    job->setResult(false);
    MUSTTAIL return complete(job, context);
  }
  job->continuation = done_left_side_iter;
  MUSTTAIL return job->continuation(job, context);
}
} // namespace check_range_read_state_machine
// Prepares this job to check `read`, storing the verdict through `result`.
// Reads below the oldest retained version are answered TooOld immediately.
// Otherwise the job starts at `root` in the point-read state machine (empty
// end key) or the range-read state machine (non-empty end key).
void CheckJob::init(const ConflictSet::ReadRange *read,
                    ConflictSet::Result *result, Node *root,
                    int64_t oldestVersionFullPrecision) {
  if (read->readVersion < oldestVersionFullPrecision) [[unlikely]] {
    *result = ConflictSet::TooOld;
    continuation = complete;
    return;
  }
  // State shared by point and range reads.
  this->begin = std::span<const uint8_t>(read->begin.p, read->begin.len);
  this->n = root;
  this->readVersion = InternalVersionT(read->readVersion);
  this->result = result;
  if (read->end.len == 0) {
    continuation = check_point_read_state_machine::begin;
  } else {
    this->end = std::span<const uint8_t>(read->end.p, read->end.len);
    continuation = check_range_read_state_machine::begin;
  }
}
// Concrete state behind a ConflictSet. It holds:
//  - the radix tree (root),
//  - the version watermarks used for TooOld checks and garbage collection,
//  - the position of the incremental gc scan,
//  - memory accounting (totalBytes) and metrics.
struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
// Checks reads[0..count) and writes one verdict per read to result[i].
//
// Up to kConcurrent reads are in flight at once as CheckJobs, linked into a
// circular doubly linked list (next/prev). Each job's continuation is a step
// of the point- or range-read state machine. Steps prefetch the next node
// before calling keepGoing, presumably so another job can make progress
// while that load completes. Per-call tallies accumulate in the local
// ReadContext and are flushed into the metrics at the end.
void check(const ReadRange *reads, Result *result, int count) {
assert(oldestVersionFullPrecision >=
newestVersionFullPrecision - kNominalVersionWindow);
if (count == 0) {
return;
}
ReadContext tls;
tls.impl = this;
int64_t check_byte_accum = 0;
constexpr int kConcurrent = 16;
CheckJob inProgress[kConcurrent];
CheckContext context;
context.count = count;
context.oldestVersionFullPrecision = oldestVersionFullPrecision;
context.root = root;
context.queries = reads;
context.results = result;
context.tls = &tls;
// The first min(kConcurrent, count) reads get a job right away;
// context.started records how many reads have been handed out.
int64_t started = std::min(kConcurrent, count);
context.started = started;
for (int i = 0; i < started; i++) {
inProgress[i].init(reads + i, result + i, root,
oldestVersionFullPrecision);
}
// Link the started jobs into a ring: next goes forward, prev goes back,
// and the last and first jobs point at each other.
for (int i = 0; i < started - 1; i++) {
inProgress[i].next = inProgress + i + 1;
}
for (int i = 1; i < started; i++) {
inProgress[i].prev = inProgress + i - 1;
}
inProgress[0].prev = inProgress + started - 1;
inProgress[started - 1].next = inProgress;
#if __has_attribute(musttail)
// Kick off the sequence of tail calls that finally returns once all jobs
// are done
inProgress->continuation(inProgress, &context);
#else
// Without guaranteed tail calls, drive the state machine with an explicit
// trampoline until some step sets context.done.
context.job = inProgress;
context.done = false;
while (!context.done) {
context.job->continuation(context.job, &context);
}
#endif
// Tally bytes and outcomes for the metrics below.
for (int i = 0; i < count; ++i) {
assert(reads[i].readVersion >= 0);
assert(reads[i].readVersion <= newestVersionFullPrecision);
const auto &r = reads[i];
check_byte_accum += r.begin.len + r.end.len;
tls.commits_accum += result[i] == Commit;
tls.conflicts_accum += result[i] == Conflict;
tls.too_olds_accum += result[i] == TooOld;
}
point_read_total.add(tls.point_read_accum);
prefix_read_total.add(tls.prefix_read_accum);
range_read_total.add(tls.range_read_accum);
range_read_node_scan_total.add(tls.range_read_node_scan_accum);
point_read_short_circuit_total.add(tls.point_read_short_circuit_accum);
prefix_read_short_circuit_total.add(tls.prefix_read_short_circuit_accum);
range_read_short_circuit_total.add(tls.range_read_short_circuit_accum);
point_read_iterations_total.add(tls.point_read_iterations_accum);
prefix_read_iterations_total.add(tls.prefix_read_iterations_accum);
range_read_iterations_total.add(tls.range_read_iterations_accum);
commits_total.add(tls.commits_accum);
conflicts_total.add(tls.conflicts_accum);
too_olds_total.add(tls.too_olds_accum);
check_bytes_total.add(check_byte_accum);
}
// Applies `count` writes at writeVersion. writeVersion must be >= the newest
// version seen so far (asserted). A non-empty end key makes a range write;
// an empty end key makes a point write.
void addWrites(const WriteRange *writes, int count, int64_t writeVersion) {
#if !USE_64_BIT
// There could be other conflict sets in the same thread. We need
// InternalVersionT::zero to be correct for this conflict set for the
// lifetime of the current call frame.
InternalVersionT::zero = tls.zero = oldestVersion;
#endif
assert(writeVersion >= newestVersionFullPrecision);
assert(tls.accum.entries_erased == 0);
assert(tls.accum.entries_inserted == 0);
// Slow path: oldestExtantVersion would fall more than
// kMaxCorrectVersionWindow behind writeVersion.
//  - If writeVersion also jumps more than kNominalVersionWindow past the
//    newest version, the whole tree is discarded and reinitialized.
//  - In either case gc then runs synchronously until oldestExtantVersion
//    is back within the window.
if (oldestExtantVersion < writeVersion - kMaxCorrectVersionWindow)
[[unlikely]] {
if (writeVersion > newestVersionFullPrecision + kNominalVersionWindow) {
eraseTree(root, &tls);
init(writeVersion - kNominalVersionWindow);
}
newestVersionFullPrecision = writeVersion;
newest_version.set(newestVersionFullPrecision);
if (newestVersionFullPrecision - kNominalVersionWindow >
oldestVersionFullPrecision) {
setOldestVersion(newestVersionFullPrecision - kNominalVersionWindow);
}
while (oldestExtantVersion <
newestVersionFullPrecision - kMaxCorrectVersionWindow) {
gcScanStep(1000);
}
} else {
newestVersionFullPrecision = writeVersion;
newest_version.set(newestVersionFullPrecision);
// Keep the oldest version no more than kNominalVersionWindow behind.
if (newestVersionFullPrecision - kNominalVersionWindow >
oldestVersionFullPrecision) {
setOldestVersion(newestVersionFullPrecision - kNominalVersionWindow);
}
}
for (int i = 0; i < count; ++i) {
const auto &w = writes[i];
tls.accum.write_bytes += w.begin.len + w.end.len;
auto begin = std::span<const uint8_t>(w.begin.p, w.begin.len);
auto end = std::span<const uint8_t>(w.end.p, w.end.len);
if (w.end.len > 0) {
addWriteRange(root, begin, end, InternalVersionT(writeVersion), &tls,
this);
} else {
addPointWrite(root, begin, InternalVersionT(writeVersion), &tls);
}
}
// Run gc at least 200% the rate we're inserting entries
keyUpdates +=
std::max<int64_t>(tls.accum.entries_inserted - tls.accum.entries_erased,
0) *
2;
// Flush this call's write accumulators into the metrics, then reset them.
point_writes_total.add(tls.accum.point_writes);
range_writes_total.add(tls.accum.range_writes);
nodes_allocated_total.add(tls.accum.nodes_allocated);
nodes_released_total.add(tls.accum.nodes_released);
entries_inserted_total.add(tls.accum.entries_inserted);
entries_erased_total.add(tls.accum.entries_erased);
insert_iterations_total.add(tls.accum.insert_iterations);
write_bytes_total.add(tls.accum.write_bytes);
memset(&tls.accum, 0, sizeof(tls.accum));
}
// Spends up to `fuel` gc'ing, and returns its unused fuel. Reclaims memory
// and updates oldestExtantVersion after spending enough fuel.
//
// The scan resumes at the first physical node >= removalKey. It rezeroes
// each node it visits and erases entries whose point and range versions
// are both <= oldestVersion. When the scan runs off the end of the tree, a
// full pass is complete: oldestExtantVersion becomes the oldest version
// recorded when that pass began, and the next pass starts from the
// beginning.
int64_t gcScanStep(int64_t fuel) {
Node *n = firstGeqPhysical(root, removalKey);
// There's no way to erase removalKey without introducing a key after it
assert(n != nullptr);
// Don't erase the root
if (n == root) {
rezero(n, oldestVersion);
n = nextPhysical(n);
}
int64_t set_oldest_iterations_accum = 0;
for (; fuel > 0 && n != nullptr; ++set_oldest_iterations_accum) {
rezero(n, oldestVersion);
// The "make sure gc keeps up with writes" calculations assume that we're
// scanning key by key, not node by node. Make sure we only spend fuel
// when there's a logical entry.
fuel -= n->entryPresent;
if (n->entryPresent && std::max(n->entry.pointVersion,
n->entry.rangeVersion) <= oldestVersion) {
// Any transaction n would have prevented from committing is
// going to fail with TooOld anyway.
// There's no way to insert a range such that range version of the right
// node is greater than the point version of the left node
assert(n->entry.rangeVersion <= oldestVersion);
n = erase(n, &tls, this, /*logical*/ false);
} else {
maybeDecreaseCapacity(n, &tls, this);
n = nextPhysical(n);
}
}
gc_iterations_total.add(set_oldest_iterations_accum);
if (n == nullptr) {
// Completed a full pass.
removalKey = {};
oldestExtantVersion = oldestVersionAtGcBegin;
oldest_extant_version.set(oldestExtantVersion);
oldestVersionAtGcBegin = oldestVersionFullPrecision;
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr,
"new oldestExtantVersion: %" PRId64
", new oldestVersionAtGcBegin: %" PRId64 "\n",
oldestExtantVersion, oldestVersionAtGcBegin);
#endif
} else {
// Remember where to resume. The key's bytes live in removalKeyArena.
removalKeyArena = Arena();
removalKey = getSearchPath(removalKeyArena, n);
}
return fuel;
}
// Advances the oldest version (reads below it are TooOld) and spends the
// accumulated keyUpdates budget on gc.
void setOldestVersion(int64_t newOldestVersion) {
assert(newOldestVersion >= 0);
assert(newOldestVersion <= newestVersionFullPrecision);
// If addWrites advances oldestVersion to keep within valid window, a
// subsequent setOldestVersion can be legitimately called with a version
// older than `oldestVersionFullPrecision`. < instead of <= so that we can
// do garbage collection work without advancing the oldest version.
if (newOldestVersion < oldestVersionFullPrecision) {
return;
}
InternalVersionT oldestVersion{newOldestVersion};
this->oldestVersionFullPrecision = newOldestVersion;
this->oldestVersion = oldestVersion;
#if !USE_64_BIT
InternalVersionT::zero = tls.zero = oldestVersion;
#endif
#ifdef NDEBUG
// This is here for performance reasons, since we want to amortize the cost
// of storing the search path as a string. In tests, we want to exercise the
// rest of the code often.
if (keyUpdates < 100) {
return;
}
#endif
keyUpdates = gcScanStep(keyUpdates);
nodes_allocated_total.add(std::exchange(tls.accum.nodes_allocated, 0));
nodes_released_total.add(std::exchange(tls.accum.nodes_released, 0));
entries_inserted_total.add(std::exchange(tls.accum.entries_inserted, 0));
entries_erased_total.add(std::exchange(tls.accum.entries_erased, 0));
oldest_version.set(oldestVersionFullPrecision);
}
int64_t getBytes() const { return totalBytes; }
// (Re)initializes to a tree whose root (a Node0) holds an entry for the
// empty key at oldestVersion, with every version watermark set to
// oldestVersion and gc progress reset. Does not free an existing tree;
// callers such as addWrites erase it first.
void init(int64_t oldestVersion) {
this->oldestVersion = InternalVersionT(oldestVersion);
oldestVersionFullPrecision = oldestExtantVersion = oldestVersionAtGcBegin =
newestVersionFullPrecision = oldestVersion;
oldest_version.set(oldestVersionFullPrecision);
newest_version.set(newestVersionFullPrecision);
oldest_extant_version.set(oldestExtantVersion);
tls.~WriteContext();
new (&tls) WriteContext();
removalKeyArena = Arena{};
removalKey = {};
keyUpdates = 10;
// Insert ""
root = tls.allocate<Node0>(0);
root->numChildren = 0;
root->parent = nullptr;
root->entryPresent = false;
root->partialKeyLen = 0;
// addKey is called while entryPresent is still false.
addKey(root);
root->entryPresent = true;
root->entry.pointVersion = this->oldestVersion;
root->entry.rangeVersion = this->oldestVersion;
#if !USE_64_BIT
InternalVersionT::zero = tls.zero = this->oldestVersion;
#endif
// Intentionally not resetting totalBytes
}
explicit Impl(int64_t oldestVersion) {
assert(oldestVersion >= 0);
init(oldestVersion);
metrics = initMetrics(metricsList, metricsCount);
}
~Impl() {
eraseTree(root, &tls);
safe_free(metrics, metricsCount * sizeof(metrics[0]));
}
// Per-instance write-side context and accumulators (reset by init).
WriteContext tls;
// Resume point of the incremental gc scan; its bytes live in
// removalKeyArena. Empty means "start from the beginning".
Arena removalKeyArena;
std::span<const uint8_t> removalKey;
// gc fuel: increased by net entry insertions in addWrites, spent by
// gcScanStep via setOldestVersion.
int64_t keyUpdates;
TaggedNodePointer root;
InternalVersionT oldestVersion;
int64_t oldestVersionFullPrecision;
// Lower bound on the versions of existing entries (see gcScanStep).
int64_t oldestExtantVersion;
int64_t oldestVersionAtGcBegin;
int64_t newestVersionFullPrecision;
// Bytes attributed to this instance. Maintained by the internal_* entry
// points from mallocBytesDelta.
int64_t totalBytes = 0;
MetricsV1 *metrics;
// Members are initialized in declaration order, so metricsCount and
// metricsList must stay above the GAUGE/COUNTER members. Those members'
// initializers take them as arguments.
int metricsCount = 0;
Metric *metricsList = nullptr;
#define GAUGE(name, help) \
Gauge name { metricsList, metricsCount, #name, help }
#define COUNTER(name, help) \
Counter name { metricsList, metricsCount, #name, help }
// ==================== METRICS DEFINITIONS ====================
COUNTER(point_read_total, "Total number of point reads checked");
COUNTER(point_read_short_circuit_total,
"Total number of point reads that did not require a full search to "
"check");
COUNTER(point_read_iterations_total,
"Total number of iterations of the main loop for point read checks");
COUNTER(prefix_read_total, "Total number of prefix reads checked");
COUNTER(prefix_read_short_circuit_total,
"Total number of prefix reads that did not require a full search to "
"check");
COUNTER(prefix_read_iterations_total,
"Total number of iterations of the main loop for prefix read checks");
COUNTER(range_read_total, "Total number of range reads checked");
COUNTER(range_read_short_circuit_total,
"Total number of range reads that did not require a full search to "
"check");
COUNTER(range_read_iterations_total,
"Total number of iterations of the main loops for range read checks");
COUNTER(range_read_node_scan_total,
"Total number of scans of individual nodes while "
"checking a range read");
COUNTER(commits_total,
"Total number of checks where the result is \"commit\"");
COUNTER(conflicts_total,
"Total number of checks where the result is \"conflict\"");
COUNTER(too_olds_total,
"Total number of checks where the result is \"too old\"");
COUNTER(check_bytes_total, "Total number of key bytes checked");
COUNTER(point_writes_total, "Total number of point writes");
COUNTER(range_writes_total,
"Total number of range writes (includes prefix writes)");
GAUGE(memory_bytes, "Total number of bytes in use");
COUNTER(nodes_allocated_total,
"The total number of physical tree nodes allocated");
COUNTER(nodes_released_total,
"The total number of physical tree nodes released");
COUNTER(insert_iterations_total,
"The total number of iterations of the main loop for insertion. "
"Includes searches where the entry already existed, and so insertion "
"did not take place");
COUNTER(entries_inserted_total,
"The total number of entries inserted in the tree");
COUNTER(entries_erased_total,
"The total number of entries erased from the tree");
COUNTER(
gc_iterations_total,
"The total number of iterations of the main loop for garbage collection");
COUNTER(write_bytes_total, "Total number of key bytes in calls to addWrites");
GAUGE(oldest_version,
"The lowest version that doesn't result in \"TooOld\" for checks");
GAUGE(newest_version, "The version of the most recent call to addWrites");
GAUGE(
oldest_extant_version,
"A lower bound on the lowest version associated with an existing entry");
// ==================== END METRICS DEFINITIONS ====================
#undef GAUGE
#undef COUNTER
// Exposes the metrics array built in the constructor; ownership stays with
// this Impl (freed in ~Impl).
void getMetricsV1(MetricsV1 **metrics, int *count) {
*metrics = this->metrics;
*count = metricsCount;
}
};
// Returns a reference to the tree slot that owns `n`, so callers can replace
// the node in place. The root has no parent, so its slot is impl->root.
// Every other node lives in its parent's child array at parentsIndex.
TaggedNodePointer &getInTree(Node *n, ConflictSet::Impl *impl) {
  if (n->parent == nullptr) {
    return impl->root;
  }
  return getChildExists(n->parent, n->parentsIndex);
}
// Internal entry points. Public entry points should just delegate to these
//
// Checks `count` reads, writing one result per read. Unlike the mutating
// entry points below, this does not reset or fold in mallocBytesDelta.
void internal_check(ConflictSet::Impl *impl,
                    const ConflictSet::ReadRange *reads,
                    ConflictSet::Result *results, int count) {
  ConflictSet::Impl &self = *impl;
  self.check(reads, results, count);
}
// Applies writes at writeVersion. It measures the net bytes allocated during
// the call (via mallocBytesDelta), adds them to the instance's byte count,
// and publishes the result on the memory_bytes gauge.
void internal_addWrites(ConflictSet::Impl *impl,
                        const ConflictSet::WriteRange *writes, int count,
                        int64_t writeVersion) {
  mallocBytesDelta = 0;
  impl->addWrites(writes, count, writeVersion);
  const int64_t allocatedDuringCall = mallocBytesDelta;
  impl->totalBytes += allocatedDuringCall;
  impl->memory_bytes.set(impl->totalBytes);
#if SHOW_MEMORY
  // Debug cross-check of per-instance accounting against the global counter.
  if (mallocBytes != impl->totalBytes) {
    abort();
  }
#endif
}
// Advances the oldest version (which may run gc). It accounts for the net
// bytes allocated or freed during the call and publishes the new total on
// the memory_bytes gauge.
void internal_setOldestVersion(ConflictSet::Impl *impl, int64_t oldestVersion) {
  mallocBytesDelta = 0;
  impl->setOldestVersion(oldestVersion);
  const int64_t allocatedDuringCall = mallocBytesDelta;
  impl->totalBytes += allocatedDuringCall;
  impl->memory_bytes.set(impl->totalBytes);
#if SHOW_MEMORY
  // Debug cross-check of per-instance accounting against the global counter.
  if (mallocBytes != impl->totalBytes) {
    abort();
  }
#endif
}
// Allocates and constructs an Impl with storage from safe_malloc. Bytes
// allocated during construction (tracked via mallocBytesDelta) seed the
// instance's byte count.
ConflictSet::Impl *internal_create(int64_t oldestVersion) {
  mallocBytesDelta = 0;
  auto *result = new (safe_malloc(sizeof(ConflictSet::Impl)))
      ConflictSet::Impl{oldestVersion};
  result->totalBytes += mallocBytesDelta;
  // Publish the initial footprint, as internal_addWrites and
  // internal_setOldestVersion do. Otherwise the memory_bytes gauge reads 0
  // (disagreeing with getBytes()) until the first mutating call.
  result->memory_bytes.set(result->totalBytes);
  return result;
}
// Counterpart of internal_create: runs the destructor explicitly, then
// releases the raw storage with safe_free.
void internal_destroy(ConflictSet::Impl *impl) {
  using Impl = ConflictSet::Impl;
  impl->~Impl();
  safe_free(impl, sizeof(Impl));
}
// Bytes attributed to this instance by the internal_* entry points.
int64_t internal_getBytes(ConflictSet::Impl *impl) {
  const ConflictSet::Impl &self = *impl;
  return self.getBytes();
}
// Reports the instance's metrics array and its length through the out
// parameters. The array remains owned by the Impl.
void internal_getMetricsV1(ConflictSet::Impl *impl,
                           ConflictSet::MetricsV1 **metrics, int *count) {
  ConflictSet::Impl &self = *impl;
  self.getMetricsV1(metrics, count);
}
// Reads the current value of the Metric that `metric->p` points at, using a
// relaxed atomic load.
double internal_getMetricValue(const ConflictSet::MetricsV1 *metric) {
  auto *backing = (Metric *)metric->p;
  return backing->value.load(std::memory_order_relaxed);
}
// ==================== END IMPLEMENTATION ====================
// GCOVR_EXCL_START
// Returns the first node, in key order, that has an entry and whose search
// path is >= `key`, searching from `n`. Returns nullptr if there is none.
Node *firstGeqLogical(Node *n, const std::span<const uint8_t> key) {
  auto remaining = key;
  for (;;) {
    if (remaining.size() == 0) {
      // n's search path equals key.
      if (n->entryPresent) {
        return n;
      }
      n = getFirstChildExists(n);
      goto downLeftSpine;
    }
    Node *child = getChild(n, remaining[0]);
    if (child == nullptr) {
      // No child for the next byte: the answer is under the first child
      // after it, or else after n's subtree (nextSibling).
      auto c = getChildGeq(n, remaining[0]);
      if (c != nullptr) {
        n = c;
        goto downLeftSpine;
      } else {
        n = nextSibling(n);
        if (n == nullptr) {
          return nullptr;
        }
        goto downLeftSpine;
      }
    }
    n = child;
    remaining = remaining.subspan(1, remaining.size() - 1);
    if (n->partialKeyLen > 0) {
      int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
      int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
      if (i < commonLen) {
        auto c = n->partialKey()[i] <=> remaining[i];
        if (c > 0) {
          // Every key under n sorts after key.
          goto downLeftSpine;
        } else {
          // Every key under n sorts before key: continue after n's subtree.
          // nextSibling may return nullptr (nothing follows), which must not
          // reach downLeftSpine.
          n = nextSibling(n);
          if (n == nullptr) {
            return nullptr;
          }
          goto downLeftSpine;
        }
      }
      if (commonLen == n->partialKeyLen) {
        // partial key matches
        remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
      } else if (n->partialKeyLen > int(remaining.size())) {
        // n is the first physical node greater than remaining, and there's no
        // eq node
        goto downLeftSpine;
      }
    }
  }
downLeftSpine:
  // Leftmost node with an entry in the subtree rooted at n.
  for (; !n->entryPresent; n = getFirstChildExists(n)) {
  }
  return n;
}
void ConflictSet::check(const ReadRange *reads, Result *results,
int count) const {
internal_check(impl, reads, results, count);
}
void ConflictSet::addWrites(const WriteRange *writes, int count,
int64_t writeVersion) {
internal_addWrites(impl, writes, count, writeVersion);
}
void ConflictSet::setOldestVersion(int64_t oldestVersion) {
internal_setOldestVersion(impl, oldestVersion);
}
int64_t ConflictSet::getBytes() const { return internal_getBytes(impl); }
// Public entry point: forwards to internal_getMetricsV1.
void ConflictSet::getMetricsV1(MetricsV1 **metrics, int *count) const {
  internal_getMetricsV1(this->impl, metrics, count);
}
double ConflictSet::MetricsV1::getValue() const {
return internal_getMetricValue(this);
}
ConflictSet::ConflictSet(int64_t oldestVersion)
: impl(internal_create(oldestVersion)) {}
// Releases the implementation. A moved-from ConflictSet has a null impl and
// owns nothing.
ConflictSet::~ConflictSet() {
  if (impl != nullptr) {
    internal_destroy(impl);
  }
}
// Takes ownership of other's implementation, leaving `other` empty.
ConflictSet::ConflictSet(ConflictSet &&other) noexcept : impl(other.impl) {
  other.impl = nullptr;
}
// Move assignment: destroys whatever this object currently owns, then takes
// ownership of other's implementation and leaves `other` empty.
// Self-assignment is a no-op.
ConflictSet &ConflictSet::operator=(ConflictSet &&other) noexcept {
  if (this != &other) {
    // Overwriting the pointer without releasing it first would leak the
    // previously owned Impl (tree, metrics, and the Impl itself).
    if (impl != nullptr) {
      internal_destroy(impl);
    }
    impl = std::exchange(other.impl, nullptr);
  }
  return *this;
}
// Aliases without `::` for the types that appear in the extern "C" function
// signatures that follow.
using ConflictSet_Result = ConflictSet::Result;
using ConflictSet_Key = ConflictSet::Key;
using ConflictSet_ReadRange = ConflictSet::ReadRange;
using ConflictSet_WriteRange = ConflictSet::WriteRange;
extern "C" {
__attribute__((__visibility__("default"))) void
ConflictSet_check(void *cs, const ConflictSet_ReadRange *reads,
ConflictSet_Result *results, int count) {
internal_check((ConflictSet::Impl *)cs, reads, results, count);
}
__attribute__((__visibility__("default"))) void
ConflictSet_addWrites(void *cs, const ConflictSet_WriteRange *writes, int count,
int64_t writeVersion) {
internal_addWrites((ConflictSet::Impl *)cs, writes, count, writeVersion);
}
__attribute__((__visibility__("default"))) void
ConflictSet_setOldestVersion(void *cs, int64_t oldestVersion) {
internal_setOldestVersion((ConflictSet::Impl *)cs, oldestVersion);
}
__attribute__((__visibility__("default"))) void *
ConflictSet_create(int64_t oldestVersion) {
return internal_create(oldestVersion);
}
__attribute__((__visibility__("default"))) void ConflictSet_destroy(void *cs) {
internal_destroy((ConflictSet::Impl *)cs);
}
__attribute__((__visibility__("default"))) int64_t
ConflictSet_getBytes(void *cs) {
return internal_getBytes((ConflictSet::Impl *)cs);
}
}
// Make sure abi is well-defined: every type passed across the API boundary
// must be standard-layout.
static_assert(std::is_standard_layout_v<ConflictSet::Result>,
              "ConflictSet::Result must be standard-layout");
static_assert(std::is_standard_layout_v<ConflictSet::Key>,
              "ConflictSet::Key must be standard-layout");
static_assert(std::is_standard_layout_v<ConflictSet::ReadRange>,
              "ConflictSet::ReadRange must be standard-layout");
static_assert(std::is_standard_layout_v<ConflictSet::WriteRange>,
              "ConflictSet::WriteRange must be standard-layout");
static_assert(std::is_standard_layout_v<ConflictSet::MetricsV1>,
              "ConflictSet::MetricsV1 must be standard-layout");
namespace {

// Returns the full key (search path) leading to `n`, rendered with
// printable(), or "<end>" for nullptr.
std::string getSearchPathPrintable(Node *n) {
  if (n == nullptr) {
    return "<end>";
  }
  Arena arena;
  auto result = vector<char>(arena);
  // Walk from n up to the root, appending each node's partial key and the
  // index byte that selects it in its parent, all in reverse order; then
  // flip the whole buffer once.
  for (;;) {
    for (int i = n->partialKeyLen - 1; i >= 0; --i) {
      result.push_back(n->partialKey()[i]);
    }
    if (n->parent == nullptr) {
      break;
    }
    result.push_back(n->parentsIndex);
    n = n->parent;
  }
  std::reverse(result.begin(), result.end());
  if (result.size() > 0) {
    return printable(std::string_view((const char *)&result[0],
                                      result.size())); // NOLINT
  } else {
    return std::string();
  }
}

// Returns the bytes contributed by `n` alone (its index byte in the parent,
// if any, followed by its partial key), rendered with printable(), or
// "<end>" for nullptr.
std::string getPartialKeyPrintable(Node *n) {
  if (n == nullptr) {
    return "<end>";
  }
  auto result = std::string((const char *)&n->parentsIndex,
                            n->parent == nullptr ? 0 : 1) +
                std::string((const char *)n->partialKey(), n->partialKeyLen);
  return printable(result); // NOLINT
}

// Returns the smallest string greater than every string that has `str` as a
// prefix: trailing 0xff bytes are dropped and the last remaining byte is
// incremented. If `str` is empty or consists only of 0xff bytes, no such
// string exists; `ok` is set to false and an empty string is returned.
// Otherwise `ok` is set to true.
std::string strinc(std::string_view str, bool &ok) {
  int index = int(str.size()) - 1;
  while (index >= 0 && uint8_t(str[index]) == 0xff) {
    --index;
  }
  if (index < 0) {
    ok = false;
    return {};
  }
  ok = true;
  auto r = std::string(str.substr(0, index + 1));
  ++reinterpret_cast<uint8_t &>(r.back());
  return r;
}

// Returns the search path of `n` (which must be non-null) as a std::string.
std::string getSearchPath(Node *n) {
  assert(n != nullptr);
  Arena arena;
  auto result = getSearchPath(arena, n);
  return std::string((const char *)result.data(), result.size());
}

// Writes a Graphviz "dot" rendering of the subtree rooted at `node` to
// `file`. Each node is labeled with:
//  - its max version (m; -1 for the root),
//  - for entries, its point (p) and range (r) versions,
//  - its partial key.
// Nodes are positioned explicitly: x advances by kSeparation in pre-order,
// and y decreases by kSeparation per level.
[[maybe_unused]] void debugPrintDot(FILE *file, Node *node,
                                    ConflictSet::Impl *impl) {
  constexpr int kSeparation = 3;
  struct DebugDotPrinter {
    explicit DebugDotPrinter(FILE *file, ConflictSet::Impl *impl)
        : file(file), impl(impl) {}
    void print(Node *n, int y = 0) {
      assert(n != nullptr);
      if (n->entryPresent) {
        fprintf(file,
                " k_%p [label=\"m=%" PRId64 " p=%" PRId64 " r=%" PRId64
                "\n%s\", pos=\"%d,%d!\"];\n",
                (void *)n, n->parent == nullptr ? -1 : maxVersion(n).toInt64(),
                n->entry.pointVersion.toInt64(),
                n->entry.rangeVersion.toInt64(),
                getPartialKeyPrintable(n).c_str(), x, y);
      } else {
        fprintf(file, " k_%p [label=\"m=%" PRId64 "\n%s\", pos=\"%d,%d!\"];\n",
                (void *)n, n->parent == nullptr ? -1 : maxVersion(n).toInt64(),
                getPartialKeyPrintable(n).c_str(), x, y);
      }
      x += kSeparation;
      for (auto c = getChildGeq(n, 0); c != nullptr;
           c = getChildGeq(n, c->parentsIndex + 1)) {
        fprintf(file, " k_%p -> k_%p;\n", (void *)n, (void *)c);
        print(c, y - kSeparation);
      }
    }
    int x = 0;
    FILE *file;
    ConflictSet::Impl *impl;
  };
  fprintf(file, "digraph ConflictSet {\n");
  fprintf(file, " node [shape = box];\n");
  assert(node != nullptr);
  DebugDotPrinter printer{file, impl};
  printer.print(node);
  fprintf(file, "}\n");
}

// Recursively verifies that every child's parent pointer refers back to its
// parent. Mismatches are reported to stderr and clear `success`.
void checkParentPointers(Node *node, bool &success) {
  for (auto child = getChildGeq(node, 0); child != nullptr;
       child = getChildGeq(node, child->parentsIndex + 1)) {
    if (child->parent != node) {
      fprintf(stderr, "%s child %d has parent pointer %p. Expected %p\n",
              getSearchPathPrintable(node).c_str(), child->parentsIndex,
              (void *)child->parent, (void *)node);
      success = false;
    }
    checkParentPointers(child, success);
  }
}

// std::string_view convenience wrapper around firstGeqLogical.
Node *firstGeq(Node *n, std::string_view key) {
  return firstGeqLogical(
      n, std::span<const uint8_t>((const uint8_t *)key.data(), key.size()));
}

#if USE_64_BIT
void checkVersionsGeqOldestExtant(Node *, InternalVersionT) {}
#else
// Asserts that every version stored in `n` (entry versions and per-child max
// versions, including maxOfMax for the larger node types) is at least
// oldestExtantVersion.
void checkVersionsGeqOldestExtant(Node *n,
                                  InternalVersionT oldestExtantVersion) {
  if (n->entryPresent) {
    assert(n->entry.pointVersion >= oldestExtantVersion);
    assert(n->entry.rangeVersion >= oldestExtantVersion);
  }
  switch (n->getType()) {
  case Type_Node0: {
  } break;
  case Type_Node3: {
    auto *self = static_cast<Node3 *>(n);
    for (int i = 0; i < 3; ++i) {
      assert(self->childMaxVersion[i] >= oldestExtantVersion);
    }
  } break;
  case Type_Node16: {
    auto *self = static_cast<Node16 *>(n);
    for (int i = 0; i < 16; ++i) {
      assert(self->childMaxVersion[i] >= oldestExtantVersion);
    }
  } break;
  case Type_Node48: {
    auto *self = static_cast<Node48 *>(n);
    for (int i = 0; i < 48; ++i) {
      assert(self->childMaxVersion[i] >= oldestExtantVersion);
    }
    for (auto m : self->maxOfMax) {
      assert(m >= oldestExtantVersion);
    }
  } break;
  case Type_Node256: {
    auto *self = static_cast<Node256 *>(n);
    for (int i = 0; i < 256; ++i) {
      assert(self->childMaxVersion[i] >= oldestExtantVersion);
    }
    for (auto m : self->maxOfMax) {
      assert(m >= oldestExtantVersion);
    }
  } break;
  default:
    abort();
  }
}
#endif

// Recomputes the expected max version of `node`'s subtree and compares it
// with the stored maxVersion(node) for non-root nodes whose stored value
// exceeds oldestVersion. Returns the recomputed value. The expected value is
// the max of oldestVersion, the node's point version, and, for each child,
// that child's recomputed max and its range version.
[[maybe_unused]] InternalVersionT
checkMaxVersion(Node *root, Node *node, InternalVersionT oldestVersion,
                bool &success, ConflictSet::Impl *impl) {
  checkVersionsGeqOldestExtant(node,
                               InternalVersionT(impl->oldestExtantVersion));
  auto expected = oldestVersion;
  if (node->entryPresent) {
    expected = std::max(expected, node->entry.pointVersion);
  }
  for (auto child = getChildGeq(node, 0); child != nullptr;
       child = getChildGeq(node, child->parentsIndex + 1)) {
    expected = std::max(
        expected, checkMaxVersion(root, child, oldestVersion, success, impl));
    if (child->entryPresent) {
      expected = std::max(expected, child->entry.rangeVersion);
    }
  }
  // NOTE(review): `root` has no parent and (per checkCorrectness) no partial
  // key, so getSearchPath(root) is presumably always the empty key. strinc
  // then reports !ok and this "borrowed" range version is never folded in.
  // Confirm whether getSearchPath(node) was intended.
  auto key = getSearchPath(root);
  bool ok;
  auto inc = strinc(key, ok);
  if (ok) {
    auto borrowed = firstGeq(root, inc);
    if (borrowed != nullptr) {
      expected = std::max(expected, borrowed->entry.rangeVersion);
    }
  }
  if (node->parent && maxVersion(node) > oldestVersion &&
      maxVersion(node) != expected) {
    fprintf(stderr, "%s has max version %" PRId64 " . Expected %" PRId64 "\n",
            getSearchPathPrintable(node).c_str(), maxVersion(node).toInt64(),
            expected.toInt64());
    success = false;
  }
  return expected;
}

// Returns the number of entries in `node`'s subtree. Any child whose
// subtree contains no entries is reported to stderr and clears `success`.
[[maybe_unused]] int64_t checkEntriesExist(Node *node, bool &success) {
  int64_t total = node->entryPresent;
  for (auto child = getChildGeq(node, 0); child != nullptr;
       child = getChildGeq(node, child->parentsIndex + 1)) {
    int64_t e = checkEntriesExist(child, success);
    total += e;
    if (e == 0) {
      fprintf(stderr, "%s has child %02x with no reachable entries\n",
              getSearchPathPrintable(node).c_str(), child->parentsIndex);
      success = false;
    }
  }
  return total;
}

// Recursively checks that each node has at least its type's minimum number
// of children plus entry. Violations are reported to stderr and clear
// `success`.
[[maybe_unused]] void checkMemoryBoundInvariants(Node *node, bool &success) {
  int minNumChildren;
  switch (node->getType()) {
  case Type_Node0:
    minNumChildren = kMinChildrenNode0;
    break;
  case Type_Node3:
    minNumChildren = kMinChildrenNode3;
    break;
  case Type_Node16:
    minNumChildren = kMinChildrenNode16;
    break;
  case Type_Node48:
    minNumChildren = kMinChildrenNode48;
    break;
  case Type_Node256:
    minNumChildren = kMinChildrenNode256;
    break;
  default:
    abort();
  }
  if (node->numChildren + int(node->entryPresent) < minNumChildren) {
    fprintf(stderr,
            "%s has %d children + %d entries, which is less than the minimum "
            "required %d\n",
            getSearchPathPrintable(node).c_str(), node->numChildren,
            int(node->entryPresent), minNumChildren);
    success = false;
  }
  // TODO check that the max capacity property eventually holds
  for (auto child = getChildGeq(node, 0); child != nullptr;
       child = getChildGeq(node, child->parentsIndex + 1)) {
    checkMemoryBoundInvariants(child, success);
  }
}

// Runs all structural checks on the tree rooted at `node` and returns
// whether they all passed. The root must have no partial key.
[[maybe_unused]] bool checkCorrectness(Node *node,
                                       InternalVersionT oldestVersion,
                                       ConflictSet::Impl *impl) {
  bool success = true;
  if (node->partialKeyLen > 0) {
    fprintf(stderr, "Root cannot have a partial key\n");
    success = false;
  }
  checkParentPointers(node, success);
  checkMaxVersion(node, node, oldestVersion, success, impl);
  checkEntriesExist(node, success);
  checkMemoryBoundInvariants(node, success);
  return success;
}

} // namespace
#if SHOW_MEMORY
// SHOW_MEMORY bookkeeping: current and peak totals maintained by
// addNode/removeNode/addKey/removeKey and printed by PeakPrinter.
// Bytes of node structs currently allocated (sum of getNodeSize).
int64_t nodeBytes = 0;
int64_t peakNodeBytes = 0;
// Sum of n->getCapacity() over allocated nodes.
int64_t partialCapacityBytes = 0;
int64_t peakPartialCapacityBytes = 0;
// Number of entries currently present.
int64_t totalKeys = 0;
int64_t peakKeys = 0;
// Sum of full key lengths of present entries (not sharing prefixes).
int64_t keyBytes = 0;
int64_t peakKeyBytes = 0;
// sizeof the concrete node struct for `n`'s runtime type. Aborts on an
// unknown type.
int64_t getNodeSize(struct Node *n) {
  switch (n->getType()) {
  case Type_Node0: return sizeof(Node0);
  case Type_Node3: return sizeof(Node3);
  case Type_Node16: return sizeof(Node16);
  case Type_Node48: return sizeof(Node48);
  case Type_Node256: return sizeof(Node256);
  default: abort();
  }
}
// Length of the full key leading to `n`. Every node on the path to the root
// contributes its partial key, and every non-root node also contributes the
// one index byte that selects it in its parent.
int64_t getSearchPathLength(Node *n) {
  assert(n != nullptr);
  int64_t length = n->partialKeyLen;
  for (; n->parent != nullptr; n = n->parent) {
    length += 1 + n->parent->partialKeyLen;
  }
  return length;
}
// Accounts for a newly allocated node and updates the peak watermarks.
void addNode(Node *n) {
  nodeBytes += getNodeSize(n);
  partialCapacityBytes += n->getCapacity();
  peakNodeBytes = std::max(peakNodeBytes, nodeBytes);
  peakPartialCapacityBytes =
      std::max(peakPartialCapacityBytes, partialCapacityBytes);
}
void removeNode(Node *n) {
nodeBytes -= getNodeSize(n);
partialCapacityBytes -= n->getCapacity();
}
// Counts the key at `n` in the SHOW_MEMORY statistics. This is a no-op if
// `n` already has an entry, so it must be called before entryPresent is set.
void addKey(Node *n) {
  if (n->entryPresent) {
    return;
  }
  ++totalKeys;
  keyBytes += getSearchPathLength(n);
  peakKeys = std::max(peakKeys, totalKeys);
  peakKeyBytes = std::max(peakKeyBytes, keyBytes);
}
// Uncounts the key at `n`. This is a no-op if `n` has no entry, so it must
// be called before entryPresent is cleared.
void removeKey(Node *n) {
  if (!n->entryPresent) {
    return;
  }
  --totalKeys;
  keyBytes -= getSearchPathLength(n);
}
// Reports the SHOW_MEMORY statistics on stdout. The report is printed from
// the destructor of the global `peakPrinter` instance, i.e. during static
// destruction when the program exits.
struct __attribute__((visibility("default"))) PeakPrinter {
  ~PeakPrinter() {
    printf("--- radix_tree ---\n");
    // NOTE(review): mallocBytes / peakMallocBytes are defined outside this
    // chunk; presumably maintained by the allocation wrappers — confirm.
    printf("malloc bytes: %g\n", double(mallocBytes));
    printf("Peak malloc bytes: %g\n", double(peakMallocBytes));
    printf("Node bytes: %g\n", double(nodeBytes));
    printf("Peak node bytes: %g\n", double(peakNodeBytes));
    // Bound derived from the peak key count times the per-key budget
    // kBytesPerKey.
    printf("Expected worst case node bytes: %g\n",
           double(peakKeys * kBytesPerKey));
    printf("Key bytes: %g\n", double(keyBytes));
    printf("Peak key bytes: %g (not sharing common prefixes)\n",
           double(peakKeyBytes));
    printf("Partial key capacity bytes: %g\n", double(partialCapacityBytes));
    printf("Peak partial key capacity bytes: %g\n",
           double(peakPartialCapacityBytes));
  }
} peakPrinter;
#endif
#ifdef ENABLE_MAIN
#include "third_party/nanobench.h"
template <int kN> void benchRezero() {
static_assert(kN % 16 == 0);
ankerl::nanobench::Bench bench;
InternalVersionT vs[kN];
InternalVersionT zero;
bench.run("rezero" + std::to_string(kN), [&]() {
bench.doNotOptimizeAway(vs);
bench.doNotOptimizeAway(zero);
for (int i = 0; i < kN; i += 16) {
rezero16(vs + i, zero);
}
});
}
template <int kN> void benchScan1() {
static_assert(kN % 16 == 0);
ankerl::nanobench::Bench bench;
InternalVersionT vs[kN];
uint8_t is[kN];
uint8_t begin;
uint8_t end;
InternalVersionT v;
bench.run("scan" + std::to_string(kN), [&]() {
bench.doNotOptimizeAway(vs);
bench.doNotOptimizeAway(is);
bench.doNotOptimizeAway(begin);
bench.doNotOptimizeAway(end);
bench.doNotOptimizeAway(v);
for (int i = 0; i < kN; i += 16) {
scan16</*kAVX512=*/true>(vs + i, is + i, begin, end, v);
}
});
}
template <int kN> void benchScan2() {
static_assert(kN % 16 == 0);
ankerl::nanobench::Bench bench;
InternalVersionT vs[kN];
uint8_t is[kN];
uint8_t begin;
uint8_t end;
InternalVersionT v;
bench.run("scan" + std::to_string(kN), [&]() {
bench.doNotOptimizeAway(vs);
bench.doNotOptimizeAway(begin);
bench.doNotOptimizeAway(end);
bench.doNotOptimizeAway(v);
for (int i = 0; i < kN; i += 16) {
scan16</*kAVX512=*/true>(vs + i, begin, end, v);
}
});
}
void benchHorizontal16() {
ankerl::nanobench::Bench bench;
InternalVersionT vs[16];
for (int i = 0; i < 16; ++i) {
vs[i] = InternalVersionT(rand() % 1000 + 1000);
}
#if !USE_64_BIT
InternalVersionT::zero = InternalVersionT(rand() % 1000);
#endif
bench.run("horizontal16", [&]() {
bench.doNotOptimizeAway(horizontalMax16(vs, InternalVersionT::zero));
});
int x = rand() % 15 + 1;
bench.run("horizontalUpTo16", [&]() {
bench.doNotOptimizeAway(horizontalMaxUpTo16(vs, InternalVersionT::zero, x));
});
}
// Benchmarks longestCommonPrefix on two zero-filled buffers of `len` bytes.
// Because the buffers are identical, the whole length is compared.
void benchLCP(int len) {
  std::vector<uint8_t> left(len);
  std::vector<uint8_t> right(len);
  ankerl::nanobench::Bench bench;
  const std::string label = "lcp " + std::to_string(len);
  bench.run(label, [&]() {
    bench.doNotOptimizeAway(left);
    bench.doNotOptimizeAway(right);
    bench.doNotOptimizeAway(longestCommonPrefix(left.data(), right.data(), len));
  });
}
// Debug helper: builds a small conflict set from a few hand-picked writes and
// prints its tree to stdout in Graphviz dot form.
//
// The writes are:
//   - range ["and", "ant") at version 1;
//   - point writes "any", "are", "art" (empty end) at versions 2..4.
//
// The original also constructed an unused ReferenceImpl and Arena; those
// locals have been removed.
void printTree() {
  int64_t writeVersion = 0;
  ConflictSet::Impl cs{writeVersion};
  ConflictSet::WriteRange write;
  write.begin = "and"_s;
  write.end = "ant"_s;
  cs.addWrites(&write, 1, ++writeVersion);
  write.begin = "any"_s;
  write.end = ""_s;
  cs.addWrites(&write, 1, ++writeVersion);
  write.begin = "are"_s;
  write.end = ""_s;
  cs.addWrites(&write, 1, ++writeVersion);
  write.begin = "art"_s;
  write.end = ""_s;
  cs.addWrites(&write, 1, ++writeVersion);
  debugPrintDot(stdout, cs.root, &cs);
}
// Benchmark driver entry point. Currently only the horizontal-max benchmarks
// are run.
int main() {
  benchHorizontal16();
  return 0;
}
#endif
#ifdef ENABLE_FUZZ
extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
Arbitrary arbitrary({data, size});
TestDriver<ConflictSet::Impl> driver1{arbitrary};
TestDriver<ConflictSet::Impl> driver2{arbitrary};
bool done1 = false;
bool done2 = false;
for (;;) {
if (!done1) {
done1 = driver1.next();
if (!driver1.ok) {
debugPrintDot(stdout, driver1.cs.root, &driver1.cs);
fflush(stdout);
abort();
}
if (!checkCorrectness(driver1.cs.root, driver1.cs.oldestVersion,
&driver1.cs)) {
debugPrintDot(stdout, driver1.cs.root, &driver1.cs);
fflush(stdout);
abort();
}
}
if (!done2) {
done2 = driver2.next();
if (!driver2.ok) {
debugPrintDot(stdout, driver2.cs.root, &driver2.cs);
fflush(stdout);
abort();
}
if (!checkCorrectness(driver2.cs.root, driver2.cs.oldestVersion,
&driver2.cs)) {
debugPrintDot(stdout, driver2.cs.root, &driver2.cs);
fflush(stdout);
abort();
}
}
if (done1 && done2) {
break;
}
}
return 0;
}
#endif
// GCOVR_EXCL_STOP