Some checks failed
Tests / Clang total: 2096, failed: 520, passed: 1576
Tests / Debug total: 2094, failed: 520, passed: 1574
Tests / SIMD fallback total: 2096, passed: 2096
Tests / Release [gcc] total: 2096, passed: 2096
Tests / Release [gcc,aarch64] total: 1564, passed: 1564
Tests / Coverage total: 1574, passed: 1574
weaselab/conflict-set/pipeline/head There was a failure building this commit
177 lines
4.4 KiB
C++
177 lines
4.4 KiB
C++
#pragma once
|
|
|
|
#include <assert.h>
|
|
#include <bit>
|
|
#include <stdint.h>
|
|
#include <string.h>
|
|
|
|
// Pull in the intrinsics header for whichever SIMD backend the build selected.
// Both macros are tested with `defined(...)`, the same way every other check
// in this file tests them, so that a macro defined without a value (or with
// the value 0) selects the header consistently with the code that needs it.
#if defined(HAS_AVX)
#include <immintrin.h>
#elif defined(HAS_ARM_NEON)
#include <arm_neon.h>
#endif

// Number of bytes examined by a single compareStride / firstNeqStride call.
// The SIMD implementations are written for exactly 64 bytes (they
// static_assert on it); the portable fallback uses a smaller stride.
#if defined(HAS_AVX) || defined(HAS_ARM_NEON)
constexpr int kStride = 64;
#else
constexpr int kStride = 16;
#endif

// Number of strides longestCommonPrefix compares per iteration of its widest
// loop.
constexpr int kUnrollFactor = 4;
|
|
|
|
inline bool compareStride(const uint8_t *ap, const uint8_t *bp) {
|
|
#if defined(HAS_ARM_NEON)
|
|
static_assert(kStride == 64);
|
|
uint8x16_t x[4]; // GCOVR_EXCL_LINE
|
|
for (int i = 0; i < 4; ++i) {
|
|
x[i] = vceqq_u8(vld1q_u8(ap + i * 16), vld1q_u8(bp + i * 16));
|
|
}
|
|
auto results = vreinterpretq_u16_u8(
|
|
vandq_u8(vandq_u8(x[0], x[1]), vandq_u8(x[2], x[3])));
|
|
bool eq = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) ==
|
|
uint64_t(-1);
|
|
#elif defined(HAS_AVX)
|
|
static_assert(kStride == 64);
|
|
__m128i x[4]; // GCOVR_EXCL_LINE
|
|
for (int i = 0; i < 4; ++i) {
|
|
x[i] = _mm_cmpeq_epi8(_mm_loadu_si128((__m128i *)(ap + i * 16)),
|
|
_mm_loadu_si128((__m128i *)(bp + i * 16)));
|
|
}
|
|
auto eq =
|
|
_mm_movemask_epi8(_mm_and_si128(_mm_and_si128(x[0], x[1]),
|
|
_mm_and_si128(x[2], x[3]))) == 0xffff;
|
|
#else
|
|
// Hope it gets vectorized
|
|
auto eq = memcmp(ap, bp, kStride) == 0;
|
|
#endif
|
|
return eq;
|
|
}
|
|
|
|
// Precondition: ap[:kStride] != bp[:kStride]
|
|
inline int firstNeqStride(const uint8_t *ap, const uint8_t *bp) {
|
|
#if defined(HAS_AVX)
|
|
static_assert(kStride == 64);
|
|
uint64_t c[kStride / 16]; // GCOVR_EXCL_LINE
|
|
for (int i = 0; i < kStride; i += 16) {
|
|
const auto a = _mm_loadu_si128((__m128i *)(ap + i));
|
|
const auto b = _mm_loadu_si128((__m128i *)(bp + i));
|
|
const auto compared = _mm_cmpeq_epi8(a, b);
|
|
c[i / 16] = _mm_movemask_epi8(compared) & 0xffff;
|
|
}
|
|
return std::countr_zero(~(c[0] | c[1] << 16 | c[2] << 32 | c[3] << 48));
|
|
#elif defined(HAS_ARM_NEON)
|
|
static_assert(kStride == 64);
|
|
for (int i = 0; i < kStride; i += 16) {
|
|
// 0xff for each match
|
|
uint16x8_t results =
|
|
vreinterpretq_u16_u8(vceqq_u8(vld1q_u8(ap + i), vld1q_u8(bp + i)));
|
|
// 0xf for each mismatch
|
|
uint64_t bitfield =
|
|
~vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0);
|
|
if (bitfield) {
|
|
return i + (std::countr_zero(bitfield) >> 2);
|
|
}
|
|
}
|
|
__builtin_unreachable(); // GCOVR_EXCL_LINE
|
|
#else
|
|
int i = 0;
|
|
for (; i < kStride - 1; ++i) {
|
|
if (*ap++ != *bp++) {
|
|
break;
|
|
}
|
|
}
|
|
return i;
|
|
#endif
|
|
}
|
|
|
|
// AVX-512 version of longestCommonPrefix. It is compiled only when HAS_AVX is
// defined and __SANITIZE_THREAD__ is not. This gets covered in local
// development, so it is excluded from coverage reports.
// GCOVR_EXCL_START
#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__)
__attribute__((target("avx512f,avx512bw"))) inline int
longestCommonPrefix(const uint8_t *ap, const uint8_t *bp, int cl) {
  const int fullBlocksEnd = cl & ~63;
  int matched = 0;

  // Whole 64-byte blocks: the compare yields one bit per byte, set on equality.
  for (; matched < fullBlocksEnd; matched += 64, ap += 64, bp += 64) {
    const uint64_t equalBits =
        _mm512_cmpeq_epi8_mask(_mm512_loadu_epi8(ap), _mm512_loadu_epi8(bp));
    if (~equalBits != 0) {
      return matched + std::countr_one(equalBits);
    }
  }

  // Remaining 1..63 bytes: masked loads only read the bytes selected by
  // tailMask, and the mask is re-applied so the zeroed lanes do not count as
  // matches.
  const int tail = cl - matched;
  if (tail > 0) {
    const uint64_t tailMask = (uint64_t(1) << tail) - 1;
    const uint64_t equalBits =
        _mm512_cmpeq_epi8_mask(_mm512_maskz_loadu_epi8(tailMask, ap),
                               _mm512_maskz_loadu_epi8(tailMask, bp));
    return matched + std::countr_one(equalBits & tailMask);
  }

  assert(matched == cl);
  return matched;
}
// Applies to the longestCommonPrefix definition that follows this region,
// marking it as the default version alongside the AVX-512 one above.
__attribute__((target("default")))
#endif
// GCOVR_EXCL_STOP
|
|
|
|
inline int
|
|
longestCommonPrefix(const uint8_t *ap, const uint8_t *bp, int cl) {
|
|
if (!(cl >= 0)) {
|
|
__builtin_unreachable();
|
|
}
|
|
|
|
int i = 0;
|
|
int end;
|
|
|
|
// kStride * kUnrollCount at a time
|
|
end = cl & ~(kStride * kUnrollFactor - 1);
|
|
while (i < end) {
|
|
for (int j = 0; j < kUnrollFactor; ++j) {
|
|
if (!compareStride(ap, bp)) {
|
|
return i + firstNeqStride(ap, bp);
|
|
}
|
|
i += kStride;
|
|
ap += kStride;
|
|
bp += kStride;
|
|
}
|
|
}
|
|
|
|
// kStride at a time
|
|
end = cl & ~(kStride - 1);
|
|
while (i < end) {
|
|
if (!compareStride(ap, bp)) {
|
|
return i + firstNeqStride(ap, bp);
|
|
}
|
|
i += kStride;
|
|
ap += kStride;
|
|
bp += kStride;
|
|
}
|
|
|
|
// word at a time
|
|
end = cl & ~(sizeof(uint64_t) - 1);
|
|
while (i < end) {
|
|
uint64_t a; // GCOVR_EXCL_LINE
|
|
uint64_t b; // GCOVR_EXCL_LINE
|
|
memcpy(&a, ap, 8);
|
|
memcpy(&b, bp, 8);
|
|
const auto mismatched = a ^ b;
|
|
if (mismatched) {
|
|
return i + std::countr_zero(mismatched) / 8;
|
|
}
|
|
i += 8;
|
|
ap += 8;
|
|
bp += 8;
|
|
}
|
|
|
|
// byte at a time
|
|
while (i < cl) {
|
|
if (*ap != *bp) {
|
|
break;
|
|
}
|
|
++ap;
|
|
++bp;
|
|
++i;
|
|
}
|
|
|
|
return i;
|
|
} |