From 122cddb54d0c1a28877157373ef4bc87812fc88c Mon Sep 17 00:00:00 2001 From: Andrew Noyes Date: Tue, 23 Jan 2024 15:32:45 -0800 Subject: [PATCH] WIP - seems to work for point reads/writes --- CMakeLists.txt | 24 ++++- ConflictSet.cpp | 258 ++++++++++++++++++++++++++++++++++++------------ 2 files changed, 216 insertions(+), 66 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 66e2deb..6d54610 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,11 +13,29 @@ else() add_link_options(-Wl,--gc-sections) endif() +include(CheckIncludeFileCXX) +include(CMakePushCheckState) + +cmake_push_check_state() +# Fall back to non-simd implementations if avx isn't available +list(APPEND CMAKE_REQUIRED_FLAGS -mavx) +check_include_file_cxx("immintrin.h" HAS_AVX) +if(HAS_AVX) + add_compile_options(-mavx) + add_compile_definitions(HAS_AVX) +endif() +cmake_pop_check_state() + +check_include_file_cxx("arm_neon.h" HAS_ARM_NEON) +if (HAS_ARM_NEON) + add_compile_definitions(HAS_ARM_NEON) +endif() + add_library(conflict_set SHARED ConflictSet.cpp) target_include_directories(conflict_set PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) target_compile_options(conflict_set PRIVATE -fno-exceptions -fvisibility=hidden) -target_link_options(conflict_set PRIVATE -nodefaultlibs -lc -fvisibility=hidden) if (CMAKE_BUILD_TYPE STREQUAL Release) + target_link_options(conflict_set PRIVATE -nodefaultlibs -lc -fvisibility=hidden) add_custom_command(TARGET conflict_set POST_BUILD COMMAND ${CMAKE_STRIP} -x $) endif() @@ -25,7 +43,7 @@ if (NOT APPLE) target_link_options(conflict_set PRIVATE "LINKER:--version-script=${CMAKE_CURRENT_SOURCE_DIR}/linker.map") endif() -set(TEST_FLAGS -Wall -Wextra -Wpedantic -Wunreachable-code -Werror -UNDEBUG) +set(TEST_FLAGS -Wall -Wextra -Wpedantic -Wunreachable-code -UNDEBUG) include(CTest) @@ -42,8 +60,10 @@ add_test(NAME conflict_set_test COMMAND conflict_set_test) # fuzz test set(FUZZ_FLAGS "-fsanitize=fuzzer-no-link,address,undefined") include(CheckCXXCompilerFlag) +cmake_push_check_state() set(CMAKE_REQUIRED_LINK_OPTIONS ${FUZZ_FLAGS}) check_cxx_compiler_flag(${FUZZ_FLAGS} HAS_LIB_FUZZER) +cmake_pop_check_state() if (HAS_LIB_FUZZER) add_executable(conflict_set_fuzz_test ConflictSet.cpp) diff --git a/ConflictSet.cpp b/ConflictSet.cpp index 4773963..47cc445 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -12,7 +13,13 @@ #include #include -#define DEBUG 0 +#ifdef HAS_AVX +#include +#elif defined(HAS_ARM_NEON) +#include +#endif + +#define DEBUG_VERBOSE 0 __attribute__((always_inline)) void *safe_malloc(size_t s) { if (void *p = malloc(s)) { @@ -563,7 +570,7 @@ static int getNodeIndex(Node16 *self, uint8_t index) { } #ifdef HAS_AVX -static int firstNonNeg1(const int8_t x[16]) { +int firstNonNeg1(const int8_t x[16]) { __m128i key_vec = _mm_set1_epi8(-1); __m128i indices; memcpy(&indices, x, 16); @@ -571,12 +578,12 @@ static int firstNonNeg1(const int8_t x[16]) { uint32_t bitfield = _mm_movemask_epi8(results) ^ 0xffff; if (bitfield == 0) return -1; - return __builtin_ctz(bitfield); + return __builtin_clz(bitfield); } #endif #ifdef HAS_ARM_NEON -static int firstNonNeg1(const int8_t x[16]) { +int firstNonNeg1(const int8_t x[16]) { uint8x16_t indices; memcpy(&indices, x, 16); uint16x8_t results = vreinterpretq_u16_u8(vceqq_u8(vdupq_n_u8(-1), indices)); @@ -717,6 +724,7 @@ int getChildGeq(Node *self, int child) { } int getChildLeq(Node *self, int child) { + // TODO simd if (self->type == Type::Node4) { auto *self4 = static_cast(self); for (int i = self->numChildren - 1; i >= 0; --i) { @@ -739,7 +747,6 @@ int getChildLeq(Node *self, int child) { } } else if (self->type == Type::Node48) { auto *self48 = static_cast(self); - // TODO simd for (int i = child; i >= 0; --i) { if (self48->index[i] >= 0) { assert(self48->children[self48->index[i]] != nullptr); @@ -934,8 +941,8 @@ void debugPrintDot(FILE *file, Node *node) { for (int child = getChildGeq(n, 0); child >= 0; child = getChildGeq(n, child + 1)) { auto *c = getChildExists(n, child); - fprintf(file, " k_%p -> k_%p [label=\"'%c'\"];\n", (void *)n, (void *)c, - child); + fprintf(file, " k_%p -> k_%p [label=\"'%02x'\"];\n", (void *)n, + (void *)c, child); print(c); } } @@ -949,46 +956,100 @@ void debugPrintDot(FILE *file, Node *node) { fprintf(file, "}\n"); } -void printSearchPath(Node *n) { - Arena arena; +Node *nextPhysical(Node *node) { + int index = -1; + for (;;) { + auto nextChild = getChildGeq(node, index + 1); + if (nextChild >= 0) { + return getChildExists(node, nextChild); + } + index = node->parentsIndex; + node = node->parent; + if (node == nullptr) { + return nullptr; + } + } +} + +Node *nextLogical(Node *node) { + for (node = nextPhysical(node); node != nullptr && !node->entryPresent; + node = nextPhysical(node)) + ; + return node; +} + +std::string printable(std::string_view key) { + std::string result; + for (uint8_t c : key) { + result += "x"; + result += "0123456789abcdef"[c / 16]; + result += "0123456789abcdef"[c % 16]; + } + return result; +} + +std::string printable(const Key &key) { + return printable(std::string_view((const char *)key.p, key.len)); +} + +std::string_view getSearchPath(Arena &arena, Node *n) { + if (n->parent == nullptr) { + return {}; + } auto result = vector(arena); for (; n->parent != nullptr; n = n->parent) { result.push_back(n->parentsIndex); } std::reverse(result.begin(), result.end()); - result.push_back(0); - printf("Search path: %s\n", result.data()); +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wreturn-stack-address" + return std::string_view((const char *)&result[0], result.size()); // NOLINT +#pragma GCC diagnostic pop +} + +void printLogical(std::string &result, Node *node) { + Arena arena; + for (Node *iter = node; iter != nullptr;) { + auto *next = nextLogical(iter); + std::string key; + for (uint8_t c : getSearchPath(arena, iter)) { + key += "x"; + key += "0123456789abcdef"[c / 16]; + key += "0123456789abcdef"[c % 16]; + } + if (iter->entry.pointVersion == iter->entry.rangeVersion) { + result += key + " -> " + std::to_string(iter->entry.pointVersion) + "\n"; + } else { + result += key + " -> " + std::to_string(iter->entry.pointVersion) + "\n"; + if (next == nullptr || (getSearchPath(arena, next) != + (std::string(getSearchPath(arena, iter)) + + std::string("\x00", 1)))) { + result += + key + "x00 -> " + std::to_string(iter->entry.rangeVersion) + "\n"; + } + } + iter = next; + } } Node *prevPhysical(Node *node) { - // Move up until there's a node at a lower index than the current search path - int selfIndex = 256; - for (;;) { - if (node->parent == nullptr) { - return nullptr; - } - auto prevChild = getChildLeq(node->parent, node->parentsIndex - 1); - if (prevChild >= 0) { - node = getChildExists(node->parent, prevChild); - break; - } else { - node = node->parent; - selfIndex = prevChild; - if (node->entryPresent) { - break; + assert(node->parent != nullptr); + auto prevChild = getChildLeq(node->parent, node->parentsIndex - 1); + assert(prevChild < node->parentsIndex); + if (prevChild >= 0) { + node = getChildExists(node->parent, prevChild); + // Move down the right spine + for (;;) { + auto rightMostChild = getChildLeq(node, 255); + if (rightMostChild >= 0) { + node = getChildExists(node, rightMostChild); + } else { + return node; } } + } else { + return node->parent; } - // Move down the right spine - for (;;) { - auto rightMostChild = getChildLeq(node, selfIndex - 1); - if (rightMostChild >= 0) { - node = getChildExists(node, rightMostChild); - } else { - break; - } - } - return node; } struct Iterator { @@ -996,33 +1057,45 @@ struct Iterator { int cmp; }; -Iterator lastLeq(Node *n, std::span key) { +Iterator lastLeq(Node *n, const std::span key) { + auto remaining = key; for (;;) { - if (key.size() == 0) { + Arena arena; + assert((std::string(getSearchPath(arena, n)) + + std::string((const char *)remaining.data(), remaining.size())) + .ends_with(std::string((const char *)key.data(), key.size()))); + if (remaining.size() == 0) { + // We've found the physical node corresponding to search path `key` if (n->entryPresent) { return {n, 0}; } else { break; } } else { - int c = getChildLeq(n, key[0]); - if (c == key[0]) { + int c = getChildLeq(n, remaining[0]); + if (c == remaining[0]) { n = getChildExists(n, c); - key = key.subspan(1, key.size() - 1); - } else if (c >= 0) { - n = getChildExists(n, c); - break; + remaining = remaining.subspan(1, remaining.size() - 1); } else { + // The physical node corresponding to search path `key` does not exist. + // Let's find the physical node corresponding to the highest search key + // (not necessarily present) less than key + // Move down the right spine + for (;;) { + if (c >= 0) { + n = getChildExists(n, c); + } else { + break; + } + c = getChildLeq(n, 255); + } break; } } } - for (;;) { - if (n->entryPresent) { - break; - } - n = prevPhysical(n); - assert(n != nullptr); + // Iterate backwards along existing physical nodes until we find a present + // entry + for (; !n->entryPresent; n = prevPhysical(n)) { } return {n, -1}; } @@ -1063,6 +1136,11 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { assert(r.end.len == 0); auto [l, c] = lastLeq(root, std::span(r.begin.p, r.begin.len)); +#if DEBUG_VERBOSE && !defined(NDEBUG) + Arena arena; + printf("LastLeq for `%s' got `%s'\n", printable(r.begin).c_str(), + printable(getSearchPath(arena, l)).c_str()); +#endif assert(l != nullptr); assert(l->entryPresent); result[i] = (c == 0 ? l->entry.pointVersion : l->entry.rangeVersion) > @@ -1113,6 +1191,59 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl { int64_t oldestVersion; }; +void checkParentPointers(Node *node, bool &success) { + for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) { + auto *child = getChild(node, i); + if (child->parent != node) { + Arena arena; + fprintf(stderr, "%s child %d has parent pointer %p. Expected %p\n", + printable(getSearchPath(arena, node)).c_str(), i, + (void *)child->parent, (void *)node); + success = false; + } + checkParentPointers(child, success); + } +} + +int64_t checkMaxVersion(Node *node, bool &success) { + int64_t expected = + node->entryPresent + ? std::max(node->entry.pointVersion, node->entry.rangeVersion) + : std::numeric_limits::lowest(); + for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) { + auto *child = getChild(node, i); + expected = std::max(expected, checkMaxVersion(child, success)); + } + if (node->maxVersion != expected) { + Arena arena; + fprintf(stderr, "%s has max version %d. Expected %d\n", + printable(getSearchPath(arena, node)).c_str(), + int(node->maxVersion), int(expected)); + success = false; + } + return expected; +} + +bool checkCorrectness(Node *node, ReferenceImpl &refImpl) { + bool success = true; + + checkParentPointers(node, success); + + std::string logicalMap; + std::string referenceLogicalMap; + printLogical(logicalMap, node); + refImpl.printLogical(referenceLogicalMap); + if (logicalMap != referenceLogicalMap) { + fprintf(stderr, + "Logical map not equal to reference logical map.\n\nActual:\n" + "%s\nExpected:\n%s\n", + logicalMap.c_str(), referenceLogicalMap.c_str()); + success = false; + } + + return success; +} + // ==================== END IMPLEMENTATION ==================== void ConflictSet::check(const ReadRange *reads, Result *results, @@ -1225,7 +1356,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { } int keyLen = gArbitrary.bounded(8); auto *begin = new (arena) uint8_t[keyLen]; - gArbitrary.randomHex(begin, keyLen); + gArbitrary.randomBytes(begin, keyLen); keys.insert(std::string_view((const char *)begin, keyLen)); } auto iter = keys.begin(); @@ -1234,18 +1365,18 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { writes[i].begin.len = iter->size(); writes[i].end.len = 0; writes[i].writeVersion = v; -#if DEBUG - printf("Write: {%.*s} -> %d\n", writes[i].begin.len, writes[i].begin.p, +#if DEBUG_VERBOSE && !defined(NDEBUG) + printf("Write: {%s} -> %d\n", printable(writes[i].begin).c_str(), int(writes[i].writeVersion)); #endif } cs.addWrites(writes, numWrites); refImpl.addWrites(writes, numWrites); } - // bool success = checkCorrectness(cs.root, refImpl); - // if (!success) { - // abort(); - // } + bool success = checkCorrectness(cs.root, refImpl); + if (!success) { + abort(); + } { int numReads = gArbitrary.bounded(10); int64_t v = writeVersion - gArbitrary.bounded(10); @@ -1258,7 +1389,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { } int keyLen = gArbitrary.bounded(8); auto *begin = new (arena) uint8_t[keyLen]; - gArbitrary.randomHex(begin, keyLen); + gArbitrary.randomBytes(begin, keyLen); keys.insert(std::string_view((const char *)begin, keyLen)); } auto iter = keys.begin(); @@ -1267,8 +1398,8 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { reads[i].begin.len = iter->size(); reads[i].end.len = 0; reads[i].readVersion = v; -#if DEBUG - printf("Read: {%.*s} at %d\n", reads[i].begin.len, reads[i].begin.p, +#if DEBUG_VERBOSE && !defined(NDEBUG) + printf("Read: {%s} at %d\n", printable(reads[i].begin).c_str(), int(reads[i].readVersion)); #endif } @@ -1278,10 +1409,9 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { refImpl.check(reads, results2, numReads); for (int i = 0; i < numReads; ++i) { if (results1[i] != results2[i]) { - fprintf(stderr, - "Expected %d, got %d for read of %.*s at version %d\n", - results2[i], results1[i], reads[i].begin.len, - reads[i].begin.p, int(reads[i].readVersion)); + fprintf(stderr, "Expected %d, got %d for read of %s at version %d\n", + results2[i], results1[i], printable(reads[i].begin).c_str(), + int(reads[i].readVersion)); std::string referenceLogicalMap; refImpl.printLogical(referenceLogicalMap); fprintf(stderr, "Logical map:\n\n%s\n", referenceLogicalMap.c_str());