Files
conflict-set/ConflictSet.cpp
2024-08-06 14:45:52 -07:00

4290 lines
132 KiB
C++

/*
Copyright 2024 Andrew Noyes
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "ConflictSet.h"
#include "Internal.h"
#include <algorithm>
#include <atomic>
#include <bit>
#include <cassert>
#include <compare>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <inttypes.h>
#include <limits>
#include <span>
#include <string>
#include <string_view>
#include <sys/time.h>
#include <type_traits>
#include <utility>
#ifdef HAS_AVX
#include <immintrin.h>
#elif defined(HAS_ARM_NEON)
#include <arm_neon.h>
#endif
#if defined(__has_feature)
#if __has_feature(thread_sanitizer)
#define __SANITIZE_THREAD__
#endif
#endif
#include <memcheck.h>
using namespace weaselab;
// Use assert for checking potentially complex properties during tests.
// Use assume to hint simple properties to the optimizer.
// TODO use the c++23 version when that's available
//
// With NDEBUG defined, `assume(e)` promises the optimizer that `e` holds (a
// false `e` is undefined behavior). Without NDEBUG it is a plain assert.
#ifdef NDEBUG
// Probe `__has_builtin` in a nested #if so that a preprocessor lacking it never
// has to parse the function-like invocation.
#if defined(__has_builtin)
#if __has_builtin(__builtin_assume)
#define assume(e) __builtin_assume(e)
#endif
#endif
#ifndef assume
// Wrapped in do/while(0) so the macro expands to exactly one statement and
// cannot capture an `else` that follows the call site.
#define assume(e)                                                              \
  do {                                                                         \
    if (!(e))                                                                  \
      __builtin_unreachable();                                                 \
  } while (0)
#endif
#else
#define assume assert
#endif
// Memory-accounting hooks. With SHOW_MEMORY enabled they are only declared
// here (the definitions are not in this part of the file); otherwise they are
// empty constexpr functions, so calling them does nothing.
#if SHOW_MEMORY
void addNode(struct Node *);
void removeNode(struct Node *);
void addKey(struct Node *);
void removeKey(struct Node *);
#else
constexpr void addNode(struct Node *) {}
constexpr void removeNode(struct Node *) {}
constexpr void addKey(struct Node *) {}
constexpr void removeKey(struct Node *) {}
#endif
// ==================== BEGIN IMPLEMENTATION ====================
// Version-window limits. The nominal window must fit inside the largest window
// for which the wrapping 32-bit version comparison still orders correctly.
constexpr int64_t kNominalVersionWindow = 2'000'000'000;
constexpr int64_t kMaxCorrectVersionWindow{std::numeric_limits<int32_t>::max()};
static_assert(kMaxCorrectVersionWindow >= kNominalVersionWindow);
// A version stored in 32 bits. Ordering is derived from the wrapped difference
// of two values, so comparisons stay meaningful after the counter overflows.
struct InternalVersionT {
constexpr InternalVersionT() = default;
// Keeps only the low 32 bits of `value` (the member is a uint32_t).
constexpr explicit InternalVersionT(int64_t value) : value(value) {}
// Returns the stored 32-bit value widened to int64_t.
constexpr int64_t toInt64() const { return value; } // GCOVR_EXCL_LINE
constexpr auto operator<=>(const InternalVersionT &rhs) const {
// Maintains ordering after overflow, as long as the full-precision versions
// are within `kMaxCorrectVersionWindow` of each other.
return int32_t(value - rhs.value) <=> 0;
}
constexpr bool operator==(const InternalVersionT &) const = default;
// Per-thread reference version. NOTE(review): it is assigned outside this part
// of the file - confirm what it is set to.
static thread_local InternalVersionT zero;
private:
uint32_t value;
};
thread_local InternalVersionT InternalVersionT::zero;
// The pair of versions stored in a node.
// NOTE(review): presumably `pointVersion` applies to the node's own key and
// `rangeVersion` to a range adjacent to it - confirm against the users of Entry.
struct Entry {
InternalVersionT pointVersion;
InternalVersionT rangeVersion;
};
// A set over the integers [0, 256), stored as four 64-bit words.
struct BitSet {
  bool test(int i) const;
  void set(int i);
  void reset(int i);
  int firstSetGeq(int i) const;

  // Invokes `f(index)` once for every member of the set, lowest index first.
  template <class F> void forEachSet(F f) const {
    // See section 3.1 in https://arxiv.org/pdf/1709.07821.pdf for details about
    // this approach
    int base = 0;
    for (uint64_t remaining : words) {
      // `remaining &= remaining - 1` clears the lowest set bit
      for (; remaining != 0; remaining &= remaining - 1) {
        f(base + std::countr_zero(remaining));
      }
      base += 64;
    }
  }

  // Removes every member.
  void init() {
    for (int w = 0; w < 4; ++w) {
      words[w] = 0;
    }
  }

private:
  uint64_t words[4];
};
bool BitSet::test(int i) const {
assert(0 <= i);
assert(i < 256);
return words[i >> 6] & (uint64_t(1) << (i & 63));
}
void BitSet::set(int i) {
assert(0 <= i);
assert(i < 256);
words[i >> 6] |= uint64_t(1) << (i & 63);
}
void BitSet::reset(int i) {
assert(0 <= i);
assert(i < 256);
words[i >> 6] &= ~(uint64_t(1) << (i & 63));
}
int BitSet::firstSetGeq(int i) const {
assume(0 <= i);
// i may be >= 256
uint64_t mask = uint64_t(-1) << (i & 63);
for (int j = i >> 6; j < 4; ++j) {
uint64_t masked = mask & words[j];
if (masked) {
return (j << 6) + std::countr_zero(masked);
}
mask = -1;
}
return -1;
}
// Tag identifying which concrete node struct (Node0, Node3, Node16, Node48,
// Node256) a `Node` actually is.
enum Type {
Type_Node0,
Type_Node3,
Type_Node16,
Type_Node48,
Type_Node256,
};
template <class T> struct BoundedFreeListAllocator;
// Header shared by every node representation (Node0 through Node256).
struct Node {
/* begin section that's copied to the next node */
Entry entry;
Node *parent;
int32_t partialKeyLen;
int16_t numChildren;
bool entryPresent;
uint8_t parentsIndex;
/* end section that's copied to the next node */
// partialKey() and size() dispatch on `type` to the concrete node struct.
uint8_t *partialKey();
size_t size() const;
Type getType() const { return type; }
int32_t getCapacity() const { return partialKeyCapacity; }
private:
template <class T> friend struct BoundedFreeListAllocator;
// These are publicly readable, but should only be written by
// BoundedFreeListAllocator
Type type;
int32_t partialKeyCapacity;
};
// Byte offset and length of the span of `Node` (from `entry` through
// `parentsIndex`) that the copyChildrenAndKeyFrom functions memcpy between
// nodes.
constexpr int kNodeCopyBegin = offsetof(Node, entry);
constexpr int kNodeCopySize =
offsetof(Node, parentsIndex) + sizeof(Node::parentsIndex) - kNodeCopyBegin;
// copyChildrenAndKeyFrom is responsible for copying all
// public members of Node, copying the partial key, logically copying the
// children (converting representation if necessary), and updating all the
// children's parent pointers. The caller must then insert the new node into the
// tree.
// Node representation with no child storage.
struct Node0 : Node {
constexpr static auto kType = Type_Node0;
// The partial key bytes are stored directly after the struct.
uint8_t *partialKey() { return (uint8_t *)(this + 1); }
void copyChildrenAndKeyFrom(const Node0 &other);
void copyChildrenAndKeyFrom(const struct Node3 &other);
// Struct size plus the partial key capacity.
size_t size() const { return sizeof(Node0) + getCapacity(); }
};
// Node representation with room for up to 3 children, kept in parallel arrays:
// slot i holds key byte `index[i]`, pointer `children[i]` and
// `childMaxVersion[i]`.
struct Node3 : Node {
constexpr static auto kMaxNodes = 3;
constexpr static auto kType = Type_Node3;
// Sorted
uint8_t index[kMaxNodes];
Node *children[kMaxNodes];
InternalVersionT childMaxVersion[kMaxNodes];
// The partial key bytes are stored directly after the struct.
uint8_t *partialKey() { return (uint8_t *)(this + 1); }
void copyChildrenAndKeyFrom(const Node0 &other);
void copyChildrenAndKeyFrom(const Node3 &other);
void copyChildrenAndKeyFrom(const struct Node16 &other);
// Struct size plus the partial key capacity.
size_t size() const { return sizeof(Node3) + getCapacity(); }
};
// Node representation with room for up to 16 children, kept in parallel
// arrays: slot i holds key byte `index[i]`, pointer `children[i]` and
// `childMaxVersion[i]`.
struct Node16 : Node {
constexpr static auto kType = Type_Node16;
constexpr static auto kMaxNodes = 16;
// Sorted
uint8_t index[kMaxNodes];
Node *children[kMaxNodes];
InternalVersionT childMaxVersion[kMaxNodes];
// The partial key bytes are stored directly after the struct.
uint8_t *partialKey() { return (uint8_t *)(this + 1); }
void copyChildrenAndKeyFrom(const Node3 &other);
void copyChildrenAndKeyFrom(const Node16 &other);
void copyChildrenAndKeyFrom(const struct Node48 &other);
// Struct size plus the partial key capacity.
size_t size() const { return sizeof(Node16) + getCapacity(); }
};
// Node representation with room for up to 48 children. Children live in slots;
// a 256-entry table maps each key byte to its slot.
struct Node48 : Node {
constexpr static auto kType = Type_Node48;
constexpr static auto kMaxNodes = 48;
// Which key bytes have a child
BitSet bitSet;
// NOTE(review): the copy functions set this to the number of occupied slots;
// presumably it is the next unused slot - confirm against the insert path.
int8_t nextFree;
// Key byte -> slot in `children`; negative when that byte has no child
int8_t index[256];
Node *children[kMaxNodes];
InternalVersionT childMaxVersion[kMaxNodes];
// Slot -> key byte
uint8_t reverseIndex[kMaxNodes];
// `maxOfMax[p]` tracks the maximum of `childMaxVersion` over the page of
// kMaxOfMaxPageSize consecutive slots starting at p * kMaxOfMaxPageSize.
constexpr static int kMaxOfMaxPageSize = 16;
constexpr static int kMaxOfMaxShift =
std::countr_zero(uint32_t(kMaxOfMaxPageSize));
constexpr static int kMaxOfMaxTotalPages = kMaxNodes / kMaxOfMaxPageSize;
InternalVersionT maxOfMax[kMaxOfMaxTotalPages];
// The partial key bytes are stored directly after the struct.
uint8_t *partialKey() { return (uint8_t *)(this + 1); }
void copyChildrenAndKeyFrom(const Node16 &other);
void copyChildrenAndKeyFrom(const Node48 &other);
void copyChildrenAndKeyFrom(const struct Node256 &other);
// Struct size plus the partial key capacity.
size_t size() const { return sizeof(Node48) + getCapacity(); }
};
// Node representation with one child slot per possible key byte.
struct Node256 : Node {
constexpr static auto kType = Type_Node256;
// Which key bytes have a child
BitSet bitSet;
// Indexed directly by key byte
Node *children[256];
InternalVersionT childMaxVersion[256];
// `maxOfMax[p]` tracks the maximum of `childMaxVersion` over the page of
// kMaxOfMaxPageSize consecutive key bytes starting at p * kMaxOfMaxPageSize.
constexpr static int kMaxOfMaxPageSize = 16;
constexpr static int kMaxOfMaxShift =
std::countr_zero(uint32_t(kMaxOfMaxPageSize));
constexpr static int kMaxOfMaxTotalPages = 256 / kMaxOfMaxPageSize;
InternalVersionT maxOfMax[kMaxOfMaxTotalPages];
// The partial key bytes are stored directly after the struct.
uint8_t *partialKey() { return (uint8_t *)(this + 1); }
void copyChildrenAndKeyFrom(const Node48 &other);
void copyChildrenAndKeyFrom(const Node256 &other);
// Struct size plus the partial key capacity.
size_t size() const { return sizeof(Node256) + getCapacity(); }
};
// Copies the shared header fields and the partial key bytes from `other`.
// Node0 stores no children, so there is nothing else to transfer.
inline void Node0::copyChildrenAndKeyFrom(const Node0 &other) {
  char *headerDst = reinterpret_cast<char *>(this) + kNodeCopyBegin;
  const char *headerSrc =
      reinterpret_cast<const char *>(&other) + kNodeCopyBegin;
  memcpy(headerDst, headerSrc, kNodeCopySize);
  // The source's partial key sits directly after its struct
  const void *otherKey = &other + 1;
  memcpy(partialKey(), otherKey, partialKeyLen);
}
// Copies the shared header fields and the partial key bytes from a Node3.
// No child data is transferred because Node0 has nowhere to store it.
inline void Node0::copyChildrenAndKeyFrom(const Node3 &other) {
  char *headerDst = reinterpret_cast<char *>(this) + kNodeCopyBegin;
  const char *headerSrc =
      reinterpret_cast<const char *>(&other) + kNodeCopyBegin;
  memcpy(headerDst, headerSrc, kNodeCopySize);
  // The source's partial key sits directly after its struct
  const void *otherKey = &other + 1;
  memcpy(partialKey(), otherKey, partialKeyLen);
}
// Copies the shared header fields and the partial key bytes from a Node0.
// The source has no children, so the child arrays are left as they are.
inline void Node3::copyChildrenAndKeyFrom(const Node0 &other) {
  char *headerDst = reinterpret_cast<char *>(this) + kNodeCopyBegin;
  const char *headerSrc =
      reinterpret_cast<const char *>(&other) + kNodeCopyBegin;
  memcpy(headerDst, headerSrc, kNodeCopySize);
  // The source's partial key sits directly after its struct
  const void *otherKey = &other + 1;
  memcpy(partialKey(), otherKey, partialKeyLen);
}
// Copies header, child arrays and partial key from another Node3, then points
// every copied child's `parent` at this node.
inline void Node3::copyChildrenAndKeyFrom(const Node3 &other) {
  char *headerDst = reinterpret_cast<char *>(this) + kNodeCopyBegin;
  const char *headerSrc =
      reinterpret_cast<const char *>(&other) + kNodeCopyBegin;
  memcpy(headerDst, headerSrc, kNodeCopySize);
  // Everything Node3 adds to Node (index, children, childMaxVersion) is copied
  // as one contiguous run starting at `index`
  constexpr size_t kTailBytes = sizeof(*this) - sizeof(Node);
  memcpy(index, other.index, kTailBytes);
  const void *otherKey = &other + 1;
  memcpy(partialKey(), otherKey, partialKeyLen);
  const int count = numChildren;
  for (int i = 0; i < count; ++i) {
    assert(children[i]->parent == &other);
    children[i]->parent = this;
  }
}
// Copies header and partial key from a Node16 together with the first
// Node3::kMaxNodes entries of its child arrays, then re-parents the
// `numChildren` copied children.
inline void Node3::copyChildrenAndKeyFrom(const Node16 &other) {
  char *headerDst = reinterpret_cast<char *>(this) + kNodeCopyBegin;
  const char *headerSrc =
      reinterpret_cast<const char *>(&other) + kNodeCopyBegin;
  memcpy(headerDst, headerSrc, kNodeCopySize);
  constexpr size_t kSlots = kMaxNodes;
  memcpy(index, other.index, kSlots);
  memcpy(children, other.children, kSlots * sizeof(children[0])); // NOLINT
  memcpy(childMaxVersion, other.childMaxVersion,
         kSlots * sizeof(childMaxVersion[0]));
  const void *otherKey = &other + 1;
  memcpy(partialKey(), otherKey, partialKeyLen);
  const int count = numChildren;
  for (int i = 0; i < count; ++i) {
    assert(children[i]->parent == &other);
    children[i]->parent = this;
  }
}
// Copies header, partial key and all Node3::kMaxNodes child slots from a full
// Node3 (asserted), then re-parents those children.
inline void Node16::copyChildrenAndKeyFrom(const Node3 &other) {
  char *headerDst = reinterpret_cast<char *>(this) + kNodeCopyBegin;
  const char *headerSrc =
      reinterpret_cast<const char *>(&other) + kNodeCopyBegin;
  memcpy(headerDst, headerSrc, kNodeCopySize);
  constexpr size_t kSlots = Node3::kMaxNodes;
  memcpy(index, other.index, kSlots);
  memcpy(children, other.children, kSlots * sizeof(children[0])); // NOLINT
  memcpy(childMaxVersion, other.childMaxVersion,
         kSlots * sizeof(childMaxVersion[0]));
  const void *otherKey = &other + 1;
  memcpy(partialKey(), otherKey, partialKeyLen);
  assert(numChildren == Node3::kMaxNodes);
  for (int i = 0; i != Node3::kMaxNodes; ++i) {
    assert(children[i]->parent == &other);
    children[i]->parent = this;
  }
}
// Copies header, key-byte index, the first `numChildren` child slots and the
// partial key from another Node16, re-parenting each child as it is copied.
inline void Node16::copyChildrenAndKeyFrom(const Node16 &other) {
  char *headerDst = reinterpret_cast<char *>(this) + kNodeCopyBegin;
  const char *headerSrc =
      reinterpret_cast<const char *>(&other) + kNodeCopyBegin;
  memcpy(headerDst, headerSrc, kNodeCopySize);
  memcpy(index, other.index, sizeof(index));
  const int count = numChildren;
  for (int i = 0; i < count; ++i) {
    childMaxVersion[i] = other.childMaxVersion[i];
    children[i] = other.children[i];
    assert(children[i]->parent == &other);
    children[i]->parent = this;
  }
  const void *otherKey = &other + 1;
  memcpy(partialKey(), otherKey, partialKeyLen);
}
// Copies header and partial key from a Node48 and packs its children into
// consecutive slots in increasing key-byte order, re-parenting each one.
inline void Node16::copyChildrenAndKeyFrom(const Node48 &other) {
  char *headerDst = reinterpret_cast<char *>(this) + kNodeCopyBegin;
  const char *headerSrc =
      reinterpret_cast<const char *>(&other) + kNodeCopyBegin;
  memcpy(headerDst, headerSrc, kNodeCopySize);
  int i = 0;
  auto takeChild = [&](int c) {
    // Suppress a false positive -Waggressive-loop-optimizations warning
    // in gcc
    assume(i < Node16::kMaxNodes);
    const auto sourceSlot = other.index[c];
    index[i] = c;
    children[i] = other.children[sourceSlot];
    childMaxVersion[i] = other.childMaxVersion[sourceSlot];
    assert(children[i]->parent == &other);
    children[i]->parent = this;
    ++i;
  };
  other.bitSet.forEachSet(takeChild);
  const void *otherKey = &other + 1;
  memcpy(partialKey(), otherKey, partialKeyLen);
}
// Builds this Node48 from a full Node16 (asserted: exactly Node16::kMaxNodes
// children). Copies the header and partial key, places child `i` of `other`
// in slot `i`, re-parents every child and computes the per-page maxima in
// `maxOfMax`.
inline void Node48::copyChildrenAndKeyFrom(const Node16 &other) {
  memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
         kNodeCopySize);
  assert(numChildren == Node16::kMaxNodes);
  // Negative index entries mean "no child for this key byte"
  memset(index, -1, sizeof(index));
  memset(children, 0, sizeof(children));
  const auto z = InternalVersionT::zero;
  for (auto &v : childMaxVersion) {
    v = z;
  }
  // Reset every page maximum before folding child versions into it, so the
  // result does not depend on what `maxOfMax` held beforehand. This mirrors
  // Node256::copyChildrenAndKeyFrom(const Node48 &).
  for (auto &v : maxOfMax) {
    v = z;
  }
  memcpy(partialKey(), &other + 1, partialKeyLen);
  bitSet.init();
  nextFree = Node16::kMaxNodes;
  int i = 0;
  for (auto x : other.index) {
    bitSet.set(x);
    index[x] = i;
    children[i] = other.children[i];
    childMaxVersion[i] = other.childMaxVersion[i];
    assert(children[i]->parent == &other);
    children[i]->parent = this;
    reverseIndex[i] = x;
    maxOfMax[i >> Node48::kMaxOfMaxShift] =
        std::max(maxOfMax[i >> Node48::kMaxOfMaxShift], childMaxVersion[i]);
    ++i;
  }
}
inline void Node48::copyChildrenAndKeyFrom(const Node48 &other) {
memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
kNodeCopySize);
bitSet = other.bitSet;
nextFree = other.nextFree;
memcpy(index, other.index, sizeof(index));
memset(children, 0, sizeof(children));
const auto z = InternalVersionT::zero;
for (auto &v : childMaxVersion) {
v = z;
}
for (int i = 0; i < numChildren; ++i) {
children[i] = other.children[i];
childMaxVersion[i] = other.childMaxVersion[i];
assert(children[i]->parent == &other);
children[i]->parent = this;
}
memcpy(reverseIndex, other.reverseIndex, sizeof(reverseIndex));
memcpy(maxOfMax, other.maxOfMax, sizeof(maxOfMax));
memcpy(partialKey(), &other + 1, partialKeyLen);
}
// Builds this Node48 from a Node256. Copies the header and partial key, packs
// the children into consecutive slots in increasing key-byte order,
// re-parents them and computes the per-page maxima in `maxOfMax`.
// `other` must hold at most Node48::kMaxNodes children.
inline void Node48::copyChildrenAndKeyFrom(const Node256 &other) {
  memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
         kNodeCopySize);
  // Negative index entries mean "no child for this key byte"
  memset(index, -1, sizeof(index));
  memset(children, 0, sizeof(children));
  const auto z = InternalVersionT::zero;
  for (auto &v : childMaxVersion) {
    v = z;
  }
  // Reset every page maximum before folding child versions into it, so the
  // result does not depend on what `maxOfMax` held beforehand. This mirrors
  // Node256::copyChildrenAndKeyFrom(const Node48 &).
  for (auto &v : maxOfMax) {
    v = z;
  }
  nextFree = other.numChildren;
  bitSet = other.bitSet;
  int i = 0;
  bitSet.forEachSet([&](int c) {
    // Suppress a false positive -Waggressive-loop-optimizations warning
    // in gcc.
    assume(i < Node48::kMaxNodes);
    index[c] = i;
    children[i] = other.children[c];
    childMaxVersion[i] = other.childMaxVersion[c];
    assert(children[i]->parent == &other);
    children[i]->parent = this;
    reverseIndex[i] = c;
    maxOfMax[i >> Node48::kMaxOfMaxShift] =
        std::max(maxOfMax[i >> Node48::kMaxOfMaxShift], childMaxVersion[i]);
    ++i;
  });
  memcpy(partialKey(), &other + 1, partialKeyLen);
}
inline void Node256::copyChildrenAndKeyFrom(const Node48 &other) {
memcpy((char *)this + kNodeCopyBegin, (char *)&other + kNodeCopyBegin,
kNodeCopySize);
bitSet = other.bitSet;
memset(children, 0, sizeof(children));
const auto z = InternalVersionT::zero;
for (auto &v : childMaxVersion) {
v = z;
}
for (auto &v : maxOfMax) {
v = z;
}
bitSet.forEachSet([&](int c) {
children[c] = other.children[other.index[c]];
childMaxVersion[c] = other.childMaxVersion[other.index[c]];
assert(children[c]->parent == &other);
children[c]->parent = this;
maxOfMax[c >> Node256::kMaxOfMaxShift] =
std::max(maxOfMax[c >> Node256::kMaxOfMaxShift], childMaxVersion[c]);
});
memcpy(partialKey(), &other + 1, partialKeyLen);
}
// Copies another Node256: header, bit set, the children named by the bit set
// (re-parented), the page maxima and the partial key. Slots without a child
// are cleared.
inline void Node256::copyChildrenAndKeyFrom(const Node256 &other) {
  char *headerDst = reinterpret_cast<char *>(this) + kNodeCopyBegin;
  const char *headerSrc =
      reinterpret_cast<const char *>(&other) + kNodeCopyBegin;
  memcpy(headerDst, headerSrc, kNodeCopySize);
  memset(children, 0, sizeof(children));
  const auto z = InternalVersionT::zero;
  for (int slot = 0; slot < 256; ++slot) {
    childMaxVersion[slot] = z;
  }
  bitSet = other.bitSet;
  auto takeChild = [&](int c) {
    childMaxVersion[c] = other.childMaxVersion[c];
    children[c] = other.children[c];
    assert(children[c]->parent == &other);
    children[c]->parent = this;
  };
  bitSet.forEachSet(takeChild);
  memcpy(maxOfMax, other.maxOfMax, sizeof(maxOfMax));
  const void *otherKey = &other + 1;
  memcpy(partialKey(), otherKey, partialKeyLen);
}
// Forward declarations of file-local helpers; their definitions are not in this
// part of the file.
namespace {
std::string getSearchPathPrintable(Node *n);
std::string getSearchPath(Node *n);
} // namespace
// Bound memory usage following the analysis in the ART paper
// Each node with an entry present gets a budget of kBytesPerKey. Node0 always
// has an entry present.
// Induction hypothesis is that each node's surplus is >= kMinNodeSurplus
//
// All arithmetic below is done in `int`. Mixing in `size_t` would make an
// over-budget node wrap around to a huge unsigned value instead of going
// negative, which could let a check pass that ought to fail.
constexpr int kBytesPerKey = 112;
constexpr int kMinNodeSurplus = 80;
constexpr int kMinChildrenNode3 = 2;
constexpr int kMinChildrenNode16 = 4;
constexpr int kMinChildrenNode48 = 17;
constexpr int kMinChildrenNode256 = 49;
constexpr int kNode256Surplus =
    kMinChildrenNode256 * kMinNodeSurplus - static_cast<int>(sizeof(Node256));
static_assert(kNode256Surplus >= kMinNodeSurplus);
constexpr int kNode48Surplus =
    kMinChildrenNode48 * kMinNodeSurplus - static_cast<int>(sizeof(Node48));
static_assert(kNode48Surplus >= kMinNodeSurplus);
constexpr int kNode16Surplus =
    kMinChildrenNode16 * kMinNodeSurplus - static_cast<int>(sizeof(Node16));
static_assert(kNode16Surplus >= kMinNodeSurplus);
constexpr int kNode3Surplus =
    kMinChildrenNode3 * kMinNodeSurplus - static_cast<int>(sizeof(Node3));
static_assert(kNode3Surplus >= kMinNodeSurplus);
static_assert(kBytesPerKey - static_cast<int>(sizeof(Node0)) >=
              kMinNodeSurplus);
// setOldestVersion will additionally try to maintain this property:
// `(children + entryPresent) * length >= capacity`
//
// Which should give us the budget to pay for the key bytes. (children +
// entryPresent) is a lower bound on how many keys these bytes are a prefix of
//
// Threshold (1 MiB) for each BoundedFreeListAllocator: once its free list holds
// at least this many bytes, released nodes are freed instead of cached.
constexpr int64_t kFreeListMaxMemory = 1 << 20;
// Base for a named metric holding an atomic 64-bit value. Only derived types
// can construct it; the constructor is defined outside this part of the file.
struct Metric {
// NOTE(review): presumably links metrics into a list owned by the Impl -
// confirm in the constructor definition.
Metric *prev;
const char *name;
const char *help;
ConflictSet::MetricsV1::Type type;
std::atomic<int64_t> value;
protected:
Metric(ConflictSet::Impl *impl, const char *name, const char *help,
ConflictSet::MetricsV1::Type type);
};
struct Gauge : private Metric {
Gauge(ConflictSet::Impl *impl, const char *name, const char *help)
: Metric(impl, name, help, ConflictSet::MetricsV1::Gauge) {}
void set(int64_t value) {
this->value.store(value, std::memory_order_relaxed);
}
};
struct Counter : private Metric {
Counter(ConflictSet::Impl *impl, const char *name, const char *help)
: Metric(impl, name, help, ConflictSet::MetricsV1::Counter) {}
// Expensive. Accumulate locally and then call add instead of repeatedly
// calling add.
void add(int64_t value) {
assert(value >= 0);
static_assert(std::atomic<int64_t>::is_always_lock_free);
this->value.fetch_add(value, std::memory_order_relaxed);
}
};
// Allocator for one node type T that keeps released nodes on an intrusive
// singly linked free list (the link is stored in the first bytes of the freed
// node) so they can be reused. The list stops growing once it holds
// kFreeListMaxMemory bytes.
template <class T> struct BoundedFreeListAllocator {
static_assert(sizeof(T) >= sizeof(void *));
static_assert(std::derived_from<T, Node>);
static_assert(std::is_trivial_v<T>);
// Returns a node of type T whose partial key capacity is at least
// `partialKeyCapacity`, with `type` and `partialKeyCapacity` set. All other
// contents are unspecified.
T *allocate_helper(int partialKeyCapacity) {
if (freeList != nullptr) {
T *n = (T *)freeList;
// Pop the head: the next-pointer lives in the first bytes of the freed node
VALGRIND_MAKE_MEM_DEFINED(freeList, sizeof(freeList));
memcpy(&freeList, freeList, sizeof(freeList));
// Tell valgrind that only `type` and `partialKeyCapacity` survive in a
// recycled node
VALGRIND_MAKE_MEM_UNDEFINED(n, sizeof(T));
VALGRIND_MAKE_MEM_DEFINED(&n->partialKeyCapacity,
sizeof(n->partialKeyCapacity));
VALGRIND_MAKE_MEM_DEFINED(&n->type, sizeof(n->type));
assert(n->type == T::kType);
VALGRIND_MAKE_MEM_UNDEFINED(n + 1, n->partialKeyCapacity);
freeListBytes -= sizeof(T) + n->partialKeyCapacity;
if (n->partialKeyCapacity >= partialKeyCapacity) {
return n;
} else {
// The intent is to filter out too-small nodes in the freelist
removeNode(n);
safe_free(n, sizeof(T) + n->partialKeyCapacity);
}
}
// Free list was empty or its head was too small: allocate a fresh node
auto *result = (T *)safe_malloc(sizeof(T) + partialKeyCapacity);
result->type = T::kType;
result->partialKeyCapacity = partialKeyCapacity;
addNode(result);
return result;
}
// Like allocate_helper, but additionally clears the child pointers and sets the
// child versions (and, for Node48/Node256, the page maxima) to
// InternalVersionT::zero. Node0 has none of these.
T *allocate(int partialKeyCapacity) {
T *result = allocate_helper(partialKeyCapacity);
if constexpr (!std::is_same_v<T, Node0>) {
memset(result->children, 0, sizeof(result->children));
const auto z = InternalVersionT::zero;
for (auto &v : result->childMaxVersion) {
v = z;
}
}
if constexpr (std::is_same_v<T, Node48> || std::is_same_v<T, Node256>) {
const auto z = InternalVersionT::zero;
for (auto &v : result->maxOfMax) {
v = z;
}
}
return result;
}
// Gives `p` back: frees it outright if the free list already holds
// kFreeListMaxMemory bytes, otherwise pushes it onto the free list.
void release(T *p) {
if (freeListBytes >= kFreeListMaxMemory) {
removeNode(p);
return safe_free(p, sizeof(T) + p->partialKeyCapacity);
}
// Store the old head in the first bytes of `p`, then make `p` the new head
memcpy((void *)p, &freeList, sizeof(freeList));
freeList = p;
freeListBytes += sizeof(T) + p->partialKeyCapacity;
VALGRIND_MAKE_MEM_NOACCESS(freeList, sizeof(T) + p->partialKeyCapacity);
}
BoundedFreeListAllocator() = default;
BoundedFreeListAllocator(const BoundedFreeListAllocator &) = delete;
BoundedFreeListAllocator &
operator=(const BoundedFreeListAllocator &) = delete;
BoundedFreeListAllocator(BoundedFreeListAllocator &&) = delete;
BoundedFreeListAllocator &operator=(BoundedFreeListAllocator &&) = delete;
// Frees every node still on the free list.
~BoundedFreeListAllocator() {
for (void *iter = freeList; iter != nullptr;) {
// The header was marked NOACCESS in release(); re-enable it so the link and
// the capacity can be read
VALGRIND_MAKE_MEM_DEFINED(iter, sizeof(Node));
auto *tmp = (T *)iter;
memcpy(&iter, iter, sizeof(void *));
removeNode((tmp));
safe_free(tmp, sizeof(T) + tmp->partialKeyCapacity);
}
}
private:
// Total bytes (struct plus key capacity) of the nodes currently on the list
int64_t freeListBytes = 0;
// Head of the free list, or nullptr when empty
void *freeList = nullptr;
};
uint8_t *Node::partialKey() {
switch (type) {
case Type_Node0:
return ((Node0 *)this)->partialKey();
case Type_Node3:
return ((Node3 *)this)->partialKey();
case Type_Node16:
return ((Node16 *)this)->partialKey();
case Type_Node48:
return ((Node48 *)this)->partialKey();
case Type_Node256:
return ((Node256 *)this)->partialKey();
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
size_t Node::size() const {
switch (type) {
case Type_Node0:
return ((Node0 *)this)->size();
case Type_Node3:
return ((Node3 *)this)->size();
case Type_Node16:
return ((Node16 *)this)->size();
case Type_Node48:
return ((Node48 *)this)->size();
case Type_Node256:
return ((Node256 *)this)->size();
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
// A type that's plumbed along the check call tree. Lifetime ends after each
// check call.
// The `*_accum` fields are local tallies that all start at zero.
// NOTE(review): presumably they are flushed into Counter metrics after the
// check - confirm at the point where ReadContext is consumed.
struct ReadContext {
int64_t point_read_accum = 0;
int64_t prefix_read_accum = 0;
int64_t range_read_accum = 0;
int64_t point_read_short_circuit_accum = 0;
int64_t prefix_read_short_circuit_accum = 0;
int64_t range_read_short_circuit_accum = 0;
int64_t point_read_iterations_accum = 0;
int64_t prefix_read_iterations_accum = 0;
int64_t range_read_iterations_accum = 0;
int64_t range_read_node_scan_accum = 0;
int64_t commits_accum = 0;
int64_t conflicts_accum = 0;
int64_t too_olds_accum = 0;
// Not initialized here; must be set by whoever creates the context
ConflictSet::Impl *impl;
};
// A type that's plumbed along the non-const call tree. Same lifetime as
// ConflictSet::Impl
struct WriteContext {
  // Local tallies; every field starts at zero.
  struct Accum {
    int64_t entries_erased;
    int64_t insert_iterations;
    int64_t entries_inserted;
    int64_t nodes_allocated;
    int64_t nodes_released;
    int64_t point_writes;
    int64_t range_writes;
    int64_t write_bytes;
  } accum;
  // Cache a copy of InternalVersionT::zero, so we don't need to do the TLS
  // lookup as often.
  InternalVersionT zero;

  WriteContext() : accum{} {}

  // Allocates a node of type T with partial key capacity `c` from the matching
  // per-type allocator and counts the allocation.
  template <class T> T *allocate(int c) {
    accum.nodes_allocated += 1;
    if constexpr (std::is_same_v<T, Node256>) {
      return node256.allocate(c);
    } else if constexpr (std::is_same_v<T, Node48>) {
      return node48.allocate(c);
    } else if constexpr (std::is_same_v<T, Node16>) {
      return node16.allocate(c);
    } else if constexpr (std::is_same_v<T, Node3>) {
      return node3.allocate(c);
    } else if constexpr (std::is_same_v<T, Node0>) {
      return node0.allocate(c);
    }
  }

  // Returns node `c` to the matching per-type allocator and counts the
  // release. Must be called with the concrete node type, never plain Node.
  template <class T> void release(T *c) {
    static_assert(!std::is_same_v<T, Node>);
    accum.nodes_released += 1;
    if constexpr (std::is_same_v<T, Node256>) {
      return node256.release(c);
    } else if constexpr (std::is_same_v<T, Node48>) {
      return node48.release(c);
    } else if constexpr (std::is_same_v<T, Node16>) {
      return node16.release(c);
    } else if constexpr (std::is_same_v<T, Node3>) {
      return node3.release(c);
    } else if constexpr (std::is_same_v<T, Node0>) {
      return node0.release(c);
    }
  }

private:
  BoundedFreeListAllocator<Node0> node0;
  BoundedFreeListAllocator<Node3> node3;
  BoundedFreeListAllocator<Node16> node16;
  BoundedFreeListAllocator<Node48> node48;
  BoundedFreeListAllocator<Node256> node256;
};
// Returns the slot i (0 <= i < numChildren) with `self->index[i] == index`, or
// -1 if no slot holds that key byte. Only defined for Node3 and Node16.
template <class NodeT> int getNodeIndex(NodeT *self, uint8_t index) {
static_assert(std::is_same_v<NodeT, Node3> || std::is_same_v<NodeT, Node16>);
// cachegrind says the plain loop is fewer instructions and more mis-predicted
// branches. Microbenchmark says plain loop is faster. It's written in this
// weird "generic" way though in case someday we can use the simd
// implementation easily if we want.
if constexpr (std::is_same_v<NodeT, Node3>) {
Node3 *n = (Node3 *)self;
for (int i = 0; i < n->numChildren; ++i) {
if (n->index[i] == index) {
return i;
}
}
return -1;
}
#ifdef HAS_AVX
// Based on https://www.the-paper-trail.org/post/art-paper-notes/
// key_vec is 16 repeated copies of the searched-for byte, one for every
// possible position in child_keys that needs to be searched.
__m128i key_vec = _mm_set1_epi8(index);
// Compare all child_keys to 'index' in parallel. Don't worry if some of the
// keys aren't valid, we'll mask the results to only consider the valid ones
// below.
__m128i indices;
memcpy(&indices, self->index, NodeT::kMaxNodes);
__m128i results = _mm_cmpeq_epi8(key_vec, indices);
// Build a mask to select only the first node->num_children values from the
// comparison (because the other values are meaningless)
uint32_t mask = (1 << self->numChildren) - 1;
// Change the results of the comparison into a bitfield, masking off any
// invalid comparisons.
uint32_t bitfield = _mm_movemask_epi8(results) & mask;
// No match if there are no '1's in the bitfield.
if (bitfield == 0)
return -1;
// Find the index of the first '1' in the bitfield by counting the trailing
// zeros.
return std::countr_zero(bitfield);
#elif defined(HAS_ARM_NEON)
// Based on
// https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
uint8x16_t indices;
memcpy(&indices, self->index, NodeT::kMaxNodes);
// 0xff for each match
uint16x8_t results =
vreinterpretq_u16_u8(vceqq_u8(vdupq_n_u8(index), indices));
static_assert(NodeT::kMaxNodes <= 16);
assume(self->numChildren <= NodeT::kMaxNodes);
// Four mask bits per slot; 16 slots would need a shift by 64, so that case is
// handled separately
uint64_t mask = self->numChildren == 16
? uint64_t(-1)
: (uint64_t(1) << (self->numChildren * 4)) - 1;
// 0xf for each match in valid range
uint64_t bitfield =
vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) & mask;
if (bitfield == 0)
return -1;
// Each slot occupies four bits of the bitfield, hence the division
return std::countr_zero(bitfield) / 4;
#else
for (int i = 0; i < self->numChildren; ++i) {
if (self->index[i] == index) {
return i;
}
}
return -1;
#endif
}
// Returns a reference to the child slot for key byte `index`.
// Precondition - an entry for index must exist in the node
Node *&getChildExists(Node3 *self, uint8_t index) {
  const int slot = getNodeIndex(self, index);
  return self->children[slot];
}
// Returns a reference to the child slot for key byte `index`.
// Precondition - an entry for index must exist in the node
Node *&getChildExists(Node16 *self, uint8_t index) {
  const int slot = getNodeIndex(self, index);
  return self->children[slot];
}
// Returns a reference to the child slot for key byte `index`, found through
// the node's key-byte-to-slot table.
// Precondition - an entry for index must exist in the node
Node *&getChildExists(Node48 *self, uint8_t index) {
  assert(self->bitSet.test(index));
  const auto slot = self->index[index];
  return self->children[slot];
}
// Returns a reference to the child slot for key byte `index`; Node256 slots
// are addressed directly by key byte.
// Precondition - an entry for index must exist in the node
Node *&getChildExists(Node256 *self, uint8_t index) {
  assert(self->bitSet.test(index));
  Node *&slot = self->children[index];
  return slot;
}
// Type-dispatching form: forwards to the getChildExists overload for the
// node's concrete type. A Node0 cannot satisfy the precondition.
// Precondition - an entry for index must exist in the node
Node *&getChildExists(Node *self, uint8_t index) {
  const Type kind = self->getType();
  switch (kind) {
  case Type_Node0:           // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  case Type_Node3:
    return getChildExists(static_cast<Node3 *>(self), index);
  case Type_Node16:
    return getChildExists(static_cast<Node16 *>(self), index);
  case Type_Node48:
    return getChildExists(static_cast<Node48 *>(self), index);
  case Type_Node256:
    return getChildExists(static_cast<Node256 *>(self), index);
  default:                   // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
}
// Forward declarations; the definitions are not in this part of the file.
InternalVersionT maxVersion(Node *n, ConflictSet::Impl *);
InternalVersionT exchangeMaxVersion(Node *n, InternalVersionT newMax,
ConflictSet::Impl *);
void setMaxVersion(Node *n, ConflictSet::Impl *, InternalVersionT maxVersion);
Node *&getInTree(Node *n, ConflictSet::Impl *);
// A Node0 has no children, so the lookup always yields nullptr.
Node *getChild(Node0 *, uint8_t) { return nullptr; }
// Returns the child for key byte `index`, or nullptr if there is none.
Node *getChild(Node3 *self, uint8_t index) {
  const int slot = getNodeIndex(self, index);
  if (slot < 0) {
    return nullptr;
  }
  return self->children[slot];
}
// Returns the child for key byte `index`, or nullptr if there is none.
Node *getChild(Node16 *self, uint8_t index) {
  const int slot = getNodeIndex(self, index);
  if (slot < 0) {
    return nullptr;
  }
  return self->children[slot];
}
// Returns the child for key byte `index`, or nullptr if there is none. A
// negative entry in the key-byte-to-slot table marks an absent child.
Node *getChild(Node48 *self, uint8_t index) {
  const int slot = self->index[index];
  if (slot < 0) {
    return nullptr;
  }
  return self->children[slot];
}
// Returns whatever is stored in the slot for key byte `index`. The copy and
// allocate paths null out unused slots, so an absent child reads as nullptr.
Node *getChild(Node256 *self, uint8_t index) { return self->children[index]; }
// Type-dispatching form: forwards to the getChild overload for the node's
// concrete type. Returns nullptr when no child exists for `index`.
Node *getChild(Node *self, uint8_t index) {
  const Type kind = self->getType();
  switch (kind) {
  case Type_Node0: {
    return getChild(static_cast<Node0 *>(self), index);
  }
  case Type_Node3: {
    return getChild(static_cast<Node3 *>(self), index);
  }
  case Type_Node16: {
    return getChild(static_cast<Node16 *>(self), index);
  }
  case Type_Node48: {
    return getChild(static_cast<Node48 *>(self), index);
  }
  case Type_Node256: {
    return getChild(static_cast<Node256 *>(self), index);
  }
  default:                   // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
}
// Result of a combined lookup: a child pointer together with the
// `childMaxVersion` value its parent stores for it. `child` is null when the
// lookup found nothing.
struct ChildAndMaxVersion {
Node *child;
InternalVersionT maxVersion;
};
// A Node0 has no children, so the result is always value-initialized (null
// child).
ChildAndMaxVersion getChildAndMaxVersion(Node0 *, uint8_t) { return {}; }
// Returns the child for key byte `index` along with its stored max version, or
// a value-initialized result (null child) if there is no such child.
ChildAndMaxVersion getChildAndMaxVersion(Node3 *self, uint8_t index) {
  const int slot = getNodeIndex(self, index);
  return slot < 0 ? ChildAndMaxVersion{}
                  : ChildAndMaxVersion{self->children[slot],
                                       self->childMaxVersion[slot]};
}
// Returns the child for key byte `index` along with its stored max version, or
// a value-initialized result (null child) if there is no such child.
ChildAndMaxVersion getChildAndMaxVersion(Node16 *self, uint8_t index) {
  const int slot = getNodeIndex(self, index);
  return slot < 0 ? ChildAndMaxVersion{}
                  : ChildAndMaxVersion{self->children[slot],
                                       self->childMaxVersion[slot]};
}
// Returns the child for key byte `index` along with its stored max version, or
// a value-initialized result (null child) if the key-byte-to-slot table marks
// the byte as absent.
ChildAndMaxVersion getChildAndMaxVersion(Node48 *self, uint8_t index) {
  const int slot = self->index[index];
  return slot < 0 ? ChildAndMaxVersion{}
                  : ChildAndMaxVersion{self->children[slot],
                                       self->childMaxVersion[slot]};
}
// Returns the contents of the slot for key byte `index`: the child pointer and
// the max version stored beside it.
ChildAndMaxVersion getChildAndMaxVersion(Node256 *self, uint8_t index) {
  ChildAndMaxVersion found{self->children[index],
                           self->childMaxVersion[index]};
  return found;
}
// Type-dispatching form: forwards to the getChildAndMaxVersion overload for
// the node's concrete type.
ChildAndMaxVersion getChildAndMaxVersion(Node *self, uint8_t index) {
  const Type kind = self->getType();
  switch (kind) {
  case Type_Node0: {
    return getChildAndMaxVersion(static_cast<Node0 *>(self), index);
  }
  case Type_Node3: {
    return getChildAndMaxVersion(static_cast<Node3 *>(self), index);
  }
  case Type_Node16: {
    return getChildAndMaxVersion(static_cast<Node16 *>(self), index);
  }
  case Type_Node48: {
    return getChildAndMaxVersion(static_cast<Node48 *>(self), index);
  }
  case Type_Node256: {
    return getChildAndMaxVersion(static_cast<Node256 *>(self), index);
  }
  default:                   // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
}
// Returns the child in the first slot whose key byte is >= `child`, or nullptr
// if no slot qualifies. Because the key bytes are kept sorted, that is the
// child with the smallest qualifying key byte. Only defined for Node3 and
// Node16.
template <class NodeT> Node *getChildGeqSimd(NodeT *self, int child) {
static_assert(std::is_same_v<NodeT, Node3> || std::is_same_v<NodeT, Node16>);
// cachegrind says the plain loop is fewer instructions and more mis-predicted
// branches. Microbenchmark says plain loop is faster. It's written in this
// weird "generic" way though so that someday we can use the simd
// implementation easily if we want.
if constexpr (std::is_same_v<NodeT, Node3>) {
Node3 *n = (Node3 *)self;
for (int i = 0; i < n->numChildren; ++i) {
if (n->index[i] >= child) {
return n->children[i];
}
}
return nullptr;
}
// No key byte can be >= 256; bail out before `child` is narrowed to a byte
if (child > 255) {
return nullptr;
}
#ifdef HAS_AVX
__m128i key_vec = _mm_set1_epi8(child);
__m128i indices;
memcpy(&indices, self->index, NodeT::kMaxNodes);
// min(child, index) == child exactly when child <= index (unsigned bytes)
__m128i results = _mm_cmpeq_epi8(key_vec, _mm_min_epu8(key_vec, indices));
// Ignore lanes beyond the occupied slots
int mask = (1 << self->numChildren) - 1;
uint32_t bitfield = _mm_movemask_epi8(results) & mask;
return bitfield == 0 ? nullptr : self->children[std::countr_zero(bitfield)];
#elif defined(HAS_ARM_NEON)
uint8x16_t indices;
memcpy(&indices, self->index, sizeof(self->index));
// 0xff for each leq
auto results = vcleq_u8(vdupq_n_u8(child), indices);
static_assert(NodeT::kMaxNodes <= 16);
assume(self->numChildren <= NodeT::kMaxNodes);
// Four mask bits per slot; 16 slots would need a shift by 64, so that case is
// handled separately
uint64_t mask = self->numChildren == 16
? uint64_t(-1)
: (uint64_t(1) << (self->numChildren * 4)) - 1;
// 0xf for each 0xff (within mask)
uint64_t bitfield =
vget_lane_u64(
vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)),
0) &
mask;
return bitfield == 0 ? nullptr
: self->children[std::countr_zero(bitfield) / 4];
#else
for (int i = 0; i < self->numChildren; ++i) {
if (i > 0) {
// Key bytes must be strictly increasing
assert(self->index[i - 1] < self->index[i]);
}
if (self->index[i] >= child) {
return self->children[i];
}
}
return nullptr;
#endif
}
// A Node0 has no children, so there is never a child at or after `child`.
Node *getChildGeq(Node0 *, int) { return nullptr; }
// Returns the child with the smallest key byte >= `child`, or nullptr.
Node *getChildGeq(Node3 *self, int child) {
  Node *const found = getChildGeqSimd(self, child);
  return found;
}
// Returns the child with the smallest key byte >= `child`, or nullptr.
Node *getChildGeq(Node16 *self, int child) {
  Node *const found = getChildGeqSimd(self, child);
  return found;
}
// Returns the child with the smallest key byte >= `child`, or nullptr. The bit
// set supplies the key byte, which is then mapped to its slot.
Node *getChildGeq(Node48 *self, int child) {
  const int keyByte = self->bitSet.firstSetGeq(child);
  return keyByte < 0 ? nullptr : self->children[self->index[keyByte]];
}
// Returns the child with the smallest key byte >= `child`, or nullptr. The bit
// set supplies the key byte, which addresses the slot directly.
Node *getChildGeq(Node256 *self, int child) {
  const int keyByte = self->bitSet.firstSetGeq(child);
  return keyByte < 0 ? nullptr : self->children[keyByte];
}
// Type-dispatching form: forwards to the getChildGeq overload for the node's
// concrete type.
Node *getChildGeq(Node *self, int child) {
  const Type kind = self->getType();
  switch (kind) {
  case Type_Node0: {
    return getChildGeq(static_cast<Node0 *>(self), child);
  }
  case Type_Node3: {
    return getChildGeq(static_cast<Node3 *>(self), child);
  }
  case Type_Node16: {
    return getChildGeq(static_cast<Node16 *>(self), child);
  }
  case Type_Node48: {
    return getChildGeq(static_cast<Node48 *>(self), child);
  }
  case Type_Node256: {
    return getChildGeq(static_cast<Node256 *>(self), child);
  }
  default:                   // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
}
// Returns the child stored in slot 0 of a Node3. The precondition is checked
// with an assert only, so it is not enforced when asserts are compiled out.
// Precondition: self has a child
Node *getFirstChildExists(Node3 *self) {
assert(self->numChildren > 0);
return self->children[0];
}
// Returns the child stored in slot 0 of a Node16. The precondition is checked
// with an assert only, so it is not enforced when asserts are compiled out.
// Precondition: self has a child
Node *getFirstChildExists(Node16 *self) {
assert(self->numChildren > 0);
return self->children[0];
}
// Returns the child for the lowest position set in the node's bitSet, mapping
// that position through `index` to a slot of `children`.
// Precondition: self has a child
Node *getFirstChildExists(Node48 *self) {
  const auto lowest = self->bitSet.firstSetGeq(0);
  const auto slot = self->index[lowest];
  return self->children[slot];
}
// Returns the child for the lowest position set in the node's bitSet; that
// position indexes `children` directly.
// Precondition: self has a child
Node *getFirstChildExists(Node256 *self) {
  const auto lowest = self->bitSet.firstSetGeq(0);
  return self->children[lowest];
}
// Type-erased entry point: forwards to the getFirstChildExists overload for the
// node's concrete type. Node0 is treated as unreachable, like any unknown type.
// Precondition: self has a child
Node *getFirstChildExists(Node *self) {
  switch (self->getType()) {
  case Type_Node3: {
    auto *typed = static_cast<Node3 *>(self);
    return getFirstChildExists(typed);
  }
  case Type_Node16: {
    auto *typed = static_cast<Node16 *>(self);
    return getFirstChildExists(typed);
  }
  case Type_Node48: {
    auto *typed = static_cast<Node48 *>(self);
    return getFirstChildExists(typed);
  }
  case Type_Node256: {
    auto *typed = static_cast<Node256 *>(self);
    return getFirstChildExists(typed);
  }
  case Type_Node0:           // GCOVR_EXCL_LINE
  default:                   // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
}
// Returns a reference to the slot in `self` that holds (or will hold) the child
// for key byte `index`. If no such child exists, a slot is created; when the
// node is already full it is first replaced by the next larger node type and
// `self` is reassigned to the replacement.
//
// Caller is responsible for assigning a non-null pointer to the returned
// reference if null. Updates child's max version to `newMaxVersion` if child
// exists but does not have a partial key.
Node *&getOrCreateChild(Node *&self, uint8_t index,
InternalVersionT newMaxVersion, WriteContext *tls) {
// Fast path for if it exists already
switch (self->getType()) {
case Type_Node0:
break;
case Type_Node3: {
auto *self3 = static_cast<Node3 *>(self);
// getNodeIndex yields a negative value when `index` is not present.
int i = getNodeIndex(self3, index);
if (i >= 0) {
if (self3->children[i]->partialKeyLen == 0) {
self3->childMaxVersion[i] = newMaxVersion;
}
return self3->children[i];
}
} break;
case Type_Node16: {
auto *self16 = static_cast<Node16 *>(self);
int i = getNodeIndex(self16, index);
if (i >= 0) {
if (self16->children[i]->partialKeyLen == 0) {
self16->childMaxVersion[i] = newMaxVersion;
}
return self16->children[i];
}
} break;
case Type_Node48: {
auto *self48 = static_cast<Node48 *>(self);
// `index[...]` maps a key byte to a slot in `children`; negative means absent.
int secondIndex = self48->index[index];
if (secondIndex >= 0) {
if (self48->children[secondIndex]->partialKeyLen == 0) {
self48->childMaxVersion[secondIndex] = newMaxVersion;
// Keep the per-page maximum at least as large as the new child version.
self48->maxOfMax[secondIndex >> Node48::kMaxOfMaxShift] =
std::max(self48->maxOfMax[secondIndex >> Node48::kMaxOfMaxShift],
newMaxVersion);
}
return self48->children[secondIndex];
}
} break;
case Type_Node256: {
auto *self256 = static_cast<Node256 *>(self);
if (auto &result = self256->children[index]; result != nullptr) {
if (self256->children[index]->partialKeyLen == 0) {
self256->childMaxVersion[index] = newMaxVersion;
// Keep the per-page maximum at least as large as the new child version.
self256->maxOfMax[index >> Node256::kMaxOfMaxShift] = std::max(
self256->maxOfMax[index >> Node256::kMaxOfMaxShift], newMaxVersion);
}
return result;
}
} break;
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
// Slow path: the child does not exist. Each case grows the node into the next
// larger type if it is full (jumping to that type's insert label), otherwise
// inserts into the current node.
switch (self->getType()) {
case Type_Node0: {
// A Node0 is always replaced by a Node3 before inserting.
auto *self0 = static_cast<Node0 *>(self);
auto *newSelf = tls->allocate<Node3>(self->partialKeyLen);
newSelf->copyChildrenAndKeyFrom(*self0);
tls->release(self0);
self = newSelf;
goto insert3;
}
case Type_Node3: {
if (self->numChildren == Node3::kMaxNodes) {
auto *self3 = static_cast<Node3 *>(self);
auto *newSelf = tls->allocate<Node16>(self->partialKeyLen);
newSelf->copyChildrenAndKeyFrom(*self3);
tls->release(self3);
self = newSelf;
goto insert16;
}
insert3:
auto *self3 = static_cast<Node3 *>(self);
// Shift entries with a key byte >= `index` one slot to the right so that
// `index` stays sorted, then claim the freed slot.
int i = self->numChildren - 1;
for (; i >= 0; --i) {
if (int(self3->index[i]) < int(index)) {
break;
}
self3->index[i + 1] = self3->index[i];
self3->children[i + 1] = self3->children[i];
self3->childMaxVersion[i + 1] = self3->childMaxVersion[i];
}
self3->index[i + 1] = index;
auto &result = self3->children[i + 1];
result = nullptr;
++self->numChildren;
return result;
}
case Type_Node16: {
if (self->numChildren == Node16::kMaxNodes) {
auto *self16 = static_cast<Node16 *>(self);
auto *newSelf = tls->allocate<Node48>(self->partialKeyLen);
newSelf->copyChildrenAndKeyFrom(*self16);
tls->release(self16);
self = newSelf;
goto insert48;
}
insert16:
assert(self->getType() == Type_Node16);
auto *self16 = static_cast<Node16 *>(self);
// Same sorted insertion as the Node3 case.
int i = self->numChildren - 1;
for (; i >= 0; --i) {
if (int(self16->index[i]) < int(index)) {
break;
}
self16->index[i + 1] = self16->index[i];
self16->children[i + 1] = self16->children[i];
self16->childMaxVersion[i + 1] = self16->childMaxVersion[i];
}
self16->index[i + 1] = index;
auto &result = self16->children[i + 1];
result = nullptr;
++self->numChildren;
return result;
}
case Type_Node48: {
if (self->numChildren == 48) {
auto *self48 = static_cast<Node48 *>(self);
auto *newSelf = tls->allocate<Node256>(self->partialKeyLen);
newSelf->copyChildrenAndKeyFrom(*self48);
tls->release(self48);
self = newSelf;
goto insert256;
}
insert48:
auto *self48 = static_cast<Node48 *>(self);
self48->bitSet.set(index);
++self->numChildren;
assert(self48->nextFree < 48);
// Take the next unused slot and record the mapping in both directions.
int nextFree = self48->nextFree++;
self48->index[index] = nextFree;
self48->reverseIndex[nextFree] = index;
auto &result = self48->children[nextFree];
result = nullptr;
return result;
}
case Type_Node256: {
insert256:
auto *self256 = static_cast<Node256 *>(self);
++self->numChildren;
self256->bitSet.set(index);
// The slot is returned as-is; unlike the smaller node types it is not
// assigned nullptr here.
return self256->children[index];
}
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
// Returns the node that follows `node` in a pre-order walk of the tree: its
// first child if it has one, otherwise the first child with a larger key byte
// found while climbing towards the root. Returns nullptr once the climb passes
// the root (a node whose parent is null) without finding one.
Node *nextPhysical(Node *node) {
  // Smallest key byte still acceptable at the current level. Starts at 0 so
  // that every child of `node` itself is a candidate.
  int searchFrom = 0;
  do {
    if (Node *candidate = getChildGeq(node, searchFrom); candidate != nullptr) {
      return candidate;
    }
    // Nothing here: continue in the parent, strictly after our own position.
    searchFrom = node->parentsIndex + 1;
    node = node->parent;
  } while (node != nullptr);
  return nullptr;
}
// Advances with nextPhysical until reaching a node whose `entryPresent` flag is
// set, and returns it. Returns nullptr if the walk runs out of nodes first.
Node *nextLogical(Node *node) {
  do {
    node = nextPhysical(node);
  } while (node != nullptr && !node->entryPresent);
  return node;
}
// Invalidates `self`, replacing it with a node of at least capacity.
// Does not return nodes to freelists when kUseFreeList is false.
//
// The replacement has the same concrete node type as `self`. It receives the
// children and key of the old node, is installed in the tree slot returned by
// getInTree, and `self` is reassigned to point at it.
void freeAndMakeCapacityAtLeast(Node *&self, int capacity, WriteContext *tls,
                                ConflictSet::Impl *impl,
                                const bool kUseFreeList) {
  // One implementation shared by every concrete node type; the type is deduced
  // from the (already downcast) pointer passed in.
  const auto reallocate = [&](auto *old) {
    using NodeT = std::remove_pointer_t<decltype(old)>;
    auto *replacement = tls->template allocate<NodeT>(capacity);
    replacement->copyChildrenAndKeyFrom(*old);
    getInTree(self, impl) = replacement;
    if (kUseFreeList) {
      tls->release(old);
    } else {
      // Bypass the freelist and free the old node directly.
      removeNode(old);
      safe_free(old, old->size());
    }
    self = replacement;
  };

  switch (self->getType()) {
  case Type_Node0:
    reallocate(static_cast<Node0 *>(self));
    break;
  case Type_Node3:
    reallocate(static_cast<Node3 *>(self));
    break;
  case Type_Node16:
    reallocate(static_cast<Node16 *>(self));
    break;
  case Type_Node48:
    reallocate(static_cast<Node48 *>(self));
    break;
  case Type_Node256:
    reallocate(static_cast<Node256 *>(self));
    break;
  default:                   // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
}
// Fix larger-than-desired capacities. Does not return nodes to freelists,
// since that wouldn't actually reclaim the memory used for partial key
// capacity.
//
// The largest capacity tolerated is
// (numChildren + entryPresent) * (partialKeyLen + 1). A node above that is
// reallocated with exactly that capacity, and `self` is updated to the new
// node.
void maybeDecreaseCapacity(Node *&self, WriteContext *tls,
                           ConflictSet::Impl *impl) {
  const int occupancy = self->numChildren + int(self->entryPresent);
  const int maxCapacity = occupancy * (self->partialKeyLen + 1);
#if DEBUG_VERBOSE && !defined(NDEBUG)
  fprintf(stderr, "maybeDecreaseCapacity: current: %d, max: %d, key: %s\n",
          self->getCapacity(), maxCapacity,
          getSearchPathPrintable(self).c_str());
#endif
  if (self->getCapacity() > maxCapacity) {
    freeAndMakeCapacityAtLeast(self, maxCapacity, tls, impl, false);
  }
}
#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
// This gets covered in local development
// GCOVR_EXCL_START
// AVX-512 variant of rezero16: overwrites with `zero` every one of the 16
// versions at `vs` whose 32-bit difference from `zero` is negative.
__attribute__((target("avx512f"))) void rezero16(InternalVersionT *vs,
                                                 InternalVersionT zero) {
  uint32_t zeroBits;
  memcpy(&zeroBits, &zero, sizeof(zeroBits));
  const auto zeroVec = _mm512_set1_epi32(zeroBits);
  const auto relative = _mm512_sub_epi32(_mm512_loadu_epi32(vs), zeroVec);
  const auto belowZero =
      _mm512_cmplt_epi32_mask(relative, _mm512_setzero_epi32());
  _mm512_mask_storeu_epi32(vs, belowZero, zeroVec);
}
// GCOVR_EXCL_STOP
__attribute__((target("default")))
#endif
// Raises each of the 16 versions starting at `vs` to at least `zero`.
void rezero16(InternalVersionT *vs, InternalVersionT zero) {
  for (InternalVersionT *v = vs; v != vs + 16; ++v) {
    *v = std::max(*v, zero);
  }
}
// Raises every version stored in `n` to at least `z`: the entry's point and
// range versions (when an entry is present), the per-child max versions, and
// for Node48/Node256 the per-page maxOfMax values.
void rezero(Node *n, InternalVersionT z) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
  fprintf(stderr, "rezero to %" PRId64 ": %s\n", z.toInt64(),
          getSearchPathPrintable(n).c_str());
#endif
  const auto raise = [z](InternalVersionT &v) { v = std::max(v, z); };
  if (n->entryPresent) {
    raise(n->entry.pointVersion);
    raise(n->entry.rangeVersion);
  }
  switch (n->getType()) {
  case Type_Node0:
    break;
  case Type_Node3: {
    auto *node3 = static_cast<Node3 *>(n);
    for (int slot = 0; slot != 3; ++slot) {
      raise(node3->childMaxVersion[slot]);
    }
    break;
  }
  case Type_Node16: {
    auto *node16 = static_cast<Node16 *>(n);
    rezero16(node16->childMaxVersion, z);
    break;
  }
  case Type_Node48: {
    auto *node48 = static_cast<Node48 *>(n);
    // 48 child versions, processed as three groups of 16.
    for (int group = 0; group != 3; ++group) {
      rezero16(node48->childMaxVersion + group * 16, z);
    }
    for (auto &pageMax : node48->maxOfMax) {
      raise(pageMax);
    }
    break;
  }
  case Type_Node256: {
    auto *node256 = static_cast<Node256 *>(n);
    // 256 child versions, processed as sixteen groups of 16.
    for (int group = 0; group != 16; ++group) {
      rezero16(node256->childMaxVersion + group * 16, z);
    }
    rezero16(node256->maxOfMax, z);
    break;
  }
  default:                   // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
}
// Replaces `self` with a smaller node type when its occupancy allows it:
//  - Node3 with no children becomes a Node0.
//  - Node3 with exactly one child and no entry is removed entirely; its partial
//    key and the child's key byte are prepended to the child's partial key and
//    the child takes its place in the tree.
//  - Node16/Node48/Node256 shrink one step when numChildren + entryPresent
//    falls below kMinChildrenNode16/48/256 respectively.
// In every case the old node is released to `tls`. If the merge case has to
// reallocate the child and `dontInvalidate` pointed at it, `dontInvalidate` is
// updated to the reallocated child. Must not be called on a Node0.
void maybeDownsize(Node *self, WriteContext *tls, ConflictSet::Impl *impl,
Node *&dontInvalidate) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "maybeDownsize: %s\n", getSearchPathPrintable(self).c_str());
#endif
switch (self->getType()) {
case Type_Node0: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
case Type_Node3: {
auto *self3 = (Node3 *)self;
if (self->numChildren == 0) {
auto *newSelf = tls->allocate<Node0>(self->partialKeyLen);
newSelf->copyChildrenAndKeyFrom(*self3);
getInTree(self, impl) = newSelf;
tls->release(self3);
} else if (self->numChildren == 1 && !self->entryPresent) {
auto *child = self3->children[0];
// The merged partial key is: self's partial key, the child's key byte, then
// the child's own partial key.
int minCapacity = self3->partialKeyLen + 1 + child->partialKeyLen;
if (minCapacity > child->getCapacity()) {
// Reallocating invalidates the old child pointer; remember whether the
// caller's pointer referred to it so it can be redirected afterwards.
const bool update = child == dontInvalidate;
freeAndMakeCapacityAtLeast(child, minCapacity, tls, impl, true);
if (update) {
dontInvalidate = child;
}
}
// Merge partial key with child
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Merge %s into %s\n",
getSearchPathPrintable(self).c_str(),
getSearchPathPrintable(child).c_str());
#endif
// Read the child's max version before re-parenting it; it is written back
// below once the child has its new parent.
InternalVersionT childMaxVersion = maxVersion(child, impl);
// Construct new partial key for child
// memmove because source and destination overlap within the child's key.
memmove(child->partialKey() + self3->partialKeyLen + 1,
child->partialKey(), child->partialKeyLen);
memcpy(child->partialKey(), self3->partialKey(), self->partialKeyLen);
child->partialKey()[self3->partialKeyLen] = self3->index[0];
child->partialKeyLen += 1 + self3->partialKeyLen;
child->parent = self->parent;
child->parentsIndex = self->parentsIndex;
// Max versions are stored in the parent, so we need to update it now
// that we have a new parent.
setMaxVersion(child, impl, childMaxVersion);
if (child->parent) {
rezero(child->parent, tls->zero);
}
getInTree(self, impl) = child;
tls->release(self3);
}
} break;
case Type_Node16:
if (self->numChildren + int(self->entryPresent) < kMinChildrenNode16) {
auto *self16 = (Node16 *)self;
auto *newSelf = tls->allocate<Node3>(self->partialKeyLen);
newSelf->copyChildrenAndKeyFrom(*self16);
getInTree(self, impl) = newSelf;
tls->release(self16);
}
break;
case Type_Node48:
if (self->numChildren + int(self->entryPresent) < kMinChildrenNode48) {
auto *self48 = (Node48 *)self;
auto *newSelf = tls->allocate<Node16>(self->partialKeyLen);
newSelf->copyChildrenAndKeyFrom(*self48);
getInTree(self, impl) = newSelf;
tls->release(self48);
}
break;
case Type_Node256:
if (self->numChildren + int(self->entryPresent) < kMinChildrenNode256) {
auto *self256 = (Node256 *)self;
auto *newSelf = tls->allocate<Node48>(self->partialKeyLen);
newSelf->copyChildrenAndKeyFrom(*self256);
getInTree(self, impl) = newSelf;
tls->release(self256);
}
break;
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
// Precondition: self is not the root. May invalidate nodes along the search
// path to self. May invalidate children of self->parent. Returns a pointer to
// the node after self. If erase invalidates the pointee of `dontInvalidate`, it
// will update it to its new pointee as well. Precondition: `self->entryPresent`
//
// `logical` selects how "the node after self" is computed: nextLogical(self)
// when true, nextPhysical(self) when false. The successor is computed before
// any mutation.
Node *erase(Node *self, WriteContext *tls, ConflictSet::Impl *impl,
bool logical, Node *&dontInvalidate) {
++tls->accum.entries_erased;
assert(self->parent != nullptr);
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Erase: %s\n", getSearchPathPrintable(self).c_str());
#endif
// Capture everything needed from `self` up front; `self` may be released
// below.
Node *parent = self->parent;
uint8_t parentsIndex = self->parentsIndex;
auto *result = logical ? nextLogical(self) : nextPhysical(self);
removeKey(self);
assert(self->entryPresent);
self->entryPresent = false;
if (self->numChildren != 0) {
// `self` still has children, so the node stays in the tree; it may only be
// shrunk or merged. maybeDownsize may redirect `result`, in which case the
// caller's pointer is redirected too.
const bool update = result == dontInvalidate;
maybeDownsize(self, tls, impl, result);
if (update) {
dontInvalidate = result;
}
return result;
}
// `self` has neither entry nor children: free it and unlink it from `parent`.
assert(self->getType() == Type_Node0);
tls->release((Node0 *)self);
switch (parent->getType()) {
case Type_Node0: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
case Type_Node3: {
auto *parent3 = static_cast<Node3 *>(parent);
int nodeIndex = getNodeIndex(parent3, parentsIndex);
assert(nodeIndex >= 0);
--parent->numChildren;
// Close the gap by shifting the later entries one slot to the left.
for (int i = nodeIndex; i < parent->numChildren; ++i) {
parent3->index[i] = parent3->index[i + 1];
parent3->children[i] = parent3->children[i + 1];
parent3->childMaxVersion[i] = parent3->childMaxVersion[i + 1];
}
assert(parent->numChildren > 0 || parent->entryPresent);
} break;
case Type_Node16: {
auto *parent16 = static_cast<Node16 *>(parent);
int nodeIndex = getNodeIndex(parent16, parentsIndex);
assert(nodeIndex >= 0);
--parent->numChildren;
// Close the gap by shifting the later entries one slot to the left.
for (int i = nodeIndex; i < parent->numChildren; ++i) {
parent16->index[i] = parent16->index[i + 1];
parent16->children[i] = parent16->children[i + 1];
parent16->childMaxVersion[i] = parent16->childMaxVersion[i + 1];
}
// By kMinChildrenNode16
assert(parent->numChildren > 0);
} break;
case Type_Node48: {
auto *parent48 = static_cast<Node48 *>(parent);
parent48->bitSet.reset(parentsIndex);
int8_t toRemoveChildrenIndex =
std::exchange(parent48->index[parentsIndex], -1);
int8_t lastChildrenIndex = --parent48->nextFree;
assert(toRemoveChildrenIndex >= 0);
assert(lastChildrenIndex >= 0);
// Keep the used slots dense: move the child in the last used slot into the
// slot being vacated and fix up both index mappings for it.
if (toRemoveChildrenIndex != lastChildrenIndex) {
parent48->children[toRemoveChildrenIndex] =
parent48->children[lastChildrenIndex];
parent48->childMaxVersion[toRemoveChildrenIndex] =
parent48->childMaxVersion[lastChildrenIndex];
parent48->maxOfMax[toRemoveChildrenIndex >> Node48::kMaxOfMaxShift] =
std::max(parent48->maxOfMax[toRemoveChildrenIndex >>
Node48::kMaxOfMaxShift],
parent48->childMaxVersion[toRemoveChildrenIndex]);
auto parentIndex =
parent48->children[toRemoveChildrenIndex]->parentsIndex;
parent48->index[parentIndex] = toRemoveChildrenIndex;
parent48->reverseIndex[toRemoveChildrenIndex] = parentIndex;
}
// The slot that is no longer in use gets its version reset to tls->zero.
parent48->childMaxVersion[lastChildrenIndex] = tls->zero;
--parent->numChildren;
// By kMinChildrenNode48
assert(parent->numChildren > 0);
} break;
case Type_Node256: {
auto *parent256 = static_cast<Node256 *>(parent);
parent256->bitSet.reset(parentsIndex);
parent256->children[parentsIndex] = nullptr;
--parent->numChildren;
// By kMinChildrenNode256
assert(parent->numChildren > 0);
} break;
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
// The parent lost a child, so it may now qualify for a smaller node type.
const bool update = result == dontInvalidate;
maybeDownsize(parent, tls, impl, result);
if (update) {
dontInvalidate = result;
}
return result;
}
// Returns the first node to the right of `node` that is not one of its
// descendants: the next child (by key byte) of its parent, or failing that of
// the closest ancestor that has one. Returns nullptr when no such node exists.
Node *nextSibling(Node *node) {
  while (node->parent != nullptr) {
    Node *const parent = node->parent;
    if (Node *sibling = getChildGeq(parent, node->parentsIndex + 1);
        sibling != nullptr) {
      return sibling;
    }
    // Nothing to the right at this level; try one level up.
    node = parent;
  }
  return nullptr;
}
// Number of bytes handled per compareStride / firstNeqStride call: 64 when a
// SIMD implementation is compiled in, 16 for the portable fallback.
#if defined(HAS_AVX) || defined(HAS_ARM_NEON)
constexpr int kStride = 64;
#else
constexpr int kStride = 16;
#endif
// Number of strides processed per iteration of the unrolled loop in
// longestCommonPrefix.
constexpr int kUnrollFactor = 4;

// Returns true iff the first kStride bytes at `ap` and `bp` are equal.
bool compareStride(const uint8_t *ap, const uint8_t *bp) {
#if defined(HAS_ARM_NEON)
  static_assert(kStride == 64);
  // AND together the per-byte equality masks of the four 16-byte lanes.
  uint8x16_t allEqual = vceqq_u8(vld1q_u8(ap), vld1q_u8(bp));
  for (int lane = 1; lane < 4; ++lane) {
    allEqual = vandq_u8(
        allEqual, vceqq_u8(vld1q_u8(ap + lane * 16), vld1q_u8(bp + lane * 16)));
  }
  // Narrow to one nibble per byte; all ones means every byte matched.
  const uint64_t nibbles = vget_lane_u64(
      vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(allEqual), 4)), 0);
  return nibbles == uint64_t(-1);
#elif defined(HAS_AVX)
  static_assert(kStride == 64);
  const auto laneEqual = [&](int lane) {
    return _mm_cmpeq_epi8(
        _mm_loadu_si128(reinterpret_cast<const __m128i *>(ap + lane * 16)),
        _mm_loadu_si128(reinterpret_cast<const __m128i *>(bp + lane * 16)));
  };
  // AND together the per-byte equality masks of the four 16-byte lanes.
  __m128i allEqual = laneEqual(0);
  for (int lane = 1; lane < 4; ++lane) {
    allEqual = _mm_and_si128(allEqual, laneEqual(lane));
  }
  return _mm_movemask_epi8(allEqual) == 0xffff;
#else
  // Hope it gets vectorized
  return memcmp(ap, bp, kStride) == 0;
#endif
}
// Returns the index of the first byte at which the kStride-byte blocks at `ap`
// and `bp` differ.
// Precondition: ap[:kStride] != bp[:kStride]
int firstNeqStride(const uint8_t *ap, const uint8_t *bp) {
#if defined(HAS_AVX)
  static_assert(kStride == 64);
  // Bit k is set iff byte k of the two blocks is equal.
  uint64_t equalBits = 0;
  for (int lane = 0; lane < kStride / 16; ++lane) {
    const auto a =
        _mm_loadu_si128(reinterpret_cast<const __m128i *>(ap + lane * 16));
    const auto b =
        _mm_loadu_si128(reinterpret_cast<const __m128i *>(bp + lane * 16));
    const uint64_t laneBits = _mm_movemask_epi8(_mm_cmpeq_epi8(a, b)) & 0xffff;
    equalBits |= laneBits << (lane * 16);
  }
  return std::countr_zero(~equalBits);
#elif defined(HAS_ARM_NEON)
  static_assert(kStride == 64);
  for (int offset = 0; offset < kStride; offset += 16) {
    // 0xff for each match
    const uint16x8_t matches = vreinterpretq_u16_u8(
        vceqq_u8(vld1q_u8(ap + offset), vld1q_u8(bp + offset)));
    // 0xf for each mismatch
    const uint64_t mismatches =
        ~vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(matches, 4)), 0);
    if (mismatches) {
      // Four bits per byte, hence the division by four.
      return offset + (std::countr_zero(mismatches) >> 2);
    }
  }
  __builtin_unreachable(); // GCOVR_EXCL_LINE
#else
  // The last byte never needs comparing: by the precondition, if everything
  // before it matched then it must be the one that differs.
  int offset = 0;
  while (offset < kStride - 1 && ap[offset] == bp[offset]) {
    ++offset;
  }
  return offset;
#endif
}
// This gets covered in local development
// GCOVR_EXCL_START
#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
__attribute__((target("avx512f,avx512bw"))) int
longestCommonPrefix(const uint8_t *ap, const uint8_t *bp, int cl) {
int i = 0;
int end = cl & ~63;
while (i < end) {
const uint64_t eq =
_mm512_cmpeq_epi8_mask(_mm512_loadu_epi8(ap), _mm512_loadu_epi8(bp));
if (eq != uint64_t(-1)) {
return i + std::countr_one(eq);
}
i += 64;
ap += 64;
bp += 64;
}
if (i < cl) {
const uint64_t mask = (uint64_t(1) << (cl - i)) - 1;
const uint64_t eq = _mm512_cmpeq_epi8_mask(
_mm512_maskz_loadu_epi8(mask, ap), _mm512_maskz_loadu_epi8(mask, bp));
return i + std::countr_one(eq & mask);
}
assert(i == cl);
return i;
}
__attribute__((target("default")))
#endif
// GCOVR_EXCL_STOP
int longestCommonPrefix(const uint8_t *ap, const uint8_t *bp, int cl) {
assume(cl >= 0);
int i = 0;
int end;
// kStride * kUnrollCount at a time
end = cl & ~(kStride * kUnrollFactor - 1);
while (i < end) {
for (int j = 0; j < kUnrollFactor; ++j) {
if (!compareStride(ap, bp)) {
return i + firstNeqStride(ap, bp);
}
i += kStride;
ap += kStride;
bp += kStride;
}
}
// kStride at a time
end = cl & ~(kStride - 1);
while (i < end) {
if (!compareStride(ap, bp)) {
return i + firstNeqStride(ap, bp);
}
i += kStride;
ap += kStride;
bp += kStride;
}
// word at a time
end = cl & ~(sizeof(uint64_t) - 1);
while (i < end) {
uint64_t a; // GCOVR_EXCL_LINE
uint64_t b; // GCOVR_EXCL_LINE
memcpy(&a, ap, 8);
memcpy(&b, bp, 8);
const auto mismatched = a ^ b;
if (mismatched) {
return i + std::countr_zero(mismatched) / 8;
}
i += 8;
ap += 8;
bp += 8;
}
// byte at a time
while (i < cl) {
if (*ap != *bp) {
break;
}
++ap;
++bp;
++i;
}
return i;
}
// Logically this is the same as performing firstGeq and then checking against
// point or range version according to cmp, but this version short circuits as
// soon as it can prove that there's no conflict.
//
// Walks down from `n` consuming `key` one node at a time. Returns true when the
// version it ends up comparing is <= readVersion (or when nothing is found at
// or after `key`), false otherwise. Increments the point read counters on
// `tls`.
bool checkPointRead(Node *n, const std::span<const uint8_t> key,
InternalVersionT readVersion, ReadContext *tls) {
++tls->point_read_accum;
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check point read: %s\n", printable(key).c_str());
#endif
auto remaining = key;
for (;; ++tls->point_read_iterations_accum) {
if (remaining.size() == 0) {
// The whole key has been consumed, so `n` is the node for `key`. If it holds
// an entry the point version decides; otherwise fall through to the first
// entry below it and use that entry's range version.
if (n->entryPresent) {
return n->entry.pointVersion <= readVersion;
}
n = getFirstChildExists(n);
goto downLeftSpine;
}
auto [child, maxV] = getChildAndMaxVersion(n, remaining[0]);
if (child == nullptr) {
// No child for the next key byte: continue from the next greater child of
// `n`, or failing that from the next sibling subtree. If neither exists
// there is nothing at or after `key`.
auto c = getChildGeq(n, remaining[0]);
if (c != nullptr) {
n = c;
goto downLeftSpine;
} else {
n = nextSibling(n);
if (n == nullptr) {
return true;
}
goto downLeftSpine;
}
}
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
if (i < commonLen) {
// The partial key and the rest of the search key differ at byte i.
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
// The partial key is greater at that byte: continue from this subtree.
goto downLeftSpine;
} else {
// The partial key is smaller at that byte: skip past this subtree.
n = nextSibling(n);
if (n == nullptr) {
return true;
}
goto downLeftSpine;
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// n is the first physical node greater than remaining, and there's no
// eq node
goto downLeftSpine;
}
}
// maxV is the max version getChildAndMaxVersion reported for the child just
// entered; if it is already <= readVersion, stop early with "no conflict".
if (maxV <= readVersion) {
++tls->point_read_short_circuit_accum;
return true;
}
}
downLeftSpine:
// Follow first children until reaching a node that holds an entry, then
// compare that entry's range version.
for (; !n->entryPresent; n = getFirstChildExists(n)) {
}
return n->entry.rangeVersion <= readVersion;
}
// Logically this is the same as performing firstGeq and then checking against
// max version or range version if this prefix doesn't exist, but this version
// short circuits as soon as it can prove that there's no conflict.
//
// Walks down from `n` consuming `key` one node at a time. Returns true when the
// version it ends up comparing is <= readVersion (or when nothing is found at
// or after `key`), false otherwise. Increments the prefix read counters on
// `tls`.
bool checkPrefixRead(Node *n, const std::span<const uint8_t> key,
InternalVersionT readVersion, ReadContext *tls) {
++tls->prefix_read_accum;
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check prefix read: %s\n", printable(key).c_str());
#endif
auto remaining = key;
auto *impl = tls->impl;
for (;; ++tls->prefix_read_iterations_accum) {
if (remaining.size() == 0) {
// The whole prefix has been consumed, so `n` is the node for it; the answer
// is decided by maxVersion(n, impl).
return maxVersion(n, impl) <= readVersion;
}
auto [child, maxV] = getChildAndMaxVersion(n, remaining[0]);
if (child == nullptr) {
// No child for the next key byte: continue from the next greater child of
// `n`, or failing that from the next sibling subtree. If neither exists
// there is nothing at or after `key`.
auto c = getChildGeq(n, remaining[0]);
if (c != nullptr) {
n = c;
goto downLeftSpine;
} else {
n = nextSibling(n);
if (n == nullptr) {
return true;
}
goto downLeftSpine;
}
}
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
if (i < commonLen) {
// The partial key and the rest of the search key differ at byte i.
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
// The partial key is greater at that byte: continue from this subtree.
goto downLeftSpine;
} else {
// The partial key is smaller at that byte: skip past this subtree.
n = nextSibling(n);
if (n == nullptr) {
return true;
}
goto downLeftSpine;
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// n is the first physical node greater than remaining, and there's no
// eq node. All physical nodes that start with prefix are reachable from
// n.
if (maxVersion(n, impl) > readVersion) {
return false;
}
goto downLeftSpine;
}
}
// maxV is the max version getChildAndMaxVersion reported for the child just
// entered; if it is already <= readVersion, stop early with "no conflict".
if (maxV <= readVersion) {
++tls->prefix_read_short_circuit_accum;
return true;
}
}
downLeftSpine:
// Follow first children until reaching a node that holds an entry, then
// compare that entry's range version.
for (; !n->entryPresent; n = getFirstChildExists(n)) {
}
return n->entry.rangeVersion <= readVersion;
}
#ifdef HAS_AVX
// Returns a 16-bit mask in which bit i is set iff the signed 32-bit difference
// vs[i] - rv is greater than zero. Reads exactly 16 versions starting at `vs`.
uint32_t compare16_32bit(const InternalVersionT *vs, InternalVersionT rv) {
  __m128i lanes[4];
  memcpy(lanes, vs, sizeof(lanes));
  uint32_t rvBits;
  memcpy(&rvBits, &rv, sizeof(rvBits));
  const __m128i rvVec = _mm_set1_epi32(rvBits);
  const __m128i zero = _mm_setzero_si128();
  uint32_t result = 0;
  for (int lane = 0; lane < 4; ++lane) {
    const __m128i greater =
        _mm_cmpgt_epi32(_mm_sub_epi32(lanes[lane], rvVec), zero);
    // movemask_ps collects the sign bit of each 32-bit element: four result
    // bits per lane.
    result |= _mm_movemask_ps(_mm_castsi128_ps(greater)) << (lane * 4);
  }
  return result;
}
// This gets covered in local development
// GCOVR_EXCL_START
// AVX-512 variant of compare16_32bit: produces the same mask with a single
// 512-bit compare.
__attribute__((target("avx512f"))) uint32_t
compare16_32bit_avx512(const InternalVersionT *vs, InternalVersionT rv) {
  uint32_t rvBits;
  memcpy(&rvBits, &rv, sizeof(rvBits));
  const auto relative =
      _mm512_sub_epi32(_mm512_loadu_epi32(vs), _mm512_set1_epi32(rvBits));
  return _mm512_cmpgt_epi32_mask(relative, _mm512_setzero_epi32());
}
#endif
// GCOVR_EXCL_STOP
// Returns true if v[i] <= readVersion for all i such that begin <= is[i] < end
// Preconditions: begin <= end, end - begin < 256
//
// `vs` and `is` are each read as 16 elements. Builds one mask of the positions
// whose version exceeds readVersion and one of the positions whose index byte
// lies in [begin, end); the result is true iff the two masks do not intersect.
template <bool kAVX512>
bool scan16(const InternalVersionT *vs, const uint8_t *is, int begin, int end,
            InternalVersionT readVersion) {
  assert(begin <= end);
  assert(end - begin < 256);
#ifdef HAS_ARM_NEON
  uint8x16_t keyBytes;
  memcpy(&keyBytes, is, 16);
  // 0xff for each in bounds
  const auto inRange = vcltq_u8(vsubq_u8(keyBytes, vdupq_n_u8(begin)),
                                vdupq_n_u8(end - begin));
  // 0xf for each 0xff
  const uint64_t inBoundsMask = vget_lane_u64(
      vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(inRange), 4)), 0);

  uint32x4_t versions[4];
  memcpy(versions, vs, sizeof(versions));
  uint32_t rvBits;
  memcpy(&rvBits, &readVersion, sizeof(rvBits));
  const auto rvVec = vdupq_n_u32(rvBits);
  int32x4_t zero;
  memset(&zero, 0, sizeof(zero));
  uint16x4_t narrowed[4];
  for (int lane = 0; lane < 4; ++lane) {
    narrowed[lane] = vmovn_u32(vcgtq_s32(
        vreinterpretq_s32_u32(vsubq_u32(versions[lane], rvVec)), zero));
  }
  const auto perByte =
      vcombine_u8(vmovn_u16(vcombine_u16(narrowed[0], narrowed[1])),
                  vmovn_u16(vcombine_u16(narrowed[2], narrowed[3])));
  const uint64_t conflicts = vget_lane_u64(
      vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(perByte), 4)), 0);
  return (conflicts & inBoundsMask) == 0;
#elif defined(HAS_AVX)
  __m128i shifted;
  memcpy(&shifted, is, 16);
  shifted = _mm_sub_epi8(shifted, _mm_set1_epi8(begin));
  // shifted == max(shifted, width) exactly when shifted >= width, i.e. when the
  // index byte is out of bounds; invert to get the in-bounds positions.
  const __m128i clamped = _mm_max_epu8(shifted, _mm_set1_epi8(end - begin));
  const uint32_t inBoundsMask =
      ~_mm_movemask_epi8(_mm_cmpeq_epi8(shifted, clamped));
  uint32_t conflicts = 0;
  if constexpr (kAVX512) {
    conflicts = compare16_32bit_avx512(vs, readVersion); // GCOVR_EXCL_LINE
  } else {
    conflicts = compare16_32bit(vs, readVersion); // GCOVR_EXCL_LINE
  }
  return (conflicts & inBoundsMask) == 0;
#else
  // A single unsigned comparison tests begin <= c < end: values below `begin`
  // wrap around to something large.
  const unsigned width = end - begin;
  const unsigned base = begin;
  uint32_t conflicts = 0;
  uint32_t inBoundsMask = 0;
  for (int i = 0; i < 16; ++i) {
    const uint32_t bit = uint32_t(1) << i;
    if (vs[i] > readVersion) {
      conflicts |= bit;
    }
    if (unsigned(is[i]) - base < width) {
      inBoundsMask |= bit;
    }
  }
  return (conflicts & inBoundsMask) == 0;
#endif
}
// Returns true if v[i] <= readVersion for all i such that begin <= i < end
//
// `vs` is read as 16 elements; positions outside [begin, end) are ignored.
//
// always_inline So that we can optimize when begin or end is a constant.
// gcovr exclude annotation necessary because of always_inline?
template <bool kAVX512>
inline __attribute__((always_inline)) bool
scan16(const InternalVersionT *vs, int begin, int end, // GCOVR_EXCL_LINE
       InternalVersionT readVersion) {                 // GCOVR_EXCL_LINE
  assert(0 <= begin && begin < 16);
  assert(0 <= end && end <= 16);
  assert(begin <= end);
#if defined(HAS_ARM_NEON)
  uint32x4_t versions[4];
  memcpy(versions, vs, sizeof(versions));
  uint32_t rvBits;
  memcpy(&rvBits, &readVersion, sizeof(rvBits));
  const auto rvVec = vdupq_n_u32(rvBits);
  int32x4_t zero;
  memset(&zero, 0, sizeof(zero));
  uint16x4_t narrowed[4];
  for (int lane = 0; lane < 4; ++lane) {
    narrowed[lane] = vmovn_u32(vcgtq_s32(
        vreinterpretq_s32_u32(vsubq_u32(versions[lane], rvVec)), zero));
  }
  const auto perByte =
      vcombine_u8(vmovn_u16(vcombine_u16(narrowed[0], narrowed[1])),
                  vmovn_u16(vcombine_u16(narrowed[2], narrowed[3])));
  // Four bits per position, so bounds are scaled by four.
  const uint64_t conflicts = vget_lane_u64(
      vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(perByte), 4)), 0);
  const uint64_t belowEnd =
      end == 16 ? uint64_t(-1) : (uint64_t(1) << (end << 2)) - 1;
  const uint64_t inWindow = (conflicts & belowEnd) >> (begin << 2);
  return inWindow == 0;
#elif defined(HAS_AVX)
  uint32_t conflicts;
  if constexpr (kAVX512) {
    conflicts = compare16_32bit_avx512(vs, readVersion); // GCOVR_EXCL_LINE
  } else {
    conflicts = compare16_32bit(vs, readVersion); // GCOVR_EXCL_LINE
  }
  // Drop positions at or after `end`, then positions before `begin`.
  const uint32_t inWindow = (conflicts & ((1 << end) - 1)) >> begin;
  return inWindow == 0;
#else
  uint64_t conflicts = 0;
  for (int i = 0; i < 16; ++i) {
    if (vs[i] > readVersion) {
      conflicts |= uint64_t(1) << i;
    }
  }
  // Drop positions at or after `end`, then positions before `begin`.
  const uint64_t inWindow = (conflicts & ((1 << end) - 1)) >> begin;
  return inWindow == 0;
#endif
}
// Return whether or not the max version among all keys starting with the search
// path of n + [child], where child in (begin, end) is <= readVersion. Does not
// account for the range version of firstGt(searchpath(n) + [end - 1])
template <bool kAVX512>
bool checkMaxBetweenExclusiveImpl(Node *n, int begin, int end,
InternalVersionT readVersion,
ReadContext *tls) {
++tls->range_read_node_scan_accum;
assume(-1 <= begin);
assume(begin <= 256);
assume(-1 <= end);
assume(end <= 256);
assume(begin < end);
assert(!(begin == -1 && end == 256));
switch (n->getType()) {
case Type_Node0:
return true;
case Type_Node3: {
auto *self = static_cast<Node3 *>(n);
++begin;
const unsigned shiftUpperBound = end - begin;
const unsigned shiftAmount = begin;
auto inBounds = [&](unsigned c) {
return c - shiftAmount < shiftUpperBound;
};
uint32_t mask = 0;
for (int i = 0; i < Node3::kMaxNodes; ++i) {
mask |= inBounds(self->index[i]) << i;
}
mask &= (1 << self->numChildren) - 1;
if (!mask) {
return true;
}
auto *child = self->children[std::countr_zero(mask)];
const bool firstRangeOk =
!child->entryPresent || child->entry.rangeVersion <= readVersion;
uint32_t compared = 0;
for (int i = 0; i < Node3::kMaxNodes; ++i) {
compared |= (self->childMaxVersion[i] > readVersion) << i;
}
return !(compared & mask) && firstRangeOk;
}
case Type_Node16: {
auto *self = static_cast<Node16 *>(n);
++begin;
assert(begin <= end);
assert(end - begin < 256);
#ifdef HAS_ARM_NEON
uint8x16_t indices;
memcpy(&indices, self->index, 16);
// 0xff for each in bounds
auto results =
vcltq_u8(vsubq_u8(indices, vdupq_n_u8(begin)), vdupq_n_u8(end - begin));
// 0xf for each 0xff
uint64_t mask = vget_lane_u64(
vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)), 0);
mask &= self->numChildren == 16
? uint64_t(-1)
: (uint64_t(1) << (self->numChildren << 2)) - 1;
if (!mask) {
return true;
}
auto *child = self->children[std::countr_zero(mask) >> 2];
const bool firstRangeOk =
!child->entryPresent || child->entry.rangeVersion <= readVersion;
uint32x4_t w4[4];
memcpy(w4, self->childMaxVersion, sizeof(w4));
uint32_t rv;
memcpy(&rv, &readVersion, sizeof(rv));
const auto rvVec = vdupq_n_u32(rv);
int32x4_t z;
memset(&z, 0, sizeof(z));
uint16x4_t conflicting[4];
for (int i = 0; i < 4; ++i) {
conflicting[i] = vmovn_u32(
vcgtq_s32(vreinterpretq_s32_u32(vsubq_u32(w4[i], rvVec)), z));
}
auto combined =
vcombine_u8(vmovn_u16(vcombine_u16(conflicting[0], conflicting[1])),
vmovn_u16(vcombine_u16(conflicting[2], conflicting[3])));
uint64_t compared = vget_lane_u64(
vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(combined), 4)), 0);
return !(compared & mask) && firstRangeOk;
#elif defined(HAS_AVX)
__m128i indices;
memcpy(&indices, self->index, 16);
indices = _mm_sub_epi8(indices, _mm_set1_epi8(begin));
uint32_t mask =
0xffff &
~_mm_movemask_epi8(_mm_cmpeq_epi8(
indices, _mm_max_epu8(indices, _mm_set1_epi8(end - begin))));
mask &= (1 << self->numChildren) - 1;
if (!mask) {
return true;
}
auto *child = self->children[std::countr_zero(mask)];
const bool firstRangeOk =
!child->entryPresent || child->entry.rangeVersion <= readVersion;
uint32_t compared = 0;
if constexpr (kAVX512) {
compared = // GCOVR_EXCL_LINE
compare16_32bit_avx512(self->childMaxVersion, // GCOVR_EXCL_LINE
readVersion); // GCOVR_EXCL_LINE
} else {
compared = compare16_32bit(self->childMaxVersion,
readVersion); // GCOVR_EXCL_LINE
}
return !(compared & mask) && firstRangeOk;
#else
const unsigned shiftUpperBound = end - begin;
const unsigned shiftAmount = begin;
auto inBounds = [&](unsigned c) {
return c - shiftAmount < shiftUpperBound;
};
uint32_t mask = 0;
for (int i = 0; i < 16; ++i) {
mask |= inBounds(self->index[i]) << i;
}
mask &= (1 << self->numChildren) - 1;
if (!mask) {
return true;
}
auto *child = self->children[std::countr_zero(mask)];
const bool firstRangeOk =
!child->entryPresent || child->entry.rangeVersion <= readVersion;
uint32_t compared = 0;
for (int i = 0; i < 16; ++i) {
compared |= (self->childMaxVersion[i] > readVersion) << i;
}
return !(compared & mask) && firstRangeOk;
#endif
}
case Type_Node48: {
auto *self = static_cast<Node48 *>(n);
{
int c = self->bitSet.firstSetGeq(begin + 1);
if (c >= 0 && c < end) {
auto *child = self->children[self->index[c]];
if (child->entryPresent) {
if (!(child->entry.rangeVersion <= readVersion)) {
return false;
};
}
begin = c;
} else {
return true;
}
// [begin, end) is now the half-open interval of children we're interested
// in.
assert(begin < end);
}
// Check all pages
static_assert(Node48::kMaxOfMaxPageSize == 16);
for (int i = 0; i < Node48::kMaxOfMaxTotalPages; ++i) {
if (self->maxOfMax[i] > readVersion) {
if (!scan16<kAVX512>(self->childMaxVersion +
(i << Node48::kMaxOfMaxShift),
self->reverseIndex + (i << Node48::kMaxOfMaxShift),
begin, end, readVersion)) {
return false;
}
}
}
return true;
}
case Type_Node256: {
static_assert(Node256::kMaxOfMaxTotalPages == 16);
auto *self = static_cast<Node256 *>(n);
{
int c = self->bitSet.firstSetGeq(begin + 1);
if (c >= 0 && c < end) {
auto *child = self->children[c];
if (child->entryPresent) {
if (!(child->entry.rangeVersion <= readVersion)) {
return false;
};
}
begin = c;
} else {
return true;
}
// [begin, end) is now the half-open interval of children we're interested
// in.
assert(begin < end);
}
const int firstPage = begin >> Node256::kMaxOfMaxShift;
const int lastPage = (end - 1) >> Node256::kMaxOfMaxShift;
// Check the only page if there's only one
if (firstPage == lastPage) {
if (self->maxOfMax[firstPage] <= readVersion) {
return true;
}
const int intraPageBegin = begin & (Node256::kMaxOfMaxPageSize - 1);
const int intraPageEnd = end - (lastPage << Node256::kMaxOfMaxShift);
return scan16<kAVX512>(self->childMaxVersion +
(firstPage << Node256::kMaxOfMaxShift),
intraPageBegin, intraPageEnd, readVersion);
}
// Check the first page
if (self->maxOfMax[firstPage] > readVersion) {
const int intraPageBegin = begin & (Node256::kMaxOfMaxPageSize - 1);
if (!scan16<kAVX512>(self->childMaxVersion +
(firstPage << Node256::kMaxOfMaxShift),
intraPageBegin, 16, readVersion)) {
return false;
}
}
// Check the last page
if (self->maxOfMax[lastPage] > readVersion) {
const int intraPageEnd = end - (lastPage << Node256::kMaxOfMaxShift);
if (!scan16<kAVX512>(self->childMaxVersion +
(lastPage << Node256::kMaxOfMaxShift),
0, intraPageEnd, readVersion)) {
return false;
}
}
// Check inner pages
const int innerPageBegin = (begin >> Node256::kMaxOfMaxShift) + 1;
const int innerPageEnd = (end - 1) >> Node256::kMaxOfMaxShift;
return scan16<kAVX512>(self->maxOfMax, innerPageBegin, innerPageEnd,
readVersion);
}
default: // GCOVR_EXCL_LINE
__builtin_unreachable(); // GCOVR_EXCL_LINE
}
}
// checkMaxBetweenExclusive: non-template entry point that forwards to
// checkMaxBetweenExclusiveImpl.
//
// When HAS_AVX is defined and we are not building with thread sanitizer, two
// definitions with the same name are emitted: one compiled with
// target("avx512f") that forwards to the <true> instantiation, and one
// compiled with target("default") that forwards to the <false> instantiation.
// NOTE(review): presumably the toolchain's function multiversioning picks
// between the two based on the cpu we're running on -- confirm against the
// compiler documentation.
// In every other configuration only the <false> forwarding definition exists.
#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
// This gets covered in local development
// GCOVR_EXCL_START
__attribute__((target("avx512f"))) bool
checkMaxBetweenExclusive(Node *n, int begin, int end,
InternalVersionT readVersion, ReadContext *tls) {
return checkMaxBetweenExclusiveImpl<true>(n, begin, end, readVersion, tls);
}
// GCOVR_EXCL_STOP
__attribute__((target("default")))
#endif
bool checkMaxBetweenExclusive(Node *n, int begin, int end,
InternalVersionT readVersion, ReadContext *tls) {
return checkMaxBetweenExclusiveImpl<false>(n, begin, end, readVersion, tls);
}
// Returns the sequence of key bytes that leads from the root of the tree to
// `n`. For every node on the way this is the byte under which it hangs off its
// parent (parentsIndex) followed by its partial key. The result is allocated
// from `arena`.
Vector<uint8_t> getSearchPath(Arena &arena, Node *n) {
  assert(n != nullptr);
  auto path = vector<uint8_t>(arena);
  // Collect the bytes leaf-to-root (so each partial key goes in back to
  // front), and flip the whole thing once at the end.
  Node *cursor = n;
  while (true) {
    for (int left = cursor->partialKeyLen; left > 0; --left) {
      path.push_back(cursor->partialKey()[left - 1]);
    }
    if (cursor->parent == nullptr) {
      // The root contributes no parentsIndex byte
      break;
    }
    path.push_back(cursor->parentsIndex);
    cursor = cursor->parent;
  }
  std::reverse(path.begin(), path.end());
  return path;
} // GCOVR_EXCL_LINE
// Return true if the max version among all keys that start with key + [child],
// where begin < child < end, is <= readVersion.
//
// Precondition: transitively, no child of n has a search path that's a longer
// prefix of key than n
bool checkRangeStartsWith(Node *n, std::span<const uint8_t> key, int begin,
                          int end, InternalVersionT readVersion,
                          ReadContext *tls) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
  fprintf(stderr, "%s(%02x,%02x)*\n", printable(key).c_str(), begin, end);
#endif
  auto *const impl = tls->impl;

  // Follow first children from `from` until a node with an entry is found, and
  // test that entry's range version.
  const auto firstEntryRangeOk = [&](Node *from) {
    while (!from->entryPresent) {
      from = getFirstChildExists(from);
    }
    return from->entry.rangeVersion <= readVersion;
  };
  // Same test, but starting from the next sibling of `from`. With no sibling
  // there is nothing left to look at.
  const auto nextSiblingRangeOk = [&](Node *from) {
    auto *sibling = nextSibling(from);
    if (sibling == nullptr) {
      return true;
    }
    return firstEntryRangeOk(sibling);
  };

  if (key.size() == 0) {
    // n itself is the node for `key`: scan its children directly
    return checkMaxBetweenExclusive(n, begin, end, readVersion, tls);
  }

  auto *child = getChild(n, key[0]);
  if (child == nullptr) {
    // No child for the next key byte
    auto geq = getChildGeq(n, key[0]);
    if (geq != nullptr) {
      return firstEntryRangeOk(geq);
    }
    return nextSiblingRangeOk(n);
  }

  // One byte was consumed by stepping into child; the rest has to be compared
  // against child's partial key.
  const auto rest = key.subspan(1, key.size() - 1);
  assert(child->partialKeyLen > 0);
  const int commonLen = std::min<int>(child->partialKeyLen, rest.size());
  const int firstDiff =
      longestCommonPrefix(child->partialKey(), rest.data(), commonLen);
  if (firstDiff < commonLen) {
    // The partial key and the key disagree at firstDiff
    if (child->partialKey()[firstDiff] > rest[firstDiff]) {
      return firstEntryRangeOk(child);
    }
    return nextSiblingRangeOk(child);
  }

  // `rest` is a proper prefix of child's partial key. The byte right after it
  // plays the role of the "child" byte that must lie strictly inside
  // (begin, end).
  assert(child->partialKeyLen > int(rest.size()));
  const int branchByte = child->partialKey()[rest.size()];
  if (!(begin < branchByte && branchByte < end)) {
    return true;
  }
  if (child->entryPresent && child->entry.rangeVersion > readVersion) {
    return false;
  }
  return maxVersion(child, impl) <= readVersion;
}
namespace {
// Return true if the max version among all keys that start with key[:prefixLen]
// that are >= key is <= readVersion
//
// The search is resumable: construct it, then call step() until it returns
// true. The answer is then available in `ok`. checkRangeRead interleaves the
// steps of this search with those of CheckRangeRightSide.
struct CheckRangeLeftSide {
CheckRangeLeftSide(Node *n, std::span<const uint8_t> key, int prefixLen,
InternalVersionT readVersion, ReadContext *tls)
: n(n), remaining(key), prefixLen(prefixLen), readVersion(readVersion),
impl(tls->impl), tls(tls) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check range left side from %s for keys starting with %s\n",
printable(key).c_str(),
printable(key.subspan(0, prefixLen)).c_str());
#endif
}
// Node the search is currently positioned at
Node *n;
// Part of the key that has not been matched against the tree yet
std::span<const uint8_t> remaining;
// Only keys that share the first prefixLen bytes of the key are in scope
int prefixLen;
InternalVersionT readVersion;
ConflictSet::Impl *impl;
ReadContext *tls;
// Number of key bytes matched on the way down to n
int searchPathLen = 0;
// Result of the check. Only meaningful once step() has returned true.
bool ok;
// Advances the search by (at most) one node. Returns true when the search is
// finished and `ok` has been set, false when step() must be called again.
bool step() {
if (remaining.size() == 0) {
// The whole key has been matched, so n is the node for the key itself
assert(searchPathLen >= prefixLen);
ok = maxVersion(n, impl) <= readVersion;
return true;
}
if (searchPathLen >= prefixLen) {
// Past the required prefix: check the children of n whose byte is strictly
// greater than the next key byte (256 is one past the largest byte)
if (!checkMaxBetweenExclusive(n, remaining[0], 256, readVersion, tls)) {
ok = false;
return true;
}
}
auto [child, maxV] = getChildAndMaxVersion(n, remaining[0]);
if (child == nullptr) {
// There's no child for the next key byte. Fall back to the first child with
// a greater byte, or to the next sibling of n if there is none.
auto c = getChildGeq(n, remaining[0]);
if (c != nullptr) {
if (searchPathLen < prefixLen) {
n = c;
return downLeftSpine();
}
n = c;
ok = maxVersion(n, impl) <= readVersion;
return true;
} else {
n = nextSibling(n);
if (n == nullptr) {
ok = true;
return true;
}
return downLeftSpine();
}
}
// Descend into child, consuming the byte that selected it
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
++searchPathLen;
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
searchPathLen += i;
if (i < commonLen) {
// The partial key and the key disagree at index i
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
// The partial key compares greater than the key here
if (searchPathLen < prefixLen) {
return downLeftSpine();
}
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
ok = false;
return true;
}
ok = maxVersion(n, impl) <= readVersion;
return true;
} else {
// The partial key compares less than the key here: continue from the next
// sibling, if any
n = nextSibling(n);
if (n == nullptr) {
ok = true;
return true;
}
return downLeftSpine();
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// The key ends inside n's partial key
assert(searchPathLen >= prefixLen);
if (n->entryPresent && n->entry.rangeVersion > readVersion) {
ok = false;
return true;
}
ok = maxVersion(n, impl) <= readVersion;
return true;
}
}
// maxV is the max version reported by getChildAndMaxVersion for the child we
// just stepped into. NOTE(review): presumably it bounds every version stored
// under that child, which is why the search can stop here -- confirm.
if (maxV <= readVersion) {
ok = true;
return true;
}
return false;
}
// Follows first children from n until a node with an entry is found, sets `ok`
// from that entry's range version, and reports the search as finished.
bool downLeftSpine() {
for (; !n->entryPresent; n = getFirstChildExists(n)) {
}
ok = n->entry.rangeVersion <= readVersion;
return true;
}
};
// Return true if the max version among all keys that start with key[:prefixLen]
// that are < key is <= readVersion
//
// The search is resumable: construct it, then call step() until it returns
// true. The answer is then available in `ok`. checkRangeRead interleaves the
// steps of this search with those of CheckRangeLeftSide.
struct CheckRangeRightSide {
CheckRangeRightSide(Node *n, std::span<const uint8_t> key, int prefixLen,
InternalVersionT readVersion, ReadContext *tls)
: n(n), key(key), remaining(key), prefixLen(prefixLen),
readVersion(readVersion), impl(tls->impl), tls(tls) {
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr, "Check range right side to %s for keys starting with %s\n",
printable(key).c_str(),
printable(key.subspan(0, prefixLen)).c_str());
#endif
}
// Node the search is currently positioned at
Node *n;
// The full key. Only used for the sanity check at the top of step().
std::span<const uint8_t> key;
// Part of the key that has not been matched against the tree yet
std::span<const uint8_t> remaining;
// Only keys that share the first prefixLen bytes of the key are in scope
int prefixLen;
InternalVersionT readVersion;
ConflictSet::Impl *impl;
ReadContext *tls;
// Length of the search path accounted for so far. Note that, unlike in
// CheckRangeLeftSide, this is also bumped for the mismatching byte of a
// partial key, and adjusted while backtracking.
int searchPathLen = 0;
// Result of the check. Only meaningful once step() has returned true.
bool ok;
// Advances the search by (at most) one node. Returns true when the search is
// finished and `ok` has been set, false when step() must be called again.
bool step() {
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr,
"Search path: %s, searchPathLen: %d, prefixLen: %d, remaining: "
"%s\n",
getSearchPathPrintable(n).c_str(), searchPathLen, prefixLen,
printable(remaining).c_str());
#endif
assert(searchPathLen <= int(key.size()));
if (remaining.size() == 0) {
// The whole key has been matched
return downLeftSpine();
}
if (searchPathLen >= prefixLen) {
// n's own point version is in scope once the required prefix is matched
if (n->entryPresent && n->entry.pointVersion > readVersion) {
ok = false;
return true;
}
// Check the children of n whose byte is strictly less than the next key
// byte (-1 is one before the smallest byte)
if (!checkMaxBetweenExclusive(n, -1, remaining[0], readVersion, tls)) {
ok = false;
return true;
}
}
// n's range version is only checked once we're strictly past the prefix
if (searchPathLen > prefixLen && n->entryPresent &&
n->entry.rangeVersion > readVersion) {
ok = false;
return true;
}
auto *child = getChild(n, remaining[0]);
if (child == nullptr) {
// There's no child for the next key byte. Use the first child with a
// greater byte if there is one, otherwise walk back up.
auto c = getChildGeq(n, remaining[0]);
if (c != nullptr) {
n = c;
return downLeftSpine();
} else {
return backtrack();
}
}
// Descend into child, consuming the byte that selected it
n = child;
remaining = remaining.subspan(1, remaining.size() - 1);
++searchPathLen;
if (n->partialKeyLen > 0) {
int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
searchPathLen += i;
if (i < commonLen) {
// The partial key and the key disagree at index i
++searchPathLen;
auto c = n->partialKey()[i] <=> remaining[i];
if (c > 0) {
// The partial key compares greater than the key here
return downLeftSpine();
} else {
// The partial key compares less than the key here
if (searchPathLen > prefixLen && n->entryPresent &&
n->entry.rangeVersion > readVersion) {
ok = false;
return true;
}
return backtrack();
}
}
if (commonLen == n->partialKeyLen) {
// partial key matches
remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
} else if (n->partialKeyLen > int(remaining.size())) {
// The key ends inside n's partial key
return downLeftSpine();
}
}
return false;
}
// Walks back up from n. At each node visited the search fails if we're
// strictly past the prefix and maxVersion(n) exceeds readVersion. Reaching
// the root finishes the search successfully. Otherwise, if n's parent has a
// child after n, the search moves there and finishes via downLeftSpine(); if
// not, it climbs to the parent and repeats. searchPathLen is kept in sync with
// each move.
bool backtrack() {
for (;;) {
if (searchPathLen > prefixLen && maxVersion(n, impl) > readVersion) {
ok = false;
return true;
}
if (n->parent == nullptr) {
ok = true;
return true;
}
auto next = getChildGeq(n->parent, n->parentsIndex + 1);
if (next == nullptr) {
// Climb: drop n's partial key and the byte that selected n
searchPathLen -= 1 + n->partialKeyLen;
n = n->parent;
} else {
// Move sideways: swap n's partial key length for next's
searchPathLen -= n->partialKeyLen;
n = next;
searchPathLen += n->partialKeyLen;
return downLeftSpine();
}
}
}
// Follows first children from n until a node with an entry is found, sets `ok`
// from that entry's range version, and reports the search as finished.
bool downLeftSpine() {
for (; !n->entryPresent; n = getFirstChildExists(n)) {
}
ok = n->entry.rangeVersion <= readVersion;
return true;
}
};
} // namespace
// Checks a read of the key range [begin, end) at readVersion, starting the
// search at n. Returns true when none of the checks below finds a version
// greater than readVersion (ConflictSet::Impl::check reports true as Commit
// and false as Conflict).
//
// Reads of the form [k, k + '\0') are handed to checkPointRead, and reads
// where begin and end have equal length and differ only by end's last byte
// being one greater are handed to checkPrefixRead.
bool checkRangeRead(Node *n, std::span<const uint8_t> begin,
std::span<const uint8_t> end, InternalVersionT readVersion,
ReadContext *tls) {
int lcp = longestCommonPrefix(begin.data(), end.data(),
std::min(begin.size(), end.size()));
if (lcp == int(begin.size()) && end.size() == begin.size() + 1 &&
end.back() == 0) {
return checkPointRead(n, begin, readVersion, tls);
}
if (lcp == int(begin.size() - 1) && end.size() == begin.size() &&
int(begin.back()) + 1 == int(end.back())) {
return checkPrefixRead(n, begin, readVersion, tls);
}
++tls->range_read_accum;
// Descend along the common prefix of begin and end for as long as there are
// nodes whose search path is a prefix of it.
auto remaining = begin.subspan(0, lcp);
// Only used by the asserts below
Arena arena;
for (;; ++tls->range_read_iterations_accum) {
assert(getSearchPath(arena, n) <=>
begin.subspan(0, lcp - remaining.size()) ==
0);
if (remaining.size() == 0) {
break;
}
auto [child, v] = getChildAndMaxVersion(n, remaining[0]);
if (child == nullptr) {
break;
}
if (child->partialKeyLen > 0) {
// Only step into child if its entire partial key matches
int cl = std::min<int>(child->partialKeyLen, remaining.size() - 1);
int i =
longestCommonPrefix(child->partialKey(), remaining.data() + 1, cl);
if (i != child->partialKeyLen) {
break;
}
}
// v is the max version reported for child by getChildAndMaxVersion. If it
// doesn't exceed readVersion we can stop without looking any further.
if (v <= readVersion) {
++tls->range_read_short_circuit_accum;
return true;
}
n = child;
remaining =
remaining.subspan(1 + child->partialKeyLen,
remaining.size() - (1 + child->partialKeyLen));
}
assert(getSearchPath(arena, n) <=> begin.subspan(0, lcp - remaining.size()) ==
0);
// Drop the bytes already matched on the way to n, so that begin, end and lcp
// are all relative to n from here on.
const int consumed = lcp - remaining.size();
assume(consumed >= 0);
begin = begin.subspan(consumed, int(begin.size()) - consumed);
end = end.subspan(consumed, int(end.size()) - consumed);
lcp -= consumed;
if (lcp == int(begin.size())) {
// begin is a prefix of end, so only the right side needs to be checked
CheckRangeRightSide checkRangeRightSide{n, end, lcp, readVersion, tls};
while (!checkRangeRightSide.step())
;
return checkRangeRightSide.ok;
}
// Keys that continue the common prefix with a byte strictly between
// begin[lcp] and end[lcp]
if (!checkRangeStartsWith(n, begin.subspan(0, lcp), begin[lcp], end[lcp],
readVersion, tls)) {
return false;
}
// The remaining keys start with begin[:lcp + 1] (left side) or end[:lcp + 1]
// (right side). Step both searches alternately until one finishes, then run
// the other one to completion.
CheckRangeLeftSide checkRangeLeftSide{n, begin, lcp + 1, readVersion, tls};
CheckRangeRightSide checkRangeRightSide{n, end, lcp + 1, readVersion, tls};
for (;;) {
bool leftDone = checkRangeLeftSide.step();
bool rightDone = checkRangeRightSide.step();
if (!leftDone && !rightDone) {
tls->range_read_iterations_accum += 2;
continue;
}
if (leftDone && rightDone) {
break;
} else if (leftDone) {
while (!checkRangeRightSide.step()) {
++tls->range_read_iterations_accum;
}
break;
} else {
assert(rightDone);
while (!checkRangeLeftSide.step()) {
++tls->range_read_iterations_accum;
}
}
break;
}
return checkRangeLeftSide.ok && checkRangeRightSide.ok;
}
#ifdef __x86_64__
// Explicitly instantiate with target avx512f attribute so the compiler can
// inline compare16_32bit_avx512, and generally use avx512f within more
// functions
//
// Three instantiations are emitted: both scan16<true> overloads (the one that
// takes an index array `is` and the one that doesn't, the latter additionally
// marked always_inline), and checkMaxBetweenExclusiveImpl<true>.
template __attribute__((target("avx512f"))) bool
scan16<true>(const InternalVersionT *vs, const uint8_t *is, int begin, int end,
InternalVersionT readVersion);
template __attribute__((always_inline, target("avx512f"))) bool
scan16<true>(const InternalVersionT *vs, int begin, int end,
InternalVersionT readVersion);
template __attribute__((target("avx512f"))) bool
checkMaxBetweenExclusiveImpl<true>(Node *n, int begin, int end,
InternalVersionT readVersion, ReadContext *);
#endif
// Consume the partial key of `self` (which must exist), and update `self` and
// `key` such that `self` is along the search path of `key`
//
// If `key` does not cover all of self's partial key (it either diverges from
// it or ends inside it), the node is split: a new Node3 holding the matching
// part of the partial key takes self's place, and the old node becomes its
// only child. `self` is a reference to the pointer through which the node was
// reached, so assigning to it is what links the new node in.
// On return `key` has been advanced past the bytes that matched.
void consumePartialKey(Node *&self, std::span<const uint8_t> &key,
InternalVersionT writeVersion, WriteContext *tls,
ConflictSet::Impl *impl) {
assert(self->partialKeyLen > 0);
// Handle an existing partial key
int commonLen = std::min<int>(self->partialKeyLen, key.size());
int partialKeyIndex =
longestCommonPrefix(self->partialKey(), key.data(), commonLen);
if (partialKeyIndex < self->partialKeyLen) {
auto *old = self;
// NOTE(review): presumably this stores writeVersion as the max version in
// old's current slot and returns the previous value -- confirm against
// exchangeMaxVersion. The returned value is what gets recorded for `old`
// under the new node below.
InternalVersionT oldMaxVersion =
exchangeMaxVersion(old, writeVersion, impl);
// *self will have one child (old)
auto *newSelf = tls->allocate<Node3>(partialKeyIndex);
newSelf->parent = old->parent;
newSelf->parentsIndex = old->parentsIndex;
newSelf->partialKeyLen = partialKeyIndex;
newSelf->entryPresent = false;
newSelf->numChildren = 1;
memcpy(newSelf->partialKey(), old->partialKey(), newSelf->partialKeyLen);
// The first byte of old's partial key that is not part of newSelf's partial
// key becomes the byte under which old hangs off newSelf. It has to be read
// before old's partial key is shifted below.
uint8_t oldDistinguishingByte = old->partialKey()[partialKeyIndex];
old->parent = newSelf;
old->parentsIndex = oldDistinguishingByte;
newSelf->index[0] = oldDistinguishingByte;
newSelf->children[0] = old;
newSelf->childMaxVersion[0] = oldMaxVersion;
self = newSelf;
// What's left of old's partial key is everything after the distinguishing
// byte
memmove(old->partialKey(), old->partialKey() + partialKeyIndex + 1,
old->partialKeyLen - (partialKeyIndex + 1));
old->partialKeyLen -= partialKeyIndex + 1;
// We would consider decreasing capacity here, but we can't invalidate
// old since it's not on the search path. setOldestVersion will clean it
// up.
}
key = key.subspan(partialKeyIndex, key.size() - partialKeyIndex);
}
// Returns a pointer to the newly inserted node. Caller must set
// `entryPresent`, and `entry` fields. All nodes along the search path of the
// result will have `maxVersion` set to `writeVersion` as a postcondition. Nodes
// along the search path may be invalidated.
//
// `self` points at the pointer through which the starting node is reached, and
// `key` is relative to that node. If a node for `key` already exists it is
// returned as is (its `entryPresent` is left untouched), so callers can tell a
// fresh node from an existing entry by inspecting `entryPresent`.
[[nodiscard]]
Node *insert(Node **self, std::span<const uint8_t> key,
InternalVersionT writeVersion, WriteContext *tls,
ConflictSet::Impl *impl) {
if ((*self)->partialKeyLen > 0) {
consumePartialKey(*self, key, writeVersion, tls, impl);
}
assert(maxVersion(*self, impl) <= writeVersion);
setMaxVersion(*self, impl, writeVersion);
for (;; ++tls->accum.insert_iterations) {
if (key.size() == 0) {
// *self is the node for the key
return *self;
}
// `child` is a reference to the slot in *self for the next key byte. The slot
// holds a null pointer when there was no such child yet.
auto &child = getOrCreateChild(*self, key.front(), writeVersion, tls);
if (!child) {
// No child yet: create a leaf whose partial key holds the entire rest of the
// key
child = tls->allocate<Node0>(key.size() - 1);
child->numChildren = 0;
child->entryPresent = false;
child->partialKeyLen = key.size() - 1;
child->parent = *self;
child->parentsIndex = key.front();
setMaxVersion(child, impl, writeVersion);
memcpy(child->partialKey(), key.data() + 1, child->partialKeyLen);
return child;
}
// Continue from the existing child
self = &child;
key = key.subspan(1, key.size() - 1);
if ((*self)->partialKeyLen > 0) {
consumePartialKey(*self, key, writeVersion, tls, impl);
assert(maxVersion(*self, impl) <= writeVersion);
setMaxVersion(*self, impl, writeVersion);
}
}
}
// Frees every node of the tree rooted at `root`. The traversal uses an
// explicit work list instead of recursion. With SHOW_MEMORY, every node and
// key is first removed from the memory accounting.
void destroyTree(Node *root) {
  Arena arena;
  auto pending = vector<Node *>(arena);
  pending.push_back(root);

#if SHOW_MEMORY
  {
    Node *visit = root;
    while (visit != nullptr) {
      removeNode(visit);
      removeKey(visit);
      visit = nextPhysical(visit);
    }
  }
#endif

  // `pending` starts out holding the root, so there's always at least one
  // node to process.
  do {
    Node *doomed = pending.back();
    pending.pop_back();
    // Queue up the children before letting go of their parent
    auto child = getChildGeq(doomed, 0);
    while (child != nullptr) {
      pending.push_back(child);
      child = getChildGeq(doomed, child->parentsIndex + 1);
    }
    safe_free(doomed, doomed->size());
  } while (pending.size() > 0);
}
// Records a write of the single key `key` at `writeVersion`.
//
// If the key has no entry yet, one is created: its point version is
// `writeVersion`, and its range version is taken from the next logical entry
// (but is never less than tls->zero), or is tls->zero if there is no such
// entry. If the key already has an entry only its point version is updated.
void addPointWrite(Node *&root, std::span<const uint8_t> key,
                   InternalVersionT writeVersion, WriteContext *tls,
                   ConflictSet::Impl *impl) {
  ++tls->accum.point_writes;
  Node *const target = insert(&root, key, writeVersion, tls, impl);

  if (target->entryPresent) {
    // Existing entry
    assert(writeVersion >= target->entry.pointVersion);
    target->entry.pointVersion = writeVersion;
    return;
  }

  // Fresh entry
  ++tls->accum.entries_inserted;
  auto *const successor = nextLogical(target);
  addKey(target);
  target->entryPresent = true;
  target->entry.pointVersion = writeVersion;
  if (successor == nullptr) {
    target->entry.rangeVersion = tls->zero;
  } else {
    target->entry.rangeVersion =
        std::max(successor->entry.rangeVersion, tls->zero);
  }
}
// Recomputes the max version of `node` from scratch and stores it with
// setMaxVersion. The new value is the largest of tls->zero, node's own point
// version (if it has an entry), and the max versions node records for its
// children (for Node48 and Node256 the per-page maxOfMax values are used).
void fixupMaxVersion(Node *node, ConflictSet::Impl *impl, WriteContext *tls) {
  InternalVersionT best = tls->zero;
  if (node->entryPresent) {
    best = std::max(node->entry.pointVersion, tls->zero);
  }
  const auto foldIn = [&best](const InternalVersionT &candidate) {
    best = std::max(candidate, best);
  };

  switch (node->getType()) {
  case Type_Node0:
    // No children to account for
    break;
  case Type_Node3: {
    auto *n3 = static_cast<Node3 *>(node);
    const int used = n3->numChildren;
    for (int slot = 0; slot != used && slot < used; ++slot) {
      foldIn(n3->childMaxVersion[slot]);
    }
    break;
  }
  case Type_Node16: {
    auto *n16 = static_cast<Node16 *>(node);
    const int used = n16->numChildren;
    for (int slot = 0; slot != used && slot < used; ++slot) {
      foldIn(n16->childMaxVersion[slot]);
    }
    break;
  }
  case Type_Node48: {
    auto *n48 = static_cast<Node48 *>(node);
    for (const auto &pageMax : n48->maxOfMax) {
      foldIn(pageMax);
    }
    break;
  }
  case Type_Node256: {
    auto *n256 = static_cast<Node256 *>(node);
    for (const auto &pageMax : n256->maxOfMax) {
      foldIn(pageMax);
    }
    break;
  }
  default:                   // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
  setMaxVersion(node, impl, best);
}
// Records a write of the key range [begin, end) at `writeVersion`.
//
// A range of the form [k, k + '\0') is handed to addPointWrite. Otherwise
// entries are inserted for `begin` and `end`, `end`'s range version is set to
// `writeVersion`, and every entry strictly between the two is erased.
void addWriteRange(Node *&root, std::span<const uint8_t> begin,
std::span<const uint8_t> end, InternalVersionT writeVersion,
WriteContext *tls, ConflictSet::Impl *impl) {
int lcp = longestCommonPrefix(begin.data(), end.data(),
std::min(begin.size(), end.size()));
if (lcp == int(begin.size()) && end.size() == begin.size() + 1 &&
end.back() == 0) {
return addPointWrite(root, begin, writeVersion, tls, impl);
}
++tls->accum.range_writes;
const bool beginIsPrefix = lcp == int(begin.size());
// Descend along the common prefix of begin and end as far as existing nodes
// allow, so that both inserts below can start from the same node. Every node
// we step past gets its max version raised to writeVersion.
auto remaining = begin.subspan(0, lcp);
auto *n = root;
for (;;) {
if (int(remaining.size()) <= n->partialKeyLen) {
break;
}
int i = longestCommonPrefix(n->partialKey(), remaining.data(),
n->partialKeyLen);
if (i != n->partialKeyLen) {
break;
}
auto *child = getChild(n, remaining[n->partialKeyLen]);
if (child == nullptr) {
break;
}
assert(maxVersion(n, impl) <= writeVersion);
setMaxVersion(n, impl, writeVersion);
remaining = remaining.subspan(n->partialKeyLen + 1,
remaining.size() - (n->partialKeyLen + 1));
n = child;
}
// NOTE(review): presumably getInTree returns the pointer slot in the tree that
// refers to n, so that insert can replace n through it -- confirm.
Node **useAsRoot = &getInTree(n, impl);
// Make begin and end relative to *useAsRoot
int consumed = lcp - remaining.size();
begin = begin.subspan(consumed, begin.size() - consumed);
end = end.subspan(consumed, end.size() - consumed);
auto *beginNode = insert(useAsRoot, begin, writeVersion, tls, impl);
const bool insertedBegin = !beginNode->entryPresent;
addKey(beginNode);
beginNode->entryPresent = true;
if (insertedBegin) {
++tls->accum.entries_inserted;
// A new entry takes its range version from the next logical entry, but
// never less than tls->zero
auto *p = nextLogical(beginNode);
beginNode->entry.rangeVersion =
p == nullptr ? tls->zero : std::max(p->entry.rangeVersion, tls->zero);
beginNode->entry.pointVersion = writeVersion;
}
assert(writeVersion >= beginNode->entry.pointVersion);
beginNode->entry.pointVersion = writeVersion;
auto *endNode = insert(useAsRoot, end, writeVersion, tls, impl);
const bool insertedEnd = !endNode->entryPresent;
addKey(endNode);
endNode->entryPresent = true;
if (insertedEnd) {
++tls->accum.entries_inserted;
// `end` itself is not written. A new entry for it takes its point version
// from the range version of the next logical entry, but never less than
// tls->zero.
auto *p = nextLogical(endNode);
endNode->entry.pointVersion =
p == nullptr ? tls->zero : std::max(p->entry.rangeVersion, tls->zero);
}
endNode->entry.rangeVersion = writeVersion;
if (beginIsPrefix && insertedEnd) {
// beginNode may have been invalidated when inserting end. TODO can we do
// better?
beginNode = insert(useAsRoot, begin, writeVersion, tls, impl);
assert(beginNode->entryPresent);
}
// Erase every entry strictly between begin and end. erase returns the node to
// continue from.
for (beginNode = nextLogical(beginNode); beginNode != endNode;
beginNode = erase(beginNode, tls, impl, /*logical*/ true, endNode)) {
}
// Inserting end trashed endNode's maxVersion. Fix that
fixupMaxVersion(endNode, impl, tls);
}
// Searches the tree rooted at `n` for `key`. Returns the node for `key` if
// there is one. Otherwise returns the node the search falls through to: the
// first child with a greater byte, the node whose partial key compares greater
// than the key, or the next sibling (which may be null).
Node *firstGeqPhysical(Node *n, const std::span<const uint8_t> key) {
  std::span<const uint8_t> rest = key;
  while (rest.size() != 0) {
    auto *exact = getChild(n, rest[0]);
    if (exact == nullptr) {
      auto geq = getChildGeq(n, rest[0]);
      if (geq != nullptr) {
        return geq;
      }
      // A null result here is genuinely unreachable from any entry point of
      // the final library, since we can't remove a key without introducing a
      // key after it, and the only production caller of firstGeq is for
      // resuming the setOldestVersion scan.
      return nextSibling(n);
    }

    // Step into the child, consuming the byte that selected it
    n = exact;
    rest = rest.subspan(1, rest.size() - 1);
    if (!(n->partialKeyLen > 0)) {
      continue;
    }

    const int commonLen = std::min<int>(n->partialKeyLen, rest.size());
    const int firstDiff =
        longestCommonPrefix(n->partialKey(), rest.data(), commonLen);
    if (firstDiff < commonLen) {
      // The partial key and the key disagree at firstDiff
      if (n->partialKey()[firstDiff] > rest[firstDiff]) {
        return n;
      }
      return nextSibling(n);
    }
    if (commonLen == n->partialKeyLen) {
      // The whole partial key matched
      rest = rest.subspan(commonLen, rest.size() - commonLen);
    } else if (n->partialKeyLen > int(rest.size())) {
      // n is the first physical node greater than the key, and there's no
      // node equal to it
      return n;
    }
  }
  // The key has been matched in full
  return n;
}
struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
void check(const ReadRange *reads, Result *result, int count) {
ReadContext tls;
tls.impl = this;
int64_t check_byte_accum = 0;
for (int i = 0; i < count; ++i) {
const auto &r = reads[i];
check_byte_accum += r.begin.len + r.end.len;
auto begin = std::span<const uint8_t>(r.begin.p, r.begin.len);
auto end = std::span<const uint8_t>(r.end.p, r.end.len);
assert(oldestVersionFullPrecision >=
newestVersionFullPrecision - kNominalVersionWindow);
result[i] =
reads[i].readVersion < oldestVersionFullPrecision ? TooOld
: (end.size() > 0
? checkRangeRead(root, begin, end,
InternalVersionT(reads[i].readVersion), &tls)
: checkPointRead(root, begin,
InternalVersionT(reads[i].readVersion), &tls))
? Commit
: Conflict;
tls.commits_accum += result[i] == Commit;
tls.conflicts_accum += result[i] == Conflict;
tls.too_olds_accum += result[i] == TooOld;
}
point_read_total.add(tls.point_read_accum);
prefix_read_total.add(tls.prefix_read_accum);
range_read_total.add(tls.range_read_accum);
range_read_node_scan_total.add(tls.range_read_node_scan_accum);
point_read_short_circuit_total.add(tls.point_read_short_circuit_accum);
prefix_read_short_circuit_total.add(tls.prefix_read_short_circuit_accum);
range_read_short_circuit_total.add(tls.range_read_short_circuit_accum);
point_read_iterations_total.add(tls.point_read_iterations_accum);
prefix_read_iterations_total.add(tls.prefix_read_iterations_accum);
range_read_iterations_total.add(tls.range_read_iterations_accum);
commits_total.add(tls.commits_accum);
conflicts_total.add(tls.conflicts_accum);
too_olds_total.add(tls.too_olds_accum);
check_bytes_total.add(check_byte_accum);
}
// Applies `count` writes at `writeVersion`, which must not be less than the
// newest version seen so far. A write with a non-empty end key is a range
// write, otherwise it's a point write.
//
// Before the writes are applied the newest version is advanced to
// `writeVersion` and the oldest version to
// `writeVersion - kNominalVersionWindow` (see setOldestVersion). Afterwards the
// statistics gathered in tls.accum are published to the metric counters and
// reset.
void addWrites(const WriteRange *writes, int count, int64_t writeVersion) {
// There could be other conflict sets in the same thread. We need
// InternalVersionT::zero to be correct for this conflict set for the
// lifetime of the current call frame.
InternalVersionT::zero = tls.zero = oldestVersion;
assert(writeVersion >= newestVersionFullPrecision);
// The oldest version still present in the tree is further behind writeVersion
// than kMaxCorrectVersionWindow allows
if (oldestExtantVersion < writeVersion - kMaxCorrectVersionWindow)
[[unlikely]] {
if (writeVersion > newestVersionFullPrecision + kNominalVersionWindow) {
// The jump is larger than the nominal window: throw the whole tree away and
// start over with a fresh one
destroyTree(root);
init(writeVersion - kNominalVersionWindow);
}
newestVersionFullPrecision = writeVersion;
newest_version.set(newestVersionFullPrecision);
setOldestVersion(newestVersionFullPrecision - kNominalVersionWindow);
// Run gc until oldestExtantVersion is back within kMaxCorrectVersionWindow of
// the newest version
while (oldestExtantVersion <
newestVersionFullPrecision - kMaxCorrectVersionWindow) {
gcScanStep(1000);
}
} else {
newestVersionFullPrecision = writeVersion;
newest_version.set(newestVersionFullPrecision);
setOldestVersion(newestVersionFullPrecision - kNominalVersionWindow);
}
for (int i = 0; i < count; ++i) {
const auto &w = writes[i];
tls.accum.write_bytes += w.begin.len + w.end.len;
auto begin = std::span<const uint8_t>(w.begin.p, w.begin.len);
auto end = std::span<const uint8_t>(w.end.p, w.end.len);
// keyUpdates is the fuel that setOldestVersion later hands to gcScanStep
if (w.end.len > 0) {
keyUpdates += 3;
addWriteRange(root, begin, end, InternalVersionT(writeVersion), &tls,
this);
} else {
keyUpdates += 2;
addPointWrite(root, begin, InternalVersionT(writeVersion), &tls, this);
}
}
memory_bytes.set(totalBytes);
point_writes_total.add(tls.accum.point_writes);
range_writes_total.add(tls.accum.range_writes);
nodes_allocated_total.add(tls.accum.nodes_allocated);
nodes_released_total.add(tls.accum.nodes_released);
entries_inserted_total.add(tls.accum.entries_inserted);
entries_erased_total.add(tls.accum.entries_erased);
insert_iterations_total.add(tls.accum.insert_iterations);
write_bytes_total.add(tls.accum.write_bytes);
memset(&tls.accum, 0, sizeof(tls.accum));
}
// Spends up to `fuel` gc'ing, and returns its unused fuel. Reclaims memory
// and updates oldestExtantVersion after spending enough fuel.
//
// The scan resumes at the node found via `removalKey` and walks the tree in
// nextPhysical order. If fuel runs out first, the search path of the node to
// resume at is saved in `removalKey`. If the end of the tree is reached,
// `removalKey` is cleared so that the next scan starts over from the root,
// and oldestExtantVersion is advanced.
int64_t gcScanStep(int64_t fuel) {
Node *n = firstGeqPhysical(root, removalKey);
// There's no way to erase removalKey without introducing a key after it
assert(n != nullptr);
// Don't erase the root
if (n == root) {
rezero(n, oldestVersion);
rootMaxVersion = std::max(rootMaxVersion, oldestVersion);
n = nextPhysical(n);
}
int64_t set_oldest_iterations_accum = 0;
for (; fuel > 0 && n != nullptr; ++set_oldest_iterations_accum) {
rezero(n, oldestVersion);
// The "make sure gc keeps up with writes" calculations assume that we're
// scanning key by key, not node by node. Make sure we only spend fuel
// when there's a logical entry.
fuel -= n->entryPresent;
if (n->entryPresent && std::max(n->entry.pointVersion,
n->entry.rangeVersion) <= oldestVersion) {
// Any transaction n would have prevented from committing is
// going to fail with TooOld anyway.
// There's no way to insert a range such that range version of the right
// node is greater than the point version of the left node
assert(n->entry.rangeVersion <= oldestVersion);
Node *dummy = nullptr;
// erase returns the node to continue the scan from
n = erase(n, &tls, this, /*logical*/ false, dummy);
} else {
maybeDecreaseCapacity(n, &tls, this);
n = nextPhysical(n);
}
}
gc_iterations_total.add(set_oldest_iterations_accum);
if (n == nullptr) {
// A full pass over the tree has completed. Everything that was collectable
// as of oldestVersionAtGcBegin is gone now, so that becomes the new
// oldestExtantVersion, and the next pass is measured from the current oldest
// version.
removalKey = {};
oldestExtantVersion = oldestVersionAtGcBegin;
oldest_extant_version.set(oldestExtantVersion);
oldestVersionAtGcBegin = oldestVersionFullPrecision;
#if DEBUG_VERBOSE && !defined(NDEBUG)
fprintf(stderr,
"new oldestExtantVersion: %" PRId64
", new oldestVersionAtGcBegin: %" PRId64 "\n",
oldestExtantVersion, oldestVersionAtGcBegin);
#endif
} else {
// Out of fuel: remember where to resume. The previous key's storage is
// released by replacing the arena.
removalKeyArena = Arena();
removalKey = getSearchPath(removalKeyArena, n);
}
return fuel;
}
// Advances the oldest version to `o`. Does nothing if `o` is not greater than
// the current oldest version.
//
// Besides updating the version bookkeeping (including InternalVersionT::zero
// and tls.zero) and the oldest_version gauge, this spends the fuel accumulated
// in keyUpdates on a gc scan step and publishes the node/entry statistics that
// the scan produced. In NDEBUG builds the gc step is skipped while less than
// 100 units of fuel have accumulated.
void setOldestVersion(int64_t o) {
  if (o <= oldestVersionFullPrecision) {
    return;
  }
  InternalVersionT oldestVersion{o};
  this->oldestVersionFullPrecision = o;
  this->oldestVersion = oldestVersion;
  InternalVersionT::zero = tls.zero = oldestVersion;
  // Publish the gauge as soon as the version changes. Doing this only after
  // the gc step would leave the gauge stale in NDEBUG builds whenever the gc
  // step is skipped below, even though the oldest version has advanced.
  oldest_version.set(oldestVersionFullPrecision);
#ifdef NDEBUG
  // This is here for performance reasons, since we want to amortize the cost
  // of storing the search path as a string. In tests, we want to exercise the
  // rest of the code often.
  if (keyUpdates < 100) {
    return;
  }
#endif
  // gcScanStep hands back whatever fuel it didn't use
  keyUpdates = gcScanStep(keyUpdates);
  memory_bytes.set(totalBytes);
  nodes_allocated_total.add(std::exchange(tls.accum.nodes_allocated, 0));
  nodes_released_total.add(std::exchange(tls.accum.nodes_released, 0));
  entries_inserted_total.add(std::exchange(tls.accum.entries_inserted, 0));
  entries_erased_total.add(std::exchange(tls.accum.entries_erased, 0));
}
// Returns the current value of totalBytes (the same value that is reported
// through the memory_bytes gauge).
int64_t getBytes() const { return totalBytes; }
// (Re)initializes this conflict set so that it is empty and all of its version
// bookkeeping equals `oldestVersion`. Called from the constructor, and from
// addWrites after it has destroyed the tree. Does not free an existing tree.
void init(int64_t oldestVersion) {
this->oldestVersion = InternalVersionT(oldestVersion);
oldestVersionFullPrecision = oldestExtantVersion = oldestVersionAtGcBegin =
newestVersionFullPrecision = oldestVersion;
oldest_version.set(oldestVersionFullPrecision);
newest_version.set(newestVersionFullPrecision);
oldest_extant_version.set(oldestExtantVersion);
// Reset the write context to a freshly constructed one, in place
tls.~WriteContext();
new (&tls) WriteContext();
// Start the next gc scan from the beginning of the tree
removalKeyArena = Arena{};
removalKey = {};
keyUpdates = 10;
// Insert ""
root = tls.allocate<Node0>(0);
root->numChildren = 0;
root->parent = nullptr;
rootMaxVersion = this->oldestVersion;
// entryPresent is cleared before addKey and set right after it.
// NOTE(review): presumably addKey looks at entryPresent -- confirm.
root->entryPresent = false;
root->partialKeyLen = 0;
addKey(root);
root->entryPresent = true;
root->entry.pointVersion = this->oldestVersion;
root->entry.rangeVersion = this->oldestVersion;
InternalVersionT::zero = tls.zero = this->oldestVersion;
// Intentionally not resetting totalBytes
}
// Creates an empty conflict set whose version bookkeeping starts at
// `oldestVersion`, then builds the `metrics` array.
explicit Impl(int64_t oldestVersion) {
init(oldestVersion);
initMetrics();
}
// Frees all tree nodes and the `metrics` array allocated by initMetrics.
~Impl() {
destroyTree(root);
safe_free(metrics, metricsCount * sizeof(metrics[0]));
}
// Context handed to the write and gc paths. Holds the statistics accumulated
// since they were last published (tls.accum) and a copy of the current zero
// version (tls.zero), and is used to allocate nodes.
WriteContext tls;
// Backing storage for removalKey
Arena removalKeyArena;
// Search path of the node at which the next gcScanStep resumes. Empty means
// the scan starts at the root.
std::span<const uint8_t> removalKey;
// Fuel for gc. addWrites adds 2 per point write and 3 per range write, and
// setOldestVersion stores back whatever gcScanStep didn't use.
int64_t keyUpdates;
Node *root;
// NOTE(review): presumably the max version of the root, which has no parent
// to record it in -- confirm against maxVersion/setMaxVersion.
InternalVersionT rootMaxVersion;
// oldestVersionFullPrecision as an InternalVersionT
InternalVersionT oldestVersion;
// Reads at versions older than this are reported as TooOld by check()
int64_t oldestVersionFullPrecision;
// Advanced to oldestVersionAtGcBegin whenever a gc pass reaches the end of
// the tree. addWrites keeps this within kMaxCorrectVersionWindow of the newest
// version.
int64_t oldestExtantVersion;
// Value oldestVersionFullPrecision had when the previous gc pass completed
int64_t oldestVersionAtGcBegin;
// The latest write version passed to addWrites
int64_t newestVersionFullPrecision;
// Returned by getBytes and reported through the memory_bytes gauge.
// Deliberately not reset by init.
int64_t totalBytes = 0;
// Array of metricsCount records, allocated by initMetrics and freed by the
// destructor
MetricsV1 *metrics;
int metricsCount = 0;
void initMetrics() {
metrics = (MetricsV1 *)safe_malloc(metricsCount * sizeof(metrics[0]));
for (auto [i, m] = std::make_tuple(metricsCount - 1, metricList); i >= 0;
--i, m = m->prev) {
metrics[i].name = m->name;
metrics[i].help = m->help;
metrics[i].p = m;
metrics[i].type = m->type;
}
}
// Head of the singly linked list of registered metrics (most recently
// constructed first). Each Metric constructor pushes itself here.
Metric *metricList = nullptr;
// Helper macros: declare a Gauge/Counter data member named `name`, passing it
// this Impl, the stringified member name, and the help text.
// NOTE(review): Gauge and Counter are defined outside this section;
// presumably their constructors forward to Metric::Metric -- confirm.
#define GAUGE(name, help) \
Gauge name { this, #name, help }
#define COUNTER(name, help) \
Counter name { this, #name, help }
// ==================== METRICS DEFINITIONS ====================
COUNTER(point_read_total, "Total number of point reads checked");
COUNTER(point_read_short_circuit_total,
"Total number of point reads that did not require a full search to "
"check");
COUNTER(point_read_iterations_total,
"Total number of iterations of the main loop for point read checks");
COUNTER(prefix_read_total, "Total number of prefix reads checked");
COUNTER(prefix_read_short_circuit_total,
"Total number of prefix reads that did not require a full search to "
"check");
COUNTER(prefix_read_iterations_total,
"Total number of iterations of the main loop for prefix read checks");
COUNTER(range_read_total, "Total number of range reads checked");
COUNTER(range_read_short_circuit_total,
"Total number of range reads that did not require a full search to "
"check");
COUNTER(range_read_iterations_total,
"Total number of iterations of the main loops for range read checks");
COUNTER(range_read_node_scan_total,
"Total number of scans of individual nodes while "
"checking a range read");
COUNTER(commits_total,
"Total number of checks where the result is \"commit\"");
COUNTER(conflicts_total,
"Total number of checks where the result is \"conflict\"");
COUNTER(too_olds_total,
"Total number of checks where the result is \"too old\"");
COUNTER(check_bytes_total, "Total number of key bytes checked");
COUNTER(point_writes_total, "Total number of point writes");
COUNTER(range_writes_total,
"Total number of range writes (includes prefix writes)");
GAUGE(memory_bytes, "Total number of bytes in use");
COUNTER(nodes_allocated_total,
"The total number of physical tree nodes allocated");
COUNTER(nodes_released_total,
"The total number of physical tree nodes released");
COUNTER(insert_iterations_total,
"The total number of iterations of the main loop for insertion. "
"Includes searches where the entry already existed, and so insertion "
"did not take place");
COUNTER(entries_inserted_total,
"The total number of entries inserted in the tree");
COUNTER(entries_erased_total,
"The total number of entries erased from the tree");
COUNTER(
gc_iterations_total,
"The total number of iterations of the main loop for garbage collection");
COUNTER(write_bytes_total, "Total number of key bytes in calls to addWrites");
GAUGE(oldest_version,
"The lowest version that doesn't result in \"TooOld\" for checks");
GAUGE(newest_version, "The version of the most recent call to addWrites");
GAUGE(
oldest_extant_version,
"A lower bound on the lowest version associated with an existing entry");
// ==================== END METRICS DEFINITIONS ====================
#undef GAUGE
#undef COUNTER
// Hands out the array built by initMetrics() and its length. The parameter
// `metrics` shadows the member, hence the explicit `this->`. The array remains
// owned by this Impl.
void getMetricsV1(MetricsV1 **metrics, int *count) {
*metrics = this->metrics;
*count = metricsCount;
}
};
// Registers this metric with `impl`: pushes it onto the front of
// impl->metricList (remembering the previous head in `prev`), starts its
// value at 0, and bumps impl->metricsCount.
Metric::Metric(ConflictSet::Impl *impl, const char *name, const char *help,
ConflictSet::MetricsV1::Type type)
: prev(std::exchange(impl->metricList, this)), name(name), help(help),
type(type), value(0) {
++impl->metricsCount;
}
// Returns the max version stored for `n`. For the root (no parent) that is
// impl->rootMaxVersion; for any other node it is the entry for `n` in its
// parent's childMaxVersion array, located according to the parent's node type.
InternalVersionT maxVersion(Node *n, ConflictSet::Impl *impl) {
  Node *parent = n->parent;
  if (parent == nullptr) {
    return impl->rootMaxVersion;
  }
  int keyByte = n->parentsIndex;
  switch (parent->getType()) {
  case Type_Node0:           // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  case Type_Node3: {
    auto *p3 = static_cast<Node3 *>(parent);
    int slot = getNodeIndex(p3, keyByte);
    return p3->childMaxVersion[slot];
  }
  case Type_Node16: {
    auto *p16 = static_cast<Node16 *>(parent);
    int slot = getNodeIndex(p16, keyByte);
    return p16->childMaxVersion[slot];
  }
  case Type_Node48: {
    auto *p48 = static_cast<Node48 *>(parent);
    assert(p48->bitSet.test(keyByte));
    // Node48 maps the key byte to a dense slot through `index`.
    return p48->childMaxVersion[p48->index[keyByte]];
  }
  case Type_Node256: {
    auto *p256 = static_cast<Node256 *>(parent);
    assert(p256->bitSet.test(keyByte));
    // Node256 is indexed directly by the key byte.
    return p256->childMaxVersion[keyByte];
  }
  default:                   // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
}
// Stores `newMax` as the max version for `n` and returns the value it
// replaced. Uses impl->rootMaxVersion for the root and the parent's
// childMaxVersion slot otherwise.
// NOTE(review): unlike setMaxVersion(), this does not touch `maxOfMax` on
// Node48/Node256 parents -- confirm callers rely on that.
InternalVersionT exchangeMaxVersion(Node *n, InternalVersionT newMax,
                                    ConflictSet::Impl *impl) {
  Node *parent = n->parent;
  if (parent == nullptr) {
    return std::exchange(impl->rootMaxVersion, newMax);
  }
  int keyByte = n->parentsIndex;
  switch (parent->getType()) {
  case Type_Node0:           // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  case Type_Node3: {
    auto *p3 = static_cast<Node3 *>(parent);
    int slot = getNodeIndex(p3, keyByte);
    return std::exchange(p3->childMaxVersion[slot], newMax);
  }
  case Type_Node16: {
    auto *p16 = static_cast<Node16 *>(parent);
    int slot = getNodeIndex(p16, keyByte);
    return std::exchange(p16->childMaxVersion[slot], newMax);
  }
  case Type_Node48: {
    auto *p48 = static_cast<Node48 *>(parent);
    assert(p48->bitSet.test(keyByte));
    return std::exchange(p48->childMaxVersion[p48->index[keyByte]], newMax);
  }
  case Type_Node256: {
    auto *p256 = static_cast<Node256 *>(parent);
    assert(p256->bitSet.test(keyByte));
    return std::exchange(p256->childMaxVersion[keyByte], newMax);
  }
  default:                   // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
}
// Stores `newMax` as the max version for `n`: in impl->rootMaxVersion for the
// root, otherwise in the parent's childMaxVersion slot. For Node48/Node256
// parents the covering `maxOfMax` bucket is also raised to at least `newMax`
// (it is never lowered here).
void setMaxVersion(Node *n, ConflictSet::Impl *impl, InternalVersionT newMax) {
  Node *parent = n->parent;
  if (parent == nullptr) {
    impl->rootMaxVersion = newMax;
    return;
  }
  int keyByte = n->parentsIndex;
  switch (parent->getType()) {
  case Type_Node0:           // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  case Type_Node3: {
    auto *p3 = static_cast<Node3 *>(parent);
    int slot = getNodeIndex(p3, keyByte);
    p3->childMaxVersion[slot] = newMax;
    return;
  }
  case Type_Node16: {
    auto *p16 = static_cast<Node16 *>(parent);
    int slot = getNodeIndex(p16, keyByte);
    p16->childMaxVersion[slot] = newMax;
    return;
  }
  case Type_Node48: {
    auto *p48 = static_cast<Node48 *>(parent);
    assert(p48->bitSet.test(keyByte));
    int slot = p48->index[keyByte];
    p48->childMaxVersion[slot] = newMax;
    auto &bucket = p48->maxOfMax[slot >> Node48::kMaxOfMaxShift];
    bucket = std::max<InternalVersionT>(bucket, newMax);
    return;
  }
  case Type_Node256: {
    auto *p256 = static_cast<Node256 *>(parent);
    assert(p256->bitSet.test(keyByte));
    p256->childMaxVersion[keyByte] = newMax;
    auto &bucket = p256->maxOfMax[keyByte >> Node256::kMaxOfMaxShift];
    bucket = std::max<InternalVersionT>(bucket, newMax);
    return;
  }
  default:                   // GCOVR_EXCL_LINE
    __builtin_unreachable(); // GCOVR_EXCL_LINE
  }
}
// Returns a reference to the pointer through which the tree refers to `n`:
// impl->root when `n` has no parent, otherwise the parent's child slot for
// n->parentsIndex.
Node *&getInTree(Node *n, ConflictSet::Impl *impl) {
  if (n->parent != nullptr) {
    return getChildExists(n->parent, n->parentsIndex);
  }
  return impl->root;
}
// Internal entry points. Public entry points should just delegate to these
// Forwards to Impl::check for `count` read ranges, writing one result per
// read into `results`.
void internal_check(ConflictSet::Impl *impl,
const ConflictSet::ReadRange *reads,
ConflictSet::Result *results, int count) {
impl->check(reads, results, count);
}
// Forwards to Impl::addWrites and folds the allocation delta observed during
// the call into impl->totalBytes. `mallocBytesDelta` is zeroed first so only
// this call's net change is added.
// NOTE(review): presumably safe_malloc/safe_free maintain mallocBytesDelta --
// they are defined outside this section.
void internal_addWrites(ConflictSet::Impl *impl,
const ConflictSet::WriteRange *writes, int count,
int64_t writeVersion) {
mallocBytesDelta = 0;
impl->addWrites(writes, count, writeVersion);
impl->totalBytes += mallocBytesDelta;
#if SHOW_MEMORY
// Debug cross-check: the per-instance total must match the global counter.
if (impl->totalBytes != mallocBytes) {
abort();
}
#endif
}
// Forwards to Impl::setOldestVersion and folds the allocation delta observed
// during the call into impl->totalBytes (same accounting pattern as
// internal_addWrites).
void internal_setOldestVersion(ConflictSet::Impl *impl, int64_t oldestVersion) {
mallocBytesDelta = 0;
impl->setOldestVersion(oldestVersion);
impl->totalBytes += mallocBytesDelta;
#if SHOW_MEMORY
// Debug cross-check: the per-instance total must match the global counter.
if (impl->totalBytes != mallocBytes) {
abort();
}
#endif
}
// Allocates storage with safe_malloc, constructs an Impl in it, and seeds
// totalBytes with the allocation delta accumulated since it was zeroed at the
// top of this function. Pair with internal_destroy().
ConflictSet::Impl *internal_create(int64_t oldestVersion) {
  mallocBytesDelta = 0;
  void *storage = safe_malloc(sizeof(ConflictSet::Impl));
  auto *impl = new (storage) ConflictSet::Impl{oldestVersion};
  impl->totalBytes += mallocBytesDelta;
  return impl;
}
// Counterpart of internal_create(): runs the destructor explicitly (the
// object was placement-new'ed) and releases the storage with safe_free.
void internal_destroy(ConflictSet::Impl *impl) {
impl->~Impl();
safe_free(impl, sizeof(ConflictSet::Impl));
}
// Forwards to Impl::getBytes.
int64_t internal_getBytes(ConflictSet::Impl *impl) { return impl->getBytes(); }
// Forwards to Impl::getMetricsV1, which fills `*metrics` and `*count`.
void internal_getMetricsV1(ConflictSet::Impl *impl,
ConflictSet::MetricsV1 **metrics, int *count) {
impl->getMetricsV1(metrics, count);
}
// Reads the current value of the Metric that `metric->p` points at (set by
// Impl::initMetrics) using a relaxed atomic load.
double internal_getMetricValue(const ConflictSet::MetricsV1 *metric) {
return ((Metric *)metric->p)->value.load(std::memory_order_relaxed);
}
// ==================== END IMPLEMENTATION ====================
// GCOVR_EXCL_START
// Walks down from `n` following the bytes of `key` and returns the first node
// (in key order) that holds an entry and whose key is >= `key`. Returns
// nullptr when the search runs off the right edge of the tree, i.e. there is
// no such entry.
//
// Whenever the walk has identified a subtree whose keys are all greater than
// `key`, it jumps to `downLeftSpine` and descends first children until it
// reaches a node with an entry.
Node *firstGeqLogical(Node *n, const std::span<const uint8_t> key) {
  auto remaining = key;
  for (;;) {
    if (remaining.size() == 0) {
      // Key fully consumed: `n` itself matches if it holds an entry,
      // otherwise the answer is the leftmost entry below it.
      if (n->entryPresent) {
        return n;
      }
      n = getFirstChildExists(n);
      goto downLeftSpine;
    }

    auto *child = getChild(n, remaining[0]);
    if (child == nullptr) {
      // No exact child for the next byte: take the next greater child, or
      // failing that the next sibling subtree.
      auto c = getChildGeq(n, remaining[0]);
      if (c != nullptr) {
        n = c;
        goto downLeftSpine;
      } else {
        n = nextSibling(n);
        if (n == nullptr) {
          return nullptr;
        }
        goto downLeftSpine;
      }
    }

    n = child;
    remaining = remaining.subspan(1, remaining.size() - 1);

    if (n->partialKeyLen > 0) {
      int commonLen = std::min<int>(n->partialKeyLen, remaining.size());
      int i = longestCommonPrefix(n->partialKey(), remaining.data(), commonLen);
      if (i < commonLen) {
        // The partial key diverges from `remaining` at byte i.
        auto c = n->partialKey()[i] <=> remaining[i];
        if (c > 0) {
          // Everything under `n` is greater than the key.
          goto downLeftSpine;
        } else {
          // Everything under `n` is less than the key; move right. There may
          // be no sibling, in which case nothing in the tree is >= key. This
          // check mirrors the missing-child branch above and avoids
          // dereferencing nullptr at downLeftSpine.
          n = nextSibling(n);
          if (n == nullptr) {
            return nullptr;
          }
          goto downLeftSpine;
        }
      }
      if (commonLen == n->partialKeyLen) {
        // partial key matches
        remaining = remaining.subspan(commonLen, remaining.size() - commonLen);
      } else if (n->partialKeyLen > int(remaining.size())) {
        // n is the first physical node greater than remaining, and there's no
        // eq node
        goto downLeftSpine;
      }
    }
  }
downLeftSpine:
  for (; !n->entryPresent; n = getFirstChildExists(n)) {
  }
  return n;
}
// Public entry point; delegates to internal_check.
void ConflictSet::check(const ReadRange *reads, Result *results,
int count) const {
internal_check(impl, reads, results, count);
}
// Public entry point; delegates to internal_addWrites.
void ConflictSet::addWrites(const WriteRange *writes, int count,
int64_t writeVersion) {
internal_addWrites(impl, writes, count, writeVersion);
}
// Public entry point; delegates to internal_setOldestVersion.
void ConflictSet::setOldestVersion(int64_t oldestVersion) {
internal_setOldestVersion(impl, oldestVersion);
}
// Public entry point; delegates to internal_getBytes.
int64_t ConflictSet::getBytes() const { return internal_getBytes(impl); }
// Public entry point; delegates to internal_getMetricsV1, which fills
// `*metrics` and `*count`.
void ConflictSet::getMetricsV1(MetricsV1 **metrics, int *count) const {
return internal_getMetricsV1(impl, metrics, count);
}
// Returns the current value of this metric via internal_getMetricValue.
double ConflictSet::MetricsV1::getValue() const {
return internal_getMetricValue(this);
}
// Creates the underlying Impl via internal_create; released in ~ConflictSet.
ConflictSet::ConflictSet(int64_t oldestVersion)
: impl(internal_create(oldestVersion)) {}
// Destroys the owned Impl. `impl` is null after this object has been moved
// from, in which case there is nothing to release.
ConflictSet::~ConflictSet() {
if (impl) {
internal_destroy(impl);
}
}
// Move constructor: takes over `other`'s Impl and leaves `other` with a null
// impl so its destructor does nothing.
ConflictSet::ConflictSet(ConflictSet &&other) noexcept
: impl(std::exchange(other.impl, nullptr)) {}
// Move assignment: releases the Impl currently owned by this object (if any),
// then takes over `other`'s Impl and leaves `other` with a null impl.
// Self-assignment is a no-op. Without the release step the previously owned
// Impl would be leaked.
ConflictSet &ConflictSet::operator=(ConflictSet &&other) noexcept {
  if (this != &other) {
    if (impl) {
      internal_destroy(impl);
    }
    impl = std::exchange(other.impl, nullptr);
  }
  return *this;
}
// C-compatible spellings of the nested C++ types used in the C API below.
using ConflictSet_Result = ConflictSet::Result;
using ConflictSet_Key = ConflictSet::Key;
using ConflictSet_ReadRange = ConflictSet::ReadRange;
using ConflictSet_WriteRange = ConflictSet::WriteRange;
// C API. Each function is exported with default visibility, treats the opaque
// `cs` handle as a ConflictSet::Impl*, and delegates to the matching
// internal_* entry point. Handles come from ConflictSet_create and must be
// released with ConflictSet_destroy.
extern "C" {
__attribute__((__visibility__("default"))) void
ConflictSet_check(void *cs, const ConflictSet_ReadRange *reads,
ConflictSet_Result *results, int count) {
internal_check((ConflictSet::Impl *)cs, reads, results, count);
}
__attribute__((__visibility__("default"))) void
ConflictSet_addWrites(void *cs, const ConflictSet_WriteRange *writes, int count,
int64_t writeVersion) {
internal_addWrites((ConflictSet::Impl *)cs, writes, count, writeVersion);
}
__attribute__((__visibility__("default"))) void
ConflictSet_setOldestVersion(void *cs, int64_t oldestVersion) {
internal_setOldestVersion((ConflictSet::Impl *)cs, oldestVersion);
}
__attribute__((__visibility__("default"))) void *
ConflictSet_create(int64_t oldestVersion) {
return internal_create(oldestVersion);
}
__attribute__((__visibility__("default"))) void ConflictSet_destroy(void *cs) {
internal_destroy((ConflictSet::Impl *)cs);
}
__attribute__((__visibility__("default"))) int64_t
ConflictSet_getBytes(void *cs) {
return internal_getBytes((ConflictSet::Impl *)cs);
}
}
// Make sure the ABI is well-defined
static_assert(std::is_standard_layout_v<ConflictSet::Result>);
static_assert(std::is_standard_layout_v<ConflictSet::Key>);
static_assert(std::is_standard_layout_v<ConflictSet::ReadRange>);
static_assert(std::is_standard_layout_v<ConflictSet::WriteRange>);
static_assert(std::is_standard_layout_v<ConflictSet::MetricsV1>);
namespace {
// Returns a printable rendering of the full key leading to `n` (the bytes of
// every partial key and parent index from the root down), or "<end>" when `n`
// is null. The bytes are gathered leaf-to-root in reverse and flipped at the
// end.
std::string getSearchPathPrintable(Node *n) {
  Arena arena;
  if (n == nullptr) {
    return "<end>";
  }
  auto bytes = vector<char>(arena);
  for (Node *cur = n;; cur = cur->parent) {
    // Partial key bytes, last byte first.
    for (int i = cur->partialKeyLen; i-- > 0;) {
      bytes.push_back(cur->partialKey()[i]);
    }
    if (cur->parent == nullptr) {
      break;
    }
    // The byte under which `cur` hangs off its parent.
    bytes.push_back(cur->parentsIndex);
  }
  std::reverse(bytes.begin(), bytes.end());
  if (bytes.size() == 0) {
    return std::string();
  }
  return printable(std::string_view((const char *)&bytes[0],
                                    bytes.size())); // NOLINT
}
// Returns a printable rendering of the bytes that `n` itself contributes to
// its key: the parent index byte (omitted for a node without a parent)
// followed by the node's partial key. Returns "<end>" when `n` is null.
std::string getPartialKeyPrintable(Node *n) {
  if (n == nullptr) {
    return "<end>";
  }
  auto result = std::string((const char *)&n->parentsIndex,
                            n->parent == nullptr ? 0 : 1) +
                std::string((const char *)n->partialKey(), n->partialKeyLen);
  return printable(result); // NOLINT
}
// Returns the shortest string that is strictly greater than every string
// having `str` as a prefix: trailing 0xff bytes are dropped and the last
// remaining byte is incremented.
//
// `ok` is set to false (and an empty string returned) when `str` consists of
// zero or more 0xff bytes only, since no such successor exists; otherwise it
// is set to true.
std::string strinc(std::string_view str, bool &ok) {
  const auto last = str.find_last_not_of('\xff');
  if (last == std::string_view::npos) {
    ok = false;
    return {};
  }
  ok = true;
  std::string bumped(str.substr(0, last + 1));
  // The final byte is known not to be 0xff, so this cannot wrap around.
  bumped.back() = static_cast<char>(static_cast<uint8_t>(bumped.back()) + 1);
  return bumped;
}
// Convenience overload: computes the search path of `n` with the
// arena-taking getSearchPath overload (defined outside this section) and
// copies the bytes into an owning std::string before the arena goes away.
std::string getSearchPath(Node *n) {
assert(n != nullptr);
Arena arena;
auto result = getSearchPath(arena, n);
return std::string((const char *)result.data(), result.size());
}
// Debug helper: writes the subtree rooted at `node` to `file` as a Graphviz
// "digraph". Each node is labelled with its max version ("m="), and, when it
// holds an entry, its point ("p=") and range ("r=") versions, followed by its
// printable partial key. Positions are pinned: x advances by kSeparation per
// node printed, y decreases by kSeparation per tree level.
[[maybe_unused]] void debugPrintDot(FILE *file, Node *node,
ConflictSet::Impl *impl) {
constexpr int kSeparation = 3;
struct DebugDotPrinter {
explicit DebugDotPrinter(FILE *file, ConflictSet::Impl *impl)
: file(file), impl(impl) {}
// Emits `n`, then recurses into its children in ascending index order,
// emitting an edge to each child.
void print(Node *n, int y = 0) {
assert(n != nullptr);
if (n->entryPresent) {
fprintf(file,
" k_%p [label=\"m=%" PRId64 " p=%" PRId64 " r=%" PRId64
"\n%s\", pos=\"%d,%d!\"];\n",
(void *)n, maxVersion(n, impl).toInt64(),
n->entry.pointVersion.toInt64(),
n->entry.rangeVersion.toInt64(),
getPartialKeyPrintable(n).c_str(), x, y);
} else {
fprintf(file, " k_%p [label=\"m=%" PRId64 "\n%s\", pos=\"%d,%d!\"];\n",
(void *)n, maxVersion(n, impl).toInt64(),
getPartialKeyPrintable(n).c_str(), x, y);
}
x += kSeparation;
for (auto c = getChildGeq(n, 0); c != nullptr;
c = getChildGeq(n, c->parentsIndex + 1)) {
fprintf(file, " k_%p -> k_%p;\n", (void *)n, (void *)c);
print(c, y - kSeparation);
}
}
// Running x position shared across the whole traversal.
int x = 0;
FILE *file;
ConflictSet::Impl *impl;
};
fprintf(file, "digraph ConflictSet {\n");
fprintf(file, " node [shape = box];\n");
assert(node != nullptr);
DebugDotPrinter printer{file, impl};
printer.print(node);
fprintf(file, "}\n");
}
// Recursively verifies that every child reachable from `node` has a `parent`
// pointer that points back at the node it hangs off. Each violation is
// reported on stderr and clears `success`; `success` is never set to true
// here.
void checkParentPointers(Node *node, bool &success) {
  auto kid = getChildGeq(node, 0);
  while (kid != nullptr) {
    if (kid->parent != node) {
      fprintf(stderr, "%s child %d has parent pointer %p. Expected %p\n",
              getSearchPathPrintable(node).c_str(), kid->parentsIndex,
              (void *)kid->parent, (void *)node);
      success = false;
    }
    checkParentPointers(kid, success);
    kid = getChildGeq(node, kid->parentsIndex + 1);
  }
}
// string_view convenience wrapper around firstGeqLogical: reinterprets the
// characters of `key` as bytes without copying.
Node *firstGeq(Node *n, std::string_view key) {
return firstGeqLogical(
n, std::span<const uint8_t>((const uint8_t *)key.data(), key.size()));
}
// Asserts that every version stored directly in `n` is >= oldestExtantVersion:
// the entry's point/range versions (when an entry is present), every
// childMaxVersion slot, and, for Node48/Node256, every maxOfMax bucket.
// Note that all slots of the fixed-size arrays are checked, not only the ones
// belonging to existing children. The checks are plain asserts, so this
// function verifies nothing when NDEBUG is defined. Aborts on an unknown node
// type.
void checkVersionsGeqOldestExtant(Node *n,
InternalVersionT oldestExtantVersion) {
if (n->entryPresent) {
assert(n->entry.pointVersion >= oldestExtantVersion);
assert(n->entry.rangeVersion >= oldestExtantVersion);
}
switch (n->getType()) {
case Type_Node0: {
} break;
case Type_Node3: {
auto *self = static_cast<Node3 *>(n);
for (int i = 0; i < 3; ++i) {
assert(self->childMaxVersion[i] >= oldestExtantVersion);
}
} break;
case Type_Node16: {
auto *self = static_cast<Node16 *>(n);
for (int i = 0; i < 16; ++i) {
assert(self->childMaxVersion[i] >= oldestExtantVersion);
}
} break;
case Type_Node48: {
auto *self = static_cast<Node48 *>(n);
for (int i = 0; i < 48; ++i) {
assert(self->childMaxVersion[i] >= oldestExtantVersion);
}
for (auto m : self->maxOfMax) {
assert(m >= oldestExtantVersion);
}
} break;
case Type_Node256: {
auto *self = static_cast<Node256 *>(n);
for (int i = 0; i < 256; ++i) {
assert(self->childMaxVersion[i] >= oldestExtantVersion);
}
for (auto m : self->maxOfMax) {
assert(m >= oldestExtantVersion);
}
} break;
default:
abort();
}
}
// Recursively recomputes the expected max version for `node` and compares it
// with the stored one (maxVersion(node, impl)). The expected value is the
// maximum of: `oldestVersion`, the node's own point version, the expected
// value of each child, and each child's range version. A mismatch is reported
// on stderr and clears `success`, but only when the stored value is greater
// than `oldestVersion`. Returns the expected value so parents can fold it in.
[[maybe_unused]] InternalVersionT
checkMaxVersion(Node *root, Node *node, InternalVersionT oldestVersion,
bool &success, ConflictSet::Impl *impl) {
checkVersionsGeqOldestExtant(node,
InternalVersionT(impl->oldestExtantVersion));
auto expected = oldestVersion;
if (node->entryPresent) {
expected = std::max(expected, node->entry.pointVersion);
}
for (auto child = getChildGeq(node, 0); child != nullptr;
child = getChildGeq(node, child->parentsIndex + 1)) {
expected = std::max(
expected, checkMaxVersion(root, child, oldestVersion, success, impl));
if (child->entryPresent) {
expected = std::max(expected, child->entry.rangeVersion);
}
}
// Fold in the range version of the first entry at or after strinc(key).
// NOTE(review): `key` is computed from `root`, not `node`. If `root` is the
// tree root with an empty search path, strinc fails (ok == false) and this
// step never contributes. Possibly `node` was intended -- confirm against
// the intended invariant before changing, since it alters what is checked.
auto key = getSearchPath(root);
bool ok;
auto inc = strinc(key, ok);
if (ok) {
auto borrowed = firstGeq(root, inc);
if (borrowed != nullptr) {
expected = std::max(expected, borrowed->entry.rangeVersion);
}
}
if (maxVersion(node, impl) > oldestVersion &&
maxVersion(node, impl) != expected) {
fprintf(stderr, "%s has max version %" PRId64 " . Expected %" PRId64 "\n",
getSearchPathPrintable(node).c_str(),
maxVersion(node, impl).toInt64(), expected.toInt64());
success = false;
}
return expected;
}
// Returns the number of entries in the subtree rooted at `node` (including
// `node` itself). While counting, reports on stderr and clears `success` for
// every child whose subtree contains no entries at all.
[[maybe_unused]] int64_t checkEntriesExist(Node *node, bool &success) {
  int64_t total = node->entryPresent;
  for (auto child = getChildGeq(node, 0); child != nullptr;
       child = getChildGeq(node, child->parentsIndex + 1)) {
    const int64_t entriesBelow = checkEntriesExist(child, success);
    total += entriesBelow;
    if (entriesBelow == 0) {
      fprintf(stderr, "%s has child %02x with no reachable entries\n",
              getSearchPathPrintable(node).c_str(), child->parentsIndex);
      success = false;
    }
  }
  return total;
}
// Recursively verifies that every node holds at least the minimum number of
// children-plus-entry required for its node type (the kMinChildrenNode*
// constants; 0 for Node0). Violations are reported on stderr and clear
// `success`. Aborts on an unknown node type.
[[maybe_unused]] void checkMemoryBoundInvariants(Node *node, bool &success) {
  const int required = [node]() -> int {
    switch (node->getType()) {
    case Type_Node0:
      return 0;
    case Type_Node3:
      return kMinChildrenNode3;
    case Type_Node16:
      return kMinChildrenNode16;
    case Type_Node48:
      return kMinChildrenNode48;
    case Type_Node256:
      return kMinChildrenNode256;
    default:
      abort();
    }
  }();
  const bool tooSparse =
      node->numChildren + int(node->entryPresent) < required;
  if (tooSparse) {
    fprintf(stderr,
            "%s has %d children + %d entries, which is less than the minimum "
            "required %d\n",
            getSearchPathPrintable(node).c_str(), node->numChildren,
            int(node->entryPresent), required);
    success = false;
  }
  // TODO check that the max capacity property eventually holds
  auto kid = getChildGeq(node, 0);
  while (kid != nullptr) {
    checkMemoryBoundInvariants(kid, success);
    kid = getChildGeq(node, kid->parentsIndex + 1);
  }
}
// Runs all structural checks on the tree rooted at `node` (parent pointers,
// max versions, reachable entries, minimum occupancy). Every check runs even
// if an earlier one failed, so all problems are reported. Returns true only
// if none of them cleared `success`.
[[maybe_unused]] bool checkCorrectness(Node *node,
InternalVersionT oldestVersion,
ConflictSet::Impl *impl) {
bool success = true;
checkParentPointers(node, success);
// `node` is passed as both the tree root and the starting node.
checkMaxVersion(node, node, oldestVersion, success, impl);
checkEntriesExist(node, success);
checkMemoryBoundInvariants(node, success);
return success;
}
} // namespace
#if SHOW_MEMORY
// Global accounting used only in SHOW_MEMORY builds. Each current-value
// counter is paired with a high-water mark. They are updated by
// addNode/removeNode/addKey/removeKey and printed by PeakPrinter.
int64_t nodeBytes = 0;
int64_t peakNodeBytes = 0;
int64_t partialCapacityBytes = 0;
int64_t peakPartialCapacityBytes = 0;
int64_t totalKeys = 0;
int64_t peakKeys = 0;
int64_t keyBytes = 0;
int64_t peakKeyBytes = 0;
// Returns sizeof the concrete node struct that matches n's runtime type.
// Aborts on an unknown node type.
int64_t getNodeSize(struct Node *n) {
  int64_t bytes = 0;
  switch (n->getType()) {
  case Type_Node0:
    bytes = sizeof(Node0);
    break;
  case Type_Node3:
    bytes = sizeof(Node3);
    break;
  case Type_Node16:
    bytes = sizeof(Node16);
    break;
  case Type_Node48:
    bytes = sizeof(Node48);
    break;
  case Type_Node256:
    bytes = sizeof(Node256);
    break;
  default:
    abort();
  }
  return bytes;
}
// Returns the length in bytes of the full key leading to `n`: the partial key
// length of `n` and of every ancestor, plus one byte for each parent->child
// edge on the way up to the root.
int64_t getSearchPathLength(Node *n) {
  assert(n != nullptr);
  int64_t length = n->partialKeyLen;
  for (Node *up = n->parent; up != nullptr; up = up->parent) {
    // One byte for the edge just climbed, plus the ancestor's partial key.
    length += 1;
    length += up->partialKeyLen;
  }
  return length;
}
// SHOW_MEMORY accounting for a newly allocated node: adds its struct size and
// partial-key capacity to the running totals and raises the corresponding
// peaks if the totals now exceed them.
void addNode(Node *n) {
  nodeBytes += getNodeSize(n);
  partialCapacityBytes += n->getCapacity();
  peakNodeBytes = std::max(peakNodeBytes, nodeBytes);
  peakPartialCapacityBytes =
      std::max(peakPartialCapacityBytes, partialCapacityBytes);
}
// SHOW_MEMORY accounting for a node about to be released: subtracts exactly
// what addNode added. Peaks are left untouched.
void removeNode(Node *n) {
nodeBytes -= getNodeSize(n);
partialCapacityBytes -= n->getCapacity();
}
// SHOW_MEMORY accounting for an entry being added at `n`. Must be called
// while n->entryPresent is still false; if the entry is already present the
// call is a no-op, so an existing key is not counted twice.
void addKey(Node *n) {
  if (n->entryPresent) {
    return;
  }
  ++totalKeys;
  keyBytes += getSearchPathLength(n);
  peakKeys = std::max(peakKeys, totalKeys);
  peakKeyBytes = std::max(peakKeyBytes, keyBytes);
}
// SHOW_MEMORY accounting for an entry being removed from `n`. Must be called
// while n->entryPresent is still true; a node without an entry is ignored.
void removeKey(Node *n) {
  if (!n->entryPresent) {
    return;
  }
  totalKeys -= 1;
  keyBytes -= getSearchPathLength(n);
}
// SHOW_MEMORY only: the single global `peakPrinter` object exists so that its
// destructor prints the memory statistics gathered above to stdout when the
// object is destroyed.
struct __attribute__((visibility("default"))) PeakPrinter {
~PeakPrinter() {
printf("--- radix_tree ---\n");
printf("malloc bytes: %g\n", double(mallocBytes));
printf("Peak malloc bytes: %g\n", double(peakMallocBytes));
printf("Node bytes: %g\n", double(nodeBytes));
printf("Peak node bytes: %g\n", double(peakNodeBytes));
printf("Expected worst case node bytes: %g\n",
double(peakKeys * kBytesPerKey));
printf("Key bytes: %g\n", double(keyBytes));
printf("Peak key bytes: %g (not sharing common prefixes)\n",
double(peakKeyBytes));
printf("Partial key capacity bytes: %g\n", double(partialCapacityBytes));
printf("Peak partial key capacity bytes: %g\n",
double(peakPartialCapacityBytes));
}
} peakPrinter;
#endif
#ifdef ENABLE_MAIN
// ENABLE_MAIN demo: builds a small tree from one range write ("and".."ant")
// and three writes with an empty `end` ("any", "are", "art"), each at a new
// version, then dumps the tree to stdout in Graphviz format.
void printTree() {
int64_t writeVersion = 0;
ConflictSet::Impl cs{writeVersion};
// NOTE(review): `refImpl` and `arena` are not referenced again in this
// function; they look like leftovers -- confirm their constructors have no
// needed side effects before removing.
ReferenceImpl refImpl{writeVersion};
Arena arena;
ConflictSet::WriteRange write;
write.begin = "and"_s;
write.end = "ant"_s;
cs.addWrites(&write, 1, ++writeVersion);
write.begin = "any"_s;
write.end = ""_s;
cs.addWrites(&write, 1, ++writeVersion);
write.begin = "are"_s;
write.end = ""_s;
cs.addWrites(&write, 1, ++writeVersion);
write.begin = "art"_s;
write.end = ""_s;
cs.addWrites(&write, 1, ++writeVersion);
debugPrintDot(stdout, cs.root, &cs);
}
// ENABLE_MAIN entry point: just prints the demo tree.
int main(void) { printTree(); }
#endif
#ifdef ENABLE_FUZZ
// libFuzzer entry point. Drives two TestDriver instances from the same
// Arbitrary input, stepping them alternately until both report completion.
// After every step the driver's own `ok` flag and the tree invariants
// (checkCorrectness) are verified; on any failure the tree is dumped to
// stdout in Graphviz format and the process aborts.
extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
  Arbitrary arbitrary({data, size});
  TestDriver<ConflictSet::Impl> driver1{arbitrary};
  TestDriver<ConflictSet::Impl> driver2{arbitrary};
  bool done1 = false;
  bool done2 = false;

  // Dump the offending tree and terminate.
  auto dumpAndAbort = [](auto &driver) {
    debugPrintDot(stdout, driver.cs.root, &driver.cs);
    fflush(stdout);
    abort();
  };

  // Advance one driver by a single step (unless it already finished) and
  // validate its state afterwards.
  auto step = [&dumpAndAbort](auto &driver, bool &done) {
    if (done) {
      return;
    }
    done = driver.next();
    if (!driver.ok) {
      dumpAndAbort(driver);
    }
    if (!checkCorrectness(driver.cs.root, driver.cs.oldestVersion,
                          &driver.cs)) {
      dumpAndAbort(driver);
    }
  };

  while (true) {
    step(driver1, done1);
    step(driver2, done2);
    if (done1 && done2) {
      break;
    }
  }
  return 0;
}
#endif
// GCOVR_EXCL_STOP