Remove "writes are canonical" precondition from addWrites
Some checks failed
weaselab/conflict-set/pipeline/head There was a failure building this commit

This commit is contained in:
2024-07-10 16:42:53 -07:00
parent 34cd210907
commit f19b403f19
3 changed files with 208 additions and 17 deletions

View File

@@ -664,6 +664,44 @@ template <class ConflictSetImpl> struct TestDriver {
fprintf(stderr, "Write @ %" PRId64 "\n", v); fprintf(stderr, "Write @ %" PRId64 "\n", v);
#endif #endif
// Test non-canonical writes
if (numPointWrites > 0) {
int overlaps = arbitrary.bounded(numPointWrites);
for (int i = 0; i < numPointWrites + numRangeWrites && overlaps > 0;
++i) {
if (writes[i].end.len == 0) {
int keyLen = prefixLen + arbitrary.bounded(kMaxKeySuffixLen);
auto *begin = new (arena) uint8_t[keyLen];
memset(begin, prefixByte, prefixLen);
arbitrary.randomBytes(begin + prefixLen, keyLen - prefixLen);
writes[i].end.len = keyLen;
writes[i].end.p = begin;
auto c =
std::span<const uint8_t>(writes[i].begin.p,
writes[i].begin.len) <=>
std::span<const uint8_t>(writes[i].end.p, writes[i].end.len);
if (c > 0) {
using std::swap;
swap(writes[i].begin, writes[i].end);
} else if (c == 0) {
// It's a point write after all, I guess
writes[i].end.len = 0;
}
--overlaps;
}
}
}
if (arbitrary.bounded(2)) {
// Shuffle writes
for (int i = numPointWrites + numRangeWrites - 1; i > 0; --i) {
int j = arbitrary.bounded(i + 1);
if (i != j) {
using std::swap;
swap(writes[i], writes[j]);
}
}
}
CALLGRIND_START_INSTRUMENTATION; CALLGRIND_START_INSTRUMENTATION;
cs.addWrites(writes, numPointWrites + numRangeWrites, v); cs.addWrites(writes, numPointWrites + numRangeWrites, v);
CALLGRIND_STOP_INSTRUMENTATION; CALLGRIND_STOP_INSTRUMENTATION;

View File

@@ -52,6 +52,135 @@ struct KeyRangeRef {
: begin(begin), end(keyAfter(arena, begin)) {} : begin(begin), end(keyAfter(arena, begin)) {}
}; };
struct KeyInfo {
StringRef key;
bool begin;
bool write;
KeyInfo() = default;
KeyInfo(StringRef key, bool begin, bool write)
: key(key), begin(begin), write(write) {}
};
// Tie-break rank for endpoints that share the same key. Smaller sorts first:
//   0: end,   non-write
//   1: end,   write
//   2: begin, write
//   3: begin, non-write
force_inline int extra_ordering(const KeyInfo &ki) {
  const int beginRank = ki.begin ? 2 : 0;
  const int mismatchRank = (ki.write != ki.begin) ? 1 : 0;
  return beginRank + mismatchRank;
}
// Produces the sort symbol of `ki` at position `character` and stores it in
// `outputCharacter`. Returns true once `character` is past every symbol of
// the key (the output is then 0), false otherwise.
//
// The symbol sequence for a key of length n is:
//   positions 0..n-1 : the key byte plus 5
//   position  n      : 0, the terminator
//   position  n+1    : extra_ordering(ki), the end/begin + read/write rank
force_inline bool getCharacter(const KeyInfo &ki, int character,
                               int &outputCharacter) {
  const auto keyLength = ki.key.size();
  bool exhausted = false;
  int symbol = 0;

  if (character < keyLength) {
    // Inside the key: shift the byte up so it cannot collide with the
    // terminator or the extra-ordering ranks.
    symbol = 5 + ki.key.begin()[character];
  } else if (character == keyLength) {
    // Terminator.
    symbol = 0;
  } else if (character == keyLength + 1) {
    // Relative ordering of end/begin and read/write endpoints.
    symbol = extra_ordering(ki);
  } else {
    // Nothing left to compare.
    exhausted = true;
  }

  outputCharacter = symbol;
  return exhausted;
}
// Strict ordering on endpoints: first by key bytes, then shorter keys before
// longer ones that share the prefix, and finally by extra_ordering() when the
// keys are identical.
bool operator<(const KeyInfo &lhs, const KeyInfo &rhs) {
  const auto lhsLength = lhs.key.size();
  const auto rhsLength = rhs.key.size();

  // Compare the bytes both keys have in common.
  const int sharedLength = std::min(lhsLength, rhsLength);
  const int byteOrder = memcmp(lhs.key.data(), rhs.key.data(), sharedLength);
  if (byteOrder != 0) {
    return byteOrder < 0;
  }

  // One key is a prefix of the other: the shorter key sorts first.
  if (lhsLength != rhsLength) {
    return lhsLength < rhsLength;
  }

  // Identical keys: fall back to the extra ordering constraint.
  return extra_ordering(lhs) < extra_ordering(rhs);
}
// Two endpoints are equal when neither orders before the other.
bool operator==(const KeyInfo &lhs, const KeyInfo &rhs) {
  return !(lhs < rhs) && !(rhs < lhs);
}
void swapSort(std::vector<KeyInfo> &points, int a, int b) {
if (points[b] < points[a]) {
KeyInfo temp;
temp = points[a];
points[a] = points[b];
points[b] = temp;
}
}
// A pending unit of work for sortPoints: a contiguous run of points that
// still has to be ordered, starting at a given symbol position.
struct SortTask {
  // Index of the first point in the run.
  int begin;
  // Number of points in the run.
  int size;
  // Symbol position at which sorting of this run resumes.
  int character;

  SortTask(int firstIndex, int count, int symbolPosition)
      : begin(firstIndex), size(count), character(symbolPosition) {}
};
void sortPoints(std::vector<KeyInfo> &points) {
std::vector<SortTask> tasks;
std::vector<KeyInfo> newPoints;
std::vector<int> counts;
tasks.emplace_back(0, points.size(), 0);
while (tasks.size()) {
SortTask st = tasks.back();
tasks.pop_back();
if (st.size < 10) {
std::sort(points.begin() + st.begin, points.begin() + st.begin + st.size);
continue;
}
newPoints.resize(st.size);
counts.assign(256 + 5, 0);
// get counts
int c;
bool allDone = true;
for (int i = st.begin; i < st.begin + st.size; i++) {
allDone &= getCharacter(points[i], st.character, c);
counts[c]++;
}
if (allDone)
continue;
// calculate offsets from counts and build next level of tasks
int total = 0;
for (int i = 0; i < counts.size(); i++) {
int temp = counts[i];
if (temp > 1)
tasks.emplace_back(st.begin + total, temp, st.character + 1);
counts[i] = total;
total += temp;
}
// put in their places
for (int i = st.begin; i < st.begin + st.size; i++) {
getCharacter(points[i], st.character, c);
newPoints[counts[c]++] = points[i];
}
// copy back into original points array
for (int i = 0; i < st.size; i++)
points[st.begin + i] = newPoints[i];
}
}
static thread_local uint32_t g_seed = 0; static thread_local uint32_t g_seed = 0;
static inline int skfastrand() { static inline int skfastrand() {
@@ -602,10 +731,40 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
void addWrites(const ConflictSet::WriteRange *writes, int count, void addWrites(const ConflictSet::WriteRange *writes, int count,
int64_t writeVersion) { int64_t writeVersion) {
auto points = std::vector<KeyInfo>(count * 2);
Arena arena;
for (int r = 0; r < count; r++) {
points.emplace_back(StringRef(writes[r].begin.p, writes[r].begin.len),
true, true);
points.emplace_back(
writes[r].end.len > 0
? StringRef{writes[r].end.p, size_t(writes[r].end.len)}
: keyAfter(arena, points.back().key),
false, true);
}
sortPoints(points);
int activeWriteCount = 0;
std::vector<std::pair<StringRef, StringRef>> combinedWriteConflictRanges;
for (const KeyInfo &point : points) {
if (point.write) {
if (point.begin) {
activeWriteCount++;
if (activeWriteCount == 1)
combinedWriteConflictRanges.emplace_back(point.key, StringRef());
} else /*if (point.end)*/ {
activeWriteCount--;
if (activeWriteCount == 0)
combinedWriteConflictRanges.back().second = point.key;
}
}
}
assert(writeVersion >= newestVersion); assert(writeVersion >= newestVersion);
newestVersion = writeVersion; newestVersion = writeVersion;
Arena arena; const int stringCount = combinedWriteConflictRanges.size() * 2;
const int stringCount = count * 2;
const int stripeSize = 16; const int stripeSize = 16;
SkipList::Finger fingers[stripeSize]; SkipList::Finger fingers[stripeSize];
@@ -616,15 +775,13 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
int ss = stringCount - (stripes - 1) * stripeSize; int ss = stringCount - (stripes - 1) * stripeSize;
for (int s = stripes - 1; s >= 0; s--) { for (int s = stripes - 1; s >= 0; s--) {
for (int i = 0; i * 2 < ss; ++i) { for (int i = 0; i * 2 < ss; ++i) {
const auto &w = writes[s * stripeSize / 2 + i]; const auto &w = combinedWriteConflictRanges[s * stripeSize / 2 + i];
#if DEBUG_VERBOSE #if DEBUG_VERBOSE
printf("Write begin: %s\n", printable(w.begin).c_str()); printf("Write begin: %s\n", printable(w.begin).c_str());
fflush(stdout); fflush(stdout);
#endif #endif
values[i * 2] = {w.begin.p, size_t(w.begin.len)}; values[i * 2] = w.first;
values[i * 2 + 1] = w.end.len > 0 values[i * 2 + 1] = w.second;
? StringRef{w.end.p, size_t(w.end.len)}
: keyAfter(arena, values[i * 2]);
keyUpdates += 3; keyUpdates += 3;
} }
skipList.find(values, fingers, temp, ss); skipList.find(values, fingers, temp, ss);

View File

@@ -70,11 +70,9 @@ struct __attribute__((__visibility__("default"))) ConflictSet {
/** The result of checking reads[i] is written in results[i] */ /** The result of checking reads[i] is written in results[i] */
void check(const ReadRange *reads, Result *results, int count) const; void check(const ReadRange *reads, Result *results, int count) const;
/** `writes` must be sorted ascending, and must not have adjacent or /** Reads intersecting writes where readVersion < `writeVersion` will result
* overlapping ranges. Reads intersecting writes where readVersion < * in `Conflict` (or `TooOld`, eventually). `writeVersion` must be greater
* `writeVersion` will result in `Conflict` (or `TooOld`, eventually). * than or equal to all previous write versions. */
* `writeVersion` must be greater than or equal to all previous write
* versions. */
void addWrites(const WriteRange *writes, int count, int64_t writeVersion); void addWrites(const WriteRange *writes, int count, int64_t writeVersion);
/** Reads where readVersion < oldestVersion will result in `TooOld`. Must be /** Reads where readVersion < oldestVersion will result in `TooOld`. Must be
@@ -161,11 +159,9 @@ void ConflictSet_check(const ConflictSet *cs,
const ConflictSet_ReadRange *reads, const ConflictSet_ReadRange *reads,
ConflictSet_Result *results, int count); ConflictSet_Result *results, int count);
/** `writes` must be sorted ascending, and must not have adjacent or /** Reads intersecting writes where readVersion < `writeVersion` will result in
* overlapping ranges. Reads intersecting writes where readVersion < * `Conflict` (or `TooOld`, eventually). `writeVersion` must be greater than or
* `writeVersion` will result in `Conflict` (or `TooOld`, eventually). * equal to all previous write versions. */
* `writeVersion` must be greater than or equal to all previous write versions.
*/
void ConflictSet_addWrites(ConflictSet *cs, void ConflictSet_addWrites(ConflictSet *cs,
const ConflictSet_WriteRange *writes, int count, const ConflictSet_WriteRange *writes, int count,
int64_t writeVersion); int64_t writeVersion);