From ad14db5d7c19faff36ab2a3d522e81af61c7b1b2 Mon Sep 17 00:00:00 2001 From: Andrew Noyes Date: Mon, 12 Feb 2024 12:43:36 -0800 Subject: [PATCH] Add skiplist --- Bench.cpp | 652 +++++++++++++++++++++++++++++++++++++++++++++++-- CMakeLists.txt | 4 + 2 files changed, 637 insertions(+), 19 deletions(-) diff --git a/Bench.cpp b/Bench.cpp index be23834..e47f5ab 100644 --- a/Bench.cpp +++ b/Bench.cpp @@ -6,33 +6,617 @@ #define ANKERL_NANOBENCH_IMPLEMENT #include "third_party/nanobench.h" -std::string toKey(int n) { - std::string result; - result.resize(32); - - for (int i = 0; i < 32; ++i) { - result[i] = n & (1 << (31 - i)) ? '1' : '0'; - } +std::span keyAfter(Arena &arena, std::span key) { + auto result = + std::span(new (arena) uint8_t[key.size() + 1], key.size() + 1); + memcpy(result.data(), key.data(), key.size()); + result[result.size() - 1] = 0; return result; } +namespace { + +using Version = int64_t; +#define force_inline __attribute__((always_inline)) +using StringRef = std::span; + +struct KeyRangeRef { + StringRef begin; + StringRef end; + KeyRangeRef() {} + KeyRangeRef(StringRef begin, StringRef end) : begin(begin), end(end) {} + KeyRangeRef(Arena &arena, StringRef begin) + : begin(begin), end(keyAfter(arena, begin)) {} +}; + +static thread_local uint32_t g_seed = 0; + +static inline int skfastrand() { + g_seed = g_seed * 1664525L + 1013904223L; + return g_seed; +} + +static int compare(const StringRef &a, const StringRef &b) { + int c = memcmp(a.data(), b.data(), std::min(a.size(), b.size())); + if (c < 0) + return -1; + if (c > 0) + return +1; + if (a.size() < b.size()) + return -1; + if (a.size() == b.size()) + return 0; + return +1; +} + +struct ReadConflictRange { + StringRef begin, end; + Version version; + + ReadConflictRange() {} + ReadConflictRange(StringRef begin, StringRef end, Version version) + : begin(begin), end(end), version(version) {} + bool operator<(const ReadConflictRange &rhs) const { + return compare(begin, rhs.begin) < 0; + } +}; + +class SkipList { +private: + static constexpr int MaxLevels = 26; + + int randomLevel() const { + uint32_t i = uint32_t(skfastrand()) >> (32 - (MaxLevels - 1)); + int level = 0; + while (i & 1) { + i >>= 1; + level++; + } + assert(level < MaxLevels); + return level; + } + + // Represent a node in the SkipList. The node has multiple (i.e., level) + // pointers to other nodes, and keeps a record of the max versions for each + // level. + struct Node { + int level() const { return nPointers - 1; } + uint8_t *value() { + return end() + nPointers * (sizeof(Node *) + sizeof(Version)); + } + int length() const { return valueLength; } + + // Returns the next node pointer at the given level. + Node *getNext(int level) { return *((Node **)end() + level); } + // Sets the next node pointer at the given level. + void setNext(int level, Node *n) { *((Node **)end() + level) = n; } + + // Returns the max version at the given level. + Version getMaxVersion(int i) const { + return ((Version *)(end() + nPointers * sizeof(Node *)))[i]; + } + // Sets the max version at the given level. + void setMaxVersion(int i, Version v) { + ((Version *)(end() + nPointers * sizeof(Node *)))[i] = v; + } + + // Return a node with initialized value but uninitialized pointers + // Memory layout: *this, (level+1) Node*, (level+1) Version, value + static Node *create(const StringRef &value, int level) { + int nodeSize = sizeof(Node) + value.size() + + (level + 1) * (sizeof(Node *) + sizeof(Version)); + + Node *n; + n = (Node *)new char[nodeSize]; + + n->nPointers = level + 1; + + n->valueLength = value.size(); + if (value.size() > 0) { + memcpy(n->value(), value.data(), value.size()); + } + return n; + } + + // pre: level>0, all lower level nodes between this and getNext(level) have + // correct maxversions + void calcVersionForLevel(int level) { + Node *end = getNext(level); + Version v = getMaxVersion(level - 1); + for (Node *x = getNext(level - 1); x != end; x = x->getNext(level - 1)) + v = std::max(v, x->getMaxVersion(level - 1)); + setMaxVersion(level, v); + } + + void destroy() { delete[](char *) this; } + + private: + int getNodeSize() const { + return sizeof(Node) + valueLength + + nPointers * (sizeof(Node *) + sizeof(Version)); + } + // Returns the first Node* pointer + uint8_t *end() { return (uint8_t *)(this + 1); } + uint8_t const *end() const { return (uint8_t const *)(this + 1); } + int nPointers, valueLength; + }; + + static force_inline bool less(const uint8_t *a, int aLen, const uint8_t *b, + int bLen) { + int c = memcmp(a, b, std::min(aLen, bLen)); + if (c < 0) + return true; + if (c > 0) + return false; + return aLen < bLen; + } + + Node *header; + + void destroy() { + Node *next, *x; + for (x = header; x; x = next) { + next = x->getNext(0); + x->destroy(); + } + } + +public: + // Points the location (i.e., Node*) that value would appear in the SkipList. + // If the "value" is in the list, then finger[0] points to that exact node; + // otherwise, the finger points to Nodes that the value should be inserted + // before. Note the SkipList organizes all nodes at level 0, higher levels + // contain jump pointers. + struct Finger { + Node *finger[MaxLevels]; // valid for levels >= level + int level = MaxLevels; + Node *x = nullptr; + Node *alreadyChecked = nullptr; + StringRef value; + + Finger() = default; + Finger(Node *header, const StringRef &ptr) : x(header), value(ptr) {} + + void init(const StringRef &value, Node *header) { + this->value = value; + x = header; + alreadyChecked = nullptr; + level = MaxLevels; + } + + // pre: !finished() + force_inline void prefetch() { + Node *next = x->getNext(0); + __builtin_prefetch(next); + } + + // pre: !finished() + // Advances the pointer at the current level to a Node that's >= finger's + // value if possible; or move to the next level (i.e., level--). Returns + // true if we have advanced to the next level + force_inline bool advance() { + Node *next = x->getNext(level - 1); + + if (next == alreadyChecked || + !less(next->value(), next->length(), value.data(), value.size())) { + alreadyChecked = next; + level--; + finger[level] = x; + return true; + } else { + x = next; + return false; + } + } + + // pre: !finished() + force_inline void nextLevel() { + while (!advance()) + ; + } + + force_inline bool finished() const { return level == 0; } + + // Returns if the finger value is found in the SkipList. + force_inline Node *found() const { + // valid after finished returns true + Node *n = finger[0]->getNext( + 0); // or alreadyChecked, but that is more easily invalidated + if (n && n->length() == value.size() && + !memcmp(n->value(), value.data(), value.size())) + return n; + else + return nullptr; + } + + StringRef getValue() const { + Node *n = finger[0]->getNext(0); + return n ? StringRef(n->value(), n->length()) : StringRef(); + } + }; + + // Returns the total number of nodes in the list. + int count() const { + int count = 0; + Node *x = header->getNext(0); + while (x) { + x = x->getNext(0); + count++; + } + return count; + } + + explicit SkipList(Version version = 0) { + header = Node::create(StringRef(), MaxLevels - 1); + for (int l = 0; l < MaxLevels; l++) { + header->setNext(l, nullptr); + header->setMaxVersion(l, version); + } + } + ~SkipList() { destroy(); } + SkipList(SkipList &&other) noexcept : header(other.header) { + other.header = nullptr; + } + void operator=(SkipList &&other) noexcept { + destroy(); + header = other.header; + other.header = nullptr; + } + void swap(SkipList &other) { std::swap(header, other.header); } + + void addConflictRanges(const Finger *fingers, int rangeCount, + Version *version) { + for (int r = rangeCount - 1; r >= 0; r--) { + const Finger &startF = fingers[r * 2]; + const Finger &endF = fingers[r * 2 + 1]; + + if (endF.found() == nullptr) + insert(endF, endF.finger[0]->getMaxVersion(0)); + + remove(startF, endF); + insert(startF, version[r]); + } + } + + void detectConflicts(ReadConflictRange *ranges, int count, + ConflictSet::Result *transactionConflictStatus) const { + const int M = 16; + int nextJob[M]; + CheckMax inProgress[M]; + if (!count) + return; + + int started = std::min(M, count); + for (int i = 0; i < started; i++) { + inProgress[i].init(ranges[i], header, transactionConflictStatus + i); + nextJob[i] = i + 1; + } + nextJob[started - 1] = 0; + + int prevJob = started - 1; + int job = 0; + // vtune: 340 parts + while (true) { + if (inProgress[job].advance()) { + if (started == count) { + if (prevJob == job) + break; + nextJob[prevJob] = nextJob[job]; + job = prevJob; + } else { + int temp = started++; + inProgress[job].init(ranges[temp], header, + transactionConflictStatus + temp); + } + } + prevJob = job; + job = nextJob[job]; + } + } + + void find(const StringRef *values, Finger *results, int *temp, int count) { + // Relying on the ordering of values, descend until the values aren't all in + // the same part of the tree + + // vtune: 11 parts + results[0].init(values[0], header); + const StringRef &endValue = values[count - 1]; + while (results[0].level > 1) { + results[0].nextLevel(); + Node *ac = results[0].alreadyChecked; + if (ac && + less(ac->value(), ac->length(), endValue.data(), endValue.size())) + break; + } + + // Init all the other fingers to start descending where we stopped + // the first one + + // SOMEDAY: this loop showed up on vtune, could be faster? + // vtune: 8 parts + int startLevel = results[0].level + 1; + Node *x = startLevel < MaxLevels ? results[0].finger[startLevel] : header; + for (int i = 1; i < count; i++) { + results[i].level = startLevel; + results[i].x = x; + results[i].alreadyChecked = nullptr; + results[i].value = values[i]; + for (int j = startLevel; j < MaxLevels; j++) + results[i].finger[j] = results[0].finger[j]; + } + + int *nextJob = temp; + for (int i = 0; i < count - 1; i++) + nextJob[i] = i + 1; + nextJob[count - 1] = 0; + + int prevJob = count - 1; + int job = 0; + + // vtune: 225 parts + while (true) { + Finger *f = &results[job]; + f->advance(); + if (f->finished()) { + if (prevJob == job) + break; + nextJob[prevJob] = nextJob[job]; + } else { + f->prefetch(); + prevJob = job; + } + job = nextJob[job]; + } + } + + int removeBefore(Version v, Finger &f, int nodeCount) { + // f.x, f.alreadyChecked? + + int removedCount = 0; + bool wasAbove = true; + while (nodeCount--) { + Node *x = f.finger[0]->getNext(0); + if (!x) + break; + + // double prefetch gives +25% speed (single threaded) + Node *next = x->getNext(0); + __builtin_prefetch(next); + next = x->getNext(1); + __builtin_prefetch(next); + + bool isAbove = x->getMaxVersion(0) >= v; + if (isAbove || wasAbove) { // f.nextItem + for (int l = 0; l <= x->level(); l++) + f.finger[l] = x; + } else { // f.eraseItem + removedCount++; + for (int l = 0; l <= x->level(); l++) + f.finger[l]->setNext(l, x->getNext(l)); + for (int i = 1; i <= x->level(); i++) + f.finger[i]->setMaxVersion( + i, std::max(f.finger[i]->getMaxVersion(i), x->getMaxVersion(i))); + x->destroy(); + } + wasAbove = isAbove; + } + + return removedCount; + } + +private: + void remove(const Finger &start, const Finger &end) { + if (start.finger[0] == end.finger[0]) + return; + + Node *x = start.finger[0]->getNext(0); + + // vtune says: this loop is the expensive parts (6 parts) + for (int i = 0; i < MaxLevels; i++) + if (start.finger[i] != end.finger[i]) + start.finger[i]->setNext(i, end.finger[i]->getNext(i)); + + while (true) { + Node *next = x->getNext(0); + x->destroy(); + if (x == end.finger[0]) + break; + x = next; + } + } + + void insert(const Finger &f, Version version) { + int level = randomLevel(); + // std::cout << std::string((const char*)value,length) << " level: " << + // level << std::endl; + Node *x = Node::create(f.value, level); + x->setMaxVersion(0, version); + for (int i = 0; i <= level; i++) { + x->setNext(i, f.finger[i]->getNext(i)); + f.finger[i]->setNext(i, x); + } + // vtune says: this loop is the costly part of this function + for (int i = 1; i <= level; i++) { + f.finger[i]->calcVersionForLevel(i); + x->calcVersionForLevel(i); + } + for (int i = level + 1; i < MaxLevels; i++) { + Version v = f.finger[i]->getMaxVersion(i); + if (v >= version) + break; + f.finger[i]->setMaxVersion(i, version); + } + } + + struct CheckMax { + Finger start, end; + Version version; + ConflictSet::Result *result; + int state; + + void init(const ReadConflictRange &r, Node *header, + ConflictSet::Result *result) { + this->start.init(r.begin, header); + this->end.init(r.end, header); + this->version = r.version; + this->state = 0; + this->result = result; + } + + bool noConflict() const { return true; } + bool conflict() { + *result = ConflictSet::Conflict; + return true; + } + + // Return true if finished + force_inline bool advance() { + if (*result == ConflictSet::TooOld) { + return true; + } + switch (state) { + case 0: + // find where start and end fingers diverge + while (true) { + if (!start.advance()) { + start.prefetch(); + return false; + } + end.x = start.x; + while (!end.advance()) + ; + + int l = start.level; + if (start.finger[l] != end.finger[l]) + break; + // accept if the range spans the check range, but does not have a + // greater version + if (start.finger[l]->getMaxVersion(l) <= version) + return noConflict(); + if (l == 0) + return conflict(); + } + state = 1; + case 1: { + // check the end side of the pyramid + Node *e = end.finger[end.level]; + while (e->getMaxVersion(end.level) > version) { + if (end.finished()) + return conflict(); + end.nextLevel(); + Node *f = end.finger[end.level]; + while (e != f) { + if (e->getMaxVersion(end.level) > version) + return conflict(); + e = e->getNext(end.level); + } + } + + // check the start side of the pyramid + Node *s = end.finger[start.level]; + while (true) { + Node *nextS = start.finger[start.level]->getNext(start.level); + Node *p = nextS; + while (p != s) { + if (p->getMaxVersion(start.level) > version) + return conflict(); + p = p->getNext(start.level); + } + if (start.finger[start.level]->getMaxVersion(start.level) <= version) + return noConflict(); + s = nextS; + if (start.finished()) { + if (nextS->length() == start.value.size() && + !memcmp(nextS->value(), start.value.data(), start.value.size())) + return noConflict(); + else + return conflict(); + } + start.nextLevel(); + } + } + default: + __builtin_unreachable(); + } + } + }; +}; + +struct SkipListConflictSet { + int64_t oldestVersion; + SkipListConflictSet(int64_t oldestVersion) : oldestVersion(oldestVersion) {} + void check(const ConflictSet::ReadRange *reads, ConflictSet::Result *results, + int count) const { + Arena arena; + auto *ranges = new (arena) ReadConflictRange[count]; + for (int i = 0; i < count; ++i) { + ranges[i].begin = {reads[i].begin.p, size_t(reads[i].begin.len)}; + ranges[i].end = reads[i].end.len > 0 + ? std::span{reads[i].end.p, + size_t(reads[i].end.len)} + : keyAfter(arena, ranges[i].begin); + ranges[i].version = reads[i].readVersion; + if (reads[i].readVersion < oldestVersion) { + results[i] = ConflictSet::TooOld; + } else { + results[i] = ConflictSet::Commit; + } + } + skipList.detectConflicts(ranges, count, results); + } + + void addWrites(const ConflictSet::WriteRange *writes, int count) { + Arena arena; + const int stringCount = count * 2; + + const int stripeSize = 16; + SkipList::Finger fingers[stripeSize]; + int temp[stripeSize]; + int stripes = (stringCount + stripeSize - 1) / stripeSize; + StringRef values[stripeSize]; + int64_t writeVersions[stripeSize / 2]; + int ss = stringCount - (stripes - 1) * stripeSize; + for (int s = stripes - 1; s >= 0; s--) { + for (int i = 0; i * 2 < ss; ++i) { + const auto &w = writes[s * stripeSize / 2 + i]; + values[i * 2] = {w.begin.p, size_t(w.begin.len)}; + if (w.end.len > 0) { + values[i * 2 + 1] = {w.end.p, size_t(w.end.len)}; + } else { + values[i * 2 + 1] = keyAfter(arena, values[i * 2]); + } + writeVersions[i] = w.writeVersion; + } + skipList.find(values, fingers, temp, ss); + skipList.addConflictRanges(fingers, ss / 2, writeVersions); + ss = stripeSize; + } + } + +private: + SkipList skipList; +}; + +} // namespace + constexpr int kNumKeys = 100000; -// A range read, a point read, and a point write. Range writes can erase -// keys, and we don't want to change the number of keys stored in the -// conflict set. + constexpr int kOpsPerTx = 100; constexpr int kPrefixLen = 0; std::span makeKey(Arena &arena, int index) { - auto result = std::span{new (arena) uint8_t[4], 4}; + auto result = + std::span{new (arena) uint8_t[4 + kPrefixLen], 4 + kPrefixLen}; index = __builtin_bswap32(index); - memcpy(result.data(), &index, 4); + memset(result.data(), 0, kPrefixLen); + memcpy(result.data() + kPrefixLen, &index, 4); return result; } -template void benchConflictSet() { +template void benchConflictSet(const std::string &name) { ankerl::nanobench::Bench bench; + bench.minEpochIterations(1000); ConflictSet_ cs{0}; bench.batch(kOpsPerTx); @@ -108,7 +692,7 @@ template void benchConflictSet() { auto *results = new (arena) ConflictSet::Result[kOpsPerTx]; - bench.run("radix tree (point reads)", + bench.run(name + " (point reads)", [&]() { cs.check(reads.data(), results, kOpsPerTx); }); } @@ -129,7 +713,7 @@ template void benchConflictSet() { auto *results = new (arena) ConflictSet::Result[kOpsPerTx]; - bench.run("radix tree (range reads)", + bench.run(name + " (range reads)", [&]() { cs.check(reads.data(), results, kOpsPerTx); }); } @@ -145,7 +729,7 @@ template void benchConflictSet() { ++iter; } - bench.run("radix tree (point writes)", [&]() { + bench.run(name + " (point writes)", [&]() { auto v = ++version; for (auto &w : writes) { w.writeVersion = v; @@ -157,7 +741,7 @@ template void benchConflictSet() { { std::vector writes; auto iter = points.begin(); - for (int i = 0; i < kOpsPerTx - 1; ++i) { + for (int i = 0; i < kOpsPerTx; ++i) { auto begin = *iter++; auto end = *iter++; ConflictSet::WriteRange w; @@ -168,7 +752,7 @@ template void benchConflictSet() { writes.push_back(w); } - bench.run("radix tree (range writes)", [&]() { + bench.run(name + " (range writes)", [&]() { auto v = ++version; for (auto &w : writes) { w.writeVersion = v; @@ -178,4 +762,34 @@ template void benchConflictSet() { } } -int main(void) { benchConflictSet(); } \ No newline at end of file +int main(void) { + benchConflictSet("skip list"); + benchConflictSet("radix tree"); +} + +// extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { +// TestDriver driver{data, size}; + +// for (;;) { +// bool done = driver.next(); +// if (!driver.ok) { +// // debugPrintDot(stdout, driver.cs.root); +// // fflush(stdout); +// abort(); +// } +// #if DEBUG_VERBOSE && !defined(NDEBUG) +// fprintf(stderr, "Check correctness\n"); +// #endif +// // bool success = checkCorrectness(driver.cs.root); +// // if (!success) { +// // debugPrintDot(stdout, driver.cs.root); +// // fflush(stdout); +// // abort(); +// // } +// if (done) { +// break; +// } +// } + +// return 0; +// } \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 9463824..a6033ea 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -180,6 +180,10 @@ if(BUILD_TESTING) add_executable(conflict_set_bench Bench.cpp) target_link_libraries(conflict_set_bench PRIVATE ${PROJECT_NAME}) + # target_compile_options(conflict_set_bench PRIVATE + # "-fsanitize=address,undefined,fuzzer") + # target_link_options(conflict_set_bench PRIVATE + # "-fsanitize=address,undefined,fuzzer") endif() # packaging