WIP - seems to work for point reads/writes

This commit is contained in:
2024-01-23 15:32:45 -08:00
parent 407b9af750
commit 122cddb54d
2 changed files with 216 additions and 66 deletions

View File

@@ -13,11 +13,29 @@ else()
add_link_options(-Wl,--gc-sections)
endif()
include(CheckIncludeFileCXX)
include(CMakePushCheckState)
cmake_push_check_state()
# Fall back to non-simd implementations if avx isn't available
list(APPEND CMAKE_REQUIRED_FLAGS -mavx)
check_include_file_cxx("immintrin.h" HAS_AVX)
if(HAS_AVX)
add_compile_options(-mavx)
add_compile_definitions(HAS_AVX)
endif()
cmake_pop_check_state()
check_include_file_cxx("arm_neon.h" HAS_ARM_NEON)
if (HAS_ARM_NEON)
add_compile_definitions(HAS_ARM_NEON)
endif()
add_library(conflict_set SHARED ConflictSet.cpp)
target_include_directories(conflict_set PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
target_compile_options(conflict_set PRIVATE -fno-exceptions -fvisibility=hidden)
target_link_options(conflict_set PRIVATE -nodefaultlibs -lc -fvisibility=hidden)
if (CMAKE_BUILD_TYPE STREQUAL Release)
target_link_options(conflict_set PRIVATE -nodefaultlibs -lc -fvisibility=hidden)
add_custom_command(TARGET conflict_set POST_BUILD COMMAND ${CMAKE_STRIP} -x $<TARGET_FILE:conflict_set>)
endif()
@@ -25,7 +43,7 @@ if (NOT APPLE)
target_link_options(conflict_set PRIVATE "LINKER:--version-script=${CMAKE_CURRENT_SOURCE_DIR}/linker.map")
endif()
set(TEST_FLAGS -Wall -Wextra -Wpedantic -Wunreachable-code -Werror -UNDEBUG)
set(TEST_FLAGS -Wall -Wextra -Wpedantic -Wunreachable-code -UNDEBUG)
include(CTest)
@@ -42,8 +60,10 @@ add_test(NAME conflict_set_test COMMAND conflict_set_test)
# fuzz test
set(FUZZ_FLAGS "-fsanitize=fuzzer-no-link,address,undefined")
include(CheckCXXCompilerFlag)
cmake_push_check_state()
set(CMAKE_REQUIRED_LINK_OPTIONS ${FUZZ_FLAGS})
check_cxx_compiler_flag(${FUZZ_FLAGS} HAS_LIB_FUZZER)
cmake_pop_check_state()
if (HAS_LIB_FUZZER)
add_executable(conflict_set_fuzz_test ConflictSet.cpp)

View File

@@ -4,6 +4,7 @@
#include <cassert>
#include <cstdint>
#include <cstring>
#include <limits>
#include <map>
#include <set>
#include <span>
@@ -12,7 +13,13 @@
#include <utility>
#include <vector>
#define DEBUG 0
#ifdef HAS_AVX
#include <immintrin.h>
#elif defined(HAS_ARM_NEON)
#include <arm_neon.h>
#endif
#define DEBUG_VERBOSE 0
__attribute__((always_inline)) void *safe_malloc(size_t s) {
if (void *p = malloc(s)) {
@@ -563,7 +570,7 @@ static int getNodeIndex(Node16 *self, uint8_t index) {
}
#ifdef HAS_AVX
static int firstNonNeg1(const int8_t x[16]) {
int firstNonNeg1(const int8_t x[16]) {
__m128i key_vec = _mm_set1_epi8(-1);
__m128i indices;
memcpy(&indices, x, 16);
@@ -571,12 +578,12 @@ static int firstNonNeg1(const int8_t x[16]) {
uint32_t bitfield = _mm_movemask_epi8(results) ^ 0xffff;
if (bitfield == 0)
return -1;
return __builtin_ctz(bitfield);
return __builtin_clz(bitfield);
}
#endif
#ifdef HAS_ARM_NEON
static int firstNonNeg1(const int8_t x[16]) {
int firstNonNeg1(const int8_t x[16]) {
uint8x16_t indices;
memcpy(&indices, x, 16);
uint16x8_t results = vreinterpretq_u16_u8(vceqq_u8(vdupq_n_u8(-1), indices));
@@ -717,6 +724,7 @@ int getChildGeq(Node *self, int child) {
}
int getChildLeq(Node *self, int child) {
// TODO simd
if (self->type == Type::Node4) {
auto *self4 = static_cast<Node4 *>(self);
for (int i = self->numChildren - 1; i >= 0; --i) {
@@ -739,7 +747,6 @@ int getChildLeq(Node *self, int child) {
}
} else if (self->type == Type::Node48) {
auto *self48 = static_cast<Node48 *>(self);
// TODO simd
for (int i = child; i >= 0; --i) {
if (self48->index[i] >= 0) {
assert(self48->children[self48->index[i]] != nullptr);
@@ -934,8 +941,8 @@ void debugPrintDot(FILE *file, Node *node) {
for (int child = getChildGeq(n, 0); child >= 0;
child = getChildGeq(n, child + 1)) {
auto *c = getChildExists(n, child);
fprintf(file, " k_%p -> k_%p [label=\"'%c'\"];\n", (void *)n, (void *)c,
child);
fprintf(file, " k_%p -> k_%p [label=\"'%02x'\"];\n", (void *)n,
(void *)c, child);
print(c);
}
}
@@ -949,46 +956,100 @@ void debugPrintDot(FILE *file, Node *node) {
fprintf(file, "}\n");
}
void printSearchPath(Node *n) {
Arena arena;
Node *nextPhysical(Node *node) {
int index = -1;
for (;;) {
auto nextChild = getChildGeq(node, index + 1);
if (nextChild >= 0) {
return getChildExists(node, nextChild);
}
index = node->parentsIndex;
node = node->parent;
if (node == nullptr) {
return nullptr;
}
}
}
Node *nextLogical(Node *node) {
for (node = nextPhysical(node); node != nullptr && !node->entryPresent;
node = nextPhysical(node))
;
return node;
}
std::string printable(std::string_view key) {
std::string result;
for (uint8_t c : key) {
result += "x";
result += "0123456789abcdef"[c / 16];
result += "0123456789abcdef"[c % 16];
}
return result;
}
std::string printable(const Key &key) {
return printable(std::string_view((const char *)key.p, key.len));
}
std::string_view getSearchPath(Arena &arena, Node *n) {
if (n->parent == nullptr) {
return {};
}
auto result = vector<char>(arena);
for (; n->parent != nullptr; n = n->parent) {
result.push_back(n->parentsIndex);
}
std::reverse(result.begin(), result.end());
result.push_back(0);
printf("Search path: %s\n", result.data());
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wreturn-stack-address"
return std::string_view((const char *)&result[0], result.size()); // NOLINT
#pragma GCC diagnostic pop
}
void printLogical(std::string &result, Node *node) {
Arena arena;
for (Node *iter = node; iter != nullptr;) {
auto *next = nextLogical(iter);
std::string key;
for (uint8_t c : getSearchPath(arena, iter)) {
key += "x";
key += "0123456789abcdef"[c / 16];
key += "0123456789abcdef"[c % 16];
}
if (iter->entry.pointVersion == iter->entry.rangeVersion) {
result += key + " -> " + std::to_string(iter->entry.pointVersion) + "\n";
} else {
result += key + " -> " + std::to_string(iter->entry.pointVersion) + "\n";
if (next == nullptr || (getSearchPath(arena, next) !=
(std::string(getSearchPath(arena, iter)) +
std::string("\x00", 1)))) {
result +=
key + "x00 -> " + std::to_string(iter->entry.rangeVersion) + "\n";
}
}
iter = next;
}
}
Node *prevPhysical(Node *node) {
// Move up until there's a node at a lower index than the current search path
int selfIndex = 256;
for (;;) {
if (node->parent == nullptr) {
return nullptr;
}
auto prevChild = getChildLeq(node->parent, node->parentsIndex - 1);
if (prevChild >= 0) {
node = getChildExists(node->parent, prevChild);
break;
} else {
node = node->parent;
selfIndex = prevChild;
if (node->entryPresent) {
break;
assert(node->parent != nullptr);
auto prevChild = getChildLeq(node->parent, node->parentsIndex - 1);
assert(prevChild < node->parentsIndex);
if (prevChild >= 0) {
node = getChildExists(node->parent, prevChild);
// Move down the right spine
for (;;) {
auto rightMostChild = getChildLeq(node, 255);
if (rightMostChild >= 0) {
node = getChildExists(node, rightMostChild);
} else {
return node;
}
}
} else {
return node->parent;
}
// Move down the right spine
for (;;) {
auto rightMostChild = getChildLeq(node, selfIndex - 1);
if (rightMostChild >= 0) {
node = getChildExists(node, rightMostChild);
} else {
break;
}
}
return node;
}
struct Iterator {
@@ -996,33 +1057,45 @@ struct Iterator {
int cmp;
};
Iterator lastLeq(Node *n, std::span<const uint8_t> key) {
Iterator lastLeq(Node *n, const std::span<const uint8_t> key) {
auto remaining = key;
for (;;) {
if (key.size() == 0) {
Arena arena;
assert((std::string(getSearchPath(arena, n)) +
std::string((const char *)remaining.data(), remaining.size()))
.ends_with(std::string((const char *)key.data(), key.size())));
if (remaining.size() == 0) {
// We've found the physical node corresponding to search path `key`
if (n->entryPresent) {
return {n, 0};
} else {
break;
}
} else {
int c = getChildLeq(n, key[0]);
if (c == key[0]) {
int c = getChildLeq(n, remaining[0]);
if (c == remaining[0]) {
n = getChildExists(n, c);
key = key.subspan(1, key.size() - 1);
} else if (c >= 0) {
n = getChildExists(n, c);
break;
remaining = remaining.subspan(1, remaining.size() - 1);
} else {
// The physical node corresponding to search path `key` does not exist.
// Let's find the physical node corresponding to the highest search key
// (not necessarily present) less than key
// Move down the right spine
for (;;) {
if (c >= 0) {
n = getChildExists(n, c);
} else {
break;
}
c = getChildLeq(n, 255);
}
break;
}
}
}
for (;;) {
if (n->entryPresent) {
break;
}
n = prevPhysical(n);
assert(n != nullptr);
// Iterate backwards along existing physical nodes until we find a present
// entry
for (; !n->entryPresent; n = prevPhysical(n)) {
}
return {n, -1};
}
@@ -1063,6 +1136,11 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
assert(r.end.len == 0);
auto [l, c] =
lastLeq(root, std::span<const uint8_t>(r.begin.p, r.begin.len));
#if DEBUG_VERBOSE && !defined(NDEBUG)
Arena arena;
printf("LastLeq for `%s' got `%s'\n", printable(r.begin).c_str(),
printable(getSearchPath(arena, l)).c_str());
#endif
assert(l != nullptr);
assert(l->entryPresent);
result[i] = (c == 0 ? l->entry.pointVersion : l->entry.rangeVersion) >
@@ -1113,6 +1191,59 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
int64_t oldestVersion;
};
void checkParentPointers(Node *node, bool &success) {
for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) {
auto *child = getChild(node, i);
if (child->parent != node) {
Arena arena;
fprintf(stderr, "%s child %d has parent pointer %p. Expected %p\n",
printable(getSearchPath(arena, node)).c_str(), i,
(void *)child->parent, (void *)node);
success = false;
}
checkParentPointers(child, success);
}
}
int64_t checkMaxVersion(Node *node, bool &success) {
int64_t expected =
node->entryPresent
? std::max(node->entry.pointVersion, node->entry.rangeVersion)
: std::numeric_limits<int64_t>::lowest();
for (int i = getChildGeq(node, 0); i >= 0; i = getChildGeq(node, i + 1)) {
auto *child = getChild(node, i);
expected = std::max(expected, checkMaxVersion(child, success));
}
if (node->maxVersion != expected) {
Arena arena;
fprintf(stderr, "%s has max version %d. Expected %d\n",
printable(getSearchPath(arena, node)).c_str(),
int(node->maxVersion), int(expected));
success = false;
}
return expected;
}
bool checkCorrectness(Node *node, ReferenceImpl &refImpl) {
bool success = true;
checkParentPointers(node, success);
std::string logicalMap;
std::string referenceLogicalMap;
printLogical(logicalMap, node);
refImpl.printLogical(referenceLogicalMap);
if (logicalMap != referenceLogicalMap) {
fprintf(stderr,
"Logical map not equal to reference logical map.\n\nActual:\n"
"%s\nExpected:\n%s\n",
logicalMap.c_str(), referenceLogicalMap.c_str());
success = false;
}
return success;
}
// ==================== END IMPLEMENTATION ====================
void ConflictSet::check(const ReadRange *reads, Result *results,
@@ -1225,7 +1356,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
}
int keyLen = gArbitrary.bounded(8);
auto *begin = new (arena) uint8_t[keyLen];
gArbitrary.randomHex(begin, keyLen);
gArbitrary.randomBytes(begin, keyLen);
keys.insert(std::string_view((const char *)begin, keyLen));
}
auto iter = keys.begin();
@@ -1234,18 +1365,18 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
writes[i].begin.len = iter->size();
writes[i].end.len = 0;
writes[i].writeVersion = v;
#if DEBUG
printf("Write: {%.*s} -> %d\n", writes[i].begin.len, writes[i].begin.p,
#if DEBUG_VERBOSE && !defined(NDEBUG)
printf("Write: {%s} -> %d\n", printable(writes[i].begin).c_str(),
int(writes[i].writeVersion));
#endif
}
cs.addWrites(writes, numWrites);
refImpl.addWrites(writes, numWrites);
}
// bool success = checkCorrectness(cs.root, refImpl);
// if (!success) {
// abort();
// }
bool success = checkCorrectness(cs.root, refImpl);
if (!success) {
abort();
}
{
int numReads = gArbitrary.bounded(10);
int64_t v = writeVersion - gArbitrary.bounded(10);
@@ -1258,7 +1389,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
}
int keyLen = gArbitrary.bounded(8);
auto *begin = new (arena) uint8_t[keyLen];
gArbitrary.randomHex(begin, keyLen);
gArbitrary.randomBytes(begin, keyLen);
keys.insert(std::string_view((const char *)begin, keyLen));
}
auto iter = keys.begin();
@@ -1267,8 +1398,8 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
reads[i].begin.len = iter->size();
reads[i].end.len = 0;
reads[i].readVersion = v;
#if DEBUG
printf("Read: {%.*s} at %d\n", reads[i].begin.len, reads[i].begin.p,
#if DEBUG_VERBOSE && !defined(NDEBUG)
printf("Read: {%s} at %d\n", printable(reads[i].begin).c_str(),
int(reads[i].readVersion));
#endif
}
@@ -1278,10 +1409,9 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
refImpl.check(reads, results2, numReads);
for (int i = 0; i < numReads; ++i) {
if (results1[i] != results2[i]) {
fprintf(stderr,
"Expected %d, got %d for read of %.*s at version %d\n",
results2[i], results1[i], reads[i].begin.len,
reads[i].begin.p, int(reads[i].readVersion));
fprintf(stderr, "Expected %d, got %d for read of %s at version %d\n",
results2[i], results1[i], printable(reads[i].begin).c_str(),
int(reads[i].readVersion));
std::string referenceLogicalMap;
refImpl.printLogical(referenceLogicalMap);
fprintf(stderr, "Logical map:\n\n%s\n", referenceLogicalMap.c_str());