From 6241533dfbdb004566abd867b7d4bbc7d73b6886 Mon Sep 17 00:00:00 2001 From: Andrew Noyes Date: Thu, 12 Sep 2024 22:05:00 -0700 Subject: [PATCH] Improve codegen for getChildExists(Node{3,16}*, ...) --- ConflictSet.cpp | 55 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/ConflictSet.cpp b/ConflictSet.cpp index b05d105..8e5a928 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -776,6 +776,18 @@ int getNodeIndex(Node3 *self, uint8_t index) { return -1; } +int getNodeIndexExists(Node3 *self, uint8_t index) { + Node3 *n = (Node3 *)self; + assume(n->numChildren >= 1); + assume(n->numChildren <= 3); + for (int i = 0; i < n->numChildren; ++i) { + if (n->index[i] == index) { + return i; + } + } + __builtin_unreachable(); // GCOVR_EXCL_LINE +} + int getNodeIndex(Node16 *self, uint8_t index) { #ifdef HAS_AVX @@ -836,13 +848,52 @@ int getNodeIndex(Node16 *self, uint8_t index) { #endif } +int getNodeIndexExists(Node16 *self, uint8_t index) { + +#ifdef HAS_AVX + __m128i key_vec = _mm_set1_epi8(index); + __m128i indices; + memcpy(&indices, self->index, Node16::kMaxNodes); + __m128i results = _mm_cmpeq_epi8(key_vec, indices); + uint32_t mask = (1 << self->numChildren) - 1; + uint32_t bitfield = _mm_movemask_epi8(results) & mask; + assume(bitfield != 0); + return std::countr_zero(bitfield); +#elif defined(HAS_ARM_NEON) + // Based on + // https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon + + uint8x16_t indices; + memcpy(&indices, self->index, Node16::kMaxNodes); + // 0xff for each match + uint16x8_t results = + vreinterpretq_u16_u8(vceqq_u8(vdupq_n_u8(index), indices)); + assume(self->numChildren <= Node16::kMaxNodes); + uint64_t mask = self->numChildren == 16 + ? uint64_t(-1) + : (uint64_t(1) << (self->numChildren * 4)) - 1; + // 0xf for each match in valid range + uint64_t bitfield = + vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) & mask; + assume(bitfield != 0); + return std::countr_zero(bitfield) / 4; +#else + for (int i = 0; i < self->numChildren; ++i) { + if (self->index[i] == index) { + return i; + } + } + __builtin_unreachable(); // GCOVR_EXCL_LINE +#endif +} + // Precondition - an entry for index must exist in the node Node *&getChildExists(Node3 *self, uint8_t index) { - return self->children[getNodeIndex(self, index)]; + return self->children[getNodeIndexExists(self, index)]; } // Precondition - an entry for index must exist in the node Node *&getChildExists(Node16 *self, uint8_t index) { - return self->children[getNodeIndex(self, index)]; + return self->children[getNodeIndexExists(self, index)]; } // Precondition - an entry for index must exist in the node Node *&getChildExists(Node48 *self, uint8_t index) {