From 4edf0315d987236d7139c6e7c2f8509039684031 Mon Sep 17 00:00:00 2001 From: Andrew Noyes Date: Thu, 28 Mar 2024 10:47:20 -0700 Subject: [PATCH] Find insertion point for Node16 with simd Closes #13 --- ConflictSet.cpp | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/ConflictSet.cpp b/ConflictSet.cpp index dddf47c..923bccd 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -29,6 +29,7 @@ limitations under the License. #include #include #include +#include #include #ifdef HAS_AVX @@ -954,6 +955,42 @@ Node *&getOrCreateChild(Node *&self, uint8_t index, assert(self->getType() == Type_Node16); ++self->numChildren; +#ifdef HAS_AVX + __m128i key_vec = _mm_set1_epi8(index); + __m128i indices; + memcpy(&indices, self16->index, sizeof(self16->index)); + __m128i results = _mm_cmpeq_epi8(key_vec, _mm_min_epu8(key_vec, indices)); + int mask = (1 << (self->numChildren - 1)) - 1; + uint32_t bitfield = _mm_movemask_epi8(results) & mask; + bitfield |= uint32_t(1) << (self->numChildren - 1); + int i = std::countr_zero(bitfield); + if (i < self->numChildren - 1) { + memmove(self16->index + i + 1, self16->index + i, + self->numChildren - (i + 1)); + memmove(self16->children + i + 1, self16->children + i, + (self->numChildren - (i + 1)) * sizeof(Child)); + } +#elif defined(HAS_ARM_NEON) + uint8x16_t indices; + memcpy(&indices, self16->index, sizeof(self16->index)); + // 0xff for each leq + auto results = vcleq_u8(vdupq_n_u8(index), indices); + uint64_t mask = (uint64_t(1) << ((self->numChildren - 1) * 4)) - 1; + // 0xf for each 0xff (within mask) + uint64_t bitfield = + vget_lane_u64( + vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)), + 0) & + mask; + bitfield |= uint64_t(0xf) << ((self->numChildren - 1) * 4); + int i = std::countr_zero(bitfield) / 4; + if (i < self->numChildren - 1) { + memmove(self16->index + i + 1, self16->index + i, + self->numChildren - (i + 1)); + memmove(self16->children + i + 1, self16->children + i, + (self->numChildren - (i + 1)) * sizeof(Child)); + } +#else int i = 0; for (; i < int(self->numChildren) - 1; ++i) { if (int(self16->index[i]) > int(index)) { @@ -964,6 +1001,7 @@ Node *&getOrCreateChild(Node *&self, uint8_t index, break; } } +#endif self16->index[i] = index; auto &result = self16->children[i].child; result = nullptr;