Implement "phase 1" of interleaved point writes

This commit is contained in:
2024-10-28 16:02:56 -07:00
parent f1292efe41
commit 2706b2f65e

View File

@@ -4057,6 +4057,125 @@ void Job::init(const ConflictSet::ReadRange *read, ConflictSet::Result *result,
}
} // namespace check
namespace interleaved_insert {
// Signature shared by every step of the interleaved search. A step receives
// the job it should advance and the context common to all jobs; keepGoing
// resumes a job by calling its stored pointer of this type.
// NOTE(review): PRESERVE_NONE is defined outside this view - presumably it
// expands to the preserve_none calling-convention attribute; confirm.
typedef PRESERVE_NONE void (*Continuation)(struct Job *, struct Context *);
// State relevant to an individual insertion
struct Job {
// Part of the key that has not been consumed by the search yet. init sets it
// to the whole key of the write.
std::span<const uint8_t> remaining;
// Node the search has reached so far. init sets it to Context::root.
Node *n;
// Position of this job's write in Context::writes, and of the slot in
// Context::results that receives its result.
int index;
// State for context switching machinery - not application specific
// Step to run the next time this job is resumed.
Continuation continuation;
// Neighbors in the circular doubly-linked list of in-flight jobs. These are
// wired up by the code that starts the jobs, not by init.
Job *prev;
Job *next;
// Point this job at write number `index` and restart the search from the
// root. Leaves prev and next untouched.
void init(Context *, int index);
};
// Result of the search for one insertion. The search path of insertionPoint +
// remaining == the original key, and there is no existing node in the tree
// further along the search path of the original key.
struct Result {
// Deepest existing node found on the search path of the key.
Node *insertionPoint;
// Part of the key that was left over when the search stopped at
// insertionPoint. Empty when the whole key was consumed.
std::span<const uint8_t> remaining;
};
// State relevant to every insertion
struct Context {
// Total number of writes in `writes`.
int count;
// Number of writes that have been handed to a job so far. It is also the
// index of the next write to start; no more are started once it equals count.
int64_t started;
// The writes being processed. The job with index i searches for
// writes[i].begin.
const ConflictSet::WriteRange *writes;
// Node every search starts from.
Node *root;
// Version passed to getChildUpdatingMaxVersion at every step of a search.
InternalVersionT writeVersion;
// Output array indexed like `writes`; each job stores its Result at its own
// index when its search ends.
Result *results;
};
// Yield from `job` to its successor in the circular job list, resuming the
// successor at the step it stored in `continuation`. The trace line describes
// the job being left, not the one being resumed.
// NOTE(review): the stderr trace is emitted on every switch and looks like
// temporary debugging output - confirm before this path is enabled.
PRESERVE_NONE void keepGoing(Job *job, Context *context) {
  // The printable strings are temporaries of the fprintf expression, so they
  // are gone before the tail call below.
  fprintf(stderr, "search path: %s, Remaining: %s\n",
          getSearchPathPrintable(job->n).c_str(),
          printable(job->remaining).c_str());

  Job *const successor = job->next;
  MUSTTAIL return successor->continuation(successor, context);
}
// Called once the search for `job` has stored its result. If writes remain
// unstarted, the job is reused for the next one; otherwise the job is removed
// from the circular list. Returning without a tail call (only when the last
// job finishes) unwinds the whole chain of tail calls.
// NOTE(review): the stderr trace looks like temporary debugging output -
// confirm before this path is enabled.
PRESERVE_NONE void complete(Job *job, Context *context) {
  fprintf(stderr, "search path: %s, Remaining: %s\n",
          getSearchPathPrintable(job->n).c_str(),
          printable(job->remaining).c_str());

  if (context->started != context->count) {
    // Writes remain: recycle this slot for the next unstarted write. The job
    // keeps its place in the list, so only its search state is reset.
    const int nextIndex = context->started++;
    job->init(context, nextIndex);
    MUSTTAIL return keepGoing(job, context);
  }

  // Everything has been started. A job that is its own neighbor is the last
  // one in flight, so there is nothing left to switch to.
  if (job->prev == job) {
    return;
  }

  // Unlink this job and continue from its predecessor; keepGoing then steps
  // to the predecessor's successor, which is the job that followed this one.
  Job *const before = job->prev;
  Job *const after = job->next;
  before->next = after;
  after->prev = before;
  MUSTTAIL return keepGoing(before, context);
}
// Search step specialized for the concrete node type NodeT. Defined further
// down in this namespace.
template <class NodeT> PRESERVE_NONE void iter(Job *, Context *);

// Step function to run for a node, looked up by the value its getType()
// returns. The table is only ever read, so it is const.
// NOTE(review): assumes the node type values are 0..4 in the order Node0,
// Node3, Node16, Node48, Node256 - confirm against the type enum.
static const Continuation iterTable[] = {iter<Node0>, iter<Node3>,
                                         iter<Node16>, iter<Node48>,
                                         iter<Node256>};
// First step of every search; init installs it as the job's continuation.
// Takes one step down from the node the job currently points at. The search
// ends here if the key is already exhausted, if there is no child to descend
// to, or if the key is exhausted after descending. Otherwise the step for the
// child's node type is scheduled and control passes to the next job.
PRESERVE_NONE void begin(Job *job, Context *context) {
  if (!job->remaining.empty()) [[likely]] {
    TaggedNodePointer *next = getChildUpdatingMaxVersion(
        job->n, job->remaining, context->writeVersion);
    if (next != nullptr) [[likely]] {
      job->n = *next;
      if (!job->remaining.empty()) [[likely]] {
        // More key to match: start fetching the child and let other jobs run
        // before this one touches it.
        job->continuation = iterTable[next->getType()];
        __builtin_prefetch(job->n);
        MUSTTAIL return keepGoing(job, context);
      }
    }
  }

  // Search is over: job->n is the insertion point and job->remaining is
  // whatever part of the key was not consumed.
  context->results[job->index] = {job->n, job->remaining};
  MUSTTAIL return complete(job, context);
}
template <class NodeT> void iter(Job *job, Context *context) {
assert(NodeT::kType == job->n->getType());
NodeT *n = static_cast<NodeT *>(job->n);
TaggedNodePointer *child =
getChildUpdatingMaxVersion(n, job->remaining, context->writeVersion);
if (child == nullptr) [[unlikely]] {
context->results[job->index] = {job->n, job->remaining};
MUSTTAIL return complete(job, context);
}
job->n = *child;
if (job->remaining.size() == 0) [[unlikely]] {
context->results[job->index] = {job->n, job->remaining};
MUSTTAIL return complete(job, context);
}
job->continuation = iterTable[child->getType()];
__builtin_prefetch(job->n);
MUSTTAIL return keepGoing(job, context);
}
// Reset the search state so this job looks up the begin key of
// context->writes[index], starting at the root. The list links prev and next
// are deliberately left alone so a job can be reused in place.
void Job::init(Context *context, int index) {
  const auto &key = context->writes[index].begin;
  n = context->root;
  remaining = std::span<const uint8_t>(key.p, key.len);
  // Qualified so the name cannot be mistaken for any other `begin`.
  continuation = interleaved_insert::begin;
  this->index = index;
}
} // namespace interleaved_insert
// Sequential implementations
namespace {
// Logically this is the same as performing firstGeq and then checking against
@@ -4583,6 +4702,50 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
check_bytes_total.add(check_byte_accum);
}
// Runs phase 1 of inserting `count` point writes at `writeVersion`: finds the
// insertion point of every writes[i].begin, keeping up to kConcurrent searches
// in flight and switching between them by tail calls. Phase 2 is not
// implemented yet, so the results are computed and then released.
//
// Does nothing beyond the count check when the compiler lacks the
// preserve_none or musttail attributes.
void interleavedPointWrites(const WriteRange *writes, int count,
                            InternalVersionT writeVersion) {
  // Phase 1: Search for insertion points concurrently, without modifying the
  // structure of the tree.
  if (count == 0) {
    return;
  }

#if __has_attribute(preserve_none) && __has_attribute(musttail)
  constexpr int kConcurrent = 16;
  interleaved_insert::Job inProgress[kConcurrent];

  const auto resultsBytes = sizeof(interleaved_insert::Result) * count;

  interleaved_insert::Context context;
  context.writeVersion = writeVersion;
  context.count = count;
  context.root = root;
  context.writes = writes;
  context.results =
      static_cast<interleaved_insert::Result *>(safe_malloc(resultsBytes));

  // Start as many jobs as there are slots (or writes, if fewer) and link them
  // into a circular doubly-linked list.
  const int started = std::min(kConcurrent, count);
  context.started = started;
  for (int i = 0; i < started; i++) {
    interleaved_insert::Job &job = inProgress[i];
    job.init(&context, i);
    job.next = inProgress + (i + 1 < started ? i + 1 : 0);
    job.prev = inProgress + (i > 0 ? i - 1 : started - 1);
  }

  // Kick off the sequence of tail calls that finally returns once all jobs
  // are done
  inProgress->continuation(inProgress, &context);

  // Nothing consumes the results yet, so release them here instead of leaking
  // one buffer per call. Once phase 2 exists this must move after it.
  // NOTE(review): safe_free(pointer, size) is assumed to be the counterpart of
  // safe_malloc - confirm its signature.
  safe_free(context.results, resultsBytes);
#endif

  // Phase 2: Perform insertions. Nodes may be upsized during this phase, but
  // old nodes get forwarding pointers installed and are released after
  // phase 2.
}
void addWrites(const WriteRange *writes, int count, int64_t writeVersion) {
#if !USE_64_BIT
// There could be other conflict sets in the same thread. We need
@@ -4624,17 +4787,28 @@ struct __attribute__((visibility("hidden"))) ConflictSet::Impl {
}
}
bool allPointWrites = true;
for (int i = 0; i < count; ++i) {
const auto &w = writes[i];
writeContext.accum.write_bytes += w.begin.len + w.end.len;
auto begin = std::span<const uint8_t>(w.begin.p, w.begin.len);
auto end = std::span<const uint8_t>(w.end.p, w.end.len);
if (w.end.len > 0) {
addWriteRange(root, begin, end, InternalVersionT(writeVersion),
&writeContext, this);
} else {
addPointWrite(root, begin, InternalVersionT(writeVersion),
&writeContext);
if (writes[i].end.len > 0) {
allPointWrites = false;
break;
}
}
if (0 && allPointWrites) {
interleavedPointWrites(writes, count, InternalVersionT(writeVersion));
} else {
for (int i = 0; i < count; ++i) {
const auto &w = writes[i];
writeContext.accum.write_bytes += w.begin.len + w.end.len;
auto begin = std::span<const uint8_t>(w.begin.p, w.begin.len);
auto end = std::span<const uint8_t>(w.end.p, w.end.len);
if (w.end.len > 0) {
addWriteRange(root, begin, end, InternalVersionT(writeVersion),
&writeContext, this);
} else {
addPointWrite(root, begin, InternalVersionT(writeVersion),
&writeContext);
}
}
}