diff --git a/CMakeLists.txt b/CMakeLists.txt index 9cc6de3..fc87578 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -52,8 +52,10 @@ add_test(NAME conflict_set_cxx_api_test COMMAND conflict_set_cxx_api_test) target_compile_options(conflict_set_cxx_api_test PRIVATE -Wall -Wextra -Wpedantic -Wunreachable-code -Werror) # fuzz test +set(FUZZ_FLAGS "-fsanitize=fuzzer,address,undefined") include(CheckCXXCompilerFlag) -check_cxx_compiler_flag(HAS_LIB_FUZZER -fsanitize=fuzzer) +set(CMAKE_REQUIRED_LINK_OPTIONS ${FUZZ_FLAGS}) +check_cxx_compiler_flag(${FUZZ_FLAGS} HAS_LIB_FUZZER) if (HAS_LIB_FUZZER) add_executable(conflict_set_fuzz_test ConflictSet.cpp ConflictSet.h) @@ -61,6 +63,6 @@ if (HAS_LIB_FUZZER) # keep asserts for test target_compile_options(conflict_set_fuzz_test PRIVATE -UNDEBUG) target_compile_options(conflict_set_fuzz_test PRIVATE -Wall -Wextra -Wpedantic -Wunreachable-code) - target_compile_options(conflict_set_fuzz_test PRIVATE -fsanitize=fuzzer) - target_link_options(conflict_set_fuzz_test PRIVATE -fsanitize=fuzzer) + target_compile_options(conflict_set_fuzz_test PRIVATE ${FUZZ_FLAGS}) + target_link_options(conflict_set_fuzz_test PRIVATE ${FUZZ_FLAGS}) endif() diff --git a/ConflictSet.cpp b/ConflictSet.cpp index e30f337..03eb792 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -908,21 +908,38 @@ struct __attribute__((__visibility__("hidden"))) ConflictSet::Impl { } void check(const ReadRange *reads, Result *results, int count) const { - Arena arena; - auto *iters = new (arena) Iterator[count]; - auto *begins = new (arena) Key[count]; + int searchCount = 0; for (int i = 0; i < count; ++i) { - begins[i] = reads[i].begin; + if (reads[i].readVersion >= oldestVersion) { + ++searchCount; + } else { + results[i] = ConflictSet::TooOld; + } } - lastLeqMulti(arena, root, std::span(begins, count), iters); - // TODO check non-singleton reads lol + Arena arena; + auto *iters = new (arena) Iterator[searchCount]; + auto *begins = new (arena) Key[searchCount]; + int j = 0; for (int i = 0; i < count; ++i) { - assert(reads[i].end.len == 0); - assert(iters[i].node != nullptr); - if ((iters[i].cmp == 0 - ? iters[i].node->pointVersion - : iters[i].node->rangeVersion) > reads[i].readVersion) { - results[i] = ConflictSet::Conflict; + if (reads[i].readVersion >= oldestVersion) { + begins[j++] = reads[i].begin; + } + } + lastLeqMulti(arena, root, std::span(begins, searchCount), iters); + // TODO check non-singleton reads lol + j = 0; + for (int i = 0; i < count; ++i) { + if (reads[i].readVersion >= oldestVersion) { + assert(reads[i].end.len == 0); + assert(iters[i].node != nullptr); + if ((iters[j].cmp == 0 + ? iters[j].node->pointVersion + : iters[j].node->rangeVersion) > reads[i].readVersion) { + results[i] = ConflictSet::Conflict; + } else { + results[i] = ConflictSet::Commit; + } + ++j; } } } @@ -1226,25 +1243,60 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { while (gArbitrary.hasEntropy()) { Arena arena; - int numWrites = gArbitrary.bounded(10); - int64_t v = ++writeVersion; - auto *writes = new (arena) ConflictSet::WriteRange[numWrites]; - std::set keys; - while (int(keys.size()) < numWrites) { - keys.insert(gRandom.bounded(100)); + { + int numWrites = gArbitrary.bounded(10); + int64_t v = ++writeVersion; + auto *writes = new (arena) ConflictSet::WriteRange[numWrites]; + std::set, ArenaAlloc> keys{ + ArenaAlloc(&arena)}; + while (int(keys.size()) < numWrites) { + keys.insert(gRandom.bounded(100)); + } + auto iter = keys.begin(); + for (int i = 0; i < numWrites; ++i) { + writes[i].begin = toKey(arena, *iter++); + writes[i].end.len = 0; + writes[i].writeVersion = v; + } + cs.addWrites(writes, numWrites); + refImpl.addWrites(writes, numWrites); } - auto iter = keys.begin(); - for (int i = 0; i < numWrites; ++i) { - writes[i].begin = toKey(arena, *iter++); - writes[i].end.len = 0; - writes[i].writeVersion = v; - } - cs.addWrites(writes, numWrites); - refImpl.addWrites(writes, numWrites); bool success = checkCorrectness(cs.root, refImpl); if (!success) { abort(); } + { + int numReads = gArbitrary.bounded(10); + int64_t v = writeVersion - gArbitrary.bounded(10); + auto *reads = new (arena) ConflictSet::ReadRange[numReads]; + std::set, ArenaAlloc> keys{ + ArenaAlloc(&arena)}; + while (int(keys.size()) < numReads) { + keys.insert(gRandom.bounded(100)); + } + auto iter = keys.begin(); + for (int i = 0; i < numReads; ++i) { + reads[i].begin = toKey(arena, *iter++); + reads[i].end.len = 0; + reads[i].readVersion = v; + } + auto *results1 = new (arena) ConflictSet::Result[numReads]; + auto *results2 = new (arena) ConflictSet::Result[numReads]; + cs.check(reads, results1, numReads); + refImpl.check(reads, results2, numReads); + for (int i = 0; i < numReads; ++i) { + if (results1[i] != results2[i]) { + fprintf(stderr, + "Expected %d, got %d for read of %.*s at version %d\n", + results2[i], results1[i], reads[i].begin.len, + reads[i].begin.p, int(reads[i].readVersion)); + std::string referenceLogicalMap; + refImpl.printLogical(referenceLogicalMap); + fprintf(stderr, "Logical map:\n\n%s\n", referenceLogicalMap.c_str()); + abort(); + } + } + } } return 0; }