#pragma once #include #include #include #include #ifdef HAS_AVX #include #elif HAS_ARM_NEON #include #endif #ifndef __SANITIZE_THREAD__ #if defined(__has_feature) #if __has_feature(thread_sanitizer) #define __SANITIZE_THREAD__ #endif #endif #endif #if defined(HAS_AVX) || defined(HAS_ARM_NEON) constexpr int kStride = 64; #else constexpr int kStride = 16; #endif constexpr int kUnrollFactor = 4; inline bool compareStride(const uint8_t *ap, const uint8_t *bp) { #if defined(HAS_ARM_NEON) static_assert(kStride == 64); uint8x16_t x[4]; // GCOVR_EXCL_LINE for (int i = 0; i < 4; ++i) { x[i] = vceqq_u8(vld1q_u8(ap + i * 16), vld1q_u8(bp + i * 16)); } auto results = vreinterpretq_u16_u8( vandq_u8(vandq_u8(x[0], x[1]), vandq_u8(x[2], x[3]))); bool eq = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) == uint64_t(-1); #elif defined(HAS_AVX) static_assert(kStride == 64); __m128i x[4]; // GCOVR_EXCL_LINE for (int i = 0; i < 4; ++i) { x[i] = _mm_cmpeq_epi8(_mm_loadu_si128((__m128i *)(ap + i * 16)), _mm_loadu_si128((__m128i *)(bp + i * 16))); } auto eq = _mm_movemask_epi8(_mm_and_si128(_mm_and_si128(x[0], x[1]), _mm_and_si128(x[2], x[3]))) == 0xffff; #else // Hope it gets vectorized auto eq = memcmp(ap, bp, kStride) == 0; #endif return eq; } // Precondition: ap[:kStride] != bp[:kStride] inline int firstNeqStride(const uint8_t *ap, const uint8_t *bp) { #if defined(HAS_AVX) static_assert(kStride == 64); uint64_t c[kStride / 16]; // GCOVR_EXCL_LINE for (int i = 0; i < kStride; i += 16) { const auto a = _mm_loadu_si128((__m128i *)(ap + i)); const auto b = _mm_loadu_si128((__m128i *)(bp + i)); const auto compared = _mm_cmpeq_epi8(a, b); c[i / 16] = _mm_movemask_epi8(compared) & 0xffff; } return std::countr_zero(~(c[0] | c[1] << 16 | c[2] << 32 | c[3] << 48)); #elif defined(HAS_ARM_NEON) static_assert(kStride == 64); for (int i = 0; i < kStride; i += 16) { // 0xff for each match uint16x8_t results = vreinterpretq_u16_u8(vceqq_u8(vld1q_u8(ap + i), vld1q_u8(bp + i))); // 0xf for each mismatch uint64_t bitfield = ~vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0); if (bitfield) { return i + (std::countr_zero(bitfield) >> 2); } } __builtin_unreachable(); // GCOVR_EXCL_LINE #else int i = 0; for (; i < kStride - 1; ++i) { if (*ap++ != *bp++) { break; } } return i; #endif } // This gets covered in local development // GCOVR_EXCL_START #if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__) __attribute__((target("avx512f,avx512bw"))) inline int longestCommonPrefix(const uint8_t *ap, const uint8_t *bp, int cl) { int i = 0; int end = cl & ~63; while (i < end) { const uint64_t eq = _mm512_cmpeq_epi8_mask(_mm512_loadu_epi8(ap), _mm512_loadu_epi8(bp)); if (eq != uint64_t(-1)) { return i + std::countr_one(eq); } i += 64; ap += 64; bp += 64; } if (i < cl) { const uint64_t mask = (uint64_t(1) << (cl - i)) - 1; const uint64_t eq = _mm512_cmpeq_epi8_mask( _mm512_maskz_loadu_epi8(mask, ap), _mm512_maskz_loadu_epi8(mask, bp)); return i + std::countr_one(eq & mask); } assert(i == cl); return i; } __attribute__((target("default"))) #endif // GCOVR_EXCL_STOP inline int longestCommonPrefix(const uint8_t *ap, const uint8_t *bp, int cl) { if (!(cl >= 0)) { __builtin_unreachable(); // GCOVR_EXCL_LINE } int i = 0; int end; // kStride * kUnrollCount at a time end = cl & ~(kStride * kUnrollFactor - 1); while (i < end) { for (int j = 0; j < kUnrollFactor; ++j) { if (!compareStride(ap, bp)) { return i + firstNeqStride(ap, bp); } i += kStride; ap += kStride; bp += kStride; } } // kStride at a time end = cl & ~(kStride - 1); while (i < end) { if (!compareStride(ap, bp)) { return i + firstNeqStride(ap, bp); } i += kStride; ap += kStride; bp += kStride; } // word at a time end = cl & ~(sizeof(uint64_t) - 1); while (i < end) { uint64_t a; // GCOVR_EXCL_LINE uint64_t b; // GCOVR_EXCL_LINE memcpy(&a, ap, 8); memcpy(&b, bp, 8); const auto mismatched = a ^ b; if (mismatched) { return i + std::countr_zero(mismatched) / 8; } i += 8; ap += 8; bp += 8; } // byte at a time while (i < cl) { if (*ap != *bp) { break; } ++ap; ++bp; ++i; } return i; }