diff --git a/src/ThreadPipeline.h b/src/ThreadPipeline.h
new file mode 100644
index 0000000..a1c45ad
--- /dev/null
+++ b/src/ThreadPipeline.h
@@ -0,0 +1,274 @@
+#pragma once
+
+#include <algorithm>
+#include <atomic>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <iterator>
+#include <utility>
+#include <vector>
+
+template <typename T> struct ThreadPipeline {
+  ThreadPipeline(int lgSlotCount, const std::vector<int> &threadsPerStage)
+      : slotCount(1 << lgSlotCount), slotCountMask(slotCount - 1),
+        threadState(threadsPerStage.size()), ring(slotCount) {
+    // Otherwise we can't tell the difference between full and empty.
+    assert(!(slotCountMask & 0x80000000));
+    for (size_t i = 0; i < threadsPerStage.size(); ++i) {
+      threadState[i] = std::vector<ThreadState>(threadsPerStage[i]);
+      for (auto &t : threadState[i]) {
+        if (i == 0) {
+          // Stage 0 reads the single shared producer push counter.
+          t.lastPushRead = std::vector<uint32_t>(1);
+        } else {
+          t.lastPushRead = std::vector<uint32_t>(threadsPerStage[i - 1]);
+        }
+      }
+    }
+  }
+
+  ThreadPipeline(ThreadPipeline const &) = delete;
+  ThreadPipeline &operator=(ThreadPipeline const &) = delete;
+  ThreadPipeline(ThreadPipeline &&) = delete;
+  ThreadPipeline &operator=(ThreadPipeline &&) = delete;
+
+  struct Batch {
+
+    Batch() : ring(), begin_(), end_() {}
+
+    struct Iterator {
+      // TODO implement random_access_iterator_tag
+      using iterator_category = std::forward_iterator_tag;
+      using difference_type = std::ptrdiff_t;
+      using value_type = T;
+      using pointer = value_type *;
+      using reference = value_type &;
+
+      reference operator*() const {
+        return (*ring)[index & (ring->size() - 1)];
+      }
+      pointer operator->() const {
+        return &(*ring)[index & (ring->size() - 1)];
+      }
+      Iterator &operator++() {
+        ++index;
+        return *this;
+      }
+      Iterator operator++(int) {
+        auto tmp = *this;
+        ++(*this);
+        return tmp;
+      }
+      friend bool operator==(const Iterator &lhs, const Iterator &rhs) {
+        assert(lhs.ring == rhs.ring);
+        return lhs.index == rhs.index;
+      }
+      friend bool operator!=(const Iterator &lhs, const Iterator &rhs) {
+        assert(lhs.ring == rhs.ring);
+        return lhs.index != rhs.index;
+      }
+
+    private:
+      Iterator(uint32_t index, std::vector<T> *ring)
+          : index(index), ring(ring) {}
+      friend struct Batch;
+      uint32_t index;
+      // Not const so iterators remain copy-assignable.
+      std::vector<T> *ring;
+    };
+
+    [[nodiscard]] Iterator begin() { return Iterator(begin_, ring); }
+    [[nodiscard]] Iterator end() { return Iterator(end_, ring); }
+
+    [[nodiscard]] size_t size() const { return end_ - begin_; }
+    [[nodiscard]] bool empty() const { return end_ == begin_; }
+
+  private:
+    friend struct ThreadPipeline;
+    Batch(std::vector<T> *ring, uint32_t begin_, uint32_t end_)
+        : ring(ring), begin_(begin_), end_(end_) {}
+    // Not const so Batch (and the guards holding one) remain assignable.
+    std::vector<T> *ring;
+    uint32_t begin_;
+    uint32_t end_;
+  };
+
+private:
+  Batch acquireHelper(int stage, int thread, uint32_t maxBatch,
+                      bool mayBlock) {
+    uint32_t begin = threadState[stage][thread].localPops & slotCountMask;
+    uint32_t len = getSafeLen(stage, thread, mayBlock);
+    if (maxBatch != 0) {
+      len = std::min(len, maxBatch);
+    }
+    if (len == 0) {
+      return Batch{};
+    }
+    auto result = Batch{&ring, begin, begin + len};
+    threadState[stage][thread].localPops += len;
+    return result;
+  }
+
+  // Used by producer threads to reserve slots in the ring buffer
+  alignas(128) std::atomic<uint32_t> slots{0};
+  // Used by producers to publish completed slots
+  alignas(128) std::atomic<uint32_t> pushes{0};
+
+  const uint32_t slotCount;
+  const uint32_t slotCountMask;
+
+  // We can safely acquire this many items
+  uint32_t getSafeLen(int stage, int threadIndex, bool mayBlock) {
+    uint32_t safeLen = -1;
+    auto &thread = threadState[stage][threadIndex];
+    // See if we can determine that there are entries we can acquire entirely
+    // from state local to the thread
+    for (int i = 0; i < int(thread.lastPushRead.size()); ++i) {
+      auto &lastPush = stage == 0 ? pushes : threadState[stage - 1][i].pops;
+      if (thread.lastPushRead[i] == thread.localPops) {
+        // Re-read lastPush with acquire ordering and try again
+        thread.lastPushRead[i] = lastPush.load(std::memory_order_acquire);
+        if (thread.lastPushRead[i] == thread.localPops) {
+          if (!mayBlock) {
+            return 0;
+          }
+          // Wait for lastPush to change and try again
+          lastPush.wait(thread.lastPushRead[i], std::memory_order_relaxed);
+          thread.lastPushRead[i] = lastPush.load(std::memory_order_acquire);
+        }
+      }
+      safeLen = std::min(safeLen, thread.lastPushRead[i] - thread.localPops);
+    }
+    return safeLen;
+  }
+
+  struct alignas(128) ThreadState {
+    // Where this thread has published up to
+    std::atomic<uint32_t> pops;
+    // Where this thread will publish to the next time it publishes
+    uint32_t localPops;
+    // Where the previous stage's threads have published up to last we checked
+    std::vector<uint32_t> lastPushRead;
+  };
+  // threadState[i][j] is the state for thread j in stage i
+  std::vector<std::vector<ThreadState>> threadState;
+  // Shared ring buffer
+  std::vector<T> ring;
+
+public:
+  struct StageGuard {
+    Batch batch;
+    ~StageGuard() {
+      if (ts != nullptr) {
+        // seq_cst so that the notify can't be ordered before the store
+        ts->pops.store(localPops, std::memory_order_seq_cst);
+        ts->pops.notify_all();
+      }
+    }
+
+    StageGuard(StageGuard const &) = delete;
+    StageGuard &operator=(StageGuard const &) = delete;
+    StageGuard(StageGuard &&other)
+        : batch(other.batch), localPops(other.localPops),
+          ts(std::exchange(other.ts, nullptr)) {}
+    StageGuard &operator=(StageGuard &&other) {
+      batch = other.batch;
+      localPops = other.localPops;
+      ts = std::exchange(other.ts, nullptr);
+      return *this;
+    }
+
+  private:
+    uint32_t localPops;
+    friend struct ThreadPipeline;
+    StageGuard(Batch batch, ThreadState *ts)
+        : batch(batch), localPops(ts->localPops),
+          ts(batch.empty() ? nullptr : ts) {}
+    ThreadState *ts;
+  };
+
+  struct ProducerGuard {
+    Batch batch;
+
+    ~ProducerGuard() {
+      if (tp == nullptr) {
+        return;
+      }
+      // Wait for earlier slots to finish being published, since publishing
+      // implies that all previous slots were also published.
+      for (;;) {
+        uint32_t p = tp->pushes.load(std::memory_order_acquire);
+        if (p == oldSlot) {
+          break;
+        }
+        tp->pushes.wait(p, std::memory_order_relaxed);
+      }
+      // Publish. seq_cst so that the notify can't be ordered before the store
+      tp->pushes.store(newSlot, std::memory_order_seq_cst);
+      // We have to notify every time, since we don't know if this is the last
+      // push ever
+      tp->pushes.notify_all();
+    }
+
+  private:
+    friend struct ThreadPipeline;
+    ProducerGuard() : batch(), tp() {}
+    ProducerGuard(Batch batch, ThreadPipeline *tp, uint32_t oldSlot,
+                  uint32_t newSlot)
+        : batch(batch), tp(tp), oldSlot(oldSlot), newSlot(newSlot) {}
+    ThreadPipeline *const tp;
+    uint32_t oldSlot;
+    uint32_t newSlot;
+  };
+
+  // Acquires a batch of up to `maxBatch` items (0 means no limit) for thread
+  // `thread` of stage `stage` to process. If `mayBlock` is true, blocks until
+  // at least one item is ready; otherwise it may return an empty batch.
+  [[nodiscard]] StageGuard acquire(int stage, int thread, int maxBatch = 0,
+                                   bool mayBlock = true) {
+    assert(stage < threadState.size());
+    assert(thread < threadState[stage].size());
+    auto batch = acquireHelper(stage, thread, maxBatch, mayBlock);
+    return StageGuard{std::move(batch), &threadState[stage][thread]};
+  }
+
+  // Grants a producer thread exclusive access to a span of `size` slots. If
+  // `block` is true, this call blocks while the ring does not have room for
+  // `size` items; otherwise it returns an empty batch when the ring is full.
+  [[nodiscard]] ProducerGuard push(uint32_t const size, bool block) {
+    if (size == 0) {
+      abort();
+    }
+    if (size > slotCount) {
+      abort();
+    }
+    // Reserve slots to construct items in, but don't publish to consumers yet
+    uint32_t slot;
+    uint32_t begin;
+    for (;;) {
+    begin_loop:
+      slot = slots.load(std::memory_order_relaxed);
+      begin = slot & slotCountMask;
+      // Make sure we won't stomp the back of the ring buffer
+      for (auto &thread : threadState.back()) {
+        uint32_t pops = thread.pops.load(std::memory_order_acquire);
+        if (slot + size - pops > slotCount) {
+          if (!block) {
+            return ProducerGuard{};
+          }
+          thread.pops.wait(pops, std::memory_order_relaxed);
+          goto begin_loop;
+        }
+      }
+      if (slots.compare_exchange_weak(slot, slot + size,
+                                      std::memory_order_relaxed,
+                                      std::memory_order_relaxed)) {
+        break;
+      }
+    }
+    return ProducerGuard{Batch{&ring, begin, begin + size}, this, slot,
+                         slot + size};
+  }
+};
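
The patch itself contains no usage example, so here is a minimal sketch of how the API might be driven (not part of the diff). It assumes a single producer thread feeding one stage with one consumer thread; the include path, the item count, the batch size of 16, and the 2^10-slot ring are illustrative choices, not anything fixed by the patch.

// example_pipeline_usage.cpp -- hypothetical driver, assuming the header is
// reachable as "ThreadPipeline.h" and that threadsPerStage is std::vector<int>.
#include "ThreadPipeline.h"

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
  constexpr uint32_t kItems = 1000;
  // 2^10 = 1024 slots, one stage with a single consumer thread.
  ThreadPipeline<uint32_t> pipeline(10, std::vector<int>{1});

  std::thread producer([&] {
    uint32_t next = 0;
    while (next < kItems) {
      // Reserve up to 16 slots; block if the ring is full.
      uint32_t want = std::min<uint32_t>(16, kItems - next);
      auto guard = pipeline.push(want, /*block=*/true);
      for (auto &slot : guard.batch) {
        slot = next++;
      }
      // Destroying `guard` publishes the batch to stage 0.
    }
  });

  uint64_t sum = 0;
  uint32_t seen = 0;
  while (seen < kItems) {
    // Stage 0, thread 0: take whatever is ready, blocking if nothing is.
    auto guard = pipeline.acquire(/*stage=*/0, /*thread=*/0);
    for (auto &item : guard.batch) {
      sum += item;
      ++seen;
    }
    // Destroying `guard` releases the consumed slots back to the producer.
  }

  producer.join();
  std::printf("sum = %llu\n", static_cast<unsigned long long>(sum));
}

Note that publication happens in the guard destructors: a producer's batch only becomes visible to stage 0 once its ProducerGuard is destroyed, and a stage's pops counter only advances once its StageGuard is destroyed, which is what the next stage (or, for the last stage, a blocked producer) waits on. Each guard should therefore be allowed to go out of scope before the same thread pushes or acquires again.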