ThreadPipeline.h -> thread_pipeline.hpp
src/thread_pipeline.hpp | 379 lines | new file
@@ -0,0 +1,379 @@
#pragma once

#include <algorithm> // std::min
#include <atomic>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <iterator>
#include <utility>
#include <vector>

// Multi-stage lock-free pipeline for high-throughput inter-thread
// communication.
//
// Overview:
// - Items flow through multiple processing stages (stage 0 -> stage 1 -> ... ->
//   final stage)
// - Each stage can have multiple worker threads processing items in parallel
// - Uses a shared ring buffer with atomic counters for lock-free coordination
// - Supports batch processing for efficiency
//
// Usage Pattern (a fuller worked example appears at the end of this file):
//   // Producer threads (add items to stage 0):
//   auto guard = pipeline.push(batchSize, /*block=*/true);
//   for (auto &item : guard.batch) {
//     // Initialize item data
//   }
//   // Guard destructor publishes the batch to consumers
//
//   // Consumer threads (process items from any stage):
//   auto guard = pipeline.acquire(stageNum, threadId, maxBatch,
//                                 /*mayBlock=*/true);
//   for (auto &item : guard.batch) {
//     // Process item
//   }
//   // Guard destructor marks items as consumed and available to the next
//   // stage
//
// Memory Model:
// - Ring buffer size must be a power of 2 for efficient masking
// - Actual ring slots are accessed via: index & (slotCount - 1)
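//   (e.g., with lgSlotCount = 10, slotCount = 1024, a free-running index of
//   1030 maps to ring slot 1030 & 1023 = 6)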
// - 128-byte-aligned atomics prevent false sharing between CPU cache lines
//
// Thread Safety:
// - Fully lock-free using atomic operations with acquire/release memory
//   ordering
// - Uses C++20 atomic wait/notify for efficient blocking when no work is
//   available
// - RAII guards ensure proper cleanup even with exceptions
template <class T> struct ThreadPipeline {
  // Constructor
  // lgSlotCount: log2 of the ring buffer size (e.g., 10 -> 1024 slots)
  // threadsPerStage: number of threads for each stage (e.g., {1, 4, 2} = 1
  //   stage-0 worker, 4 stage-1 workers, 2 stage-2 workers)
  ThreadPipeline(int lgSlotCount, const std::vector<int> &threadsPerStage)
      : slotCount(1u << lgSlotCount), slotCountMask(slotCount - 1),
        threadState(threadsPerStage.size()), ring(slotCount) {
    // Keep slotCount <= 2^31: indices are free-running uint32_t counters, and
    // otherwise we can't tell the difference between full and empty.
    assert(!(slotCountMask & 0x80000000));
    for (size_t i = 0; i < threadsPerStage.size(); ++i) {
      threadState[i] = std::vector<ThreadState>(threadsPerStage[i]);
      for (auto &t : threadState[i]) {
        if (i == 0) {
          // Stage 0 tracks the single shared producer counter
          t.lastPushRead = std::vector<uint32_t>(1);
        } else {
          // Later stages track every thread of the previous stage
          t.lastPushRead = std::vector<uint32_t>(threadsPerStage[i - 1]);
        }
      }
    }
  }

  ThreadPipeline(ThreadPipeline const &) = delete;
  ThreadPipeline &operator=(ThreadPipeline const &) = delete;
  ThreadPipeline(ThreadPipeline &&) = delete;
  ThreadPipeline &operator=(ThreadPipeline &&) = delete;

  struct Batch {
    Batch() : ring(), begin_(), end_() {}

    struct Iterator {
      using iterator_category = std::random_access_iterator_tag;
      using difference_type = std::ptrdiff_t;
      using value_type = T;
      using pointer = value_type *;
      using reference = value_type &;

      // Default-constructible, as the standard iterator requirements expect;
      // a default-constructed Iterator is a null sentinel.
      Iterator() : index_(0), ring(nullptr) {}

      reference operator*() const {
        return (*ring)[index_ & (ring->size() - 1)];
      }
      pointer operator->() const {
        return &(*ring)[index_ & (ring->size() - 1)];
      }
      Iterator &operator++() {
        ++index_;
        return *this;
      }
      Iterator operator++(int) {
        auto tmp = *this;
        ++(*this);
        return tmp;
      }
      Iterator &operator--() {
        --index_;
        return *this;
      }
      Iterator operator--(int) {
        auto tmp = *this;
        --(*this);
        return tmp;
      }
      Iterator &operator+=(difference_type n) {
        index_ += n;
        return *this;
      }
      Iterator &operator-=(difference_type n) {
        index_ -= n;
        return *this;
      }
      Iterator operator+(difference_type n) const {
        return Iterator(index_ + n, ring);
      }
      Iterator operator-(difference_type n) const {
        return Iterator(index_ - n, ring);
      }
      difference_type operator-(const Iterator &rhs) const {
        assert(ring == rhs.ring);
        return static_cast<difference_type>(index_) -
               static_cast<difference_type>(rhs.index_);
      }
      reference operator[](difference_type n) const {
        return (*ring)[(index_ + n) & (ring->size() - 1)];
      }
      friend Iterator operator+(difference_type n, const Iterator &iter) {
        return iter + n;
      }
      friend bool operator==(const Iterator &lhs, const Iterator &rhs) {
        assert(lhs.ring == rhs.ring);
        return lhs.index_ == rhs.index_;
      }
      friend bool operator!=(const Iterator &lhs, const Iterator &rhs) {
        assert(lhs.ring == rhs.ring);
        return lhs.index_ != rhs.index_;
      }
      friend bool operator<(const Iterator &lhs, const Iterator &rhs) {
        assert(lhs.ring == rhs.ring);
        // Handle potential uint32_t wraparound by using the signed difference
        return static_cast<int32_t>(lhs.index_ - rhs.index_) < 0;
      }
      friend bool operator<=(const Iterator &lhs, const Iterator &rhs) {
        assert(lhs.ring == rhs.ring);
        return static_cast<int32_t>(lhs.index_ - rhs.index_) <= 0;
      }
      friend bool operator>(const Iterator &lhs, const Iterator &rhs) {
        assert(lhs.ring == rhs.ring);
        return static_cast<int32_t>(lhs.index_ - rhs.index_) > 0;
      }
      friend bool operator>=(const Iterator &lhs, const Iterator &rhs) {
        assert(lhs.ring == rhs.ring);
        return static_cast<int32_t>(lhs.index_ - rhs.index_) >= 0;
      }

      /// Returns the ring buffer index (0 to ring->size()-1) for this
      /// iterator's position. Useful for distributing work across multiple
      /// threads via modulo operations.
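      /// (e.g., with two workers in a stage, worker w might process only the
      /// items where it.index() % 2 == w; an illustrative convention, not an
      /// API requirement)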
      uint32_t index() const { return index_ & (ring->size() - 1); }

    private:
      Iterator(uint32_t index, std::vector<T> *ring)
          : index_(index), ring(ring) {}
      friend struct Batch;
      uint32_t index_;
      // Not declared const: iterators must remain copy-assignable.
      std::vector<T> *ring;
    };

    [[nodiscard]] Iterator begin() { return Iterator(begin_, ring); }
    [[nodiscard]] Iterator end() { return Iterator(end_, ring); }

    [[nodiscard]] size_t size() const { return end_ - begin_; }
    [[nodiscard]] bool empty() const { return end_ == begin_; }

  private:
    friend struct ThreadPipeline<T>;
    Batch(std::vector<T> *ring, uint32_t begin_, uint32_t end_)
        : ring(ring), begin_(begin_), end_(end_) {}
    // Not declared const: the guards copy Batch in their move operations, so
    // it must stay assignable.
    std::vector<T> *ring;
    uint32_t begin_;
    uint32_t end_;
  };

private:
  Batch acquireHelper(int stage, int thread, uint32_t maxBatch, bool mayBlock) {
    uint32_t begin = threadState[stage][thread].localPops & slotCountMask;
    uint32_t len = getSafeLen(stage, thread, mayBlock);
    if (maxBatch != 0) {
      len = std::min(len, maxBatch);
    }
    if (len == 0) {
      return Batch{};
    }
    auto result = Batch{&ring, begin, begin + len};
    threadState[stage][thread].localPops += len;
    return result;
  }

  // Used by producer threads to reserve slots in the ring buffer
  alignas(128) std::atomic<uint32_t> slots{0};
  // Used by producers to publish
  alignas(128) std::atomic<uint32_t> pushes{0};

  const uint32_t slotCount;
  const uint32_t slotCountMask;

  // We can safely acquire this many items.
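  // A thread at stage s may only move past slot n once every thread of stage
  // s-1 (or, for stage 0, the shared producer counter) has published past n,
  // so the safe length is the minimum headroom across all predecessors.
  // Published counters only grow, so the cached reads in lastPushRead stay
  // valid and are refreshed only when they run out.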
  uint32_t getSafeLen(int stage, int threadIndex, bool mayBlock) {
    uint32_t safeLen = UINT32_MAX;
    auto &thread = threadState[stage][threadIndex];
    // See if we can determine that there are entries we can acquire entirely
    // from state local to this thread
    for (int i = 0; i < int(thread.lastPushRead.size()); ++i) {
      auto &lastPush = stage == 0 ? pushes : threadState[stage - 1][i].pops;
      if (thread.lastPushRead[i] == thread.localPops) {
        // Re-read lastPush with acquire ordering and try again
        thread.lastPushRead[i] = lastPush.load(std::memory_order_acquire);
        if (thread.lastPushRead[i] == thread.localPops) {
          if (!mayBlock) {
            return 0;
          }
          // Wait for lastPush to change, then try again
          lastPush.wait(thread.lastPushRead[i], std::memory_order_relaxed);
          thread.lastPushRead[i] = lastPush.load(std::memory_order_acquire);
        }
      }
      safeLen = std::min(safeLen, thread.lastPushRead[i] - thread.localPops);
    }
    return safeLen;
  }

  struct ThreadState {
    // Where this thread has published up to (visible to the next stage)
    alignas(128) std::atomic<uint32_t> pops{0};
    // Where this thread will publish to the next time it publishes
    uint32_t localPops{0};
    // Where the previous stage's threads had published as of our last check
    std::vector<uint32_t> lastPushRead;
  };
  // threadState[i][j] is the state for thread j in stage i
  std::vector<std::vector<ThreadState>> threadState;
  // Shared ring buffer
  std::vector<T> ring;

public:
  struct StageGuard {
    Batch batch;
    ~StageGuard() {
      if (ts != nullptr) {
        // seq_cst so that the notify can't be reordered before the store
        ts->pops.store(localPops, std::memory_order_seq_cst);
        ts->pops.notify_all();
      }
    }

    StageGuard(StageGuard const &) = delete;
    StageGuard &operator=(StageGuard const &) = delete;
    StageGuard(StageGuard &&other)
        : batch(other.batch), localPops(other.localPops),
          ts(std::exchange(other.ts, nullptr)) {}
    StageGuard &operator=(StageGuard &&other) {
      batch = other.batch;
      localPops = other.localPops;
      ts = std::exchange(other.ts, nullptr);
      return *this;
    }

  private:
    uint32_t localPops;
    friend struct ThreadPipeline;
    StageGuard(Batch batch, ThreadState *ts)
        : batch(batch), localPops(ts->localPops),
          ts(batch.empty() ? nullptr : ts) {}
    ThreadState *ts;
  };

  struct ProducerGuard {
    Batch batch;

    ~ProducerGuard() {
      if (tp == nullptr) {
        return;
      }
      // Wait for earlier slots to finish being published, since publishing
      // implies that all previous slots were also published.
      for (;;) {
        uint32_t p = tp->pushes.load(std::memory_order_acquire);
        if (p == oldSlot) {
          break;
        }
        tp->pushes.wait(p, std::memory_order_relaxed);
      }
      // Publish. seq_cst so that the notify can't be reordered before the
      // store.
      tp->pushes.store(newSlot, std::memory_order_seq_cst);
      // We have to notify every time, since we don't know if this is the last
      // push ever
      tp->pushes.notify_all();
    }

  private:
    friend struct ThreadPipeline;
    ProducerGuard() : batch(), tp(), oldSlot(0), newSlot(0) {}
    ProducerGuard(Batch batch, ThreadPipeline<T> *tp, uint32_t oldSlot,
                  uint32_t newSlot)
        : batch(batch), tp(tp), oldSlot(oldSlot), newSlot(newSlot) {}
    ThreadPipeline<T> *const tp;
    uint32_t oldSlot;
    uint32_t newSlot;
  };

  // Acquire a batch of items for processing by a consumer thread.
  // stage: which processing stage (0 = the first consumer stage after the
  //   producers)
  // thread: thread ID within the stage (0 to threadsPerStage[stage]-1)
  // maxBatch: maximum items to acquire (0 = no limit)
  // mayBlock: whether to block waiting for items (false = return an empty
  //   batch if none are available)
  // Returns: a StageGuard with the batch of items to process
  [[nodiscard]] StageGuard acquire(int stage, int thread, int maxBatch = 0,
                                   bool mayBlock = true) {
    assert(stage < int(threadState.size()));
    assert(thread < int(threadState[stage].size()));
    auto batch = acquireHelper(stage, thread, maxBatch, mayBlock);
    return StageGuard{std::move(batch), &threadState[stage][thread]};
  }

  // Reserve slots in the ring buffer for a producer thread to fill with items.
  // This is used by producer threads to add new items to stage 0 of the
  // pipeline.
  //
  // size: number of slots to reserve (must be > 0 and <= ring buffer capacity)
  // block: if true, blocks when the ring buffer is full; if false, returns an
  //   empty guard
  // Returns: a ProducerGuard with exclusive access to the reserved slots
  //
  // Usage: Fill in the items in the returned batch, then let the guard
  // destructor publish them. The guard destructor ensures items are published
  // in the correct order.
  //
  // Preconditions:
  // - size > 0 (must request at least one slot)
  // - size <= slotCount (cannot request more slots than the ring buffer holds)
  // Violating a precondition terminates the program via abort().
  [[nodiscard]] ProducerGuard push(uint32_t const size, bool block) {
    if (size == 0) {
      abort();
    }
    if (size > slotCount) {
      abort();
    }
    // Reserve slots to construct items in, but don't publish to consumers yet
    uint32_t slot;
    uint32_t begin;
    for (;;) {
    begin_loop:
      slot = slots.load(std::memory_order_relaxed);
      begin = slot & slotCountMask;
      // Make sure we won't stomp the back of the ring buffer
      for (auto &thread : threadState.back()) {
        uint32_t pops = thread.pops.load(std::memory_order_acquire);
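        // The counters are free-running, so slot + size - pops is the number
        // of slots that would be in flight after this reservation; unsigned
        // arithmetic keeps this correct across uint32_t wraparound because
        // slotCount <= 2^31 (checked in the constructor).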
        if (slot + size - pops > slotCount) {
          if (!block) {
            return ProducerGuard{};
          }
          thread.pops.wait(pops, std::memory_order_relaxed);
          goto begin_loop;
        }
      }
      if (slots.compare_exchange_weak(slot, slot + size,
                                      std::memory_order_relaxed,
                                      std::memory_order_relaxed)) {
        break;
      }
    }
    return ProducerGuard{Batch{&ring, begin, begin + size}, this, slot,
                         slot + size};
  }
};
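
// ---------------------------------------------------------------------------
// Worked example (an illustrative sketch, not compiled as part of this
// header): one producer thread, one stage-0 worker, and two stage-1 workers,
// i.e. threadsPerStage = {1, 2}. `Task`, `makeTask`, `transform`, and `emit`
// are placeholder names, and a real program also needs a shutdown protocol,
// which is omitted here. Note that every thread of a stage sees every item;
// threads within a stage split work by convention, e.g. via Iterator::index().
//
//   ThreadPipeline<Task> pipeline(/*lgSlotCount=*/10,
//                                 /*threadsPerStage=*/{1, 2});
//
//   void producer() {
//     for (;;) {
//       auto guard = pipeline.push(/*size=*/16, /*block=*/true);
//       for (auto &task : guard.batch) {
//         task = makeTask();
//       }
//     } // each iteration, ~ProducerGuard publishes the 16 slots to stage 0
//   }
//
//   void stage0Worker() {
//     for (;;) {
//       auto guard = pipeline.acquire(/*stage=*/0, /*thread=*/0);
//       for (auto &task : guard.batch) {
//         transform(task);
//       }
//     } // each iteration, ~StageGuard releases the items to stage 1
//   }
//
//   void stage1Worker(int id) { // id is 0 or 1
//     for (;;) {
//       auto guard = pipeline.acquire(/*stage=*/1, /*thread=*/id);
//       for (auto it = guard.batch.begin(); it != guard.batch.end(); ++it) {
//         if (it.index() % 2 == id) { // split work by ring index
//           emit(*it);
//         }
//       }
//     }
//   }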