#include "ConflictSet.h"
#include "Internal.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <span>
#include <string>
#include <utility>
#include <vector>

#define ANKERL_NANOBENCH_IMPLEMENT
#include "third_party/nanobench.h"

// Returns a copy of `key` with one 0x00 byte appended, allocated in `arena`.
// Under the byte-wise lexicographic order used in this file (shorter prefix
// sorts first), this is the smallest key strictly greater than `key`, so
// [key, keyAfter(key)) contains exactly `key`.
std::span<const uint8_t> keyAfter(Arena &arena, std::span<const uint8_t> key) {
  auto result =
      std::span<uint8_t>(new (arena) uint8_t[key.size() + 1], key.size() + 1);
  memcpy(result.data(), key.data(), key.size());
  result[result.size() - 1] = 0;
  return result;
}

namespace {

using Version = int64_t;
#define force_inline __attribute__((always_inline))
using StringRef = std::span<const uint8_t>;

// A half-open key range [begin, end). The (arena, begin) constructor builds the
// range that contains only `begin`.
struct KeyRangeRef {
  StringRef begin;
  StringRef end;
  KeyRangeRef() {}
  KeyRangeRef(StringRef begin, StringRef end) : begin(begin), end(end) {}
  KeyRangeRef(Arena &arena, StringRef begin)
      : begin(begin), end(keyAfter(arena, begin)) {}
};

// Per-thread linear congruential generator; only used to draw node levels.
static thread_local uint32_t g_seed = 0;

static inline int skfastrand() {
  g_seed = g_seed * 1664525L + 1013904223L;
  return g_seed;
}

// Three-way lexicographic byte comparison returning -1, 0 or +1. When one
// string is a prefix of the other, the shorter one sorts first.
static int compare(const StringRef &a, const StringRef &b) {
  int c = memcmp(a.data(), b.data(), std::min(a.size(), b.size()));
  if (c < 0)
    return -1;
  if (c > 0)
    return +1;
  if (a.size() < b.size())
    return -1;
  if (a.size() == b.size())
    return 0;
  return +1;
}

// A read of the key range [begin, end) at `version`, ordered by `begin`.
struct ReadConflictRange {
  StringRef begin, end;
  Version version;

  ReadConflictRange() {}
  ReadConflictRange(StringRef begin, StringRef end, Version version)
      : begin(begin), end(end), version(version) {}
  bool operator<(const ReadConflictRange &rhs) const {
    return compare(begin, rhs.begin) < 0;
  }
};

// Ordered skip list of byte-string keys that doubles as a map from key ranges
// to write versions. The level-0 max version of a node is the version of the
// range [node key, next node key). The level-l max version (l > 0) is the
// maximum level-(l-1) version over the nodes from this one up to (excluding)
// getNext(l). This lets a read check skip whole spans whose max version is not
// newer than the read version.
class SkipList {
private:
  static constexpr int MaxLevels = 26;

  // Returns the number of trailing one bits of a 25-bit random number, i.e. a
  // level in [0, MaxLevels - 1] where each extra level is roughly half as
  // likely as the previous one.
  int randomLevel() const {
    uint32_t i = uint32_t(skfastrand()) >> (32 - (MaxLevels - 1));
    int level = 0;
    while (i & 1) {
      i >>= 1;
      level++;
    }
    assert(level < MaxLevels);
    return level;
  }

  // Represent a node in the SkipList. The node has multiple (i.e., level)
  // pointers to other nodes, and keeps a record of the max versions for each
  // level. It is a single variable-length allocation laid out as:
  //   Node | nPointers x Node* | nPointers x Version | key bytes
  // where nPointers == level() + 1.
  struct Node {
    int level() const { return nPointers - 1; }
    uint8_t *value() {
      return end() + nPointers * (sizeof(Node *) + sizeof(Version));
    }
    int length() const { return valueLength; }

    // Returns the next node pointer at the given level.
    Node *getNext(int level) { return *((Node **)end() + level); }

    // Sets the next node pointer at the given level.
    void setNext(int level, Node *n) { *((Node **)end() + level) = n; }

    // Returns the max version at the given level.
    Version getMaxVersion(int i) const {
      return ((Version *)(end() + nPointers * sizeof(Node *)))[i];
    }

    // Sets the max version at the given level.
    void setMaxVersion(int i, Version v) {
      ((Version *)(end() + nPointers * sizeof(Node *)))[i] = v;
    }

    // Return a node with initialized value but uninitialized pointers and
    // versions. Memory layout: *this, (level+1) Node*, (level+1) Version,
    // value. Free with destroy().
    static Node *create(const StringRef &value, int level) {
      int nodeSize = sizeof(Node) + value.size() +
                     (level + 1) * (sizeof(Node *) + sizeof(Version));

      Node *n;
      n = (Node *)new char[nodeSize];

      n->nPointers = level + 1;

      n->valueLength = value.size();
      if (value.size() > 0) {
        memcpy(n->value(), value.data(), value.size());
      }
      return n;
    }

    // Recomputes the max version at `level` from the level-1 versions of the
    // nodes between this one and getNext(level).
    // pre: level>0, all lower level nodes between this and getNext(level) have
    // correct maxversions
    void calcVersionForLevel(int level) {
      Node *end = getNext(level);
      Version v = getMaxVersion(level - 1);
      for (Node *x = getNext(level - 1); x != end; x = x->getNext(level - 1))
        v = std::max(v, x->getMaxVersion(level - 1));
      setMaxVersion(level, v);
    }

    // Frees a node allocated by create().
    void destroy() { delete[](char *) this; }

  private:
    int getNodeSize() const {
      return sizeof(Node) + valueLength +
             nPointers * (sizeof(Node *) + sizeof(Version));
    }

    // Returns the first Node* pointer
    uint8_t *end() { return (uint8_t *)(this + 1); }
    uint8_t const *end() const { return (uint8_t const *)(this + 1); }
    int nPointers, valueLength;
  };

  // Strict lexicographic "a < b" over raw byte buffers.
  static force_inline bool less(const uint8_t *a, int aLen, const uint8_t *b,
                                int bLen) {
    int c = memcmp(a, b, std::min(aLen, bLen));
    if (c < 0)
      return true;
    if (c > 0)
      return false;
    return aLen < bLen;
  }

  // Sentinel node with MaxLevels pointers and an empty key.
  Node *header;

  // Frees every node, including the header (tolerates header == nullptr after
  // a move).
  void destroy() {
    Node *next, *x;
    for (x = header; x; x = next) {
      next = x->getNext(0);
      x->destroy();
    }
  }

public:
  // Search state for one key (`value`). Once finished(), finger[l] is, for
  // every level l, the last node at that level whose key is < value (the
  // header if there is none). A new node for `value` belongs immediately
  // after finger[0], and when `value` is already present it is stored in
  // finger[0]->getNext(0). Note the SkipList organizes all nodes at level 0,
  // higher levels contain jump pointers.
  struct Finger {
    Node *finger[MaxLevels]; // valid for levels >= level
    int level = MaxLevels;
    Node *x = nullptr;
    Node *alreadyChecked = nullptr;
    StringRef value;

    Finger() = default;
    Finger(Node *header, const StringRef &ptr) : x(header), value(ptr) {}

    // Restarts the search for `value` from the top level of `header`.
    void init(const StringRef &value, Node *header) {
      this->value = value;
      x = header;
      alreadyChecked = nullptr;
      level = MaxLevels;
    }

    // pre: !finished()
    force_inline void prefetch() {
      Node *next = x->getNext(0);
      __builtin_prefetch(next);
    }

    // pre: !finished()
    // Advances the pointer at the current level to a Node that's >= finger's
    // value if possible; or move to the next level (i.e., level--). Returns
    // true if we have advanced to the next level.
    // While alreadyChecked is still nullptr (right after init), a null `next`
    // (end of the list) compares equal to it, so the end of the list is
    // treated as >= value without being dereferenced.
    force_inline bool advance() {
      Node *next = x->getNext(level - 1);

      if (next == alreadyChecked ||
          !less(next->value(), next->length(), value.data(), value.size())) {
        alreadyChecked = next;
        level--;
        finger[level] = x;
        return true;
      } else {
        x = next;
        return false;
      }
    }

    // pre: !finished()
    // Advances until the finger drops one level.
    force_inline void nextLevel() {
      while (!advance())
        ;
    }

    force_inline bool finished() const { return level == 0; }

    // Returns the node whose key equals `value`, or nullptr if there is none.
    // Valid after finished() returns true.
    force_inline Node *found() const {
      // or alreadyChecked, but that is more easily invalidated
      Node *n = finger[0]->getNext(0);
      if (n && n->length() == value.size() &&
          !memcmp(n->value(), value.data(), value.size()))
        return n;
      else
        return nullptr;
    }

    // Returns the key of the first node >= value, or an empty span if there is
    // no such node. Valid after finished() returns true.
    StringRef getValue() const {
      Node *n = finger[0]->getNext(0);
      return n ? StringRef(n->value(), n->length()) : StringRef();
    }
  };

  // Returns the total number of nodes in the list (excluding the header).
  int count() const {
    int count = 0;
    Node *x = header->getNext(0);
    while (x) {
      x = x->getNext(0);
      count++;
    }
    return count;
  }

  // Creates an empty list in which every key range has version `version`.
  explicit SkipList(Version version = 0) {
    header = Node::create(StringRef(), MaxLevels - 1);
    for (int l = 0; l < MaxLevels; l++) {
      header->setNext(l, nullptr);
      header->setMaxVersion(l, version);
    }
  }
  ~SkipList() { destroy(); }
  SkipList(SkipList &&other) noexcept : header(other.header) {
    other.header = nullptr;
  }
  void operator=(SkipList &&other) noexcept {
    // Without this guard, self-move would free every node and then keep the
    // dangling header.
    if (this == &other)
      return;
    destroy();
    header = other.header;
    other.header = nullptr;
  }
  void swap(SkipList &other) { std::swap(header, other.header); }

  // Records writes: range r is [fingers[2r].value, fingers[2r+1].value) at
  // version[r]. `fingers` must come from find() on the sorted begin/end keys.
  // Ranges are applied from last to first so that the fingers of the ranges
  // not yet applied remain valid.
  void addConflictRanges(const Finger *fingers, int rangeCount,
                         Version *version) {
    for (int r = rangeCount - 1; r >= 0; r--) {
      const Finger &startF = fingers[r * 2];
      const Finger &endF = fingers[r * 2 + 1];

      // Keep the keys at and after `end` at the version of the range that
      // previously contained `end`.
      if (endF.found() == nullptr)
        insert(endF, endF.finger[0]->getMaxVersion(0));

      remove(startF, endF);
      insert(startF, version[r]);
    }
  }

  // For each range, sets transactionConflictStatus[i] to Conflict if any key
  // in [ranges[i].begin, ranges[i].end) was written at a version greater than
  // ranges[i].version. Entries that are already TooOld are skipped; other
  // entries are left untouched when there is no conflict. Up to M checks are
  // interleaved round-robin (nextJob is a circular list of active slots) so
  // the memory prefetch issued by one check can overlap work on the others.
  void detectConflicts(ReadConflictRange *ranges, int count,
                       ConflictSet::Result *transactionConflictStatus) const {
    const int M = 16;
    int nextJob[M];
    CheckMax inProgress[M];
    if (!count)
      return;

    int started = std::min(M, count);
    for (int i = 0; i < started; i++) {
      inProgress[i].init(ranges[i], header, transactionConflictStatus + i);
      nextJob[i] = i + 1;
    }
    nextJob[started - 1] = 0;

    int prevJob = started - 1;
    int job = 0;
    // vtune: 340 parts
    while (true) {
      if (inProgress[job].advance()) {
        if (started == count) {
          if (prevJob == job)
            break;
          nextJob[prevJob] = nextJob[job];
          job = prevJob;
        } else {
          int temp = started++;
          inProgress[job].init(ranges[temp], header,
                               transactionConflictStatus + temp);
        }
      }
      prevJob = job;
      job = nextJob[job];
    }
  }

  // Positions results[i] at values[i] for all i < count. `values` must be
  // sorted ascending and count >= 1; `temp` must have room for `count` ints.
  void find(const StringRef *values, Finger *results, int *temp, int count) {
    // Relying on the ordering of values, descend until the values aren't all in
    // the same part of the tree

    // vtune: 11 parts
    results[0].init(values[0], header);
    const StringRef &endValue = values[count - 1];
    while (results[0].level > 1) {
      results[0].nextLevel();
      Node *ac = results[0].alreadyChecked;
      if (ac &&
          less(ac->value(), ac->length(), endValue.data(), endValue.size()))
        break;
    }

    // Init all the other fingers to start descending where we stopped
    // the first one

    // SOMEDAY: this loop showed up on vtune, could be faster?
    // vtune: 8 parts
    int startLevel = results[0].level + 1;
    Node *x = startLevel < MaxLevels ? results[0].finger[startLevel] : header;
    for (int i = 1; i < count; i++) {
      results[i].level = startLevel;
      results[i].x = x;
      results[i].alreadyChecked = nullptr;
      results[i].value = values[i];
      for (int j = startLevel; j < MaxLevels; j++)
        results[i].finger[j] = results[0].finger[j];
    }

    // Descend all fingers round-robin, prefetching the next node of each.
    int *nextJob = temp;
    for (int i = 0; i < count - 1; i++)
      nextJob[i] = i + 1;
    nextJob[count - 1] = 0;

    int prevJob = count - 1;
    int job = 0;

    // vtune: 225 parts
    while (true) {
      Finger *f = &results[job];
      f->advance();
      if (f->finished()) {
        if (prevJob == job)
          break;
        nextJob[prevJob] = nextJob[job];
      } else {
        f->prefetch();
        prevJob = job;
      }
      job = nextJob[job];
    }
  }

  // Walks up to `nodeCount` nodes after f.finger[0], unlinking nodes whose
  // level-0 version is below `v` when the preceding node's is also below `v`.
  // The unlinked node's range is absorbed by its predecessor, which is itself
  // below `v`. Returns the number of nodes removed; `f` is left positioned
  // after the last node visited.
  int removeBefore(Version v, Finger &f, int nodeCount) {
    // f.x, f.alreadyChecked?

    int removedCount = 0;
    bool wasAbove = true;
    while (nodeCount--) {
      Node *x = f.finger[0]->getNext(0);
      if (!x)
        break;

      // double prefetch gives +25% speed (single threaded)
      // (For a level-0 node, getNext(1) reads the first version slot; the
      // result is only used as a prefetch hint.)
      Node *next = x->getNext(0);
      __builtin_prefetch(next);
      next = x->getNext(1);
      __builtin_prefetch(next);

      bool isAbove = x->getMaxVersion(0) >= v;
      if (isAbove || wasAbove) { // f.nextItem
        for (int l = 0; l <= x->level(); l++)
          f.finger[l] = x;
      } else { // f.eraseItem
        removedCount++;
        for (int l = 0; l <= x->level(); l++)
          f.finger[l]->setNext(l, x->getNext(l));
        for (int i = 1; i <= x->level(); i++)
          f.finger[i]->setMaxVersion(
              i, std::max(f.finger[i]->getMaxVersion(i), x->getMaxVersion(i)));
        x->destroy();
      }
      wasAbove = isAbove;
    }

    return removedCount;
  }

private:
  // Unlinks and frees every node from start.finger[0]->getNext(0) through
  // end.finger[0], inclusive.
  void remove(const Finger &start, const Finger &end) {
    if (start.finger[0] == end.finger[0])
      return;

    Node *x = start.finger[0]->getNext(0);

    // vtune says: this loop is the expensive parts (6 parts)
    for (int i = 0; i < MaxLevels; i++)
      if (start.finger[i] != end.finger[i])
        start.finger[i]->setNext(i, end.finger[i]->getNext(i));

    while (true) {
      Node *next = x->getNext(0);
      x->destroy();
      if (x == end.finger[0])
        break;
      x = next;
    }
  }

  // Inserts a node for f.value at a random level right after f.finger[0],
  // giving its range level-0 version `version`, and repairs the max versions
  // of the levels above it.
  void insert(const Finger &f, Version version) {
    int level = randomLevel();
    // std::cout << std::string((const char*)value,length) << " level: " <<
    // level << std::endl;
    Node *x = Node::create(f.value, level);
    x->setMaxVersion(0, version);
    for (int i = 0; i <= level; i++) {
      x->setNext(i, f.finger[i]->getNext(i));
      f.finger[i]->setNext(i, x);
    }
    // vtune says: this loop is the costly part of this function
    for (int i = 1; i <= level; i++) {
      f.finger[i]->calcVersionForLevel(i);
      x->calcVersionForLevel(i);
    }
    for (int i = level + 1; i < MaxLevels; i++) {
      Version v = f.finger[i]->getMaxVersion(i);
      if (v >= version)
        break;
      f.finger[i]->setMaxVersion(i, version);
    }
  }

  // Resumable check of one read range for conflicts. advance() does a bounded
  // amount of work and returns true once *result is final: it is set to
  // Conflict when a newer write is found and otherwise left unchanged.
  struct CheckMax {
    Finger start, end;
    Version version;
    ConflictSet::Result *result;
    int state;

    void init(const ReadConflictRange &r, Node *header,
              ConflictSet::Result *result) {
      this->start.init(r.begin, header);
      this->end.init(r.end, header);
      this->version = r.version;
      this->state = 0;
      this->result = result;
    }

    bool noConflict() const { return true; }
    bool conflict() {
      *result = ConflictSet::Conflict;
      return true;
    }

    // Return true if finished
    force_inline bool advance() {
      if (*result == ConflictSet::TooOld) {
        return true;
      }
      switch (state) {
      case 0:
        // find where start and end fingers diverge
        while (true) {
          if (!start.advance()) {
            start.prefetch();
            return false;
          }
          end.x = start.x;
          while (!end.advance())
            ;

          int l = start.level;
          if (start.finger[l] != end.finger[l])
            break;
          // accept if the range spans the check range, but does not have a
          // greater version
          if (start.finger[l]->getMaxVersion(l) <= version)
            return noConflict();
          if (l == 0)
            return conflict();
        }
        state = 1;
        [[fallthrough]];
      case 1: {
        // check the end side of the pyramid
        Node *e = end.finger[end.level];
        while (e->getMaxVersion(end.level) > version) {
          if (end.finished())
            return conflict();
          end.nextLevel();
          Node *f = end.finger[end.level];
          while (e != f) {
            if (e->getMaxVersion(end.level) > version)
              return conflict();
            e = e->getNext(end.level);
          }
        }

        // check the start side of the pyramid
        Node *s = end.finger[start.level];
        while (true) {
          Node *nextS = start.finger[start.level]->getNext(start.level);
          Node *p = nextS;
          while (p != s) {
            if (p->getMaxVersion(start.level) > version)
              return conflict();
            p = p->getNext(start.level);
          }
          if (start.finger[start.level]->getMaxVersion(start.level) <= version)
            return noConflict();
          s = nextS;
          if (start.finished()) {
            if (nextS->length() == start.value.size() &&
                !memcmp(nextS->value(), start.value.data(),
                        start.value.size()))
              return noConflict();
            else
              return conflict();
          }
          start.nextLevel();
        }
      }
      default:
        __builtin_unreachable();
      }
    }
  };
};

// Adapter exposing the SkipList through the same check / addWrites /
// setOldestVersion calls the benchmark makes on ConflictSet.
struct SkipListConflictSet {
  SkipListConflictSet(int64_t oldestVersion)
      : oldestVersion(oldestVersion), skipList(oldestVersion) {}

  // Sets results[i] to TooOld if reads[i].readVersion < oldestVersion,
  // otherwise to Conflict or Commit depending on whether any key in the read
  // range was written after reads[i].readVersion.
  void check(const ConflictSet::ReadRange *reads, ConflictSet::Result *results,
             int count) const {
    Arena arena;
    auto *ranges = new (arena) ReadConflictRange[count];
    for (int i = 0; i < count; ++i) {
      ranges[i].begin = {reads[i].begin.p, size_t(reads[i].begin.len)};
      ranges[i].end = rangeEnd(arena, ranges[i].begin,
                               {reads[i].end.p, size_t(reads[i].end.len)});
      ranges[i].version = reads[i].readVersion;
      if (reads[i].readVersion < oldestVersion) {
        results[i] = ConflictSet::TooOld;
      } else {
        results[i] = ConflictSet::Commit;
      }
    }
    skipList.detectConflicts(ranges, count, results);
  }

  // Records each write range at its writeVersion. Writes are processed in
  // stripes of up to 8 ranges (16 keys), from the last stripe to the first;
  // within a stripe, the begin/end keys are passed to find() in input order.
  void addWrites(const ConflictSet::WriteRange *writes, int count) {
    // Owns the keyAfter() end keys built for point writes; the skip list
    // copies key bytes into its own nodes, so the arena may die on return.
    Arena arena;
    const int stringCount = count * 2;

    const int stripeSize = 16;
    SkipList::Finger fingers[stripeSize];
    int temp[stripeSize];
    int stripes = (stringCount + stripeSize - 1) / stripeSize;
    StringRef values[stripeSize];
    int64_t writeVersions[stripeSize / 2];
    int ss = stringCount - (stripes - 1) * stripeSize;
    for (int s = stripes - 1; s >= 0; s--) {
      for (int i = 0; i * 2 < ss; ++i) {
        const auto &w = writes[s * stripeSize / 2 + i];
        values[i * 2] = {w.begin.p, size_t(w.begin.len)};
        values[i * 2 + 1] =
            rangeEnd(arena, values[i * 2], {w.end.p, size_t(w.end.len)});
        writeVersions[i] = w.writeVersion;
        keyUpdates += 2;
      }
      skipList.find(values, fingers, temp, ss);
      skipList.addConflictRanges(fingers, ss / 2, writeVersions);
      ss = stripeSize;
    }
  }

  // Advances the oldest readable version. Incrementally prunes nodes older
  // than it: each call visits as many nodes as keys were added since the
  // previous call, resuming from where the last call stopped.
  void setOldestVersion(int64_t oldestVersion) {
    this->oldestVersion = oldestVersion;
    SkipList::Finger finger;
    int temp;
    std::span<const uint8_t> key = removalKey;
    skipList.find(&key, &finger, &temp, 1);
    skipList.removeBefore(oldestVersion, finger, std::exchange(keyUpdates, 0));
    const StringRef resumeAt = finger.getValue();
    removalKey = std::basic_string<uint8_t>(resumeAt.data(), resumeAt.size());
  }

private:
  // Returns the exclusive end key of a read/write range. An empty `end`
  // denotes the point range containing only `begin`; its end key is rebuilt
  // as keyAfter(begin) in `arena`. It is not read through end.p, because
  // nothing in this interface guarantees end.p points at keyAfter(begin) when
  // end.len == 0.
  static StringRef rangeEnd(Arena &arena, StringRef begin, StringRef end) {
    return end.empty() ? keyAfter(arena, begin) : end;
  }

  // Keys written since the last setOldestVersion() call; bounds the pruning
  // work done by the next call.
  int64_t keyUpdates = 0;
  // Key at which the next incremental pruning pass resumes.
  std::basic_string<uint8_t> removalKey;
  int64_t oldestVersion;
  SkipList skipList;
};

// Returns the point read range [key, key + '\0'), with the end key allocated
// in `arena`.
ConflictSet::ReadRange singleton(Arena &arena, std::span<const uint8_t> key) {
  auto r = keyAfter(arena, key);
  return {key.data(), int(key.size()), r.data(), int(r.size())};
}

// Returns the range of all keys that start with `key`: [key, e), where e is
// `key` with its trailing 0xff bytes dropped and its last remaining byte
// incremented. The end key is allocated in `arena`.
ConflictSet::ReadRange prefixRange(Arena &arena, std::span<const uint8_t> key) {
  int index;
  for (index = key.size() - 1; index >= 0; index--)
    if ((key[index]) != 255)
      break;

  // Must not be called with a string that consists only of zero or more '\xff'
  // bytes.
  if (index < 0) {
    assert(false);
    // With NDEBUG the assert is a no-op; continuing would increment
    // r[r.size() - 1] on a zero-length buffer.
    std::abort();
  }

  auto r = std::span<uint8_t>(new (arena) uint8_t[index + 1], index + 1);
  memcpy(r.data(), key.data(), index + 1);
  r[r.size() - 1]++;
  return {key.data(), int(key.size()), r.data(), int(r.size())};
}

} // namespace

constexpr int kNumKeys = 1000000;

constexpr int kOpsPerTx = 100;

constexpr int kPrefixLen = 0;

constexpr int kMvccWindow = 100000;

// Returns a (kPrefixLen + 4)-byte key: kPrefixLen zero bytes followed by the
// byte-swapped `index` (big-endian on a little-endian host, so keys then sort
// like their non-negative indices).
std::span<const uint8_t> makeKey(Arena &arena, int index) {
  auto result =
      std::span<uint8_t>{new (arena) uint8_t[4 + kPrefixLen], 4 + kPrefixLen};
  index = __builtin_bswap32(index);
  memset(result.data(), 0, kPrefixLen);
  memcpy(result.data() + kPrefixLen, &index, 4);
  return result;
}

// Benchmarks point/prefix/range reads and writes (kOpsPerTx per batch)
// against a ConflictSet_ pre-populated with kNumKeys point writes.
template <class ConflictSet_> void benchConflictSet(const std::string &name) {
  ankerl::nanobench::Bench bench;
  ConflictSet_ cs{0};

  bench.batch(kOpsPerTx);
  bench.minEpochIterations(2000);

  int64_t version = 0;

  // Populate conflict set
  Arena arena;
  {
    std::vector<ConflictSet::WriteRange> writes;
    writes.reserve(kNumKeys);
    for (int i = 0; i < kNumKeys; ++i) {
      auto key = makeKey(arena, i);
      ConflictSet::WriteRange w{};
      auto r = singleton(arena, key);
      w.begin.p = r.begin.p;
      w.begin.len = r.begin.len;
      w.end.p = r.end.p;
      w.end.len = 0;
      // addWrites reads writeVersion, so it must be set. version + 1 (== 1)
      // is not newer than the readVersion (version - 1 == 1) used by the
      // read benchmarks below.
      w.writeVersion = version + 1;
      writes.push_back(w);
    }
    cs.addWrites(writes.data(), writes.size());
    ++version;
  }

  // I don't know why std::less didn't work /shrug
  struct Less {
    bool operator()(const std::span<const uint8_t> &lhs,
                    const std::span<const uint8_t> &rhs) const {
      return lhs < rhs;
    }
  };

  auto points = set<std::span<const uint8_t>, Less>(arena);

  while (points.size() < kOpsPerTx * 2 + 1) {
    // TODO don't use rand?
    points.insert(makeKey(arena, rand() % kNumKeys));
  }

  // Make short-circuiting non-trivial
  {
    std::vector<ConflictSet::WriteRange> writes;
    auto iter = points.begin();
    ++iter;
    // Complement of the set we'll be reading with range reads. Almost.
    for (int i = 0; i < kOpsPerTx; ++i) {
      auto begin = *iter++;
      auto end = *iter++;
      ConflictSet::WriteRange w{};
      w.begin.p = begin.data();
      w.begin.len = begin.size();
      w.end.p = end.data();
      w.end.len = end.size();
      w.writeVersion = version + 1;
      writes.push_back(w);
    }
    ++version;
    cs.addWrites(writes.data(), kOpsPerTx);
  }

  {
    std::vector<ConflictSet::ReadRange> reads;
    auto iter = points.begin();
    for (int i = 0; i < kOpsPerTx; ++i) {
      auto r = singleton(arena, *iter);
      r.end.len = 0;
      r.readVersion = version - 1;
      reads.push_back(r);
      ++iter;
    }

    auto *results = new (arena) ConflictSet::Result[kOpsPerTx];

    bench.run(name + " (point reads)",
              [&]() { cs.check(reads.data(), results, kOpsPerTx); });
  }

  {
    std::vector<ConflictSet::ReadRange> reads;
    auto iter = points.begin();
    for (int i = 0; i < kOpsPerTx; ++i) {
      auto r = prefixRange(arena, *iter);
      r.readVersion = version - 1;
      reads.push_back(r);
      ++iter;
    }

    auto *results = new (arena) ConflictSet::Result[kOpsPerTx];

    bench.run(name + " (prefix reads)",
              [&]() { cs.check(reads.data(), results, kOpsPerTx); });
  }

  {
    std::vector<ConflictSet::ReadRange> reads;
    auto iter = points.begin();
    for (int i = 0; i < kOpsPerTx; ++i) {
      auto begin = *iter++;
      auto end = *iter++;
      ConflictSet::ReadRange r{};
      r.begin.p = begin.data();
      r.begin.len = begin.size();
      r.end.p = end.data();
      r.end.len = end.size();
      r.readVersion = version - 1;
      reads.push_back(r);
    }

    auto *results = new (arena) ConflictSet::Result[kOpsPerTx];

    bench.run(name + " (range reads)",
              [&]() { cs.check(reads.data(), results, kOpsPerTx); });
  }

  {
    std::vector<ConflictSet::WriteRange> writes;
    auto iter = points.begin();
    for (int i = 0; i < kOpsPerTx; ++i) {
      ConflictSet::WriteRange w{};
      auto r = singleton(arena, *iter);
      w.begin.p = r.begin.p;
      w.begin.len = r.begin.len;
      w.end.p = r.end.p;
      w.end.len = 0;
      writes.push_back(w);
      ++iter;
    }

    bench.run(name + " (point writes)", [&]() {
      auto v = ++version;
      for (auto &w : writes) {
        w.writeVersion = v;
      }
      cs.addWrites(writes.data(), writes.size());
      cs.setOldestVersion(std::max<int64_t>(version - kMvccWindow, 0));
    });
  }

  {
    std::vector<ConflictSet::WriteRange> writes;
    auto iter = points.begin();
    for (int i = 0; i < kOpsPerTx; ++i) {
      ConflictSet::WriteRange w{};
      auto r = prefixRange(arena, *iter);
      w.begin.p = r.begin.p;
      w.begin.len = r.begin.len;
      w.end.p = r.end.p;
      w.end.len = r.end.len;
      writes.push_back(w);
      ++iter;
    }

    bench.run(name + " (prefix writes)", [&]() {
      auto v = ++version;
      for (auto &w : writes) {
        w.writeVersion = v;
      }
      cs.addWrites(writes.data(), writes.size());
      cs.setOldestVersion(std::max<int64_t>(version - kMvccWindow, 0));
    });
  }

  {
    std::vector<ConflictSet::WriteRange> writes;
    auto iter = points.begin();
    for (int i = 0; i < kOpsPerTx; ++i) {
      auto begin = *iter++;
      auto end = *iter++;
      ConflictSet::WriteRange w{};
      w.begin.p = begin.data();
      w.begin.len = begin.size();
      w.end.p = end.data();
      w.end.len = end.size();
      writes.push_back(w);
    }

    bench.run(name + " (range writes)", [&]() {
      auto v = ++version;
      for (auto &w : writes) {
        w.writeVersion = v;
      }
      cs.addWrites(writes.data(), writes.size());
      cs.setOldestVersion(std::max<int64_t>(version - kMvccWindow, 0));
    });
  }
}

int main(void) {
  benchConflictSet<SkipListConflictSet>("skip list");
  benchConflictSet<ConflictSet>("radix tree");
}

// extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
//   TestDriver driver{data, size};

//   for (;;) {
//     bool done = driver.next();
//     if (!driver.ok) {
//       // debugPrintDot(stdout, driver.cs.root);
//       // fflush(stdout);
//       abort();
//     }
// #if DEBUG_VERBOSE && !defined(NDEBUG)
//     fprintf(stderr, "Check correctness\n");
// #endif
//     // bool success = checkCorrectness(driver.cs.root);
//     // if (!success) {
//     //   debugPrintDot(stdout, driver.cs.root);
//     //   fflush(stdout);
//     //   abort();
//     // }
//     if (done) {
//       break;
//     }
//   }

//   return 0;
// }