Prepare for firstGeq to be safe on foreign threads
This commit is contained in:
209
RootSet.cpp
Normal file
209
RootSet.cpp
Normal file
@@ -0,0 +1,209 @@
|
||||
#include "RootSet.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <inttypes.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
struct RootSet::ThreadSafeHandle::Impl {
|
||||
|
||||
static Impl *create(int capacity) {
|
||||
int size =
|
||||
sizeof(Impl) + sizeof(int64_t) * capacity + sizeof(uint32_t) * capacity;
|
||||
auto *result = (Impl *)malloc(size);
|
||||
result->capacity = capacity;
|
||||
return result;
|
||||
}
|
||||
|
||||
int64_t *versions() { return (int64_t *)(this + 1); }
|
||||
uint32_t *roots() { return (uint32_t *)(versions() + capacity); }
|
||||
|
||||
// Linked list of Impl's to free, ordered by version
|
||||
Impl *next;
|
||||
int capacity;
|
||||
std::atomic<int> end;
|
||||
|
||||
// Find the index of the last version <= version, or 0 if no such version
|
||||
// exists
|
||||
uint32_t lastLeq(int64_t version) {
|
||||
int left = 1;
|
||||
int right = end.load(std::memory_order_acquire) - 1;
|
||||
int result = 0;
|
||||
while (left <= right) {
|
||||
int mid = left + (right - left) / 2;
|
||||
if (versions()[mid] <= version) {
|
||||
result = mid;
|
||||
left = mid + 1;
|
||||
} else {
|
||||
right = mid - 1;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
struct RootSet::Impl {
|
||||
|
||||
Impl() {
|
||||
auto *h = ThreadSafeHandle::Impl::create(kMinCapacity);
|
||||
h->roots()[0] = 0;
|
||||
h->versions()[0] = 0;
|
||||
h->end.store(1, std::memory_order_relaxed);
|
||||
handle.store(h, std::memory_order_relaxed);
|
||||
firstToFree = nullptr;
|
||||
lastToFree = nullptr;
|
||||
oldestVersion = 0;
|
||||
}
|
||||
|
||||
~Impl() {
|
||||
for (auto *i = firstToFree; i != nullptr;) {
|
||||
auto *tmp = i;
|
||||
i = i->next;
|
||||
|
||||
free(tmp);
|
||||
}
|
||||
free(handle.load(std::memory_order_relaxed));
|
||||
}
|
||||
|
||||
void add(uint32_t node, int64_t version) {
|
||||
ThreadSafeHandle::Impl *h = handle.load(std::memory_order_relaxed);
|
||||
|
||||
// Upsize if necessary
|
||||
if (h->end.load(std::memory_order_relaxed) == h->capacity) {
|
||||
h->next = nullptr;
|
||||
auto begin = h->lastLeq(oldestVersion);
|
||||
if (lastToFree != nullptr) {
|
||||
lastToFree->next = h;
|
||||
lastToFree = h;
|
||||
} else {
|
||||
firstToFree = h;
|
||||
lastToFree = h;
|
||||
}
|
||||
auto *newH = ThreadSafeHandle::Impl::create((h->capacity - begin) * 2);
|
||||
memcpy(newH->roots(), h->roots() + begin,
|
||||
sizeof(h->roots()[0]) * (h->capacity - begin));
|
||||
memcpy(newH->versions(), h->versions() + begin,
|
||||
sizeof(h->versions()[0]) * (h->capacity - begin));
|
||||
newH->end.store(h->capacity - begin, std::memory_order_relaxed);
|
||||
handle.store(newH, std::memory_order_release);
|
||||
h = newH;
|
||||
}
|
||||
|
||||
auto end = h->end.load(std::memory_order_relaxed);
|
||||
|
||||
if (h->roots()[end - 1] != node) {
|
||||
h->roots()[end] = node;
|
||||
h->versions()[end] = version;
|
||||
h->end.store(end + 1, std::memory_order_release);
|
||||
}
|
||||
}
|
||||
|
||||
void setOldestVersion(int64_t oldestVersion) {
|
||||
this->oldestVersion = oldestVersion;
|
||||
while (firstToFree != nullptr && firstToFree->next != nullptr &&
|
||||
firstToFree->next->versions()[firstToFree->next->end.load(
|
||||
std::memory_order_relaxed) -
|
||||
1] < oldestVersion) {
|
||||
auto *tmp = firstToFree;
|
||||
firstToFree = firstToFree->next;
|
||||
free(tmp);
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t *roots() const {
|
||||
auto *h = handle.load(std::memory_order_relaxed);
|
||||
return h->roots() + h->lastLeq(oldestVersion);
|
||||
}
|
||||
|
||||
int rootCount() const {
|
||||
auto *h = handle.load(std::memory_order_relaxed);
|
||||
return h->end.load(std::memory_order_relaxed) - h->lastLeq(oldestVersion);
|
||||
}
|
||||
|
||||
ThreadSafeHandle getThreadSafeHandle() const {
|
||||
ThreadSafeHandle result;
|
||||
auto *impl = handle.load(std::memory_order_acquire);
|
||||
memcpy(&result, &impl, sizeof(result));
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr static uint32_t kMinCapacity = 16;
|
||||
|
||||
std::atomic<ThreadSafeHandle::Impl *> handle;
|
||||
|
||||
int64_t oldestVersion;
|
||||
ThreadSafeHandle::Impl *firstToFree;
|
||||
ThreadSafeHandle::Impl *lastToFree;
|
||||
};
|
||||
|
||||
void RootSet::add(uint32_t node, int64_t version) { impl->add(node, version); }
|
||||
|
||||
void RootSet::setOldestVersion(int64_t oldestVersion) {
|
||||
impl->setOldestVersion(oldestVersion);
|
||||
}
|
||||
|
||||
uint32_t RootSet::ThreadSafeHandle::rootForVersion(int64_t version) const {
|
||||
auto result = impl->roots()[impl->lastLeq(version)];
|
||||
return result;
|
||||
}
|
||||
|
||||
RootSet::ThreadSafeHandle RootSet::getThreadSafeHandle() const {
|
||||
return impl->getThreadSafeHandle();
|
||||
}
|
||||
|
||||
const uint32_t *RootSet::roots() const { return impl->roots(); }
|
||||
int RootSet::rootCount() const { return impl->rootCount(); }
|
||||
|
||||
RootSet::RootSet() : impl(new(malloc(sizeof(Impl))) Impl()) {}
|
||||
|
||||
RootSet::~RootSet() {
|
||||
impl->~Impl();
|
||||
free(impl);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ROOTSET_TESTS
|
||||
#include <latch>
|
||||
#include <thread>
|
||||
|
||||
int main() {
|
||||
constexpr int kNumReaders = 3;
|
||||
constexpr int kNumVersions = 2000000;
|
||||
|
||||
RootSet rs;
|
||||
std::latch ready{1 + kNumReaders};
|
||||
std::atomic<int> version;
|
||||
std::vector<std::atomic<int>> doneVersions(kNumReaders);
|
||||
std::thread writer([&]() {
|
||||
ready.arrive_and_wait();
|
||||
for (int i = 0; i < kNumVersions; ++i) {
|
||||
rs.add(i / 10, i);
|
||||
version.store(i);
|
||||
int min = std::numeric_limits<int>::max();
|
||||
for (auto &v : doneVersions) {
|
||||
min = std::min(min, v.load());
|
||||
}
|
||||
rs.setOldestVersion(min);
|
||||
}
|
||||
});
|
||||
std::vector<std::thread> readers;
|
||||
for (int i = 0; i < kNumReaders; ++i) {
|
||||
readers.emplace_back([&, i]() {
|
||||
ready.arrive_and_wait();
|
||||
for (;;) {
|
||||
auto v = version.load();
|
||||
assert(rs.getThreadSafeHandle().rootForVersion(v) == v / 10);
|
||||
doneVersions[i].store(v);
|
||||
if (v == kNumVersions - 1) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
writer.join();
|
||||
for (auto &t : readers) {
|
||||
t.join();
|
||||
}
|
||||
}
|
||||
#endif
|
Reference in New Issue
Block a user