// File: conflict-set/ConflictSet.cpp (C++, 390 lines, 12 KiB)

#include "ConflictSet.h"
#include <algorithm>
#include <cassert>
#include <compare>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <memory>
#include <numeric>
#include <span>
#include <utility>
#include <vector>
using Key = ConflictSet::Key;
// Three-way comparison of two keys as byte strings: the common prefix is
// compared bytewise with memcmp, and if one key is a prefix of the other the
// shorter key orders first.
auto operator<=>(const Key &lhs, const Key &rhs) {
  const int minLen = std::min(lhs.len, rhs.len);
  // memcmp requires valid pointers even for a zero length, and an empty key
  // may carry p == nullptr (e.g. {nullptr, 0}), so skip the call when there
  // is nothing to compare.
  const int c = minLen == 0 ? 0 : memcmp(lhs.p, rhs.p, minLen);
  return c != 0 ? c <=> 0 : lhs.len <=> rhs.len;
}
namespace {
// A node in the tree representing write conflict history. This tree maintains
// several invariants:
// 1. BST invariant: all keys in the tree rooted at the left child of a node
// compare less than that node's key, and all keys in the tree rooted at the
// right child of a node compare greater than that node's key.
// 2. Heap invariant: the priority of a node is greater than all the priorities
// of its children (transitively). NOTE(review): addWriteNaive currently
// attaches new nodes as plain BST leaves without rotations, so insertion does
// not restore this invariant yet.
// 3. Max invariant: `maxVersion` is the max among all values of `pointVersion`
// and `rangeVersion` for this node and its children (transitively)
// 4. The lowest key (an empty byte sequence) is always physically present in
// the tree so that "last less than or equal" queries are always well-defined.
// Logically, the contents of the tree represent a "range map" where all of the
// infinitely many points in the key space are associated with a writeVersion.
// If a point is physically present in the tree, then its writeVersion is its
// node's `pointVersion`. Otherwise, its writeVersion is the `rangeVersion` of
// the node with the last key less than point.
struct Node {
  // See "Max invariant" above
  int64_t maxVersion;
  // The write version of the point in the key space represented by this node's
  // key
  int64_t pointVersion;
  // The write version of the range immediately after this node's key, until
  // just before the next key in the tree. I.e. (this key, next key)
  int64_t rangeVersion;
  // child[0] is the left child or nullptr. child[1] is the right child or
  // nullptr
  Node *child[2];
  // The parent of this node in the tree, or nullptr if this node is the root
  Node *parent;
  // As a treap, this tree satisfies the heap invariant on each node's priority
  uint32_t priority;
  // The length of this node's key
  int len;
  // The contents of this node's key are stored inline, directly after this
  // struct in the same allocation (see createNode):
  // uint8_t[len];

  // Orders nodes by key: bytewise (memcmp) order on the common prefix, with a
  // proper prefix ordering before any longer key that extends it.
  auto operator<=>(const Node &other) const {
    const int minLen = std::min(len, other.len);
    // Both key buffers live inline after their nodes, so neither pointer is
    // ever null.
    const int c = memcmp(this + 1, &other + 1, minLen);
    return c != 0 ? c <=> 0 : len <=> other.len;
  }
  // Compares this node's key with an external key using the same ordering.
  auto operator<=>(const ConflictSet::Key &other) const {
    const int minLen = std::min<int>(len, other.len);
    // An empty Key may have p == nullptr, and memcmp's arguments must be
    // valid even for a zero length, so skip the call in that case.
    const int c = minLen == 0 ? 0 : memcmp(this + 1, other.p, minLen);
    return c != 0 ? c <=> 0 : len <=> other.len;
  }
};
// Per-thread pseudo-random source for treap priorities: a 32-bit linear
// congruential generator x' = x * 1664525 + 1013904223 (mod 2^32). Each call
// returns the current state and then advances it.
// TODO: use a better prng. This is technically vulnerable to a
// denial-of-service attack that can make conflict-checking linear in the
// number of nodes in the tree.
thread_local uint32_t gSeed = 1013904223u;
uint32_t fastRand() {
  // std::exchange stores the advanced state and hands back the old one.
  return std::exchange(gSeed, gSeed * 1664525u + 1013904223u);
}
// Allocates a node holding a copy of `key`, with no children, the given
// `parent`, and both `pointVersion` and `maxVersion` set to `pointVersion`.
// The key bytes are stored inline right after the Node header in the same
// malloc'd block, so the node must be released with destroyNode (free).
// Aborts if the allocation fails.
// Note: `rangeVersion` is left uninitialized; callers must set it.
Node *createNode(const Key &key, Node *parent, int64_t pointVersion) {
  assert(key.len <= std::numeric_limits<int>::max());
  auto *result = static_cast<Node *>(malloc(sizeof(Node) + key.len));
  if (result == nullptr) {
    // Out of memory: fail loudly instead of writing through a null pointer.
    abort();
  }
  result->maxVersion = pointVersion;
  result->pointVersion = pointVersion;
  result->child[0] = nullptr;
  result->child[1] = nullptr;
  result->parent = parent;
  // fastRand is an LCG modulo 2^32, whose low-order bits have short periods
  // (bit k repeats every 2^(k+1) outputs, so the low byte cycles every 256
  // calls). Take the high byte instead: same 0..255 range, much longer period.
  result->priority = fastRand() >> 24;
  result->len = key.len;
  // memcpy requires a valid source pointer even when copying zero bytes, and
  // the empty root key is created from {nullptr, 0}.
  if (key.len > 0) {
    memcpy(result + 1, key.p, key.len);
  }
  return result;
}
// Releases the storage of a node obtained from createNode. Only childless
// nodes may be released: callers must detach both children first, otherwise
// the subtrees below would be leaked.
void destroyNode(Node *node) {
  assert(node->child[0] == nullptr);
  assert(node->child[1] == nullptr);
  // Header and inline key bytes share one malloc'd block.
  void *storage = node;
  free(storage);
}
// Answer to one "last less than or equal" lookup, as produced by
// lastLeqMulti. Member order matters: results are built with the aggregate
// initializer Iterator{node, cmp}.
struct Iterator {
// Node with the greatest key <= the query key, or nullptr if there is none.
Node *node;
// 0 if `node`'s key equals the query key exactly, -1 if it is strictly less.
int cmp;
};
// For each key in `keys`, find the node in the tree rooted at `root` whose key
// is the greatest one less than or equal to it, and store the answer in
// `results` at the same index as the key: `results[i].node` is that node
// (nullptr if no node qualifies) and `results[i].cmp` is 0 for an exact match,
// -1 if the node's key is strictly less. `results` must have room for
// keys.size() entries.
//
// Precondition: `keys` is sorted ascending (asserted). The shared-prefix
// descent below relies on it.
//
// The individual searches are advanced one step at a time in round-robin
// order rather than run to completion one after another. Presumably this lets
// the memory accesses of independent searches overlap.
// NOTE(review): confirm the intent with a benchmark.
void lastLeqMulti(Node *root, std::span<Key> keys, Iterator *results) {
assert(std::is_sorted(keys.begin(), keys.end()));
if (keys.size() == 0) {
return;
}
// One in-progress search for a single key.
struct Coro {
// Next node to compare against; nullptr once the search fell off a leaf.
Node *current;
// Greatest node seen so far on the search path whose key is <= *key.
Node *result;
const Key *key;
// 0 if `result` matches *key exactly, otherwise -1 (strictly less).
int resultC = -1;
// Position of *key within `keys`, i.e. where the answer is written.
int index;
// NOTE(review): not referenced anywhere in this function.
int lcp[2]{};
// Outcome of the most recent comparison of *current against *key.
std::strong_ordering c = std::strong_ordering::equal;
Coro() {}
Coro(Node *current, Node *result, const Key &key, int index)
: current(current), result(result), key(&key), index(index) {}
// Advances the search by one node. Returns true once finished (an exact
// match, or a null child was reached); `result`/`resultC` then hold the
// answer.
bool step() {
if (current == nullptr) {
return true;
}
c = *current <=> *key;
if (c == 0) {
result = current;
resultC = 0;
return true;
}
// A node smaller than the key is a candidate. Later candidates on the path
// are larger, because we only continue right past a smaller node.
result = c < 0 ? current : result;
// Go right (child[1]) past smaller nodes, left otherwise.
current = current->child[c < 0];
return false;
}
};
auto coros = std::unique_ptr<Coro[]>{new Coro[keys.size()]};
// Descend until queries for front and back diverge
// Every key lies between front() and back(), so as long as those two searches
// take the same branch, all the others do too. The loop records the last node
// (and candidate result) the two searches still shared; all per-key searches
// start from there. The index -1 is never used for these two probes.
Node *current = root;
Node *resultP = nullptr;
auto coroBegin = Coro(current, resultP, keys.front(), -1);
auto coroEnd = Coro(current, resultP, keys.back(), -1);
for (;;) {
bool done1 = coroBegin.step();
bool done2 = coroEnd.step();
if (!done1 && !done2 && coroBegin.c == coroEnd.c) {
assert(coroBegin.current == coroEnd.current);
assert(coroBegin.result == coroEnd.result);
current = coroBegin.current;
resultP = coroBegin.result;
} else {
break;
}
}
int index = 0;
{
auto iter = coros.get();
for (const auto &k : keys) {
*iter++ = Coro(current, resultP, k, index++);
}
}
// `remaining` is the window of unfinished searches. Each pass steps every one
// once. A finished search writes its answer and is removed by moving the last
// element into its slot, so `i` is not advanced in that case.
auto remaining = std::span<Coro>(coros.get(), keys.size());
while (remaining.size() > 0) {
for (int i = 0; i < int(remaining.size());) {
bool done = remaining[i].step();
if (done) {
const auto &c = remaining[i];
results[c.index] = Iterator{c.result, c.resultC};
if (i != int(remaining.size()) - 1) {
remaining[i] = remaining.back();
}
remaining = remaining.subspan(0, remaining.size() - 1);
} else {
++i;
}
}
}
}
// In-order neighbour of `n`: the node with the next larger key when `dir` is
// true, the next smaller key when `dir` is false. Returns nullptr when `n`
// holds the largest (respectively smallest) key in the tree.
[[maybe_unused]] Node *next(Node *n, bool dir) {
  if (Node *sub = n->child[dir]; sub != nullptr) {
    // The neighbour is inside the subtree on the `dir` side: it is that
    // subtree's extreme node in the opposite direction.
    while (sub->child[!dir] != nullptr) {
      sub = sub->child[!dir];
    }
    return sub;
  }
  // Otherwise climb while we are the `dir`-side child. The first ancestor we
  // reach from its other side is the neighbour (nullptr if we hit the root).
  Node *from = n;
  Node *up = n->parent;
  while (up != nullptr && from == up->child[dir]) {
    from = up;
    up = up->parent;
  }
  return up;
}
// Walks from `n` as far as possible towards child[dir] and returns the last
// node reached. This is the greatest key in the subtree rooted at `n` when
// `dir` is true (right), the least when false (left). Returns nullptr only when
// `n` itself is nullptr.
[[maybe_unused]] Node *extrema(Node *n, bool dir) {
  Node *last = n;
  while (last != nullptr && last->child[dir] != nullptr) {
    last = last->child[dir];
  }
  return last;
}
[[maybe_unused]] void debugPrintDot(FILE *file, Node *node) {
struct DebugDotPrinter {
explicit DebugDotPrinter(FILE *file) : file(file) {}
void print(Node *node) {
for (int i = 0; i < 2; ++i) {
if (node->child[i] != nullptr) {
fprintf(file, " k_%.*s -> k_%.*s;\n", node->len,
(const char *)(node + 1), node->child[i]->len,
(const char *)(node->child[i] + 1));
print(node->child[i]);
} else {
fprintf(file, " k_%.*s -> null%d;\n", node->len,
(const char *)(node + 1), id);
++id;
}
}
}
int id = 0;
FILE *file;
};
fprintf(file, "digraph TreeSet {\n");
fprintf(file, " node [fontname=\"Scientifica\"];\n");
if (node != nullptr) {
DebugDotPrinter printer{file};
fprintf(file, "\n");
printer.print(node);
fprintf(file, "\n");
for (auto iter = extrema(node, false); iter != nullptr;
iter = next(iter, true)) {
fprintf(file, " k_%.*s [label=\"k=%.*s;p=%u;m=%d;v=%d,r=%d\"];\n",
iter->len, (const char *)(iter + 1), iter->len,
(const char *)(iter + 1), iter->priority, int(iter->maxVersion),
int(iter->pointVersion), int(iter->rangeVersion));
}
for (int i = 0; i < printer.id; ++i) {
fprintf(file, " null%d [shape=point];\n", i);
}
}
fprintf(file, "}\n");
}
} // namespace
struct ConflictSet::Impl {
  // Root of the tree. Never null: the constructor inserts the empty key and
  // nothing here removes nodes before destruction.
  Node *root;
  // The value most recently passed to the constructor or setOldestVersion.
  // NOTE(review): currently only recorded, not consulted by check/addWrites.
  int64_t oldestVersion;

  // Creates a tree holding only the empty key. Its point and the range after
  // it (so every key) map to `oldestVersion`.
  explicit Impl(int64_t oldestVersion) noexcept
      : root(createNode({nullptr, 0}, nullptr, oldestVersion)),
        oldestVersion(oldestVersion) {
    root->rangeVersion = oldestVersion;
  }

  // Impl owns its nodes through raw pointers and frees them in ~Impl, so a
  // member-wise copy would free every node twice.
  Impl(const Impl &) = delete;
  Impl &operator=(const Impl &) = delete;

  // Checks each of `reads[0..count)` against the recorded writes. results[i]
  // is set to ConflictSet::Conflict when the write version covering
  // reads[i].begin is newer than reads[i].readVersion.
  // NOTE(review): results[i] is left untouched when there is no conflict, so
  // callers must pre-initialize `results`; confirm against ConflictSet.h.
  // Only point reads (end.len == 0) are supported so far.
  void check(const ReadRange *reads, Result *results, int count) const {
    if (count <= 0) {
      return;
    }
    // lastLeqMulti requires its keys in ascending order (its shared-prefix
    // descent is wrong otherwise), but reads may arrive in any order. Sort a
    // permutation of read indices by begin key and map each answer back.
    std::vector<int> order(count);
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(), [reads](int a, int b) {
      return reads[a].begin < reads[b].begin;
    });
    std::vector<Key> begins;
    begins.reserve(count);
    for (const int i : order) {
      begins.push_back(reads[i].begin);
    }
    std::vector<Iterator> iters(count);
    lastLeqMulti(root, std::span<Key>(begins.data(), begins.size()),
                 iters.data());
    // TODO check non-singleton reads
    for (int j = 0; j < count; ++j) {
      const int i = order[j];
      assert(reads[i].end.len == 0);
      assert(iters[j].node != nullptr);
      // An exact key match uses the node's point version; otherwise the key
      // falls in the range after the node.
      const int64_t writeVersion = iters[j].cmp == 0
                                       ? iters[j].node->pointVersion
                                       : iters[j].node->rangeVersion;
      if (writeVersion > reads[i].readVersion) {
        results[i] = ConflictSet::Conflict;
      }
    }
  }

  // Records a single point write at `write.begin`. It walks down from the
  // root, raising `maxVersion` along the path. Then it either updates the
  // matching node's `pointVersion`, or inserts a new leaf whose `rangeVersion`
  // is inherited from its predecessor (the range the new key splits).
  // NOTE(review): new nodes are attached as plain BST leaves with no treap
  // rotations, so the heap invariant on `priority` is not maintained yet.
  void addWriteNaive(const WriteRange &write) {
    // TODO handle non-singleton writes
    Node **current = &root;
    Node *parent = nullptr;
    const auto &key = write.begin;
    for (;;) {
      if (*current == nullptr) {
        auto *newNode = createNode(key, parent, write.writeVersion);
        *current = newNode;
        auto *prev = ::next(newNode, false);
        assert(prev != nullptr);
        assert(prev->rangeVersion <= write.writeVersion);
        newNode->rangeVersion = prev->rangeVersion;
        break;
      } else {
        // TODO this assert won't be valid in the final design
        assert((*current)->maxVersion <= write.writeVersion);
        (*current)->maxVersion = write.writeVersion;
        auto c = key <=> **current;
        if (c == 0) {
          (*current)->pointVersion = write.writeVersion;
          break;
        }
        parent = *current;
        current = &((*current)->child[c > 0]);
      }
    }
  }

  // Applies `writes[0..count)` one at a time, in order.
  void addWrites(const WriteRange *writes, int count) {
    for (const auto &w : std::span<const WriteRange>(writes, count)) {
      addWriteNaive(w);
    }
  }

  // Records a new oldest version; it must strictly increase (asserted).
  void setOldestVersion(int64_t oldestVersion) {
    assert(oldestVersion > this->oldestVersion);
    this->oldestVersion = oldestVersion;
  }

  // Frees every node using an explicit stack rather than recursion, since
  // the tree is not rebalanced and its height can approach its size.
  // Children are detached before each node is freed (destroyNode asserts it).
  ~Impl() {
    std::vector<Node *> toFree;
    if (root != nullptr) {
      toFree.push_back(root);
    }
    while (toFree.size() > 0) {
      Node *n = toFree.back();
      toFree.pop_back();
      for (int i = 0; i < 2; ++i) {
        auto *c = std::exchange(n->child[i], nullptr);
        if (c != nullptr) {
          toFree.push_back(c);
        }
      }
      destroyNode(n);
    }
  }
};
// Public entry point; all work is done by the pimpl.
void ConflictSet::check(const ReadRange *reads, Result *results,
                        int count) const {
  impl->check(reads, results, count);
}
// Public entry point; all work is done by the pimpl.
void ConflictSet::addWrites(const WriteRange *writes, int count) {
  impl->addWrites(writes, count);
}
// Public entry point; all work is done by the pimpl.
void ConflictSet::setOldestVersion(int64_t oldestVersion) {
  impl->setOldestVersion(oldestVersion);
}
// Creates a conflict set in which every key initially maps to `oldestVersion`.
ConflictSet::ConflictSet(int64_t oldestVersion)
    : impl(new Impl(oldestVersion)) {}
ConflictSet::~ConflictSet() {
  // `impl` is null after being moved from; deleting nullptr is a no-op.
  delete impl;
}
// Takes over other's Impl and leaves `other` empty, so its destructor does not
// free the Impl we now own.
ConflictSet::ConflictSet(ConflictSet &&other) noexcept : impl(other.impl) {
  other.impl = nullptr;
}
// Move-assigns from `other`. It releases the Impl this set currently owns (if
// any), takes ownership of other's, and leaves `other` empty (impl ==
// nullptr), the same state the move constructor leaves behind.
// Self-move-assignment is a no-op.
ConflictSet &ConflictSet::operator=(ConflictSet &&other) noexcept {
  if (this != &other) {
    // Free the Impl we currently own; overwriting the pointer alone would
    // leak it.
    delete impl;
    impl = std::exchange(other.impl, nullptr);
  }
  return *this;
}
#ifdef ENABLE_TESTS
// Smoke test: record a single point write of key "0000" at version 1 into a
// set whose oldest version is 0, then dump the resulting tree as Graphviz DOT
// on stdout.
int main(void) {
  ConflictSet::Impl cs{0};
  ConflictSet::WriteRange write;
  write.writeVersion = 1;
  write.begin.p = reinterpret_cast<const uint8_t *>("0000");
  write.begin.len = 4;
  write.end.len = 0;
  cs.addWrites(&write, 1);
  debugPrintDot(stdout, cs.root);
  return 0;
}
#endif