diff --git a/ConflictSet.cpp b/ConflictSet.cpp index 4385999..58009b6 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -5,6 +5,9 @@ #include #include #include +#include + +#define SHOW_PRIORITY 0 using Key = ConflictSet::Key; @@ -86,6 +89,9 @@ Node *createNode(const Key &key, Node *parent, int64_t pointVersion) { result->child[1] = nullptr; result->parent = parent; result->priority = fastRand(); +#if SHOW_PRIORITY + result->priority &= 0xff; +#endif result->len = key.len; memcpy(result + 1, key.p, key.len); return result; @@ -102,6 +108,34 @@ struct Iterator { int cmp; }; +struct StepwiseLastLeq { + Node *current; + Node *result; + const Key *key; + int resultC = -1; + int index; + std::strong_ordering c = std::strong_ordering::equal; + + StepwiseLastLeq() {} + StepwiseLastLeq(Node *current, Node *result, const Key &key, int index) + : current(current), result(result), key(&key), index(index) {} + + bool step() { + if (current == nullptr) { + return true; + } + c = *current <=> *key; + if (c == 0) { + result = current; + resultC = 0; + return true; + } + result = c < 0 ? current : result; + current = current->child[c < 0]; + return false; + } +}; + void lastLeqMulti(Node *root, std::span keys, Iterator *results) { assert(std::is_sorted(keys.begin(), keys.end())); @@ -109,49 +143,23 @@ void lastLeqMulti(Node *root, std::span keys, Iterator *results) { return; } - struct Coro { - Node *current; - Node *result; - const Key *key; - int resultC = -1; - int index; - std::strong_ordering c = std::strong_ordering::equal; - - Coro() {} - Coro(Node *current, Node *result, const Key &key, int index) - : current(current), result(result), key(&key), index(index) {} - - bool step() { - if (current == nullptr) { - return true; - } - c = *current <=> *key; - if (c == 0) { - result = current; - resultC = 0; - return true; - } - result = c < 0 ? current : result; - current = current->child[c < 0]; - return false; - } - }; - - auto coros = std::unique_ptr{new Coro[keys.size()]}; + auto stepwiseLastLeqs = + std::unique_ptr{new StepwiseLastLeq[keys.size()]}; // Descend until queries for front and back diverge Node *current = root; Node *resultP = nullptr; - auto coroBegin = Coro(current, resultP, keys.front(), -1); - auto coroEnd = Coro(current, resultP, keys.back(), -1); + auto stepwiseLastLeqBegin = + StepwiseLastLeq(current, resultP, keys.front(), -1); + auto stepwiseLastLeqEnd = StepwiseLastLeq(current, resultP, keys.back(), -1); for (;;) { - bool done1 = coroBegin.step(); - bool done2 = coroEnd.step(); - if (!done1 && !done2 && coroBegin.c == coroEnd.c) { - assert(coroBegin.current == coroEnd.current); - assert(coroBegin.result == coroEnd.result); - current = coroBegin.current; - resultP = coroBegin.result; + bool done1 = stepwiseLastLeqBegin.step(); + bool done2 = stepwiseLastLeqEnd.step(); + if (!done1 && !done2 && stepwiseLastLeqBegin.c == stepwiseLastLeqEnd.c) { + assert(stepwiseLastLeqBegin.current == stepwiseLastLeqEnd.current); + assert(stepwiseLastLeqBegin.result == stepwiseLastLeqEnd.result); + current = stepwiseLastLeqBegin.current; + resultP = stepwiseLastLeqBegin.result; } else { break; } @@ -159,12 +167,13 @@ void lastLeqMulti(Node *root, std::span keys, Iterator *results) { int index = 0; { - auto iter = coros.get(); + auto iter = stepwiseLastLeqs.get(); for (const auto &k : keys) { - *iter++ = Coro(current, resultP, k, index++); + *iter++ = StepwiseLastLeq(current, resultP, k, index++); } } - auto remaining = std::span(coros.get(), keys.size()); + auto remaining = + std::span(stepwiseLastLeqs.get(), keys.size()); while (remaining.size() > 0) { for (int i = 0; i < int(remaining.size());) { bool done = remaining[i].step(); @@ -247,10 +256,19 @@ void lastLeqMulti(Node *root, std::span keys, Iterator *results) { fprintf(file, "\n"); for (auto iter = extrema(node, false); iter != nullptr; iter = next(iter, true)) { - fprintf(file, " k_%.*s [label=\"k=\\\"%.*s\\\" m=%d v=%d r=%d\"];\n", + fprintf(file, + " k_%.*s [label=\"k=\\\"%.*s\\\" " +#if SHOW_PRIORITY + "p=%u " +#endif + "m=%d v=%d r=%d\"];\n", iter->len, (const char *)(iter + 1), iter->len, - (const char *)(iter + 1), int(iter->maxVersion), - int(iter->pointVersion), int(iter->rangeVersion)); + (const char *)(iter + 1), +#if SHOW_PRIORITY + iter->priority, +#endif + int(iter->maxVersion), int(iter->pointVersion), + int(iter->rangeVersion)); } for (int i = 0; i < printer.id; ++i) { fprintf(file, " null%d [shape=point];\n", i); @@ -261,9 +279,12 @@ void lastLeqMulti(Node *root, std::span keys, Iterator *results) { [[maybe_unused]] Key toKey(int n) { constexpr int kMaxLength = 4; - static unsigned char itoaBuf[kMaxLength]; + // TODO use arena allocation + static std::vector> *results = + new std::vector>; int i = kMaxLength; - memset(itoaBuf, '0', kMaxLength); + results->push_back(std::vector(kMaxLength, '0')); + uint8_t *itoaBuf = results->back().data(); do { itoaBuf[--i] = "0123456789abcdef"[n % 16]; n /= 16; @@ -272,6 +293,8 @@ void lastLeqMulti(Node *root, std::span keys, Iterator *results) { } // Recompute maxVersion, and propagate up the tree as necessary +// TODO interleave this? Will require careful analysis for correctness, and the +// performance gains may not be worth it. void updateMaxVersion(Node *n) { for (;;) { int64_t maxVersion = std::max(n->pointVersion, n->rangeVersion); @@ -352,61 +375,111 @@ struct ConflictSet::Impl { } } - void addWriteNaive(const WriteRange &write) { - // TODO handle non-singleton writes lol - Node **current = &root; - Node *parent = nullptr; - const auto &key = write.begin; - for (;;) { - if (*current == nullptr) { - auto *newNode = createNode(key, parent, write.writeVersion); - *current = newNode; - auto *prev = ::next(newNode, false); - assert(prev != nullptr); - assert(prev->rangeVersion <= write.writeVersion); - newNode->rangeVersion = prev->rangeVersion; - break; - } else { - // This is the key optimization - setting the max version on the way - // down the search path so we only have to do one traversal. - (*current)->maxVersion = - std::max((*current)->maxVersion, write.writeVersion); - auto c = key <=> **current; - if (c == 0) { - (*current)->pointVersion = write.writeVersion; - break; - } - parent = *current; - current = &((*current)->child[c > 0]); - } - } + struct StepwiseInsert { + // Search phase state. After this phase, the heap invariant may be violated + // for n->parent. + Node **current; + Node *parent; + const Key *key; + int64_t writeVersion; - auto *n = *current; - assert(n != nullptr); - for (;;) { - if (n->parent == nullptr) { - break; + // Rotation phase state. The heap invariant may be violated for n->parent. + // Once this phase is complete the heap invariant is restored for each + // n->parent encountered in a step of this phase. + Node *n; + Impl *impl; + + int state; + + StepwiseInsert() {} + StepwiseInsert(Node **root, const Key &key, int64_t writeVersion, + Impl *impl) + : current(root), parent(nullptr), key(&key), writeVersion(writeVersion), + impl(impl), state(0) {} + bool step() { + switch (state) { + // Search + case 0: { + if (*current == nullptr) { + auto *newNode = createNode(*key, parent, writeVersion); + *current = newNode; + auto *prev = ::next(newNode, false); + assert(prev != nullptr); + assert(prev->rangeVersion <= writeVersion); + newNode->rangeVersion = prev->rangeVersion; + state = 1; + n = *current; + return false; + } else { + // This is the key optimization - setting the max version on the way + // down the search path so we only have to do one traversal. + (*current)->maxVersion = + std::max((*current)->maxVersion, writeVersion); + auto c = *key <=> **current; + if (c == 0) { + (*current)->pointVersion = writeVersion; + state = 1; + n = *current; + return false; + } + parent = *current; + current = &((*current)->child[c > 0]); + } + return false; } - const bool dir = n == n->parent->child[1]; - assert(dir || n == n->parent->child[0]); - // p is the address of the pointer to n->parent in the tree - Node **p = n->parent->parent == nullptr - ? &root - : &n->parent->parent - ->child[n->parent->parent->child[1] == n->parent]; - assert(*p == n->parent); - if (n->parent->priority < n->priority) { - p = rotate(p, !dir); - n = (*p)->parent; - } else { - break; + // Rotate + case 1: { + if (n->parent == nullptr) { + return true; + } + const bool dir = n == n->parent->child[1]; + assert(dir || n == n->parent->child[0]); + // p is the address of the pointer to n->parent in the tree + Node **p = n->parent->parent == nullptr + ? &impl->root + : &n->parent->parent + ->child[n->parent->parent->child[1] == n->parent]; + assert(*p == n->parent); + if (n->parent->priority < n->priority) { + p = rotate(p, !dir); + n = (*p)->parent; + } else { + return true; + } + return false; + } + default: + __builtin_unreachable(); } } - } + }; void addWrites(const WriteRange *writes, int count) { - for (const auto &w : std::span(writes, count)) { - addWriteNaive(w); + auto stepwiseInserts = + std::unique_ptr{new StepwiseInsert[count]}; + for (int i = 0; i < count; ++i) { + // TODO handle non-singleton writes lol + assert(writes[i].end.len == 0); + + stepwiseInserts[i] = + StepwiseInsert{&root, writes[i].begin, writes[i].writeVersion, this}; + } + + // TODO Descend until queries for front and back diverge + + auto remaining = std::span(stepwiseInserts.get(), count); + while (remaining.size() > 0) { + for (int i = 0; i < int(remaining.size());) { + bool done = remaining[i].step(); + if (done) { + if (i != int(remaining.size()) - 1) { + remaining[i] = remaining.back(); + } + remaining = remaining.subspan(0, remaining.size() - 1); + } else { + ++i; + } + } } } @@ -464,13 +537,14 @@ ConflictSet &ConflictSet::operator=(ConflictSet &&other) noexcept { int main(void) { int64_t writeVersion = 0; ConflictSet::Impl cs{writeVersion}; - for (int i = 0; i < 10; ++i) { - ConflictSet::WriteRange write; - write.begin = toKey(i); - write.end.len = 0; - write.writeVersion = ++writeVersion; - cs.addWrites(&write, 1); + constexpr int kNumKeys = 5; + ConflictSet::WriteRange write[kNumKeys]; + for (int i = 0; i < kNumKeys; ++i) { + write[i].begin = toKey(i); + write[i].end.len = 0; + write[i].writeVersion = ++writeVersion; } + cs.addWrites(write, kNumKeys); debugPrintDot(stdout, cs.root); } #endif \ No newline at end of file