diff --git a/ConflictSet.cpp b/ConflictSet.cpp index 277d520..d192bc4 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -1773,22 +1773,60 @@ void eraseBetween(Node **inTree, Node3 *n, int begin, int end, void eraseBetween(Node **inTree, Node16 *n, int begin, int end, WriteContext *tls) { + if (end - begin == 256) { + for (int i = 0; i < n->numChildren; ++i) { + eraseTree(n->children[i], tls); + } + n->numChildren = 0; + auto *newNode = tls->allocate(n->partialKeyLen); + newNode->copyChildrenAndKeyFrom(*n); + tls->release(n); + *inTree = newNode; + return; + } + assert(end - begin < 256); + +#ifdef HAS_ARM_NEON + uint8x16_t indices; + memcpy(&indices, n->index, 16); + // 0xff for each in bounds + auto results = + vcltq_u8(vsubq_u8(indices, vdupq_n_u8(begin)), vdupq_n_u8(end - begin)); + // 0xf for each 0xff + uint64_t mask = vget_lane_u64( + vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(results), 4)), 0); +#elif defined(HAS_AVX) + __m128i indices; + memcpy(&indices, n->index, 16); + indices = _mm_sub_epi8(indices, _mm_set1_epi8(begin)); + uint32_t mask = ~_mm_movemask_epi8(_mm_cmpeq_epi8( + indices, _mm_max_epu8(indices, _mm_set1_epi8(end - begin)))); +#else const unsigned shiftUpperBound = end - begin; const unsigned shiftAmount = begin; auto inBounds = [&](unsigned c) { return c - shiftAmount < shiftUpperBound; }; - Node **nodeOut = n->children; - uint8_t *indexOut = n->index; - InternalVersionT *maxVOut = n->childMaxVersion; - for (int i = 0; i < n->numChildren; ++i) { - if (inBounds(n->index[i])) { - eraseTree(n->children[i], tls); - } else { - *nodeOut++ = n->children[i]; - *indexOut++ = n->index[i]; - *maxVOut++ = n->childMaxVersion[i]; - } + uint32_t mask = 0; + for (int i = 0; i < 16; ++i) { + mask |= inBounds(is[i]) << i; + } +#endif + mask &= (decltype(mask)(1) << n->numChildren) - 1; + + if (!mask) { + return; + } + + int first = std::countr_zero(mask); + int count = std::popcount(mask); + n->numChildren -= count; + for (int i = first; i < first + count; ++i) { + eraseTree(n->children[i], tls); + } + for (int i = first; i < n->numChildren; ++i) { + n->children[i] = n->children[i + count]; + n->childMaxVersion[i] = n->childMaxVersion[i + count]; + n->index[i] = n->index[i + count]; } - n->numChildren = nodeOut - n->children; if (n->numChildren > Node3::kMaxNodes) { // nop