From f6e48cca0ee0fc55b24dcd65a6bedf2961ca36ce Mon Sep 17 00:00:00 2001 From: Andrew Noyes Date: Mon, 22 Jan 2024 16:24:48 -0800 Subject: [PATCH] Switch to radix tree. WIP --- ConflictSet.cpp | 1112 +++++++++++++++++++---------------------------- 1 file changed, 441 insertions(+), 671 deletions(-) diff --git a/ConflictSet.cpp b/ConflictSet.cpp index 3f1a897..c2ec4a4 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -2,9 +2,8 @@ #include #include -#include +#include #include -#include #include #include #include @@ -187,137 +186,6 @@ bool operator!=(const ArenaAlloc &lhs, const ArenaAlloc &rhs) { // ==================== END ARENA IMPL ==================== -// ==================== BEGIN RANDOM IMPL ==================== - -struct Random { - // *Really* minimal PCG32 code / (c) 2014 M.E. O'Neill / pcg-random.org - // Licensed under Apache License 2.0 (NO WARRANTY, etc. see website) - // - // Modified - mostly c -> c++ - Random() = default; - - Random(uint64_t initState, uint64_t initSeq) { - pcg32_srandom_r(initState, initSeq); - next(); - } - - /// Draws from a uniform distribution of uint32_t's - uint32_t next() { - auto result = next_; - next_ = pcg32_random_r(); - return result; - } - - /// Draws from a uniform distribution of [0, s). From - /// https://arxiv.org/pdf/1805.10941.pdf - uint32_t bounded(uint32_t s) { - assert(s != 0); - uint32_t x = next(); - auto m = uint64_t(x) * uint64_t(s); - auto l = uint32_t(m); - if (l < s) { - uint32_t t = -s % s; - while (l < t) { - x = next(); - m = uint64_t(x) * uint64_t(s); - l = uint32_t(m); - } - } - uint32_t result = m >> 32; - return result; - } - - /// Fill `bytes` with `size` random bytes - void randomBytes(uint8_t *bytes, int size); - - /// Fill `bytes` with `size` random hex bytes - void randomHex(uint8_t *bytes, int size); - - template >> - T randT() { - T t; - randomBytes((uint8_t *)&t, sizeof(T)); - return t; - } - -private: - uint32_t pcg32_random_r() { - uint64_t oldState = state; - // Advance internal state - state = oldState * 6364136223846793005ULL + inc; - // Calculate output function (XSH RR), uses old state for max ILP - uint32_t xorShifted = ((oldState >> 18u) ^ oldState) >> 27u; - uint32_t rot = oldState >> 59u; - return (xorShifted >> rot) | (xorShifted << ((-rot) & 31)); - } - - // Seed the rng. Specified in two parts, state initializer and a - // sequence selection constant (a.k.a. stream id) - void pcg32_srandom_r(uint64_t initstate, uint64_t initSeq) { - state = 0U; - inc = (initSeq << 1u) | 1u; - pcg32_random_r(); - state += initstate; - pcg32_random_r(); - } - uint32_t next_{}; - // RNG state. All values are possible. - uint64_t state{}; - // Controls which RNG sequence (stream) is selected. Must *always* be odd. - uint64_t inc{}; -}; - -template void shuffle(Random &rand, Container &x) { - using std::swap; - for (int i = x.size() - 1; i > 0; --i) { - int j = rand.bounded(i + 1); - if (i != j) { - swap(x[i], x[j]); - } - } -} - -void Random::randomBytes(uint8_t *bytes, int size) { - int i = 0; - for (; i + 4 < size; i += 4) { - uint32_t random = next(); - memcpy(bytes + i, &random, 4); - } - if (i < size) { - uint32_t random = next(); - memcpy(bytes + i, &random, size - i); - } -} - -void Random::randomHex(uint8_t *bytes, int size) { - int i = 0; - while (i + 8 < size) { - uint32_t r = next(); - bytes[i++] = "0123456789abcdef"[r & 0b1111]; - r >>= 4; - bytes[i++] = "0123456789abcdef"[r & 0b1111]; - r >>= 4; - bytes[i++] = "0123456789abcdef"[r & 0b1111]; - r >>= 4; - bytes[i++] = "0123456789abcdef"[r & 0b1111]; - r >>= 4; - bytes[i++] = "0123456789abcdef"[r & 0b1111]; - r >>= 4; - bytes[i++] = "0123456789abcdef"[r & 0b1111]; - r >>= 4; - bytes[i++] = "0123456789abcdef"[r & 0b1111]; - r >>= 4; - bytes[i++] = "0123456789abcdef"[r & 0b1111]; - } - uint32_t r = next(); - while (i < size) { - bytes[i++] = "0123456789abcdef"[r & 0b1111]; - r >>= 4; - } -} - -// ==================== END RANDOM IMPL ==================== - // ==================== BEGIN ARBITRARY IMPL ==================== /// Think of `Arbitrary` as an attacker-controlled random number generator. @@ -564,581 +432,488 @@ Key toKeyAfter(Arena &arena, int n) { // ==================== BEGIN IMPLEMENTATION ==================== -#define SHOW_PRIORITY 0 -#define DEBUG 0 - -static auto operator<=>(const Key &lhs, const Key &rhs) { - const int minLen = std::min(lhs.len, rhs.len); - const int c = memcmp(lhs.p, rhs.p, minLen); - return c != 0 ? c <=> 0 : lhs.len <=> rhs.len; -} - -// A node in the tree representing write conflict history. This tree maintains -// several invariants: - -// 1. BST invariant: all keys in the tree rooted at the left child of a node -// compare less than that node's key, and all keys in the tree rooted at the -// right child of a node compare greater than that node's key. -// 2. Heap invariant: the priority of a node is >= all the priorities -// of its children (transitively) -// 3. Max invariant: `maxVersion` is the max among all values of `pointVersion` -// and `beyondVersion` for this node and its children (transitively) -// 4. The lowest key (an empty byte sequence) is always physically present in -// the tree so that "last less than or equal" queries are always well-defined. - -// Logically, the contents of the tree represent a "range map" where all of the -// infinitely many points in the key space are associated with a writeVersion. -// If a point is physically present in the tree, then its writeVersion is its -// node's `pointVersion`. Otherwise, its writeVersion is the `rangeVersion` of -// the node with the last key less than point. -struct Node { - // See "Max invariant" above - int64_t maxVersion; - // The write version of the point in the key space represented by this node's - // key +struct Entry { int64_t pointVersion; - // The write version of the range immediately after this node's key, until - // just before the next key in the tree. I.e. (this key, next key) int64_t rangeVersion; - // child[0] is the left child or nullptr. child[1] is the right child or - // nullptr - Node *child[2]; - // The parent of this node in the tree, or nullptr if this node is the root - Node *parent; - // As a treap, this tree satisfies the heap invariant on each node's priority - uint32_t priority; - // The length of this node's key - int len; - // The contents of this node's key - // uint8_t[len]; +}; - auto operator<=>(const ConflictSet::Key &other) const { - const int minLen = std::min(len, other.len); - const int c = memcmp(this + 1, other.p, minLen); - return c != 0 ? c <=> 0 : len <=> other.len; +enum class Type : int8_t { + Node4, + Node16, + Node48, + Node256, + Invalid, +}; +struct Node { + /* begin section that's copied to the next node */ + Node *parent = nullptr; + int64_t maxVersion; + Entry entry; + int16_t numChildren = 0; + bool entryPresent = false; + uint8_t parentsIndex = 0; + /* end section that's copied to the next node */ + + Type type = Type::Invalid; +}; +Node *getChild(Node *self, uint8_t index); +int getChildGeq(Node *self, int child); +Node *&getOrCreateChild(Node *&self, uint8_t index); +Node *newNode(); +void eraseChild(Node *self, uint8_t index); + +struct Node4 : Node { + // Sorted + uint8_t index[4] = {}; + Node *children[4] = {}; + Node4() { this->type = Type::Node4; } +}; + +Node *newNode() { return new (safe_malloc(sizeof(Node4))) Node4; } + +struct Node16 : Node { + // Sorted + uint8_t index[16] = {}; + Node *children[16] = {}; + Node16() { this->type = Type::Node16; } +}; + +struct Node48 : Node { + int8_t nextFree = 0; + int8_t index[256]; + Node *children[48] = {}; + Node48() { + this->type = Type::Node48; + memset(index, -1, 256); } }; -// Note: `rangeVersion` is left uninitialized. -Node *createNode(const Key &key, Node *parent, int64_t pointVersion, - Random &rand) { - assert(key.len <= std::numeric_limits::max()); - Node *result = (Node *)safe_malloc(sizeof(Node) + key.len); - result->maxVersion = pointVersion; - result->pointVersion = pointVersion; - result->child[0] = nullptr; - result->child[1] = nullptr; - result->parent = parent; - result->priority = rand.next(); -#if SHOW_PRIORITY - result->priority &= 0xff; +struct Node256 : Node { + Node *children[256] = {}; + Node256() { this->type = Type::Node256; } +}; + +static int getNodeIndex(Node4 *self, uint8_t index) { + for (int i = 0; i < self->numChildren; ++i) { + if (self->index[i] == index) { + return i; + } + } + return -1; +} + +static int getNodeIndex(Node16 *self, uint8_t index) { +#ifdef HAS_AVX + // Based on https://www.the-paper-trail.org/post/art-paper-notes/ + + // key_vec is 16 repeated copies of the searched-for byte, one for every + // possible position in child_keys that needs to be searched. + __m128i key_vec = _mm_set1_epi8(index); + + // Compare all child_keys to 'index' in parallel. Don't worry if some of the + // keys aren't valid, we'll mask the results to only consider the valid ones + // below. + __m128i indices; + memcpy(&indices, self->index, sizeof(self->index)); + __m128i results = _mm_cmpeq_epi8(key_vec, indices); + + // Build a mask to select only the first node->num_children values from the + // comparison (because the other values are meaningless) + int mask = (1 << self->numChildren) - 1; + + // Change the results of the comparison into a bitfield, masking off any + // invalid comparisons. + int bitfield = _mm_movemask_epi8(results) & mask; + + // No match if there are no '1's in the bitfield. + if (bitfield == 0) + return -1; + + // Find the index of the first '1' in the bitfield by counting the leading + // zeros. + return __builtin_ctz(bitfield); +#elif defined(HAS_ARM_NEON) + // Based on + // https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon + + uint8x16_t indices; + memcpy(&indices, self->index, sizeof(self->index)); + // 0xff for each match + uint16x8_t results = + vreinterpretq_u16_u8(vceqq_u8(vdupq_n_u8(index), indices)); + uint64_t mask = self->numChildren == 16 + ? uint64_t(-1) + : (uint64_t(1) << (self->numChildren * 4)) - 1; + // 0xf for each match in valid range + uint64_t bitfield = + vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) & mask; + if (bitfield == 0) + return -1; + return __builtin_ctzll(bitfield) / 4; +#else + for (int i = 0; i < self->numChildren; ++i) { + if (self->index[i] == index) { + return i; + } + } + return -1; #endif - result->len = key.len; - if (key.len > 0) { - memcpy(result + 1, key.p, key.len); - } - return result; } -void destroyNode(Node *node) { - assert(node->child[0] == nullptr); - assert(node->child[1] == nullptr); - free(node); +#ifdef HAS_AVX +static int firstNonNeg1(const int8_t x[16]) { + __m128i key_vec = _mm_set1_epi8(-1); + __m128i indices; + memcpy(&indices, x, 16); + __m128i results = _mm_cmpeq_epi8(key_vec, indices); + uint32_t bitfield = _mm_movemask_epi8(results) ^ 0xffff; + if (bitfield == 0) + return -1; + return __builtin_ctz(bitfield); } +#endif -struct Iterator { - Node *node; - int cmp; -}; +#ifdef HAS_ARM_NEON +static int firstNonNeg1(const int8_t x[16]) { + uint8x16_t indices; + memcpy(&indices, x, 16); + uint16x8_t results = vreinterpretq_u16_u8(vceqq_u8(vdupq_n_u8(-1), indices)); + uint64_t bitfield = + ~vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0); + if (bitfield == 0) + return -1; + return __builtin_ctzll(bitfield) / 4; +} +#endif -struct StepwiseLastLeq { - Node *current; - Node *result; - const Key *key; - int resultC = -1; - int index; - std::strong_ordering c = std::strong_ordering::equal; - - StepwiseLastLeq() {} - StepwiseLastLeq(Node *current, Node *result, const Key &key, int index) - : current(current), result(result), key(&key), index(index) {} - - bool step() { - if (current == nullptr) { - return true; +Node *getChild(Node *self, uint8_t index) { + if (self->type == Type::Node4) { + auto *self4 = static_cast(self); + int i = getNodeIndex(self4, index); + if (i >= 0) { + return self4->children[i]; } - c = *current <=> *key; - if (c == 0) { - result = current; - resultC = 0; - return true; + return nullptr; + } else if (self->type == Type::Node16) { + auto *self16 = static_cast(self); + int i = getNodeIndex(self16, index); + if (i >= 0) { + return self16->children[i]; } - result = c < 0 ? current : result; - current = current->child[c < 0]; - return false; - } -}; - -void lastLeqMulti(Arena &arena, Node *root, std::span keys, - Iterator *results) { - assert(std::is_sorted(keys.begin(), keys.end())); - - if (keys.size() == 0) { - return; - } - - auto *stepwiseLastLeqs = new (arena) StepwiseLastLeq[keys.size()]; - - // Descend until queries for front and back diverge - Node *current = root; - Node *resultP = nullptr; - auto stepwiseFront = StepwiseLastLeq(current, resultP, keys.front(), -1); - auto stepwiseBack = StepwiseLastLeq(current, resultP, keys.back(), -1); - for (;;) { - bool done1 = stepwiseFront.step(); - bool done2 = stepwiseBack.step(); - if (!done1 && !done2 && stepwiseFront.c == stepwiseBack.c) { - assert(stepwiseFront.current == stepwiseBack.current); - assert(stepwiseFront.result == stepwiseBack.result); - current = stepwiseFront.current; - resultP = stepwiseFront.result; - } else { - break; + return nullptr; + } else if (self->type == Type::Node48) { + auto *self48 = static_cast(self); + int secondIndex = self48->index[index]; + if (secondIndex >= 0) { + return self48->children[secondIndex]; } - } - - int index = 0; - { - auto iter = stepwiseLastLeqs; - for (const auto &k : keys) { - *iter++ = StepwiseLastLeq(current, resultP, k, index++); - } - } - auto stepwiseSpan = std::span(stepwiseLastLeqs, keys.size()); - runInterleaved(stepwiseSpan); - for (const auto &stepwise : stepwiseSpan) { - results[stepwise.index] = Iterator{stepwise.result, stepwise.resultC}; + return nullptr; + } else { + auto *self256 = static_cast(self); + return self256->children[index]; } } -// Return a pointer to the node whose key immediately follows `n`'s key (if -// `dir` is false, precedes). Return nullptr if none exists. -Node *next(Node *n, bool dir) { - // Traverse left spine of right child (when moving right, i.e. dir = true) - if (n->child[dir]) { - n = n->child[dir]; - while (n->child[!dir]) { - n = n->child[!dir]; +// Precondition - an entry for index must exist in the node +static Node *&getChildExists(Node *self, uint8_t index) { + if (self->type == Type::Node4) { + auto *self4 = static_cast(self); + return self4->children[getNodeIndex(self4, index)]; + } else if (self->type == Type::Node16) { + auto *self16 = static_cast(self); + return self16->children[getNodeIndex(self16, index)]; + } else if (self->type == Type::Node48) { + auto *self48 = static_cast(self); + int secondIndex = self48->index[index]; + if (secondIndex >= 0) { + return self48->children[secondIndex]; } } else { - // Search upward for a node such that we're the left child (when moving - // right, i.e. dir = true) - while (n->parent && n == n->parent->child[dir]) { - n = n->parent; - } - n = n->parent; + auto *self256 = static_cast(self); + return self256->children[index]; } - return n; + __builtin_unreachable(); } -// Return a pointer to the node whose key is greatest among keys in the tree -// rooted at `n` (if dir = false, least). Return nullptr if none exists (i.e. -// `n` is null). -Node *extrema(Node *n, bool dir) { - if (n == nullptr) { - return nullptr; - } - while (n->child[dir] != nullptr) { - n = n->child[dir]; - } - return n; -} - -void debugPrintDot(FILE *file, Node *node) { - - struct DebugDotPrinter { - - explicit DebugDotPrinter(FILE *file) : file(file) {} - - void print(Node *node) { - if (node->child[0] == nullptr && node->child[1] == nullptr) { - return; +int getChildGeq(Node *self, int child) { + if (self->type == Type::Node4) { + auto *self4 = static_cast(self); + for (int i = 0; i < self->numChildren; ++i) { + if (i > 0) { + assert(self4->index[i - 1] < self4->index[i]); } - for (int i = 0; i < 2; ++i) { - if (node->child[i] != nullptr) { - fprintf(file, " k_%.*s -> k_%.*s;\n", node->len, - (const char *)(node + 1), node->child[i]->len, - (const char *)(node->child[i] + 1)); - print(node->child[i]); - } else { - fprintf(file, " k_%.*s -> null%d;\n", node->len, - (const char *)(node + 1), id); - ++id; - } + if (self4->index[i] >= child) { + return self4->index[i]; } } - int id = 0; - FILE *file; - }; - - fprintf(file, "digraph ConflictSet {\n"); - if (node != nullptr) { - DebugDotPrinter printer{file}; - fprintf(file, "\n"); - printer.print(node); - fprintf(file, "\n"); - for (auto iter = extrema(node, false); iter != nullptr; - iter = next(iter, true)) { - fprintf(file, - " k_%.*s [label=\"k=\\\"%.*s\\\"\\n" -#if SHOW_PRIORITY - "p=%u\\n" -#endif - "m=%d\\nv=%d r=%d\"];\n", - iter->len, (const char *)(iter + 1), iter->len, - (const char *)(iter + 1), -#if SHOW_PRIORITY - iter->priority, -#endif - int(iter->maxVersion), int(iter->pointVersion), - int(iter->rangeVersion)); + } else if (self->type == Type::Node16) { + auto *self16 = static_cast(self); + for (int i = 0; i < self->numChildren; ++i) { + if (i > 0) { + assert(self16->index[i - 1] < self16->index[i]); + } + if (self16->index[i] >= child) { + return self16->index[i]; + } } - for (int i = 0; i < printer.id; ++i) { - fprintf(file, " null%d [shape=point];\n", i); + } else if (self->type == Type::Node48) { + auto *self48 = static_cast(self); +#if defined(HAS_AVX) || defined(HAS_ARM_NEON) + int i = child; + for (; (i & 0xf) != 0; ++i) { + if (self48->index[i] >= 0) { + assert(self48->children[self48->index[i]] != nullptr); + return i; + } } + for (; i < 256; i += 16) { + auto result = firstNonNeg1(self48->index + i); + if (result != -1) { + return i + result; + } + } +#else + for (int i = child; i < 256; ++i) { + if (self48->index[i] >= 0) { + assert(self48->children[self48->index[i]] != nullptr); + return i; + } + } +#endif + } else { + auto *self256 = static_cast(self); + // For some reason gcc can't auto vectorize this, and the plain loop is + // faster. +#if defined(__clang__) + int i = child; + constexpr int kUnrollCount = 8; // Must be a power of two and <= 8 + for (; (i & (kUnrollCount - 1)) != 0; ++i) { + if (self256->children[i]) { + return i; + } + } + for (; i < 256; i += kUnrollCount) { + uint8_t nonNull[kUnrollCount]; + for (int j = 0; j < kUnrollCount; ++j) { + nonNull[j] = self256->children[i + j] != nullptr ? 0xff : 0; + } + uint64_t word; + memcpy(&word, nonNull, kUnrollCount); + if (word) { + return i + __builtin_ctzll(word) / 8; + } + } +#else + for (int i = child; i < 256; ++i) { + if (self256->children[i]) { + return i; + } + } +#endif } - fprintf(file, "}\n"); + return -1; } -void printLogical(std::string &result, Node *node) { - for (auto iter = extrema(node, false); iter != nullptr;) { - auto *next = ::next(iter, true); - std::string key; - for (uint8_t c : std::string_view((const char *)(iter + 1), iter->len)) { - key += "x"; - key += "0123456789abcdef"[c / 16]; - key += "0123456789abcdef"[c % 16]; +static void setChildrenParents(Node *node) { + for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) { + getChildExists(node, i)->parent = node; + } +} + +Node *&getOrCreateChild(Node *&self, uint8_t index) { + if (self->type == Type::Node4) { + auto *self4 = static_cast(self); + { + int i = getNodeIndex(self4, index); + if (i >= 0) { + return self4->children[i]; + } } - if (iter->pointVersion == iter->rangeVersion) { - result += key + " -> " + std::to_string(iter->pointVersion) + "\n"; + if (self->numChildren == 4) { + auto *newSelf = new (safe_malloc(sizeof(Node16))) Node16; + memcpy((void *)newSelf, self, offsetof(Node, type)); + if (newSelf->parent) { + getOrCreateChild(newSelf->parent, newSelf->parentsIndex) = newSelf; + } + memcpy(newSelf->index, self4->index, 4); + memcpy(newSelf->children, self4->children, 4 * sizeof(void *)); + self = newSelf; + setChildrenParents(self); + goto insert16; } else { - result += key + " -> " + std::to_string(iter->pointVersion) + "\n"; - if (next == nullptr || - (std::string_view((const char *)(next + 1), iter->len) != - (std::string((const char *)(iter + 1), iter->len) + - std::string("\x00", 1)))) { - result += key + "x00 -> " + std::to_string(iter->rangeVersion) + "\n"; - } - } - iter = next; - } -} - -// Recompute maxVersion, and propagate up the tree as necessary -void updateMaxVersion(Node *n) { - for (;;) { - int64_t maxVersion = std::max(n->pointVersion, n->rangeVersion); - for (int i = 0; i < 2; ++i) { - maxVersion = - std::max(maxVersion, n->child[i] != nullptr ? n->child[i]->maxVersion - : maxVersion); - } - if (n->maxVersion == maxVersion) { - break; - } - n->maxVersion = maxVersion; - if (n->parent == nullptr) { - break; - } - n = n->parent; - } -} - -void rotate(Node **node, bool dir) { - // diagram shown for dir == true - /* n - / - l - \ - lr - */ - assert(node != nullptr); - Node *n = *node; - assert(n != nullptr); - Node *parent = n->parent; - Node *l = n->child[!dir]; - assert(l != nullptr); - Node *lr = l->child[dir]; - n->child[!dir] = lr; - if (lr) { - lr->parent = n; - } - l->child[dir] = n; - n->parent = l; - l->parent = parent; - *node = l; - /* l - \ - n - / - lr - */ - updateMaxVersion(n); - updateMaxVersion(l); -} - -void checkParentPointers(Node *node, bool &success) { - for (int i = 0; i < 2; ++i) { - if (node->child[i] != nullptr) { - if (node->child[i]->parent != node) { - fprintf(stderr, "%.*s child %d has parent pointer %p. Expected %p\n", - node->len, (const char *)(node + 1), i, - (void *)node->child[i]->parent, (void *)node); - success = false; - } - checkParentPointers(node->child[i], success); - } - } -} - -int64_t checkMaxVersion(Node *node, bool &success) { - int64_t expected = std::max(node->pointVersion, node->rangeVersion); - for (int i = 0; i < 2; ++i) { - if (node->child[i] != nullptr) { - expected = std::max(expected, checkMaxVersion(node->child[i], success)); - } - } - if (node->maxVersion != expected) { - fprintf(stderr, "%.*s has max version %d. Expected %d\n", node->len, - (const char *)(node + 1), int(node->maxVersion), int(expected)); - success = false; - } - return expected; -} - -bool checkCorrectness(Node *node, ReferenceImpl &refImpl) { - bool success = true; - // Check bst invariant - Arena arena; - auto keys = vector(arena); - for (auto iter = extrema(node, false); iter != nullptr; - iter = next(iter, true)) { - keys.push_back(std::string_view((char *)(iter + 1), iter->len)); - for (int i = 0; i < 2; ++i) { - if (iter->child[i] != nullptr) { - if (iter->priority < iter->child[i]->priority) { - fprintf(stderr, "%.*s has priority < its child %.*s\n", iter->len, - (const char *)(iter + 1), iter->child[i]->len, - (const char *)(iter->child[i] + 1)); - success = false; + ++self->numChildren; + for (int i = 0; i < int(self->numChildren) - 1; ++i) { + if (int(self4->index[i]) > int(index)) { + memmove(self4->index + i + 1, self4->index + i, + self->numChildren - (i + 1)); + memmove(self4->children + i + 1, self4->children + i, + (self->numChildren - (i + 1)) * sizeof(void *)); + self4->index[i] = index; + self4->children[i] = nullptr; + return self4->children[i]; } } + self4->index[self->numChildren - 1] = index; + self4->children[self->numChildren - 1] = nullptr; + return self4->children[self->numChildren - 1]; } + } else if (self->type == Type::Node16) { + insert16: + auto *self16 = static_cast(self); + { + int i = getNodeIndex(self16, index); + if (i >= 0) { + return self16->children[i]; + } + } + if (self->numChildren == 16) { + auto *newSelf = new (safe_malloc(sizeof(Node48))) Node48; + memcpy((void *)newSelf, self, offsetof(Node, type)); + newSelf->nextFree = 16; + if (newSelf->parent) { + getOrCreateChild(newSelf->parent, newSelf->parentsIndex) = newSelf; + } + int i = 0; + for (auto x : self16->index) { + newSelf->children[i] = self16->children[i]; + newSelf->index[x] = i; + ++i; + } + assert(i == 16); + self = newSelf; + setChildrenParents(self); + goto insert48; + } else { + ++self->numChildren; + for (int i = 0; i < int(self->numChildren) - 1; ++i) { + if (int(self16->index[i]) > int(index)) { + memmove(self16->index + i + 1, self16->index + i, + self->numChildren - (i + 1)); + memmove(self16->children + i + 1, self16->children + i, + (self->numChildren - (i + 1)) * sizeof(void *)); + self16->index[i] = index; + self16->children[i] = nullptr; + return self16->children[i]; + } + } + self16->index[self->numChildren - 1] = index; + self16->children[self->numChildren - 1] = nullptr; + return self16->children[self->numChildren - 1]; + } + } else if (self->type == Type::Node48) { + insert48: + auto *self48 = static_cast(self); + int secondIndex = self48->index[index]; + if (secondIndex >= 0) { + return self48->children[secondIndex]; + } + if (self->numChildren == 48) { + auto *newSelf = new (safe_malloc(sizeof(Node256))) Node256; + memcpy((void *)newSelf, self, offsetof(Node, type)); + if (newSelf->parent) { + getOrCreateChild(newSelf->parent, newSelf->parentsIndex) = newSelf; + } + for (int i = 0; i < 256; ++i) { + if (self48->index[i] >= 0) { + newSelf->children[i] = self48->children[self48->index[i]]; + } + } + self = newSelf; + setChildrenParents(self); + goto insert256; + } else { + ++self->numChildren; + assert(self48->nextFree < 48); + self48->index[index] = self48->nextFree; + self48->children[self48->nextFree] = nullptr; + return self48->children[self48->nextFree++]; + } + } else { + insert256: + auto *self256 = static_cast(self); + if (!self256->children[index]) { + ++self->numChildren; + } + return self256->children[index]; } - assert(std::is_sorted(keys.begin(), keys.end())); - - checkMaxVersion(node, success); - checkParentPointers(node, success); - - std::string logicalMap; - std::string referenceLogicalMap; - printLogical(logicalMap, node); - refImpl.printLogical(referenceLogicalMap); - if (logicalMap != referenceLogicalMap) { - fprintf(stderr, - "Logical map not equal to reference logical map.\n\nActual:\n" - "%s\nExpected:\n%s\n", - logicalMap.c_str(), referenceLogicalMap.c_str()); - success = false; - } - - return success; } -struct __attribute__((__visibility__("hidden"))) ConflictSet::Impl { - Random rand; - Node *root; - int64_t oldestVersion; - - explicit Impl(int64_t oldestVersion, uint64_t seed) noexcept - : rand{seed & 0xfffffffful, seed >> 32}, - root(createNode({nullptr, 0}, nullptr, oldestVersion, rand)), - oldestVersion(oldestVersion) { - root->rangeVersion = oldestVersion; +// Precondition - an entry for index must exist in the node +void eraseChild(Node *self, uint8_t index) { + if (self->type == Type::Node4) { + auto *self4 = static_cast(self); + int nodeIndex = getNodeIndex(self4, index); + memmove(self4->index + nodeIndex, self4->index + nodeIndex + 1, + sizeof(self4->index[0]) * (self->numChildren - (nodeIndex + 1))); + memmove(self4->children + nodeIndex, self4->children + nodeIndex + 1, + sizeof(self4->children[0]) * // NOLINT + (self->numChildren - (nodeIndex + 1))); + } else if (self->type == Type::Node16) { + auto *self16 = static_cast(self); + int nodeIndex = getNodeIndex(self16, index); + memmove(self16->index + nodeIndex, self16->index + nodeIndex + 1, + sizeof(self16->index[0]) * (self->numChildren - (nodeIndex + 1))); + memmove(self16->children + nodeIndex, self16->children + nodeIndex + 1, + sizeof(self16->children[0]) * // NOLINT + (self->numChildren - (nodeIndex + 1))); + } else if (self->type == Type::Node48) { + auto *self48 = static_cast(self); + int8_t toRemoveChildrenIndex = std::exchange(self48->index[index], -1); + int8_t lastChildrenIndex = --self48->nextFree; + assert(toRemoveChildrenIndex >= 0); + assert(lastChildrenIndex >= 0); + if (toRemoveChildrenIndex != lastChildrenIndex) { + self48->children[toRemoveChildrenIndex] = + std::exchange(self48->children[lastChildrenIndex], nullptr); + self48->index[self48->children[toRemoveChildrenIndex]->parentsIndex] = + toRemoveChildrenIndex; + } + } else { + auto *self256 = static_cast(self); + self256->children[index] = nullptr; } - - void check(const ReadRange *reads, Result *results, int count) const { - int searchCount = 0; - for (int i = 0; i < count; ++i) { - if (reads[i].readVersion >= oldestVersion) { - ++searchCount; - } else { - results[i] = ConflictSet::TooOld; - } - } - Arena arena; - auto *iters = new (arena) Iterator[searchCount]; - auto *begins = new (arena) Key[searchCount]; - int j = 0; - for (int i = 0; i < count; ++i) { - if (reads[i].readVersion >= oldestVersion) { - begins[j++] = reads[i].begin; - } - } - lastLeqMulti(arena, root, std::span(begins, searchCount), iters); - // TODO check non-singleton reads lol - j = 0; - for (int i = 0; i < count; ++i) { - if (reads[i].readVersion >= oldestVersion) { - assert(reads[i].end.len == 0); - assert(iters[i].node != nullptr); - if ((iters[j].cmp == 0 - ? iters[j].node->pointVersion - : iters[j].node->rangeVersion) > reads[i].readVersion) { - results[i] = ConflictSet::Conflict; - } else { - results[i] = ConflictSet::Commit; - } - ++j; - } - } - } - - struct StepwiseInsert { - // After this phase, the heap invariant may be violated for *current and - // (*current)->parent. - Node **current; - Node *parent; - const Key *key; - int64_t writeVersion; - Random *rand; - - StepwiseInsert() {} - StepwiseInsert(Node **root, const Key &key, int64_t writeVersion, - Random *rand) - : current(root), parent(nullptr), key(&key), writeVersion(writeVersion), - rand(rand) {} - bool step() { -#if DEBUG - fprintf(stderr, "Step insert of %.*s. At node: %.*s\n", key->len, key->p, - (*current) ? (*current)->len : 7, - (*current) ? (const char *)((*current) + 1) : "nullptr"); -#endif - if (*current == nullptr) { - auto *newNode = createNode(*key, parent, writeVersion, *rand); - *current = newNode; - // We could interleave the iteration in ::next, but we'd need a careful - // analysis for correctness and it's unlikely to be worthwhile. - auto *prev = ::next(newNode, false); - // The empty key always exists. If key is empty then we won't reach - // here. - assert(prev != nullptr); - assert(prev->rangeVersion <= writeVersion); - newNode->rangeVersion = prev->rangeVersion; - return true; - } else { - // This is the key optimization - setting the max version on the way - // down the search path so we only have to do one traversal. - (*current)->maxVersion = std::max((*current)->maxVersion, writeVersion); - auto c = *key <=> **current; - if (c == 0) { - (*current)->pointVersion = writeVersion; - return true; - } - parent = *current; - current = &((*current)->child[c > 0]); - } - return false; - } - }; - - void addWrites(const WriteRange *writes, int count) { - Arena arena; - auto stepwiseInserts = - std::span(new (arena) StepwiseInsert[count], count); - for (int i = 0; i < count; ++i) { - // TODO handle non-singleton writes lol - assert(writes[i].end.len == 0); - - stepwiseInserts[i] = - StepwiseInsert{&root, writes[i].begin, writes[i].writeVersion, &rand}; - } - - // TODO Descend until queries for front and back diverge - - // Mitigate potential n^2 behavior of insertion (imagine if all inserts - // shared the same search path in the pre-existing tree) by shuffling the - // insertion order. Not sure how this interacts with interleaved insertion - // but it's probably fine. There's a hand-wavy symmetry argument. - shuffle(rand, stepwiseInserts); - - runInterleaved(stepwiseInserts); - - auto workList = vector(arena); - workList.reserve(count); - for (int i = 0; i < count; ++i) { - Node *node = *stepwiseInserts[i].current; - assert(node != nullptr); - workList.push_back(*stepwiseInserts[i].current); - } - - while (!workList.empty()) { - Node *n = workList.back(); - workList.pop_back(); -#if DEBUG - fprintf(stderr, "\tcheck heap invariant %.*s\n", n->len, - (const char *)(n + 1)); -#endif - if (n->parent == nullptr) { - continue; - } - const bool dir = n == n->parent->child[1]; - assert(dir || n == n->parent->child[0]); - // p is the address of the pointer to n->parent in the tree - Node **p = n->parent->parent == nullptr - ? &root - : &n->parent->parent - ->child[n->parent->parent->child[1] == n->parent]; - assert(*p == n->parent); - if (n->parent->priority < n->priority) { -#if DEBUG - fprintf(stderr, "\trotate %.*s %s\n", n->len, (const char *)(n + 1), - !dir ? "right" : "left"); -#endif - rotate(p, !dir); - workList.push_back(*p); - assert((*p)->child[!dir] != nullptr); - auto *lr = (*p)->child[!dir]->child[dir]; - if (lr != nullptr) { - workList.push_back(lr); - } - } - } + --self->numChildren; + if (self->numChildren == 0 && !self->entryPresent && + self->parent != nullptr) { + eraseChild(self->parent, self->parentsIndex); } +} +struct __attribute__((visibility("hidden"))) ConflictSet::Impl { + void check(const ReadRange *, Result *, int) const {} + void addWrites(const WriteRange *, int) {} void setOldestVersion(int64_t oldestVersion) { - assert(oldestVersion > this->oldestVersion); this->oldestVersion = oldestVersion; } - + explicit Impl(int64_t oldestVersion) : oldestVersion(oldestVersion) { + // Insert "" + root = newNode(); + root->maxVersion = oldestVersion; + root->entry.pointVersion = oldestVersion; + root->entry.rangeVersion = oldestVersion; + root->entryPresent = true; + } ~Impl() { Arena arena; auto toFree = vector(arena); - if (root != nullptr) { - toFree.push_back(root); - } + toFree.push_back(root); while (toFree.size() > 0) { - Node *n = toFree.back(); + auto *n = toFree.back(); toFree.pop_back(); - for (int i = 0; i < 2; ++i) { - auto *c = std::exchange(n->child[i], nullptr); - if (c != nullptr) { - toFree.push_back(c); - } + // Add all children to toFree + for (int child = getChildGeq(n, 0); child >= 0; + child = getChildGeq(n, child + 1)) { + auto *c = getChild(n, child); + assert(c != nullptr); + toFree.push_back(c); } - destroyNode(n); + free(n); } } + +private: + Node *root; + int64_t oldestVersion; }; // ==================== END IMPLEMENTATION ==================== @@ -1156,8 +931,8 @@ void ConflictSet::setOldestVersion(int64_t oldestVersion) { return impl->setOldestVersion(oldestVersion); } -ConflictSet::ConflictSet(int64_t oldestVersion, uint64_t seed) - : impl(new(safe_malloc(sizeof(Impl))) Impl{oldestVersion, seed}) {} +ConflictSet::ConflictSet(int64_t oldestVersion, [[maybe_unused]] uint64_t seed) + : impl(new(safe_malloc(sizeof(Impl))) Impl{oldestVersion}) {} ConflictSet::~ConflictSet() { if (impl) { @@ -1195,9 +970,9 @@ ConflictSet_setOldestVersion(void *cs, int64_t oldestVersion) { ((ConflictSet::Impl *)cs)->setOldestVersion(oldestVersion); } __attribute__((__visibility__("default"))) void * -ConflictSet_create(int64_t oldestVersion, uint64_t seed) { +ConflictSet_create(int64_t oldestVersion, uint64_t) { return new (safe_malloc(sizeof(ConflictSet::Impl))) - ConflictSet::Impl{oldestVersion, seed}; + ConflictSet::Impl{oldestVersion}; } __attribute__((__visibility__("default"))) void ConflictSet_destroy(void *cs) { using Impl = ConflictSet::Impl; @@ -1213,7 +988,7 @@ void __throw_length_error(const char *) { __builtin_unreachable(); } #ifdef ENABLE_TESTS int main(void) { int64_t writeVersion = 0; - ConflictSet::Impl cs{writeVersion, 0}; + ConflictSet::Impl cs{writeVersion}; ReferenceImpl refImpl{writeVersion}; Arena arena; constexpr int kNumKeys = 10; @@ -1226,9 +1001,7 @@ int main(void) { } cs.addWrites(write, kNumKeys); refImpl.addWrites(write, kNumKeys); - debugPrintDot(stdout, cs.root); - bool success = checkCorrectness(cs.root, refImpl); - return success ? 0 : 1; + return 0; } #endif @@ -1236,12 +1009,9 @@ int main(void) { extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { // TODO call setOldestVersion, and check range writes/reads gArbitrary = Arbitrary{{data, size}}; - uint64_t state = gArbitrary.next(); - uint64_t seq = gArbitrary.next(); - auto rand = Random{state, seq}; int64_t writeVersion = 0; - ConflictSet::Impl cs{writeVersion, rand.next()}; + ConflictSet::Impl cs{writeVersion}; ReferenceImpl refImpl{writeVersion}; while (gArbitrary.hasEntropy()) { @@ -1271,10 +1041,10 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { cs.addWrites(writes, numWrites); refImpl.addWrites(writes, numWrites); } - bool success = checkCorrectness(cs.root, refImpl); - if (!success) { - abort(); - } + // bool success = checkCorrectness(cs.root, refImpl); + // if (!success) { + // abort(); + // } { int numReads = gArbitrary.bounded(10); int64_t v = writeVersion - gArbitrary.bounded(10);