From aa6f237d501a43756bd4c5132e30d2444dbf1eb5 Mon Sep 17 00:00:00 2001 From: Andrew Noyes Date: Tue, 19 Mar 2024 16:27:24 -0700 Subject: [PATCH] Document and test thread safety properties Closes #2 --- .clangd | 2 +- CMakeLists.txt | 15 ++++++++++ FuzzTestDriver.cpp | 24 +++++++++++----- Internal.h | 66 +++++++++++++++++++++++++++++++------------ include/ConflictSet.h | 19 ++++++++++++- 5 files changed, 99 insertions(+), 27 deletions(-) diff --git a/.clangd b/.clangd index 61e5abe..8ba862b 100644 --- a/.clangd +++ b/.clangd @@ -1,2 +1,2 @@ CompileFlags: - Add: [-DENABLE_MAIN, -UNDEBUG, -DENABLE_FUZZ, -fexceptions] + Add: [-DENABLE_MAIN, -UNDEBUG, -DENABLE_FUZZ, -DTHREAD_TEST, -fexceptions] diff --git a/CMakeLists.txt b/CMakeLists.txt index 4a00e64..5932643 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -167,6 +167,21 @@ if(BUILD_TESTING) add_test(NAME conflict_set_fuzz_${hash} COMMAND fuzz_driver ${TEST}) endforeach() + # tsan + + if(NOT CMAKE_CROSSCOMPILING) + add_executable(tsan_driver ConflictSet.cpp FuzzTestDriver.cpp) + target_compile_options(tsan_driver PRIVATE ${TEST_FLAGS} -fsanitize=thread) + target_link_options(tsan_driver PRIVATE -fsanitize=thread) + target_compile_definitions(tsan_driver PRIVATE ENABLE_FUZZ THREAD_TEST) + target_include_directories(tsan_driver + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include) + foreach(TEST ${CORPUS_TESTS}) + get_filename_component(hash ${TEST} NAME) + add_test(NAME conflict_set_tsan_${hash} COMMAND tsan_driver ${TEST}) + endforeach() + endif() + add_executable(driver TestDriver.cpp) target_compile_options(driver PRIVATE ${TEST_FLAGS}) target_link_libraries(driver PRIVATE ${PROJECT_NAME}) diff --git a/FuzzTestDriver.cpp b/FuzzTestDriver.cpp index a05c364..5a5c07d 100644 --- a/FuzzTestDriver.cpp +++ b/FuzzTestDriver.cpp @@ -2,15 +2,25 @@ #include #include #include +#include extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size); int main(int argc, char **argv) { - for (int i = 1; i < argc; ++i) { - std::ifstream t(argv[i], std::ios::binary); - std::stringstream buffer; - buffer << t.rdbuf(); - auto str = buffer.str(); - LLVMFuzzerTestOneInput((const uint8_t *)str.data(), str.size()); - } + auto doTest = [&]() { + for (int i = 1; i < argc; ++i) { + std::ifstream t(argv[i], std::ios::binary); + std::stringstream buffer; + buffer << t.rdbuf(); + auto str = buffer.str(); + LLVMFuzzerTestOneInput((const uint8_t *)str.data(), str.size()); + } + }; +#ifdef THREAD_TEST + std::thread thread2{doTest}; +#endif + doTest(); +#ifdef THREAD_TEST + thread2.join(); +#endif } diff --git a/Internal.h b/Internal.h index d6a3e53..186aff6 100644 --- a/Internal.h +++ b/Internal.h @@ -10,10 +10,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include #include @@ -652,32 +654,60 @@ template struct TestDriver { auto *results2 = new (arena) ConflictSet::Result[numPointReads + numRangeReads]; +#ifdef THREAD_TEST + auto *results3 = + new (arena) ConflictSet::Result[numPointReads + numRangeReads]; + std::latch ready{1}; + std::thread thread2{[&]() { + ready.count_down(); + cs.check(reads, results3, numPointReads + numRangeReads); + }}; + ready.wait(); +#endif + CALLGRIND_START_INSTRUMENTATION; cs.check(reads, results1, numPointReads + numRangeReads); CALLGRIND_STOP_INSTRUMENTATION; refImpl.check(reads, results2, numPointReads + numRangeReads); - for (int i = 0; i < numPointReads + numRangeReads; ++i) { - if (results1[i] != results2[i]) { - if (reads[i].end.len == 0) { - fprintf(stderr, - "Expected %s, got %s for read of {%s} at version %" PRId64 - "\n", - resultToStr(results2[i]), resultToStr(results1[i]), - printable(reads[i].begin).c_str(), reads[i].readVersion); - } else { - fprintf( - stderr, - "Expected %s, got %s for read of [%s, %s) at version %" PRId64 - "\n", - resultToStr(results2[i]), resultToStr(results1[i]), - printable(reads[i].begin).c_str(), - printable(reads[i].end).c_str(), reads[i].readVersion); + + auto compareResults = [reads](ConflictSet::Result *results1, + ConflictSet::Result *results2, int count) { + for (int i = 0; i < count; ++i) { + if (results1[i] != results2[i]) { + if (reads[i].end.len == 0) { + fprintf(stderr, + "Expected %s, got %s for read of {%s} at version %" PRId64 + "\n", + resultToStr(results2[i]), resultToStr(results1[i]), + printable(reads[i].begin).c_str(), reads[i].readVersion); + } else { + fprintf( + stderr, + "Expected %s, got %s for read of [%s, %s) at version %" PRId64 + "\n", + resultToStr(results2[i]), resultToStr(results1[i]), + printable(reads[i].begin).c_str(), + printable(reads[i].end).c_str(), reads[i].readVersion); + } + return false; } - ok = false; - return true; } + return true; + }; + + if (!compareResults(results1, results2, numPointReads + numRangeReads)) { + ok = false; + return true; } + +#ifdef THREAD_TEST + thread2.join(); + if (!compareResults(results3, results2, numPointReads + numRangeReads)) { + ok = false; + return true; + } +#endif } return false; } diff --git a/include/ConflictSet.h b/include/ConflictSet.h index 2f24f55..34b2bb0 100644 --- a/include/ConflictSet.h +++ b/include/ConflictSet.h @@ -19,7 +19,15 @@ limitations under the License. #include #ifdef __cplusplus - +/** A data structure for optimistic concurrency control on ranges of + * bitwise-lexicographically-ordered keys. + * + * Thread safety: + * - It's safe to operate on two different ConflictSets in two different + * threads concurrently + * - It's safe to have multiple threads operating on the same ConflictSet + * concurrently if and only if all threads only call `check`. + */ struct __attribute__((__visibility__("default"))) ConflictSet { enum Result { /** The result of a check which does not intersect any conflicting writes */ @@ -92,6 +100,15 @@ private: #else +/** A data structure for optimistic concurrency control on ranges of + * bitwise-lexicographically-ordered keys. + * + * Thread safety: + * - It's safe to operate on two different ConflictSets in two different + * threads concurrently + * - It's safe to have multiple threads operating on the same ConflictSet + * concurrently if and only if all threads only call `check`. + */ typedef struct ConflictSet ConflictSet; typedef enum {