WIP - need to use worklist algorithm for rotations

This commit is contained in:
2024-01-18 14:38:22 -08:00
parent b2de5c82e3
commit 321993baab

View File

@@ -4,10 +4,12 @@
#include <compare> #include <compare>
#include <memory> #include <memory>
#include <span> #include <span>
#include <string_view>
#include <utility> #include <utility>
#include <vector> #include <vector>
#define SHOW_PRIORITY 0 #define SHOW_PRIORITY 1
#define DEBUG 1
using Key = ConflictSet::Key; using Key = ConflictSet::Key;
@@ -311,7 +313,7 @@ void updateMaxVersion(Node *n) {
} }
} }
Node **rotate(Node **node, bool dir) { void rotate(Node **node, bool dir) {
// diagram shown for dir == true // diagram shown for dir == true
/* n /* n
/ /
@@ -342,7 +344,64 @@ Node **rotate(Node **node, bool dir) {
*/ */
updateMaxVersion(n); updateMaxVersion(n);
updateMaxVersion(l); updateMaxVersion(l);
return &l->child[dir]; }
bool checkInvariants(Node *node) {
bool success = true;
// Check bst invariant
std::vector<std::string_view> keys;
for (auto iter = extrema(node, false); iter != nullptr;
iter = next(iter, true)) {
keys.push_back(std::string_view((char *)(iter + 1), iter->len));
for (int i = 0; i < 2; ++i) {
if (iter->child[i] != nullptr) {
if (iter->priority < iter->child[i]->priority) {
fprintf(stderr, "%.*s has priority < its child %.*s\n", iter->len,
(const char *)(iter + 1), iter->child[i]->len,
(const char *)(iter->child[i] + 1));
success = false;
}
}
}
}
assert(std::is_sorted(keys.begin(), keys.end()));
// TODO more invariants
return success;
}
template <class Stepwise>
void runInterleaved(std::span<Stepwise> remaining, int stepLimit = -1) {
while (remaining.size() > 0) {
for (int i = 0; i < int(remaining.size());) {
if (stepLimit-- == 0) {
return;
}
bool done = remaining[i].step();
if (done) {
if (i != int(remaining.size()) - 1) {
remaining[i] = remaining.back();
}
remaining = remaining.subspan(0, remaining.size() - 1);
} else {
++i;
}
}
}
};
template <class Stepwise>
void runSequential(std::span<Stepwise> remaining, int stepLimit = -1) {
for (auto &r : remaining) {
if (stepLimit-- == 0) {
return;
}
while (!r.step()) {
if (stepLimit-- == 0) {
return;
}
}
}
} }
} // namespace } // namespace
@@ -377,81 +436,86 @@ struct ConflictSet::Impl {
struct StepwiseInsert { struct StepwiseInsert {
// Search phase state. After this phase, the heap invariant may be violated // Search phase state. After this phase, the heap invariant may be violated
// for n->parent. // for (*current)->parent.
Node **current; Node **current;
Node *parent; Node *parent;
const Key *key; const Key *key;
int64_t writeVersion; int64_t writeVersion;
StepwiseInsert() {}
StepwiseInsert(Node **root, const Key &key, int64_t writeVersion)
: current(root), parent(nullptr), key(&key),
writeVersion(writeVersion) {}
bool step() {
#if DEBUG
fprintf(stderr, "Step insert of %.*s. At node: %.*s\n", key->len, key->p,
(*current) ? (*current)->len : 7,
(*current) ? (const char *)((*current) + 1) : "nullptr");
#endif
if (*current == nullptr) {
auto *newNode = createNode(*key, parent, writeVersion);
*current = newNode;
// We could interleave the iteration in next, but we'd need a careful
// analysis for correctness and it's unlikely to be worthwhile.
auto *prev = ::next(newNode, false);
assert(prev != nullptr);
assert(prev->rangeVersion <= writeVersion);
newNode->rangeVersion = prev->rangeVersion;
return true;
} else {
// This is the key optimization - setting the max version on the way
// down the search path so we only have to do one traversal.
(*current)->maxVersion = std::max((*current)->maxVersion, writeVersion);
auto c = *key <=> **current;
if (c == 0) {
(*current)->pointVersion = writeVersion;
return true;
}
parent = *current;
current = &((*current)->child[c > 0]);
}
return false;
}
};
struct StepwiseRotate {
// Rotation phase state. The heap invariant may be violated for n->parent. // Rotation phase state. The heap invariant may be violated for n->parent.
// Once this phase is complete the heap invariant is restored for each // Once this phase is complete the heap invariant is restored for each
// n->parent encountered in a step of this phase. // n->parent encountered in a step of this phase.
Node *n; Node *n;
Impl *impl; Impl *impl;
enum Phase { Search, Rotate }; StepwiseRotate() {}
Phase phase; StepwiseRotate(Node *n, Impl *impl) : n(n), impl(impl) {}
StepwiseInsert() {}
StepwiseInsert(Node **root, const Key &key, int64_t writeVersion,
Impl *impl)
: current(root), parent(nullptr), key(&key), writeVersion(writeVersion),
impl(impl), phase(Search) {}
bool step() { bool step() {
switch (phase) { #if DEBUG
case Search: { fprintf(stderr, "Step rotate %.*s\n", n->len, (const char *)(n + 1));
if (*current == nullptr) { #endif
auto *newNode = createNode(*key, parent, writeVersion); if (n->parent == nullptr) {
*current = newNode; return true;
// We could interleave the iteration in next, but we'd need a careful
// analysis for correctness and it's unlikely to be worthwhile.
auto *prev = ::next(newNode, false);
assert(prev != nullptr);
assert(prev->rangeVersion <= writeVersion);
newNode->rangeVersion = prev->rangeVersion;
phase = Rotate;
n = *current;
return false;
} else {
// This is the key optimization - setting the max version on the way
// down the search path so we only have to do one traversal.
(*current)->maxVersion =
std::max((*current)->maxVersion, writeVersion);
auto c = *key <=> **current;
if (c == 0) {
(*current)->pointVersion = writeVersion;
phase = Rotate;
n = *current;
return false;
}
parent = *current;
current = &((*current)->child[c > 0]);
}
return false;
} }
case Rotate: { const bool dir = n == n->parent->child[1];
if (n->parent == nullptr) { assert(dir || n == n->parent->child[0]);
return true; // p is the address of the pointer to n->parent in the tree
} Node **p = n->parent->parent == nullptr
const bool dir = n == n->parent->child[1]; ? &impl->root
assert(dir || n == n->parent->child[0]); : &n->parent->parent
// p is the address of the pointer to n->parent in the tree ->child[n->parent->parent->child[1] == n->parent];
Node **p = n->parent->parent == nullptr assert(*p == n->parent);
? &impl->root if (n->parent->priority < n->priority) {
: &n->parent->parent #if DEBUG
->child[n->parent->parent->child[1] == n->parent]; fprintf(stderr, "\trotate %s\n", !dir ? "right" : "left");
assert(*p == n->parent); #endif
if (n->parent->priority < n->priority) { rotate(p, !dir);
p = rotate(p, !dir); // assert((*p)->child[0] == nullptr || (*p)->priority >=
n = (*p)->parent; // (*p)->child[0]->priority); assert((*p)->child[1] == nullptr ||
} else { // (*p)->priority >= (*p)->child[1]->priority);
return true; n = *p;
} } else {
return false; return true;
}
default:
__builtin_unreachable();
} }
return false;
} }
}; };
@@ -463,25 +527,19 @@ struct ConflictSet::Impl {
assert(writes[i].end.len == 0); assert(writes[i].end.len == 0);
stepwiseInserts[i] = stepwiseInserts[i] =
StepwiseInsert{&root, writes[i].begin, writes[i].writeVersion, this}; StepwiseInsert{&root, writes[i].begin, writes[i].writeVersion};
} }
// TODO Descend until queries for front and back diverge // TODO Descend until queries for front and back diverge
auto remaining = std::span<StepwiseInsert>(stepwiseInserts.get(), count); runInterleaved(std::span<StepwiseInsert>(stepwiseInserts.get(), count));
while (remaining.size() > 0) {
for (int i = 0; i < int(remaining.size());) { auto stepwiseRotates =
bool done = remaining[i].step(); std::unique_ptr<StepwiseRotate[]>{new StepwiseRotate[count]};
if (done) { for (int i = 0; i < count; ++i) {
if (i != int(remaining.size()) - 1) { stepwiseRotates[i] = StepwiseRotate{*stepwiseInserts[i].current, this};
remaining[i] = remaining.back();
}
remaining = remaining.subspan(0, remaining.size() - 1);
} else {
++i;
}
}
} }
runSequential(std::span<StepwiseRotate>(stepwiseRotates.get(), count), 7);
} }
void setOldestVersion(int64_t oldestVersion) { void setOldestVersion(int64_t oldestVersion) {
@@ -538,7 +596,7 @@ ConflictSet &ConflictSet::operator=(ConflictSet &&other) noexcept {
int main(void) { int main(void) {
int64_t writeVersion = 0; int64_t writeVersion = 0;
ConflictSet::Impl cs{writeVersion}; ConflictSet::Impl cs{writeVersion};
constexpr int kNumKeys = 5; constexpr int kNumKeys = 3;
ConflictSet::WriteRange write[kNumKeys]; ConflictSet::WriteRange write[kNumKeys];
for (int i = 0; i < kNumKeys; ++i) { for (int i = 0; i < kNumKeys; ++i) {
write[i].begin = toKey(i); write[i].begin = toKey(i);
@@ -547,5 +605,6 @@ int main(void) {
} }
cs.addWrites(write, kNumKeys); cs.addWrites(write, kNumKeys);
debugPrintDot(stdout, cs.root); debugPrintDot(stdout, cs.root);
checkInvariants(cs.root);
} }
#endif #endif