#include "RootSet.h"
#include "Internal.h"

// NOTE(review): the system header names below were selected to cover what
// this file uses (assert, std::atomic, placement new, size_t, fixed-width
// ints, memcpy) -- confirm against the project's intended include list.
#include <assert.h>
#include <atomic>
#include <new>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

// A single heap allocation holding a small header followed by two parallel
// arrays: `capacity` int64_t versions, then `capacity` uint32_t roots.
// Entries [0, end) are populated; entry i pairs versions()[i] with roots()[i].
struct RootSet::ThreadSafeHandle::Impl {
  // Number of bytes needed for the header plus both arrays.
  static size_t sizeForCapacity(int capacity) {
    return sizeof(Impl) +
           (sizeof(int64_t) + sizeof(uint32_t)) * static_cast<size_t>(capacity);
  }

  // Allocates a buffer able to hold `capacity` entries. The buffer starts out
  // unlinked (`next == nullptr`) and empty (`end == 0`); callers fill in the
  // entries and then publish the real `end`.
  static Impl *create(int capacity) {
    auto *result = static_cast<Impl *>(safe_malloc(sizeForCapacity(capacity)));
    result->next = nullptr;
    result->capacity = capacity;
    // The memory is raw, so the atomic must be constructed, not just stored to.
    new (&result->end) std::atomic<int>(0);
    return result;
  }

  // The versions array starts immediately after the header.
  int64_t *versions() { return reinterpret_cast<int64_t *>(this + 1); }

  // The roots array starts immediately after the versions array.
  uint32_t *roots() {
    return reinterpret_cast<uint32_t *>(versions() + capacity);
  }

  // Linked list of Impl's to free, ordered by version
  Impl *next;
  int capacity;
  // One past the index of the last populated entry. Stored with release after
  // an entry is written and loaded with acquire in lastLeq.
  std::atomic<int> end;

  // Find the index of the last version <= version, or 0 if no such version
  // exists
  //
  // Binary search over indices [1, end); index 0 is the fallback result, so
  // it is never compared. Assumes versions() is in ascending order.
  uint32_t lastLeq(int64_t version) {
    const int64_t *v = versions();
    int lo = 1;
    int hi = end.load(std::memory_order_acquire) - 1;
    int best = 0;
    while (lo <= hi) {
      const int mid = lo + (hi - lo) / 2;
      if (v[mid] <= version) {
        best = mid;
        lo = mid + 1;
      } else {
        hi = mid - 1;
      }
    }
    return static_cast<uint32_t>(best);
  }
};

// Owns the current buffer (`handle`) plus a list of retired buffers that are
// freed lazily as `oldestVersion` advances.
// NOTE(review): add/setOldestVersion use relaxed loads of `handle` and plain
// accesses to the free list, so they presumably run on a single writer thread
// -- confirm against the contract documented in RootSet.h.
struct RootSet::Impl {
  // Starts with one buffer containing the single entry (version 0, root 0).
  Impl() {
    auto *h = ThreadSafeHandle::Impl::create(kMinCapacity);
    h->roots()[0] = 0;
    h->versions()[0] = 0;
    h->end.store(1, std::memory_order_relaxed);
    handle.store(h, std::memory_order_relaxed);
    firstToFree = nullptr;
    lastToFree = nullptr;
    oldestVersion = 0;
  }

  // Frees every retired buffer and then the current one.
  ~Impl() {
    for (auto *i = firstToFree; i != nullptr;) {
      auto *tmp = i;
      i = i->next;
      safe_free(tmp, ThreadSafeHandle::Impl::sizeForCapacity(tmp->capacity));
    }
    auto *h = handle.load(std::memory_order_relaxed);
    safe_free(h, ThreadSafeHandle::Impl::sizeForCapacity(h->capacity));
  }

  // Appends the entry (version, node). Does nothing if `node` equals the most
  // recently added root.
  void add(uint32_t node, int64_t version) {
    ThreadSafeHandle::Impl *h = handle.load(std::memory_order_relaxed);
    int end = h->end.load(std::memory_order_relaxed);
    if (h->roots()[end - 1] == node) {
      return;
    }
    // Upsize if necessary
    if (end == h->capacity) {
      // Retire the full buffer by appending it to the to-free list.
      h->next = nullptr;
      const int begin = static_cast<int>(h->lastLeq(oldestVersion));
      if (lastToFree != nullptr) {
        lastToFree->next = h;
      } else {
        firstToFree = h;
      }
      lastToFree = h;

      // Carry over only the entries from `begin` onward, and allocate twice
      // that many slots.
      const int newEnd = h->capacity - begin;
      auto *newH = ThreadSafeHandle::Impl::create(newEnd * 2);
      memcpy(newH->roots(), h->roots() + begin,
             sizeof(h->roots()[0]) * static_cast<size_t>(newEnd));
      memcpy(newH->versions(), h->versions() + begin,
             sizeof(h->versions()[0]) * static_cast<size_t>(newEnd));
      newH->end.store(newEnd, std::memory_order_relaxed);
      // Release: the copied entries and `end` are written before the new
      // buffer becomes reachable through `handle`.
      handle.store(newH, std::memory_order_release);
      h = newH;
      end = newEnd;
    }
    h->roots()[end] = node;
    h->versions()[end] = version;
    // Release: the entry is written before `end` is advanced past it.
    h->end.store(end + 1, std::memory_order_release);
  }

  // Records the new oldest version and frees retired buffers that are no
  // longer valid to read at that version.
  void setOldestVersion(int64_t oldestVersion) {
    this->oldestVersion = oldestVersion;
    // The next buffer has an entry greater than all entries in this buffer.
    // If the next buffer does not have an entry > `oldestVersion`, then this
    // buffer is missing an entry <= oldestVersion on the right, so it's
    // incorrect to read this buffer at `oldestVersion` and we can safely free
    // it.
    while (firstToFree != nullptr && firstToFree->next != nullptr) {
      ThreadSafeHandle::Impl *next = firstToFree->next;
      const int64_t nextLastVersion =
          next->versions()[next->end.load(std::memory_order_relaxed) - 1];
      if (nextLastVersion > oldestVersion) {
        break;
      }
      auto *tmp = firstToFree;
      firstToFree = next;
      safe_free(tmp, ThreadSafeHandle::Impl::sizeForCapacity(tmp->capacity));
    }
#ifndef NDEBUG
    assert(rootCount() > 0);
    auto *h = handle.load(std::memory_order_relaxed);
    assert(h->versions()[h->lastLeq(oldestVersion)] <= oldestVersion);
#endif
  }

  // Pointer to the root for `oldestVersion` in the current buffer; the
  // following rootCount() - 1 entries are the roots added after it.
  const uint32_t *roots() const {
    auto *h = handle.load(std::memory_order_relaxed);
    return h->roots() + h->lastLeq(oldestVersion);
  }

  // Number of entries in the current buffer starting at the one selected by
  // `oldestVersion`.
  int rootCount() const {
    auto *h = handle.load(std::memory_order_relaxed);
    return h->end.load(std::memory_order_relaxed) -
           static_cast<int>(h->lastLeq(oldestVersion));
  }

  // Wraps the current buffer pointer in a ThreadSafeHandle.
  ThreadSafeHandle getThreadSafeHandle() const {
    // The handle is filled by copying the raw pointer over it, which is only
    // sound if the two have the same size.
    static_assert(sizeof(ThreadSafeHandle) == sizeof(ThreadSafeHandle::Impl *),
                  "ThreadSafeHandle must consist of exactly one Impl pointer");
    ThreadSafeHandle result;
    auto *impl = handle.load(std::memory_order_acquire);
    memcpy(&result, &impl, sizeof(result));
    return result;
  }

  constexpr static uint32_t kMinCapacity = 16;

  // Current buffer; replaced (with release) when add() upsizes.
  std::atomic<ThreadSafeHandle::Impl *> handle;
  int64_t oldestVersion;
  // Head and tail of the list of retired buffers, linked through Impl::next.
  ThreadSafeHandle::Impl *firstToFree;
  ThreadSafeHandle::Impl *lastToFree;
};

// Public API: thin forwarding wrappers around RootSet::Impl.

void RootSet::add(uint32_t node, int64_t version) { impl->add(node, version); }

void RootSet::setOldestVersion(int64_t oldestVersion) {
  impl->setOldestVersion(oldestVersion);
}

// Returns the root paired with the last version <= `version` in this handle's
// buffer (or the buffer's first root if there is no such version).
uint32_t RootSet::ThreadSafeHandle::rootForVersion(int64_t version) const {
  return impl->roots()[impl->lastLeq(version)];
}

RootSet::ThreadSafeHandle RootSet::getThreadSafeHandle() const {
  return impl->getThreadSafeHandle();
}

const uint32_t *RootSet::roots() const { return impl->roots(); }

int RootSet::rootCount() const { return impl->rootCount(); }

// Impl is placement-constructed in safe_malloc'd storage, so the destructor
// runs ~Impl() explicitly and hands the storage back with safe_free.
RootSet::RootSet() : impl(new (safe_malloc(sizeof(Impl))) Impl()) {}

RootSet::~RootSet() {
  impl->~Impl();
  safe_free(impl, sizeof(*impl));
}

#ifdef ENABLE_ROOTSET_TESTS
// NOTE(review): header names chosen to cover what the test uses -- confirm.
#include <algorithm>
#include <latch>
#include <thread>
#include <vector>

// Stress test: one writer adds a new root every 10 versions and advances the
// oldest version to the minimum version the readers have finished with, while
// the readers check that the root visible at each version is version / 10.
int main() {
  constexpr int kNumReaders = 3;
  constexpr int kNumVersions = 2000000;
  RootSet rs;
  std::latch ready{1 + kNumReaders};
  // NOTE(review): element types of `version` and `doneVersions` were inferred
  // from how the values are used -- confirm.
  std::atomic<int64_t> version{0};
  std::vector<std::atomic<uint32_t>> doneVersions(kNumReaders);
  std::thread writer([&]() {
    ready.arrive_and_wait();
    for (int i = 0; i < kNumVersions; ++i) {
      rs.add(i / 10, i);
      version.store(i);
      uint32_t min = UINT32_MAX;
      for (auto &v : doneVersions) {
        min = std::min(min, v.load());
      }
      rs.setOldestVersion(min);
    }
  });
  std::vector<std::thread> readers;
  for (int i = 0; i < kNumReaders; ++i) {
    readers.emplace_back([&, i]() {
      ready.arrive_and_wait();
      for (;;) {
        auto v = version.load();
        assert(rs.getThreadSafeHandle().rootForVersion(v) == v / 10);
        doneVersions[i].store(v);
        if (v == kNumVersions - 1) {
          break;
        }
      }
    });
  }
  writer.join();
  for (auto &t : readers) {
    t.join();
  }
}
#endif