/*
 * SkipList.cpp
 *
 * This source file is part of the FoundationDB open source project
 *
 * Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ConflictSet.h"
#include "Internal.h"

#include <span>

// Returns an arena-allocated copy of `key` with a single zero byte appended.
// Under the byte-wise ordering implemented by compare() below, the result
// sorts immediately after `key`.
std::span<const uint8_t> keyAfter(Arena &arena, std::span<const uint8_t> key) {
  auto result =
      std::span<uint8_t>(new (arena) uint8_t[key.size() + 1], key.size() + 1);
  // An empty span may carry a null data pointer; skip the copy in that case.
  if (!key.empty()) {
    memcpy(result.data(), key.data(), key.size());
  }
  result[result.size() - 1] = 0;
  return result;
}

// Returns an arena-allocated copy of `key`.
std::span<const uint8_t> copyToArena(Arena &arena,
                                     std::span<const uint8_t> key) {
  auto result = std::span<uint8_t>(new (arena) uint8_t[key.size()], key.size());
  // An empty span may carry a null data pointer; skip the copy in that case.
  if (!key.empty()) {
    memcpy(result.data(), key.data(), key.size());
  }
  return result;
}

using Version = int64_t;
#define force_inline __attribute__((always_inline))
// Non-owning view of a key's bytes.
using StringRef = std::span<const uint8_t>;

// A [begin, end) pair of keys. The two-argument (arena, begin) constructor
// builds the range [begin, keyAfter(begin)).
struct KeyRangeRef {
  StringRef begin;
  StringRef end;
  KeyRangeRef() {}
  KeyRangeRef(StringRef begin, StringRef end) : begin(begin), end(end) {}
  KeyRangeRef(Arena &arena, StringRef begin)
      : begin(begin), end(keyAfter(arena, begin)) {}
};

// Per-thread state of the linear congruential generator used by skfastrand().
static thread_local uint32_t g_seed = 0;

// Advances the per-thread generator and returns its new state.
static inline int skfastrand() {
  g_seed = g_seed * 1664525L + 1013904223L;
  return g_seed;
}

// Three-way byte-wise comparison: returns -1, 0 or +1. When one key is a
// prefix of the other, the shorter key sorts first.
static int compare(const StringRef &a, const StringRef &b) {
  int c = memcmp(a.data(), b.data(), std::min(a.size(), b.size()));
  if (c < 0)
    return -1;
  if (c > 0)
    return +1;
  if (a.size() < b.size())
    return -1;
  if (a.size() == b.size())
    return 0;
  return +1;
}

// A read of [begin, end) performed at `version`. Ordered by `begin` only.
struct ReadConflictRange {
  StringRef begin, end;
  Version version;
  ReadConflictRange() {}
  ReadConflictRange(StringRef begin, StringRef end, Version version)
      : begin(begin), end(end), version(version) {}
  bool operator<(const ReadConflictRange &rhs) const {
    return compare(begin, rhs.begin) < 0;
  }
};

class SkipList {
private:
  static constexpr int MaxLevels = 26;

  // Picks the level for a new node: takes the top (MaxLevels - 1) bits of a
  // pseudo-random number and counts how many consecutive low bits are set.
  // The result is always in [0, MaxLevels - 1].
  int randomLevel() const {
    uint32_t i = uint32_t(skfastrand()) >> (32 - (MaxLevels - 1));
    int level = 0;
    while (i & 1) {
      i >>= 1;
      level++;
    }
    assert(level < MaxLevels);
    return level;
  }

  // Represent a node in the SkipList. The node has multiple (i.e., level)
  // pointers to other nodes, and keeps a record of the max versions for each
  // level.
  struct Node {
    int level() const { return nPointers - 1; }
    // Start of the key bytes, stored after the pointer and version arrays.
    uint8_t *value() {
      return end() + nPointers * (sizeof(Node *) + sizeof(Version));
    }
    int length() const { return valueLength; }

    // Returns the next node pointer at the given level.
    Node *getNext(int level) { return *((Node **)end() + level); }
    // Sets the next node pointer at the given level.
    void setNext(int level, Node *n) { *((Node **)end() + level) = n; }

    // Returns the max version at the given level.
    Version getMaxVersion(int i) const {
      return ((Version *)(end() + nPointers * sizeof(Node *)))[i];
    }
    // Sets the max version at the given level.
    void setMaxVersion(int i, Version v) {
      ((Version *)(end() + nPointers * sizeof(Node *)))[i] = v;
    }

    // Return a node with initialized value but uninitialized pointers
    // Memory layout: *this, (level+1) Node*, (level+1) Version, value
    static Node *create(const StringRef &value, int level) {
      int nodeSize = sizeof(Node) + value.size() +
                     (level + 1) * (sizeof(Node *) + sizeof(Version));

      Node *n;
      n = (Node *)safe_malloc(nodeSize);

      n->nPointers = level + 1;

      n->valueLength = value.size();
      if (value.size() > 0) {
        memcpy(n->value(), value.data(), value.size());
      }
      return n;
    }

    // pre: level>0, all lower level nodes between this and getNext(level) have
    // correct maxversions
    // Sets this node's max version at `level` to the maximum of the
    // (level - 1) max versions of this node and of every node reached through
    // (level - 1) links before getNext(level).
    void calcVersionForLevel(int level) {
      Node *end = getNext(level);
      Version v = getMaxVersion(level - 1);
      for (Node *x = getNext(level - 1); x != end; x = x->getNext(level - 1))
        v = std::max(v, x->getMaxVersion(level - 1));
      setMaxVersion(level, v);
    }

    // Releases the node's single allocation (header, arrays and key bytes).
    void destroy() { safe_free(this); }

  private:
    int getNodeSize() const {
      return sizeof(Node) + valueLength +
             nPointers * (sizeof(Node *) + sizeof(Version));
    }
    // Returns the first Node* pointer
    uint8_t *end() { return (uint8_t *)(this + 1); }
    uint8_t const *end() const { return (uint8_t const *)(this + 1); }
    int nPointers, valueLength;
  };

  // Returns true iff the byte string a[0, aLen) sorts strictly before
  // b[0, bLen); a proper prefix sorts before the longer string.
  static force_inline bool less(const uint8_t *a, int aLen, const uint8_t *b,
                                int bLen) {
    int c = memcmp(a, b, std::min(aLen, bLen));
    if (c < 0)
      return true;
    if (c > 0)
      return false;
    return aLen < bLen;
  }

  // Sentinel node with an empty key and MaxLevels pointers; null once this
  // list has been moved from.
  Node *header;

  // Frees every node reachable through level-0 links, header included.
  // A null header (moved-from list) is a no-op.
  void destroy() {
    Node *next, *x;
    for (x = header; x; x = next) {
      next = x->getNext(0);
      x->destroy();
    }
  }

public:
  // Points to the location (i.e., Node*) where value would appear in the
  // SkipList. For each level, finger[level] is the last node whose key is
  // strictly less than value, so finger[0]->getNext(0) is the node holding
  // value if it is in the list, and otherwise the node value should be
  // inserted before. Note the SkipList organizes all nodes at level 0, higher
  // levels contain jump pointers.
  struct Finger {
    Node *finger[MaxLevels]; // valid for levels >= level
    int level = MaxLevels;
    Node *x = nullptr;
    Node *alreadyChecked = nullptr;
    StringRef value;

    Finger() = default;
    Finger(Node *header, const StringRef &ptr) : x(header), value(ptr) {}

    // Resets the finger to start a search for `value` from `header` at the
    // top level.
    void init(const StringRef &value, Node *header) {
      this->value = value;
      x = header;
      alreadyChecked = nullptr;
      level = MaxLevels;
    }

    // pre: !finished()
    force_inline void prefetch() {
      Node *next = x->getNext(0);
      __builtin_prefetch(next);
    }

    // pre: !finished()
    // Advances the pointer at the current level to a Node that's >= finger's
    // value if possible; or move to the next level (i.e., level--). Returns
    // true if we have advanced to the next level
    force_inline bool advance() {
      Node *next = x->getNext(level - 1);

      // `next == alreadyChecked` short-circuits the comparison for a node
      // already known not to be less than value (this includes the null end
      // of list while alreadyChecked is still null).
      if (next == alreadyChecked ||
          !less(next->value(), next->length(), value.data(), value.size())) {
        alreadyChecked = next;
        level--;
        finger[level] = x;
        return true;
      } else {
        x = next;
        return false;
      }
    }

    // pre: !finished()
    // Steps until the finger has descended exactly one level.
    force_inline void nextLevel() {
      while (!advance())
        ;
    }

    force_inline bool finished() const { return level == 0; }

    // Returns the node holding the finger value, or nullptr if the value is
    // not in the SkipList.
    force_inline Node *found() const {
      // valid after finished returns true
      Node *n = finger[0]->getNext(
          0); // or alreadyChecked, but that is more easily invalidated
      if (n && n->length() == value.size() &&
          !memcmp(n->value(), value.data(), value.size()))
        return n;
      else
        return nullptr;
    }

    // Returns the key of the node following finger[0] at level 0, or an empty
    // StringRef if there is no such node.
    StringRef getValue() const {
      Node *n = finger[0]->getNext(0);
      return n ? StringRef(n->value(), n->length()) : StringRef();
    }
  };

  // Returns the total number of nodes in the list (the header is not
  // counted).
  int count() const {
    int count = 0;
    Node *x = header->getNext(0);
    while (x) {
      x = x->getNext(0);
      count++;
    }
    return count;
  }

  // Creates a list containing only the header node, whose max version is
  // `version` at every level.
  explicit SkipList(Version version = 0) {
    header = Node::create(StringRef(), MaxLevels - 1);
    for (int l = 0; l < MaxLevels; l++) {
      header->setNext(l, nullptr);
      header->setMaxVersion(l, version);
    }
  }
  ~SkipList() { destroy(); }
  SkipList(SkipList &&other) noexcept : header(other.header) {
    other.header = nullptr;
  }
  void operator=(SkipList &&other) noexcept {
    // Self-move must not free the nodes we are about to keep.
    if (this == &other)
      return;
    destroy();
    header = other.header;
    other.header = nullptr;
  }
  void swap(SkipList &other) { std::swap(header, other.header); }

  // Records writes at `version`. fingers[2 * r] and fingers[2 * r + 1] must
  // be finished fingers for the begin and end key of range r (as produced by
  // find()). Ranges are applied from last to first. For each range: make sure
  // a node exists at the end key (carrying the version currently in effect
  // just before it), delete the nodes in [begin, end), then insert a node at
  // the begin key with `version`.
  void addConflictRanges(const Finger *fingers, int rangeCount,
                         Version version) {
    for (int r = rangeCount - 1; r >= 0; r--) {
      const Finger &startF = fingers[r * 2];
      const Finger &endF = fingers[r * 2 + 1];

      if (endF.found() == nullptr)
        insert(endF, endF.finger[0]->getMaxVersion(0));

      remove(startF, endF);
      insert(startF, version);
    }
  }

  // Checks `count` read ranges against the list. A range found to conflict
  // has its slot in transactionConflictStatus set to ConflictSet::Conflict;
  // other slots are left untouched. Up to M checks are interleaved, one step
  // at a time, so that the memory prefetch issued by one check overlaps with
  // work on the others. Active checks form a circular list linked through
  // nextJob.
  void detectConflicts(ReadConflictRange *ranges, int count,
                       ConflictSet::Result *transactionConflictStatus) const {
    const int M = 16;
    int nextJob[M];
    CheckMax inProgress[M];
    if (!count)
      return;

    int started = std::min(M, count);
    for (int i = 0; i < started; i++) {
      inProgress[i].init(ranges[i], header, transactionConflictStatus + i);
      nextJob[i] = i + 1;
    }
    nextJob[started - 1] = 0;

    int prevJob = started - 1;
    int job = 0;
    // vtune: 340 parts
    while (true) {
      if (inProgress[job].advance()) {
        if (started == count) {
          // Nothing left to start: unlink this slot, or stop if it was the
          // last active one.
          if (prevJob == job)
            break;
          nextJob[prevJob] = nextJob[job];
          job = prevJob;
        } else {
          // Reuse the finished slot for the next unstarted range.
          int temp = started++;
          inProgress[job].init(ranges[temp], header,
                               transactionConflictStatus + temp);
        }
      }
      prevJob = job;
      job = nextJob[job];
    }
  }

  // Computes a finished Finger in results[i] for each values[i], i in
  // [0, count). pre: count >= 1; `temp` has room for `count` ints. The shared
  // descent below uses values[0] and values[count - 1] as the extremes, so
  // values is presumably sorted ascending - TODO confirm against callers.
  void find(const StringRef *values, Finger *results, int *temp, int count) {
    // Relying on the ordering of values, descend until the values aren't all in
    // the same part of the tree

    // vtune: 11 parts
    results[0].init(values[0], header);
    const StringRef &endValue = values[count - 1];
    while (results[0].level > 1) {
      results[0].nextLevel();
      Node *ac = results[0].alreadyChecked;
      if (ac &&
          less(ac->value(), ac->length(), endValue.data(), endValue.size()))
        break;
    }

    // Init all the other fingers to start descending where we stopped
    //   the first one

    // SOMEDAY: this loop showed up on vtune, could be faster?
    // vtune: 8 parts
    int startLevel = results[0].level + 1;
    Node *x = startLevel < MaxLevels ? results[0].finger[startLevel] : header;
    for (int i = 1; i < count; i++) {
      results[i].level = startLevel;
      results[i].x = x;
      results[i].alreadyChecked = nullptr;
      results[i].value = values[i];
      for (int j = startLevel; j < MaxLevels; j++)
        results[i].finger[j] = results[0].finger[j];
    }

    // Advance all fingers round-robin, one step each, dropping a finger from
    // the circular nextJob list once it has finished.
    int *nextJob = temp;
    for (int i = 0; i < count - 1; i++)
      nextJob[i] = i + 1;
    nextJob[count - 1] = 0;

    int prevJob = count - 1;
    int job = 0;

    // vtune: 225 parts
    while (true) {
      Finger *f = &results[job];
      f->advance();
      if (f->finished()) {
        if (prevJob == job)
          break;
        nextJob[prevJob] = nextJob[job];
      } else {
        f->prefetch();
        prevJob = job;
      }
      job = nextJob[job];
    }
  }

  // Visits up to `nodeCount` nodes following f.finger[0] and erases each one
  // whose level-0 max version is below `v`, unless the node visited just
  // before it was at or above `v` (or it is the first node visited). Returns
  // the number of nodes erased. `f` is moved along so that it stays valid.
  int removeBefore(Version v, Finger &f, int nodeCount) {
    // f.x, f.alreadyChecked?

    int removedCount = 0;
    bool wasAbove = true;
    while (nodeCount--) {
      Node *x = f.finger[0]->getNext(0);
      if (!x)
        break;

      // double prefetch gives +25% speed (single threaded)
      Node *next = x->getNext(0);
      __builtin_prefetch(next);
      // NOTE(review): for a node with a single pointer this reads the slot
      // just past its pointer array; the value is only used as a prefetch
      // hint.
      next = x->getNext(1);
      __builtin_prefetch(next);

      bool isAbove = x->getMaxVersion(0) >= v;
      if (isAbove || wasAbove) { // f.nextItem
        for (int l = 0; l <= x->level(); l++)
          f.finger[l] = x;
      } else { // f.eraseItem
        removedCount++;
        for (int l = 0; l <= x->level(); l++)
          f.finger[l]->setNext(l, x->getNext(l));
        // The predecessor's jump links now span what x used to span, so fold
        // x's max versions into them.
        for (int i = 1; i <= x->level(); i++)
          f.finger[i]->setMaxVersion(
              i, std::max(f.finger[i]->getMaxVersion(i), x->getMaxVersion(i)));
        x->destroy();
      }
      wasAbove = isAbove;
    }

    return removedCount;
  }

private:
  // Unlinks and frees every node after start.finger[0] up to and including
  // end.finger[0]. No-op when both fingers have the same level-0 predecessor.
  void remove(const Finger &start, const Finger &end) {
    if (start.finger[0] == end.finger[0])
      return;

    Node *x = start.finger[0]->getNext(0);

    // vtune says: this loop is the expensive parts (6 parts)
    for (int i = 0; i < MaxLevels; i++)
      if (start.finger[i] != end.finger[i])
        start.finger[i]->setNext(i, end.finger[i]->getNext(i));

    while (true) {
      Node *next = x->getNext(0);
      x->destroy();
      if (x == end.finger[0])
        break;
      x = next;
    }
  }

  // Inserts a new node holding f.value directly after f.finger[i] on every
  // level up to a randomly chosen one, with level-0 max version `version`,
  // and updates the max versions of the affected jump links.
  void insert(const Finger &f, Version version) {
    int level = randomLevel();
    // std::cout << std::string((const char*)value,length) << " level: " <<
    // level << std::endl;
    Node *x = Node::create(f.value, level);
    x->setMaxVersion(0, version);
    for (int i = 0; i <= level; i++) {
      x->setNext(i, f.finger[i]->getNext(i));
      f.finger[i]->setNext(i, x);
    }
    // vtune says: this loop is the costly part of this function
    for (int i = 1; i <= level; i++) {
      f.finger[i]->calcVersionForLevel(i);
      x->calcVersionForLevel(i);
    }
    // Links above the new node's level still span it: raise their max version
    // to at least `version`, stopping at the first that is already as large.
    for (int i = level + 1; i < MaxLevels; i++) {
      Version v = f.finger[i]->getMaxVersion(i);
      if (v >= version)
        break;
      f.finger[i]->setMaxVersion(i, version);
    }
  }

  // Incremental conflict check for a single read range. advance() performs a
  // bounded amount of work per call so that detectConflicts() can interleave
  // many checks.
  struct CheckMax {
    Finger start, end;
    Version version;
    ConflictSet::Result *result;
    int state;

    void init(const ReadConflictRange &r, Node *header,
              ConflictSet::Result *result) {
      this->start.init(r.begin, header);
      this->end.init(r.end, header);
      this->version = r.version;
      this->state = 0;
      this->result = result;
    }

    // Finishes the check leaving *result untouched.
    bool noConflict() const { return true; }
    // Finishes the check and records a conflict in *result.
    bool conflict() {
      *result = ConflictSet::Conflict;
      return true;
    }

    // Return true if finished
    force_inline bool advance() {
      if (*result == ConflictSet::TooOld) {
        return true;
      }
      switch (state) {
      case 0:
        // find where start and end fingers diverge
        while (true) {
          if (!start.advance()) {
            start.prefetch();
            return false;
          }
          end.x = start.x;
          while (!end.advance())
            ;

          int l = start.level;
          if (start.finger[l] != end.finger[l])
            break;
          // accept if the range spans the check range, but does not have a
          // greater version
          if (start.finger[l]->getMaxVersion(l) <= version)
            return noConflict();
          if (l == 0)
            return conflict();
        }
        state = 1;
        [[fallthrough]];
      case 1: {
        // check the end side of the pyramid
        Node *e = end.finger[end.level];
        while (e->getMaxVersion(end.level) > version) {
          if (end.finished())
            return conflict();
          end.nextLevel();
          Node *f = end.finger[end.level];
          while (e != f) {
            if (e->getMaxVersion(end.level) > version)
              return conflict();
            e = e->getNext(end.level);
          }
        }

        // check the start side of the pyramid
        Node *s = end.finger[start.level];
        while (true) {
          Node *nextS = start.finger[start.level]->getNext(start.level);
          Node *p = nextS;
          while (p != s) {
            if (p->getMaxVersion(start.level) > version)
              return conflict();
            p = p->getNext(start.level);
          }
          if (start.finger[start.level]->getMaxVersion(start.level) <= version)
            return noConflict();
          s = nextS;
          if (start.finished()) {
            if (nextS->length() == start.value.size() &&
                !memcmp(nextS->value(), start.value.data(),
                        start.value.size()))
              return noConflict();
            else
              return conflict();
          }
          start.nextLevel();
        }
      }
      default:
        __builtin_unreachable();
      }
    }
  };
};

struct SkipListConflictSet {};

// Skip-list backed implementation behind the ConflictSet interface.
struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
  Impl(int64_t oldestVersion)
      : oldestVersion(oldestVersion), skipList(oldestVersion) {}

  // Fills results[i] for each of the `count` reads: Conflict if the skip list
  // reports a conflict, otherwise Commit; finally TooOld overrides either for
  // reads whose readVersion is below oldestVersion. A read with an empty end
  // key is treated as the range [begin, keyAfter(begin)).
  void check(const ConflictSet::ReadRange *reads, ConflictSet::Result *results,
             int count) const {
    Arena arena;
    auto *ranges = new (arena) ReadConflictRange[count];
    for (int i = 0; i < count; ++i) {
      ranges[i].begin = {reads[i].begin.p, size_t(reads[i].begin.len)};
      ranges[i].end = reads[i].end.len > 0
                          ? StringRef{reads[i].end.p, size_t(reads[i].end.len)}
                          : keyAfter(arena, ranges[i].begin);
      ranges[i].version = reads[i].readVersion;
      results[i] = ConflictSet::Commit;
    }
    skipList.detectConflicts(ranges, count, results);
    for (int i = 0; i < count; ++i) {
      if (reads[i].readVersion < oldestVersion) {
        results[i] = TooOld;
      }
    }
  }

  // Records `count` write ranges at `writeVersion`. Writes are processed in
  // stripes of at most stripeSize keys (stripeSize / 2 ranges), last stripe
  // first; the first stripe processed takes the remainder. A write with an
  // empty end key is treated as the range [begin, keyAfter(begin)).
  // NOTE(review): SkipList::find relies on the keys being ordered, so
  // `writes` is presumably sorted and non-overlapping - confirm with callers.
  void addWrites(const ConflictSet::WriteRange *writes, int count,
                 int64_t writeVersion) {
    Arena arena;
    const int stringCount = count * 2;

    const int stripeSize = 16;
    SkipList::Finger fingers[stripeSize];
    int temp[stripeSize];
    int stripes = (stringCount + stripeSize - 1) / stripeSize;
    StringRef values[stripeSize];
    // Number of keys in the stripe being processed.
    int ss = stringCount - (stripes - 1) * stripeSize;
    for (int s = stripes - 1; s >= 0; s--) {
      for (int i = 0; i * 2 < ss; ++i) {
        const auto &w = writes[s * stripeSize / 2 + i];
        values[i * 2] = {w.begin.p, size_t(w.begin.len)};
        values[i * 2 + 1] = w.end.len > 0
                                ? StringRef{w.end.p, size_t(w.end.len)}
                                : keyAfter(arena, values[i * 2]);
        // Counted so that setOldestVersion() knows how many nodes to scan.
        keyUpdates += 2;
      }
      skipList.find(values, fingers, temp, ss);
      skipList.addConflictRanges(fingers, ss / 2, writeVersion);
      ss = stripeSize;
    }
  }

  // Raises the oldest version and incrementally garbage-collects the list:
  // resumes at removalKey, scans at most keyUpdates nodes (the number of keys
  // written since the previous call, which is reset to 0), and remembers the
  // key at which the scan stopped for the next call.
  void setOldestVersion(int64_t oldestVersion) {
    this->oldestVersion = oldestVersion;
    SkipList::Finger finger;
    int temp;
    std::span<const uint8_t> key = removalKey;
    skipList.find(&key, &finger, &temp, 1);
    skipList.removeBefore(oldestVersion, finger, std::exchange(keyUpdates, 0));
    // `key` pointed into the old arena and is no longer needed; the new
    // resume key is copied out of the list into a fresh arena.
    removalArena = Arena();
    removalKey = copyToArena(
        removalArena, {finger.getValue().data(), finger.getValue().size()});
  }

private:
  // Keys written since the last setOldestVersion() call.
  int64_t keyUpdates = 0;
  // Owns the bytes referenced by removalKey.
  Arena removalArena;
  // Key at which the next setOldestVersion() scan resumes.
  std::span<const uint8_t> removalKey;
  int64_t oldestVersion;
  SkipList skipList;
};

void ConflictSet::check(const ReadRange *reads, Result *results,
                        int count) const {
  return impl->check(reads, results, count);
}

void ConflictSet::addWrites(const WriteRange *writes, int count,
                            int64_t writeVersion) {
  return impl->addWrites(writes, count, writeVersion);
}

void ConflictSet::setOldestVersion(int64_t oldestVersion) {
  return impl->setOldestVersion(oldestVersion);
}

ConflictSet::ConflictSet(int64_t oldestVersion)
    : impl(new (safe_malloc(sizeof(Impl))) Impl{oldestVersion}) {}

ConflictSet::~ConflictSet() {
  if (impl) {
    impl->~Impl();
    safe_free(impl);
  }
}

ConflictSet::ConflictSet(ConflictSet &&other) noexcept
    : impl(std::exchange(other.impl, nullptr)) {}

// Move assignment: releases the Impl currently owned by *this (if any) before
// taking ownership of other's, so that it is not leaked. Self-move is a no-op.
ConflictSet &ConflictSet::operator=(ConflictSet &&other) noexcept {
  if (this != &other) {
    if (impl) {
      impl->~Impl();
      safe_free(impl);
    }
    impl = std::exchange(other.impl, nullptr);
  }
  return *this;
}

using ConflictSet_Result = ConflictSet::Result;
using ConflictSet_Key = ConflictSet::Key;
using ConflictSet_ReadRange = ConflictSet::ReadRange;
using ConflictSet_WriteRange = ConflictSet::WriteRange;

// C API. `cs` is the opaque handle returned by ConflictSet_create().
extern "C" {
__attribute__((__visibility__("default"))) void
ConflictSet_check(void *cs, const ConflictSet_ReadRange *reads,
                  ConflictSet_Result *results, int count) {
  ((ConflictSet::Impl *)cs)->check(reads, results, count);
}
__attribute__((__visibility__("default"))) void
ConflictSet_addWrites(void *cs, const ConflictSet_WriteRange *writes, int count,
                      int64_t writeVersion) {
  ((ConflictSet::Impl *)cs)->addWrites(writes, count, writeVersion);
}
__attribute__((__visibility__("default"))) void
ConflictSet_setOldestVersion(void *cs, int64_t oldestVersion) {
  ((ConflictSet::Impl *)cs)->setOldestVersion(oldestVersion);
}
__attribute__((__visibility__("default"))) void *
ConflictSet_create(int64_t oldestVersion) {
  return new (safe_malloc(sizeof(ConflictSet::Impl)))
      ConflictSet::Impl{oldestVersion};
}
// Destroys a handle created by ConflictSet_create(). A null handle is
// ignored. The storage came from safe_malloc, so it is released with
// safe_free, matching ConflictSet::~ConflictSet.
__attribute__((__visibility__("default"))) void ConflictSet_destroy(void *cs) {
  using Impl = ConflictSet::Impl;
  if (cs == nullptr)
    return;
  ((Impl *)cs)->~Impl();
  safe_free(cs);
}
}