Implement std::random_access_iterator_tag

This commit is contained in:
2025-08-18 12:51:50 -04:00
parent 0920193e5c
commit 8d99ee9a5b

View File

@@ -78,8 +78,7 @@ template <class T> struct ThreadPipeline {
Batch() : ring(), begin_(), end_() {} Batch() : ring(), begin_(), end_() {}
struct Iterator { struct Iterator {
// TODO implement random_iterator_tag using iterator_category = std::random_access_iterator_tag;
using iterator_category = std::forward_iterator_tag;
using difference_type = std::ptrdiff_t; using difference_type = std::ptrdiff_t;
using value_type = T; using value_type = T;
using pointer = value_type *; using pointer = value_type *;
@@ -100,6 +99,40 @@ template <class T> struct ThreadPipeline {
++(*this); ++(*this);
return tmp; return tmp;
} }
Iterator &operator--() {
--index;
return *this;
}
Iterator operator--(int) {
auto tmp = *this;
--(*this);
return tmp;
}
Iterator &operator+=(difference_type n) {
index += n;
return *this;
}
Iterator &operator-=(difference_type n) {
index -= n;
return *this;
}
Iterator operator+(difference_type n) const {
return Iterator(index + n, ring);
}
Iterator operator-(difference_type n) const {
return Iterator(index - n, ring);
}
difference_type operator-(const Iterator &rhs) const {
assert(ring == rhs.ring);
return static_cast<difference_type>(index) -
static_cast<difference_type>(rhs.index);
}
reference operator[](difference_type n) const {
return (*ring)[(index + n) & (ring->size() - 1)];
}
friend Iterator operator+(difference_type n, const Iterator &iter) {
return iter + n;
}
friend bool operator==(const Iterator &lhs, const Iterator &rhs) { friend bool operator==(const Iterator &lhs, const Iterator &rhs) {
assert(lhs.ring == rhs.ring); assert(lhs.ring == rhs.ring);
return lhs.index == rhs.index; return lhs.index == rhs.index;
@@ -108,6 +141,23 @@ template <class T> struct ThreadPipeline {
assert(lhs.ring == rhs.ring); assert(lhs.ring == rhs.ring);
return lhs.index != rhs.index; return lhs.index != rhs.index;
} }
friend bool operator<(const Iterator &lhs, const Iterator &rhs) {
assert(lhs.ring == rhs.ring);
// Handle potential uint32_t wraparound by using signed difference
return static_cast<int32_t>(lhs.index - rhs.index) < 0;
}
friend bool operator<=(const Iterator &lhs, const Iterator &rhs) {
assert(lhs.ring == rhs.ring);
return static_cast<int32_t>(lhs.index - rhs.index) <= 0;
}
friend bool operator>(const Iterator &lhs, const Iterator &rhs) {
assert(lhs.ring == rhs.ring);
return static_cast<int32_t>(lhs.index - rhs.index) > 0;
}
friend bool operator>=(const Iterator &lhs, const Iterator &rhs) {
assert(lhs.ring == rhs.ring);
return static_cast<int32_t>(lhs.index - rhs.index) >= 0;
}
private: private:
Iterator(uint32_t index, std::vector<T> *const ring) Iterator(uint32_t index, std::vector<T> *const ring)