Add ThreadPipeline.h

2025-08-18 12:42:54 -04:00
parent 5c377aa14d
commit b8891eee29

src/ThreadPipeline.h Normal file

@@ -0,0 +1,267 @@
#pragma once
#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iterator>
#include <utility>
#include <vector>
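// A multi-stage thread pipeline over one shared ring buffer. Producers
// reserve and publish slots in batches via push(); every thread of stage 0
// then visits each slot once it is published, and every thread of stage i
// visits each slot once all threads of stage i - 1 have released it. Threads
// within a stage do not split the work: each of them sees every slot. All
// progress counters are wrapping uint32_t values, and ring indices are taken
// modulo the power-of-two slot count. Requires C++20 for
// std::atomic<T>::wait()/notify_all().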
template <class T> struct ThreadPipeline {
ThreadPipeline(int lgSlotCount, const std::vector<int> &threadsPerStage)
: slotCount(1u << lgSlotCount), slotCountMask(slotCount - 1),
threadState(threadsPerStage.size()), ring(slotCount) {
// slotCount must be at most 2^31 so that the wrapping uint32_t counters
// can still distinguish a full ring from an empty one.
assert(!(slotCountMask & 0x80000000));
for (size_t i = 0; i < threadsPerStage.size(); ++i) {
threadState[i] = std::vector<ThreadState>(threadsPerStage[i]);
for (auto &t : threadState[i]) {
if (i == 0) {
t.lastPushRead = std::vector<uint32_t>(1);
} else {
t.lastPushRead = std::vector<uint32_t>(threadsPerStage[i - 1]);
}
}
}
}
ThreadPipeline(ThreadPipeline const &) = delete;
ThreadPipeline &operator=(ThreadPipeline const &) = delete;
ThreadPipeline(ThreadPipeline &&) = delete;
ThreadPipeline &operator=(ThreadPipeline &&) = delete;
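// A view over the half-open range [begin_, end_) of ring slots. Indices
// wrap modulo the ring size, so a batch may straddle the physical end of
// the buffer.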
struct Batch {
Batch() : ring(), begin_(), end_() {}
struct Iterator {
// TODO implement random_access_iterator_tag
using iterator_category = std::forward_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = T;
using pointer = value_type *;
using reference = value_type &;
reference operator*() const {
return (*ring)[index & (ring->size() - 1)];
}
pointer operator->() const {
return &(*ring)[index & (ring->size() - 1)];
}
Iterator &operator++() {
++index;
return *this;
}
Iterator operator++(int) {
auto tmp = *this;
++(*this);
return tmp;
}
friend bool operator==(const Iterator &lhs, const Iterator &rhs) {
assert(lhs.ring == rhs.ring);
return lhs.index == rhs.index;
}
friend bool operator!=(const Iterator &lhs, const Iterator &rhs) {
assert(lhs.ring == rhs.ring);
return lhs.index != rhs.index;
}
private:
Iterator(uint32_t index, std::vector<T> *const ring)
: index(index), ring(ring) {}
friend struct Batch;
uint32_t index;
// Not const: iterators must remain copy-assignable.
std::vector<T> *ring;
};
[[nodiscard]] Iterator begin() { return Iterator(begin_, ring); }
[[nodiscard]] Iterator end() { return Iterator(end_, ring); }
[[nodiscard]] size_t size() const { return end_ - begin_; }
[[nodiscard]] bool empty() const { return end_ == begin_; }
private:
friend struct ThreadPipeline<T>;
Batch(std::vector<T> *const ring, uint32_t begin_, uint32_t end_)
: ring(ring), begin_(begin_), end_(end_) {}
// Not const: Batch (and the guards holding one) must remain assignable.
std::vector<T> *ring;
uint32_t begin_;
uint32_t end_;
};
private:
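// Claims up to maxBatch items (0 = no limit) that are ready for this
// stage/thread and advances the thread's local cursor past them. Returns
// an empty batch if nothing is ready and mayBlock is false.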
Batch acquireHelper(int stage, int thread, uint32_t maxBatch, bool mayBlock) {
uint32_t begin = threadState[stage][thread].localPops & slotCountMask;
uint32_t len = getSafeLen(stage, thread, mayBlock);
if (maxBatch != 0) {
len = std::min(len, maxBatch);
}
if (len == 0) {
return Batch{};
}
auto result = Batch{&ring, begin, begin + len};
threadState[stage][thread].localPops += len;
return result;
}
// Used by producer threads to reserve slots in the ring buffer; on its own
// cache line to avoid false sharing
alignas(128) std::atomic<uint32_t> slots{0};
// Used by producers to publish completed slots; stage-0 consumers read it
alignas(128) std::atomic<uint32_t> pushes{0};
const uint32_t slotCount;
const uint32_t slotCountMask;
// How many items stage `stage`, thread `threadIndex` can safely acquire.
// Blocks (if mayBlock) until at least one is likely available.
uint32_t getSafeLen(int stage, int threadIndex, bool mayBlock) {
uint32_t safeLen = UINT32_MAX;
auto &thread = threadState[stage][threadIndex];
// See if we can determine that there are entries we can acquire entirely
// from state local to the thread
for (int i = 0; i < int(thread.lastPushRead.size()); ++i) {
auto &lastPush = stage == 0 ? pushes : threadState[stage - 1][i].pops;
if (thread.lastPushRead[i] == thread.localPops) {
// Re-read lastPush with acquire ordering and try again
thread.lastPushRead[i] = lastPush.load(std::memory_order_acquire);
if (thread.lastPushRead[i] == thread.localPops) {
if (!mayBlock) {
return 0;
}
// Wait for lastPush to change and try again
lastPush.wait(thread.lastPushRead[i], std::memory_order_relaxed);
thread.lastPushRead[i] = lastPush.load(std::memory_order_acquire);
}
}
safeLen = std::min(safeLen, thread.lastPushRead[i] - thread.localPops);
}
return safeLen;
}
struct alignas(128) ThreadState {
// Where this thread has published up to
std::atomic<uint32_t> pops{0};
// Where this thread will publish to the next time it publishes
uint32_t localPops = 0;
// Where the previous stage's threads have published up to last we checked
std::vector<uint32_t> lastPushRead;
};
// threadState[i][j] is the state for thread j in stage i
std::vector<std::vector<ThreadState>> threadState;
// Shared ring buffer
std::vector<T> ring;
public:
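// RAII handle for a batch acquired by a consumer-stage thread. On
// destruction it publishes the thread's new progress and wakes any waiters
// in the next stage (or blocked producers).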
struct StageGuard {
Batch batch;
~StageGuard() {
if (ts != nullptr) {
// seq_cst so that the notify can't be ordered before the store
ts->pops.store(localPops, std::memory_order_seq_cst);
ts->pops.notify_all();
}
}
StageGuard(StageGuard const &) = delete;
StageGuard &operator=(StageGuard const &) = delete;
StageGuard(StageGuard &&other)
: batch(other.batch), localPops(other.localPops),
ts(std::exchange(other.ts, nullptr)) {}
StageGuard &operator=(StageGuard &&other) {
if (this == &other) return *this;
// Publish any batch we still hold before adopting other's state.
if (ts != nullptr) {
ts->pops.store(localPops, std::memory_order_seq_cst);
ts->pops.notify_all();
}
batch = other.batch;
localPops = other.localPops;
ts = std::exchange(other.ts, nullptr);
return *this;
}
private:
uint32_t localPops;
friend struct ThreadPipeline;
StageGuard(Batch batch, ThreadState *ts)
: batch(batch), localPops(ts->localPops),
ts(batch.empty() ? nullptr : ts) {}
ThreadState *ts;
};
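// RAII handle for a span of slots reserved by push(). On destruction it
// waits until all earlier reservations have been published, then publishes
// its own span and wakes stage-0 consumers.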
struct ProducerGuard {
Batch batch;
~ProducerGuard() {
if (tp == nullptr) {
return;
}
// Wait until all earlier reservations have been published: consumers
// treat `pushes` as a watermark, so advancing it implies every earlier
// slot is complete.
for (;;) {
uint32_t p = tp->pushes.load(std::memory_order_acquire);
if (p == oldSlot) {
break;
}
tp->pushes.wait(p, std::memory_order_relaxed);
}
// Publish. seq_cst so that the notify can't be ordered before the store
tp->pushes.store(newSlot, std::memory_order_seq_cst);
// We have to notify every time, since we don't know if this is the last
// push ever
tp->pushes.notify_all();
}
// Copying would publish the same span twice and deadlock the second
// destructor, so forbid it (prvalue returns still work via C++17 elision).
ProducerGuard(ProducerGuard const &) = delete;
ProducerGuard &operator=(ProducerGuard const &) = delete;
private:
friend struct ThreadPipeline;
ProducerGuard() : batch(), tp(), oldSlot(), newSlot() {}
ProducerGuard(Batch batch, ThreadPipeline<T> *tp, uint32_t oldSlot,
uint32_t newSlot)
: batch(batch), tp(tp), oldSlot(oldSlot), newSlot(newSlot) {}
ThreadPipeline<T> *const tp;
uint32_t oldSlot;
uint32_t newSlot;
};
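// Grants a consumer thread access to up to `maxBatch` items (0 means no
// limit) that the previous stage has finished with. With `mayBlock`, the
// call blocks until items appear, though a spurious wakeup can still yield
// an empty batch, so callers should be prepared to retry.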
[[nodiscard]] StageGuard acquire(int stage, int thread, int maxBatch = 0,
bool mayBlock = true) {
assert(stage >= 0 && size_t(stage) < threadState.size());
assert(thread >= 0 && size_t(thread) < threadState[stage].size());
auto batch = acquireHelper(stage, thread, maxBatch, mayBlock);
return StageGuard{std::move(batch), &threadState[stage][thread]};
}
// Grants a producer thread exclusive access to a span of exactly `size`
// slots. If `block` is true, the call blocks while the ring is too full for
// the span, until it fits. Otherwise it returns an empty batch if the span
// does not currently fit.
[[nodiscard]] ProducerGuard push(uint32_t const size, bool block) {
if (size == 0) {
abort();
}
if (size > slotCount) {
abort();
}
// Reserve a slot to construct an item, but don't publish to consumer yet
uint32_t slot;
uint32_t begin;
for (;;) {
begin_loop:
slot = slots.load(std::memory_order_relaxed);
begin = slot & slotCountMask;
// Make sure we won't stomp the back of the ring buffer
for (auto &thread : threadState.back()) {
uint32_t pops = thread.pops.load(std::memory_order_acquire);
if (slot + size - pops > slotCount) {
if (!block) {
return ProducerGuard{};
}
thread.pops.wait(pops, std::memory_order_relaxed);
goto begin_loop;
}
}
if (slots.compare_exchange_weak(slot, slot + size,
std::memory_order_relaxed,
std::memory_order_relaxed)) {
break;
}
}
return ProducerGuard{Batch{&ring, begin, begin + size}, this, slot,
slot + size};
}
};
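// A minimal usage sketch (illustrative only, not part of this header): one
// producer pushes 100 ints through a single one-thread consumer stage.
// Assumes a C++20 toolchain for std::atomic<T>::wait()/notify_all().
//
//   #include "ThreadPipeline.h"
//   #include <thread>
//
//   int main() {
//     ThreadPipeline<int> pipeline(/*lgSlotCount=*/4, /*threadsPerStage=*/{1});
//     std::thread consumer([&] {
//       int seen = 0;
//       while (seen < 100) {
//         // May return an empty batch on a spurious wakeup; just retry.
//         auto guard = pipeline.acquire(/*stage=*/0, /*thread=*/0);
//         for (int &item : guard.batch) {
//           (void)item; // process the item here
//           ++seen;
//         }
//       } // ~StageGuard publishes this thread's progress
//     });
//     for (int i = 0; i < 100; ++i) {
//       auto guard = pipeline.push(/*size=*/1, /*block=*/true);
//       *guard.batch.begin() = i;
//       // ~ProducerGuard publishes the slot to stage 0 at end of iteration
//     }
//     consumer.join();
//   }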