diff --git a/ConflictSet.cpp b/ConflictSet.cpp index 62b6fb9..7004be4 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -16,12 +16,12 @@ limitations under the License. #include "ConflictSet.h" #include "Internal.h" +#include "LongestCommonPrefix.h" #include #include #include #include -#include #include #include #include @@ -1687,167 +1687,6 @@ Node *nextSibling(Node *node) { } } -#if defined(HAS_AVX) || defined(HAS_ARM_NEON) -constexpr int kStride = 64; -#else -constexpr int kStride = 16; -#endif - -constexpr int kUnrollFactor = 4; - -bool compareStride(const uint8_t *ap, const uint8_t *bp) { -#if defined(HAS_ARM_NEON) - static_assert(kStride == 64); - uint8x16_t x[4]; // GCOVR_EXCL_LINE - for (int i = 0; i < 4; ++i) { - x[i] = vceqq_u8(vld1q_u8(ap + i * 16), vld1q_u8(bp + i * 16)); - } - auto results = vreinterpretq_u16_u8( - vandq_u8(vandq_u8(x[0], x[1]), vandq_u8(x[2], x[3]))); - bool eq = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) == - uint64_t(-1); -#elif defined(HAS_AVX) - static_assert(kStride == 64); - __m128i x[4]; // GCOVR_EXCL_LINE - for (int i = 0; i < 4; ++i) { - x[i] = _mm_cmpeq_epi8(_mm_loadu_si128((__m128i *)(ap + i * 16)), - _mm_loadu_si128((__m128i *)(bp + i * 16))); - } - auto eq = - _mm_movemask_epi8(_mm_and_si128(_mm_and_si128(x[0], x[1]), - _mm_and_si128(x[2], x[3]))) == 0xffff; -#else - // Hope it gets vectorized - auto eq = memcmp(ap, bp, kStride) == 0; -#endif - return eq; -} - -// Precondition: ap[:kStride] != bp[:kStride] -int firstNeqStride(const uint8_t *ap, const uint8_t *bp) { -#if defined(HAS_AVX) - static_assert(kStride == 64); - uint64_t c[kStride / 16]; // GCOVR_EXCL_LINE - for (int i = 0; i < kStride; i += 16) { - const auto a = _mm_loadu_si128((__m128i *)(ap + i)); - const auto b = _mm_loadu_si128((__m128i *)(bp + i)); - const auto compared = _mm_cmpeq_epi8(a, b); - c[i / 16] = _mm_movemask_epi8(compared) & 0xffff; - } - return std::countr_zero(~(c[0] | c[1] << 16 | c[2] << 32 | c[3] << 48)); -#elif defined(HAS_ARM_NEON) - static_assert(kStride == 64); - for (int i = 0; i < kStride; i += 16) { - // 0xff for each match - uint16x8_t results = - vreinterpretq_u16_u8(vceqq_u8(vld1q_u8(ap + i), vld1q_u8(bp + i))); - // 0xf for each mismatch - uint64_t bitfield = - ~vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0); - if (bitfield) { - return i + (std::countr_zero(bitfield) >> 2); - } - } - __builtin_unreachable(); // GCOVR_EXCL_LINE -#else - int i = 0; - for (; i < kStride - 1; ++i) { - if (*ap++ != *bp++) { - break; - } - } - return i; -#endif -} - -// This gets covered in local development -// GCOVR_EXCL_START -#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__) -__attribute__((target("avx512f,avx512bw"))) int -longestCommonPrefix(const uint8_t *ap, const uint8_t *bp, int cl) { - int i = 0; - int end = cl & ~63; - while (i < end) { - const uint64_t eq = - _mm512_cmpeq_epi8_mask(_mm512_loadu_epi8(ap), _mm512_loadu_epi8(bp)); - if (eq != uint64_t(-1)) { - return i + std::countr_one(eq); - } - i += 64; - ap += 64; - bp += 64; - } - if (i < cl) { - const uint64_t mask = (uint64_t(1) << (cl - i)) - 1; - const uint64_t eq = _mm512_cmpeq_epi8_mask( - _mm512_maskz_loadu_epi8(mask, ap), _mm512_maskz_loadu_epi8(mask, bp)); - return i + std::countr_one(eq & mask); - } - assert(i == cl); - return i; -} -__attribute__((target("default"))) -#endif -// GCOVR_EXCL_STOP - -int longestCommonPrefix(const uint8_t *ap, const uint8_t *bp, int cl) { - assume(cl >= 0); - int i = 0; - int end; - - // kStride * kUnrollCount at a time - end = cl & ~(kStride * kUnrollFactor - 1); - while (i < end) { - for (int j = 0; j < kUnrollFactor; ++j) { - if (!compareStride(ap, bp)) { - return i + firstNeqStride(ap, bp); - } - i += kStride; - ap += kStride; - bp += kStride; - } - } - - // kStride at a time - end = cl & ~(kStride - 1); - while (i < end) { - if (!compareStride(ap, bp)) { - return i + firstNeqStride(ap, bp); - } - i += kStride; - ap += kStride; - bp += kStride; - } - - // word at a time - end = cl & ~(sizeof(uint64_t) - 1); - while (i < end) { - uint64_t a; // GCOVR_EXCL_LINE - uint64_t b; // GCOVR_EXCL_LINE - memcpy(&a, ap, 8); - memcpy(&b, bp, 8); - const auto mismatched = a ^ b; - if (mismatched) { - return i + std::countr_zero(mismatched) / 8; - } - i += 8; - ap += 8; - bp += 8; - } - - // byte at a time - while (i < cl) { - if (*ap != *bp) { - break; - } - ++ap; - ++bp; - ++i; - } - - return i; -} - // Logically this is the same as performing firstGeq and then checking against // point or range version according to cmp, but this version short circuits as // soon as it can prove that there's no conflict. diff --git a/Jenkinsfile b/Jenkinsfile index a6a19c6..b9ed422 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -117,15 +117,18 @@ pipeline { } } steps { + script { + sources = "ConflictSet.cpp LongestCommonPrefix.h" + } CleanBuildAndTest("-DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ -DCMAKE_C_FLAGS=--coverage -DCMAKE_CXX_FLAGS=--coverage -DCMAKE_BUILD_TYPE=Debug -DDISABLE_TSAN=ON") - sh ''' - gcovr -f ConflictSet.cpp --cobertura > build/coverage.xml - ''' + sh """ + gcovr -f ${sources} --cobertura > build/coverage.xml + """ recordCoverage qualityGates: [[criticality: 'NOTE', metric: 'MODULE']], tools: [[parser: 'COBERTURA', pattern: 'build/coverage.xml']] - sh ''' - gcovr -f ConflictSet.cpp - gcovr -f ConflictSet.cpp --fail-under-line 100 > /dev/null - ''' + sh """ + gcovr -f ${sources} + gcovr -f ${sources} --fail-under-line 100 > /dev/null + """ } } } diff --git a/LongestCommonPrefix.h b/LongestCommonPrefix.h new file mode 100644 index 0000000..2d1d3b9 --- /dev/null +++ b/LongestCommonPrefix.h @@ -0,0 +1,177 @@ +#pragma once + +#include +#include +#include +#include + +#ifdef HAS_AVX +#include +#elif HAS_ARM_NEON +#include +#endif + +#if defined(HAS_AVX) || defined(HAS_ARM_NEON) +constexpr int kStride = 64; +#else +constexpr int kStride = 16; +#endif + +constexpr int kUnrollFactor = 4; + +inline bool compareStride(const uint8_t *ap, const uint8_t *bp) { +#if defined(HAS_ARM_NEON) + static_assert(kStride == 64); + uint8x16_t x[4]; // GCOVR_EXCL_LINE + for (int i = 0; i < 4; ++i) { + x[i] = vceqq_u8(vld1q_u8(ap + i * 16), vld1q_u8(bp + i * 16)); + } + auto results = vreinterpretq_u16_u8( + vandq_u8(vandq_u8(x[0], x[1]), vandq_u8(x[2], x[3]))); + bool eq = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0) == + uint64_t(-1); +#elif defined(HAS_AVX) + static_assert(kStride == 64); + __m128i x[4]; // GCOVR_EXCL_LINE + for (int i = 0; i < 4; ++i) { + x[i] = _mm_cmpeq_epi8(_mm_loadu_si128((__m128i *)(ap + i * 16)), + _mm_loadu_si128((__m128i *)(bp + i * 16))); + } + auto eq = + _mm_movemask_epi8(_mm_and_si128(_mm_and_si128(x[0], x[1]), + _mm_and_si128(x[2], x[3]))) == 0xffff; +#else + // Hope it gets vectorized + auto eq = memcmp(ap, bp, kStride) == 0; +#endif + return eq; +} + +// Precondition: ap[:kStride] != bp[:kStride] +inline int firstNeqStride(const uint8_t *ap, const uint8_t *bp) { +#if defined(HAS_AVX) + static_assert(kStride == 64); + uint64_t c[kStride / 16]; // GCOVR_EXCL_LINE + for (int i = 0; i < kStride; i += 16) { + const auto a = _mm_loadu_si128((__m128i *)(ap + i)); + const auto b = _mm_loadu_si128((__m128i *)(bp + i)); + const auto compared = _mm_cmpeq_epi8(a, b); + c[i / 16] = _mm_movemask_epi8(compared) & 0xffff; + } + return std::countr_zero(~(c[0] | c[1] << 16 | c[2] << 32 | c[3] << 48)); +#elif defined(HAS_ARM_NEON) + static_assert(kStride == 64); + for (int i = 0; i < kStride; i += 16) { + // 0xff for each match + uint16x8_t results = + vreinterpretq_u16_u8(vceqq_u8(vld1q_u8(ap + i), vld1q_u8(bp + i))); + // 0xf for each mismatch + uint64_t bitfield = + ~vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(results, 4)), 0); + if (bitfield) { + return i + (std::countr_zero(bitfield) >> 2); + } + } + __builtin_unreachable(); // GCOVR_EXCL_LINE +#else + int i = 0; + for (; i < kStride - 1; ++i) { + if (*ap++ != *bp++) { + break; + } + } + return i; +#endif +} + +// This gets covered in local development +// GCOVR_EXCL_START +#if defined(HAS_AVX) && !defined(__SANITIZE_THREAD__) +__attribute__((target("avx512f,avx512bw"))) inline int +longestCommonPrefix(const uint8_t *ap, const uint8_t *bp, int cl) { + int i = 0; + int end = cl & ~63; + while (i < end) { + const uint64_t eq = + _mm512_cmpeq_epi8_mask(_mm512_loadu_epi8(ap), _mm512_loadu_epi8(bp)); + if (eq != uint64_t(-1)) { + return i + std::countr_one(eq); + } + i += 64; + ap += 64; + bp += 64; + } + if (i < cl) { + const uint64_t mask = (uint64_t(1) << (cl - i)) - 1; + const uint64_t eq = _mm512_cmpeq_epi8_mask( + _mm512_maskz_loadu_epi8(mask, ap), _mm512_maskz_loadu_epi8(mask, bp)); + return i + std::countr_one(eq & mask); + } + assert(i == cl); + return i; +} +__attribute__((target("default"))) +#endif +// GCOVR_EXCL_STOP + +inline int +longestCommonPrefix(const uint8_t *ap, const uint8_t *bp, int cl) { + if (!(cl >= 0)) { + __builtin_unreachable(); + } + + int i = 0; + int end; + + // kStride * kUnrollCount at a time + end = cl & ~(kStride * kUnrollFactor - 1); + while (i < end) { + for (int j = 0; j < kUnrollFactor; ++j) { + if (!compareStride(ap, bp)) { + return i + firstNeqStride(ap, bp); + } + i += kStride; + ap += kStride; + bp += kStride; + } + } + + // kStride at a time + end = cl & ~(kStride - 1); + while (i < end) { + if (!compareStride(ap, bp)) { + return i + firstNeqStride(ap, bp); + } + i += kStride; + ap += kStride; + bp += kStride; + } + + // word at a time + end = cl & ~(sizeof(uint64_t) - 1); + while (i < end) { + uint64_t a; // GCOVR_EXCL_LINE + uint64_t b; // GCOVR_EXCL_LINE + memcpy(&a, ap, 8); + memcpy(&b, bp, 8); + const auto mismatched = a ^ b; + if (mismatched) { + return i + std::countr_zero(mismatched) / 8; + } + i += 8; + ap += 8; + bp += 8; + } + + // byte at a time + while (i < cl) { + if (*ap != *bp) { + break; + } + ++ap; + ++bp; + ++i; + } + + return i; +} \ No newline at end of file