#pragma once

#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iterator>
#include <utility>
#include <vector>

// Multi-stage lock-free pipeline for high-throughput inter-thread
// communication.
//
// Overview:
// - Items flow through multiple processing stages (stage 0 -> stage 1 -> ... ->
//   final stage)
// - Each stage can have multiple worker threads working in parallel; every
//   worker in a stage sees every item, and an item only advances past a stage
//   once all of that stage's workers have released it
// - Uses a shared ring buffer with atomic counters for lock-free coordination
// - Supports batch processing for efficiency
//
// Usage Pattern:
//
//   Producer threads (add items to stage 0):
//     auto guard = pipeline.push(batchSize, /*block=*/true);
//     for (auto& item : guard.batch) {
//       // Initialize item data
//     }
//     // Guard destructor publishes batch to consumers
//
//   Consumer threads (process items from any stage):
//     auto guard = pipeline.acquire(stageNum, threadId, maxBatch,
//                                   /*mayBlock=*/true);
//     for (auto& item : guard.batch) {
//       // Process item
//     }
//     // Guard destructor marks items as consumed and available to next stage
//
// Memory Model:
// - Ring buffer size must be a power of 2 for efficient masking
// - Indices are monotonically increasing 32-bit values; comparisons use
//   wraparound-safe signed differences, and the constructor asserts
//   slotCount <= 2^31 so a full ring can be distinguished from an empty one
// - Actual ring slots accessed via: index & (slotCount - 1)
// - 128-byte aligned atomics prevent false sharing between CPU cache lines
//
// Thread Safety:
// - Lock-free coordination using atomic operations with acquire/release memory
//   ordering
// - Uses C++20 atomic wait/notify for efficient blocking when no work is
//   available
// - RAII guards ensure proper cleanup even with exceptions
template <class T>
struct ThreadPipeline {
  // Constructor
  // lgSlotCount: log2 of ring buffer size (e.g., 10 -> 1024 slots)
  // threadsPerStage: number of worker threads in each consumer stage (e.g.,
  //   {1, 4, 2} = 1 stage-0 worker, 4 stage-1 workers, 2 stage-2 workers).
  //   Producer threads call push() and are not counted here.
  ThreadPipeline(int lgSlotCount, const std::vector<int> &threadsPerStage)
      : slotCount(1u << lgSlotCount), slotCountMask(slotCount - 1),
        threadState(threadsPerStage.size()), ring(slotCount) {
    // Otherwise we can't tell the difference between full and empty.
    assert(!(slotCountMask & 0x80000000));
    for (size_t i = 0; i < threadsPerStage.size(); ++i) {
      threadState[i] = std::vector<ThreadState>(threadsPerStage[i]);
      for (auto &t : threadState[i]) {
        if (i == 0) {
          // Stage 0 reads the producers' single shared publish counter.
          t.lastPushRead = std::vector<uint32_t>(1);
        } else {
          // Later stages read one publish counter per previous-stage thread.
          t.lastPushRead = std::vector<uint32_t>(threadsPerStage[i - 1]);
        }
      }
    }
  }

  ThreadPipeline(ThreadPipeline const &) = delete;
  ThreadPipeline &operator=(ThreadPipeline const &) = delete;
  ThreadPipeline(ThreadPipeline &&) = delete;
  ThreadPipeline &operator=(ThreadPipeline &&) = delete;

  struct Batch {
    Batch() : ring(), begin_(), end_() {}

    struct Iterator {
      using iterator_category = std::random_access_iterator_tag;
      using difference_type = std::ptrdiff_t;
      using value_type = T;
      using pointer = value_type *;
      using reference = value_type &;

      reference operator*() const {
        return (*ring)[index & (ring->size() - 1)];
      }
      pointer operator->() const {
        return &(*ring)[index & (ring->size() - 1)];
      }

      Iterator &operator++() {
        ++index;
        return *this;
      }
      Iterator operator++(int) {
        auto tmp = *this;
        ++(*this);
        return tmp;
      }
      Iterator &operator--() {
        --index;
        return *this;
      }
      Iterator operator--(int) {
        auto tmp = *this;
        --(*this);
        return tmp;
      }
      Iterator &operator+=(difference_type n) {
        index += n;
        return *this;
      }
      Iterator &operator-=(difference_type n) {
        index -= n;
        return *this;
      }
      Iterator operator+(difference_type n) const {
        return Iterator(index + n, ring);
      }
      Iterator operator-(difference_type n) const {
        return Iterator(index - n, ring);
      }
      difference_type operator-(const Iterator &rhs) const {
        assert(ring == rhs.ring);
        return static_cast<int32_t>(index) - static_cast<int32_t>(rhs.index);
      }
      reference operator[](difference_type n) const {
        return (*ring)[(index + n) & (ring->size() - 1)];
      }

      friend Iterator operator+(difference_type n, const Iterator &iter) {
        return iter + n;
      }
      friend bool operator==(const Iterator &lhs, const Iterator &rhs) {
        assert(lhs.ring == rhs.ring);
        return lhs.index == rhs.index;
      }
      friend bool operator!=(const Iterator &lhs, const Iterator &rhs) {
        assert(lhs.ring == rhs.ring);
        return lhs.index != rhs.index;
      }
      friend bool operator<(const Iterator &lhs, const Iterator &rhs) {
        assert(lhs.ring == rhs.ring);
        // Handle potential uint32_t wraparound by using signed difference
        return static_cast<int32_t>(lhs.index - rhs.index) < 0;
      }
      friend bool operator<=(const Iterator &lhs, const Iterator &rhs) {
        assert(lhs.ring == rhs.ring);
        return static_cast<int32_t>(lhs.index - rhs.index) <= 0;
      }
      friend bool operator>(const Iterator &lhs, const Iterator &rhs) {
        assert(lhs.ring == rhs.ring);
        return static_cast<int32_t>(lhs.index - rhs.index) > 0;
      }
      friend bool operator>=(const Iterator &lhs, const Iterator &rhs) {
        assert(lhs.ring == rhs.ring);
        return static_cast<int32_t>(lhs.index - rhs.index) >= 0;
      }

    private:
      Iterator(uint32_t index, std::vector<T> *const ring)
          : index(index), ring(ring) {}

      friend struct Batch;

      uint32_t index;
      std::vector<T> *const ring;
    };

    [[nodiscard]] Iterator begin() { return Iterator(begin_, ring); }
    [[nodiscard]] Iterator end() { return Iterator(end_, ring); }
    [[nodiscard]] size_t size() const { return end_ - begin_; }
    [[nodiscard]] bool empty() const { return end_ == begin_; }

  private:
    friend struct ThreadPipeline;

    Batch(std::vector<T> *const ring, uint32_t begin_, uint32_t end_)
        : ring(ring), begin_(begin_), end_(end_) {}

    std::vector<T> *const ring;
    uint32_t begin_;
    uint32_t end_;
  };

private:
  Batch acquireHelper(int stage, int thread, uint32_t maxBatch, bool mayBlock) {
    uint32_t begin = threadState[stage][thread].localPops & slotCountMask;
    uint32_t len = getSafeLen(stage, thread, mayBlock);
    if (maxBatch != 0) {
      len = std::min(len, maxBatch);
    }
    if (len == 0) {
      return Batch{};
    }
    auto result = Batch{&ring, begin, begin + len};
    threadState[stage][thread].localPops += len;
    return result;
  }

  // Used by producer threads to reserve slots in the ring buffer
  alignas(128) std::atomic<uint32_t> slots{0};
  // Used by producers to publish
  alignas(128) std::atomic<uint32_t> pushes{0};

  const uint32_t slotCount;
  const uint32_t slotCountMask;

  // We can safely acquire this many items
  uint32_t getSafeLen(int stage, int threadIndex, bool mayBlock) {
    uint32_t safeLen = UINT32_MAX;
    auto &thread = threadState[stage][threadIndex];
    // See if we can determine that there are entries we can acquire entirely
    // from state local to the thread
    for (int i = 0; i < int(thread.lastPushRead.size()); ++i) {
      auto &lastPush = stage == 0 ? pushes : threadState[stage - 1][i].pops;
      if (thread.lastPushRead[i] == thread.localPops) {
        // Re-read lastPush with acquire ordering and try again
        thread.lastPushRead[i] = lastPush.load(std::memory_order_acquire);
        if (thread.lastPushRead[i] == thread.localPops) {
          if (!mayBlock) {
            return 0;
          }
          // Wait for lastPush to change and try again
          lastPush.wait(thread.lastPushRead[i], std::memory_order_relaxed);
          thread.lastPushRead[i] = lastPush.load(std::memory_order_acquire);
        }
      }
      safeLen = std::min(safeLen, thread.lastPushRead[i] - thread.localPops);
    }
    return safeLen;
  }

  struct alignas(128) ThreadState {
    // Where this thread has published up to
    std::atomic<uint32_t> pops{0};
    // Where this thread will publish to the next time it publishes
    uint32_t localPops{0};
    // Where the previous stage's threads had published up to, last we checked
    std::vector<uint32_t> lastPushRead;
  };

  // threadState[i][j] is the state for thread j in stage i
  std::vector<std::vector<ThreadState>> threadState;

  // Shared ring buffer
  std::vector<T> ring;

public:
  struct StageGuard {
    Batch batch;

    ~StageGuard() {
      if (ts != nullptr) {
        // seq_cst so that the notify can't be ordered before the store
        ts->pops.store(localPops, std::memory_order_seq_cst);
        ts->pops.notify_all();
      }
    }

    StageGuard(StageGuard const &) = delete;
    StageGuard &operator=(StageGuard const &) = delete;
    StageGuard(StageGuard &&other)
        : batch(other.batch), localPops(other.localPops),
          ts(std::exchange(other.ts, nullptr)) {}
    StageGuard &operator=(StageGuard &&other) {
      batch = other.batch;
      localPops = other.localPops;
      ts = std::exchange(other.ts, nullptr);
      return *this;
    }

  private:
    uint32_t localPops;

    friend struct ThreadPipeline;

    StageGuard(Batch batch, ThreadState *ts)
        : batch(batch), localPops(ts->localPops),
          ts(batch.empty() ? nullptr : ts) {}

    ThreadState *ts;
  };

  struct ProducerGuard {
    Batch batch;

    ~ProducerGuard() {
      if (tp == nullptr) {
        return;
      }
      // Wait for earlier slots to finish being published, since publishing
      // implies that all previous slots were also published.
      for (;;) {
        uint32_t p = tp->pushes.load(std::memory_order_acquire);
        if (p == oldSlot) {
          break;
        }
        tp->pushes.wait(p, std::memory_order_relaxed);
      }
      // Publish. seq_cst so that the notify can't be ordered before the store
      tp->pushes.store(newSlot, std::memory_order_seq_cst);
      // We have to notify every time, since we don't know if this is the last
      // push ever
      tp->pushes.notify_all();
    }

  private:
    friend struct ThreadPipeline;

    ProducerGuard() : batch(), tp() {}
    ProducerGuard(Batch batch, ThreadPipeline *tp, uint32_t oldSlot,
                  uint32_t newSlot)
        : batch(batch), tp(tp), oldSlot(oldSlot), newSlot(newSlot) {}

    ThreadPipeline *const tp;
    uint32_t oldSlot;
    uint32_t newSlot;
  };

  // Acquire a batch of items for processing by a consumer thread.
  // stage: which processing stage (0 = first consumer stage after producers)
  // thread: thread ID within the stage (0 to threadsPerStage[stage]-1)
  // maxBatch: maximum items to acquire (0 = no limit)
  // mayBlock: whether to block waiting for items (false = return an empty
  //   batch if none are available)
  // Returns: StageGuard with a batch of items to process
  [[nodiscard]] StageGuard acquire(int stage, int thread, int maxBatch = 0,
                                   bool mayBlock = true) {
    assert(stage < threadState.size());
    assert(thread < threadState[stage].size());
    auto batch = acquireHelper(stage, thread, maxBatch, mayBlock);
    return StageGuard{std::move(batch), &threadState[stage][thread]};
  }

  // Reserve slots in the ring buffer for a producer thread to fill with items.
  // This is used by producer threads to add new items to stage 0 of the
  // pipeline.
  //
  // size: number of slots to reserve (must be > 0 and <= ring buffer capacity)
  // block: if true, blocks when the ring buffer is full; if false, returns an
  //   empty guard
  // Returns: ProducerGuard with exclusive access to the reserved slots
  //
  // Usage: Fill items in the returned batch, then let the guard destructor
  // publish them. The guard destructor ensures items are published in the
  // correct order.
  //
  // Preconditions:
  // - size > 0 (must request at least one slot)
  // - size <= slotCount (cannot request more slots than the ring buffer holds)
  // Violating the preconditions terminates the program via abort().
  [[nodiscard]] ProducerGuard push(uint32_t const size, bool block) {
    if (size == 0) {
      abort();
    }
    if (size > slotCount) {
      abort();
    }
    // Reserve slots to construct items in, but don't publish to consumers yet
    uint32_t slot;
    uint32_t begin;
    for (;;) {
    begin_loop:
      slot = slots.load(std::memory_order_relaxed);
      begin = slot & slotCountMask;
      // Make sure we won't stomp the back of the ring buffer
      for (auto &thread : threadState.back()) {
        uint32_t pops = thread.pops.load(std::memory_order_acquire);
        if (slot + size - pops > slotCount) {
          if (!block) {
            return ProducerGuard{};
          }
          thread.pops.wait(pops, std::memory_order_relaxed);
          goto begin_loop;
        }
      }
      if (slots.compare_exchange_weak(slot, slot + size,
                                      std::memory_order_relaxed,
                                      std::memory_order_relaxed)) {
        break;
      }
    }
    return ProducerGuard{Batch{&ring, begin, begin + size}, this, slot,
                         slot + size};
  }
};
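
// ---------------------------------------------------------------------------
// Example usage (illustrative sketch only; the item type `Record`, the
// `runPipeline` function, the batch sizes, and the std::thread scaffolding are
// assumptions made for this comment, not part of the pipeline itself). A
// single producer feeds one stage-0 worker, which in turn feeds one stage-1
// worker; each guard's destructor hands the batch to the next stage.
// (Requires <thread> for the example's std::thread scaffolding.)
//
//   struct Record {
//     int value = 0;
//     bool processed = false;
//   };
//
//   void runPipeline() {
//     // 1 << 8 = 256 ring slots; one stage-0 worker feeding one stage-1 worker.
//     ThreadPipeline<Record> pipeline(8, {1, 1});
//
//     std::thread producer([&] {
//       for (int i = 0; i < 1000; i += 10) {
//         auto guard = pipeline.push(/*size=*/10, /*block=*/true);
//         int v = i;
//         for (auto &r : guard.batch) {
//           r.value = v++;
//         }
//         // ~ProducerGuard publishes these 10 items to stage 0.
//       }
//     });
//
//     std::thread stage0([&] {
//       for (int seen = 0; seen < 1000;) {
//         auto guard = pipeline.acquire(/*stage=*/0, /*thread=*/0,
//                                       /*maxBatch=*/0, /*mayBlock=*/true);
//         for (auto &r : guard.batch) {
//           r.processed = true;
//         }
//         seen += int(guard.batch.size());
//         // ~StageGuard releases the items to stage 1.
//       }
//     });
//
//     std::thread stage1([&] {
//       for (int seen = 0; seen < 1000;) {
//         auto guard = pipeline.acquire(/*stage=*/1, /*thread=*/0);
//         seen += int(guard.batch.size());
//         // ~StageGuard frees the slots for the producer to reuse.
//       }
//     });
//
//     producer.join();
//     stage0.join();
//     stage1.join();
//   }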