diff --git a/ConflictSet.cpp b/ConflictSet.cpp index b383257..a0c0a7b 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -686,6 +686,9 @@ Node *&getChildExists(Node *self, uint8_t index) { } int getChildGeq(Node *self, int child) { + if (child > 255) { + return -1; + } if (self->type == Type::Node4) { auto *self4 = static_cast(self); for (int i = 0; i < self->numChildren; ++i) { @@ -698,6 +701,34 @@ int getChildGeq(Node *self, int child) { } } else if (self->type == Type::Node16) { auto *self16 = static_cast(self); +#ifdef HAS_AVX +// TODO +#elif defined(HAS_ARM_NEON) + uint8x16_t indices; + memcpy(&indices, self16->index, sizeof(self16->index)); + // 0xff for each leq + auto results = vcleq_u8(vdupq_n_u8(child), indices); + uint64_t mask = self->numChildren == 16 + ? uint64_t(-1) + : (uint64_t(1) << (self->numChildren * 4)) - 1; + // 0xf for each 0xff (within mask) + uint64_t bitfield = + vget_lane_u64( + vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)), + 0) & + mask; + int simd = + bitfield == 0 ? -1 : self16->index[__builtin_ctzll(bitfield) / 4]; + assert(simd == [&]() -> int { + for (int i = 0; i < self->numChildren; ++i) { + if (self16->index[i] >= child) { + return self16->index[i]; + } + } + return -1; + }()); + return simd; +#else for (int i = 0; i < self->numChildren; ++i) { if (i > 0) { assert(self16->index[i - 1] < self16->index[i]); @@ -706,6 +737,7 @@ int getChildGeq(Node *self, int child) { return self16->index[i]; } } +#endif } else if (self->type == Type::Node48) { auto *self48 = static_cast(self); #if defined(HAS_AVX) || defined(HAS_ARM_NEON) @@ -732,39 +764,19 @@ int getChildGeq(Node *self, int child) { #endif } else { auto *self256 = static_cast(self); - // For some reason gcc can't auto vectorize this, and the plain loop is - // faster. -#if defined(__clang__) - int i = child; - constexpr int kUnrollCount = 8; // Must be a power of two and <= 8 - for (; (i & (kUnrollCount - 1)) != 0; ++i) { - if (self256->children[i]) { - return i; - } - } - for (; i < 256; i += kUnrollCount) { - uint8_t nonNull[kUnrollCount]; - for (int j = 0; j < kUnrollCount; ++j) { - nonNull[j] = self256->children[i + j] != nullptr ? 0xff : 0; - } - uint64_t word; - memcpy(&word, nonNull, kUnrollCount); - if (word) { - return i + __builtin_ctzll(word) / 8; - } - } -#else for (int i = child; i < 256; ++i) { if (self256->children[i]) { return i; } } -#endif } return -1; } int getChildLeq(Node *self, int child) { + if (child < 0) { + return -1; + } if (self->type == Type::Node4) { auto *self4 = static_cast(self); for (int i = self->numChildren - 1; i >= 0; --i) { @@ -777,14 +789,41 @@ int getChildLeq(Node *self, int child) { } } else if (self->type == Type::Node16) { auto *self16 = static_cast(self); - for (int i = self->numChildren - 1; i >= 0; --i) { - if (i > 0) { - assert(self16->index[i - 1] < self16->index[i]); +#ifdef HAS_AVX +// TODO +#elif defined(HAS_ARM_NEON) + uint8x16_t indices; + memcpy(&indices, self16->index, sizeof(self16->index)); + // 0xff for each leq + auto results = vcleq_u8(indices, vdupq_n_u8(child)); + uint64_t mask = self->numChildren == 16 + ? uint64_t(-1) + : (uint64_t(1) << (self->numChildren * 4)) - 1; + // 0xf for each 0xff (within mask) + uint64_t bitfield = + vget_lane_u64( + vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)), + 0) & + mask; + int simd = + bitfield == 0 ? -1 : self16->index[15 - __builtin_clzll(bitfield) / 4]; + assert(simd == [&]() -> int { + for (int i = self->numChildren - 1; i >= 0; --i) { + if (self16->index[i] <= child) { + return self16->index[i]; + } } + return -1; + }()); + return simd; +#else + for (int i = self->numChildren - 1; i >= 0; --i) { if (self16->index[i] <= child) { return self16->index[i]; } } + return -1; +#endif } else if (self->type == Type::Node48) { auto *self48 = static_cast(self); // TODO the plain loop is faster? @@ -820,37 +859,11 @@ int getChildLeq(Node *self, int child) { #endif } else { auto *self256 = static_cast(self); - // TODO: The plain loop is faster? -#if defined(__clang__) - int i = child; - constexpr int kUnrollCount = 8; // Must be a power of two and <= 8 - for (; (i & (kUnrollCount - 1)) != 0; --i) { - if (self256->children[i]) { - return i; - } - } - if (self256->children[i]) { - return i; - } - i -= kUnrollCount; - for (; i >= 0; i -= kUnrollCount) { - uint8_t nonNull[kUnrollCount]; - for (int j = 0; j < kUnrollCount; ++j) { - nonNull[j] = self256->children[i + j] != nullptr ? 0xff : 0; - } - uint64_t word; - memcpy(&word, nonNull, kUnrollCount); - if (word) { - return i + 7 - __builtin_clzll(word) / 8; - } - } -#else for (int i = child; i >= 0; --i) { if (self256->children[i]) { return i; } } -#endif } return -1; }