diff --git a/src/main.cpp b/src/main.cpp index aacf6dd..2f17ac3 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,32 +1,190 @@ -#include "commit_request.hpp" #include "config.hpp" -#include "json_commit_request_parser.hpp" +#include +#include +#include +#include +#include +#include +#include +#include #include +#include +#include +#include +#include +#include +#include +#include +#include -void print_stats(const CommitRequest &request) { +#ifndef __has_feature +#define __has_feature(x) 0 +#endif - std::cout << "✓ Successfully parsed commit request:" << std::endl; - std::cout << " Request ID: " - << (request.request_id().has_value() ? request.request_id().value() - : "none") - << std::endl; - std::cout << " Leader ID: " << request.leader_id() << std::endl; - std::cout << " Read Version: " << request.read_version() << std::endl; - std::cout << " Preconditions: " << request.preconditions().size() - << std::endl; - std::cout << " Operations: " << request.operations().size() << std::endl; - std::cout << " Arena memory used: " << request.used_bytes() << " bytes" - << std::endl; +// Adapted from getaddrinfo man page +int getListenFd(const char *node, const char *service) { - if (!request.operations().empty()) { - const auto &op = request.operations()[0]; - std::cout << " First operation: " - << (op.type == Operation::Type::Write ? "write" : "other") - << " param1=" << op.param1 << " param2=" << op.param2 - << std::endl; + struct addrinfo hints; + struct addrinfo *result, *rp; + int sfd, s; + + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; /* Allow IPv4 or IPv6 */ + hints.ai_socktype = SOCK_STREAM; /* stream socket */ + hints.ai_flags = AI_PASSIVE; /* For wildcard IP address */ + hints.ai_protocol = 0; /* Any protocol */ + hints.ai_canonname = nullptr; + hints.ai_addr = nullptr; + hints.ai_next = nullptr; + + s = getaddrinfo(node, service, &hints, &result); + if (s != 0) { + fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(s)); + abort(); } + + /* getaddrinfo() returns a list of address structures. + Try each address until we successfully bind(2). + If socket(2) (or bind(2)) fails, we (close the socket + and) try the next address. */ + + for (rp = result; rp != nullptr; rp = rp->ai_next) { + sfd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + if (sfd == -1) { + continue; + } + + int val = 1; + setsockopt(sfd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val)); + + if (bind(sfd, rp->ai_addr, rp->ai_addrlen) == 0) { + break; /* Success */ + } + + close(sfd); + } + + freeaddrinfo(result); /* No longer needed */ + + if (rp == nullptr) { /* No address succeeded */ + fprintf(stderr, "Could not bind\n"); + abort(); + } + + int rv = listen(sfd, SOMAXCONN); + if (rv) { + perror("listen"); + abort(); + } + + return sfd; } +int getAcceptFd(int listenFd, struct sockaddr *addr) { + socklen_t addrlen = sizeof(sockaddr); + int fd = accept4(listenFd, addr, &addrlen, SOCK_NONBLOCK); + return fd; +} + +// Connection lifecycle. Only one of these is the case at a time +// - Created on an accept thread from a call to accept +// - Waiting on connection fd to be readable/writable +// - Owned by a network thread, which drains readable and writable bytes +// - Owned by a thread in the request processing pipeline +// - Closed by a network thread according to http protocol +// +// Since only one thread owns a connection at a time, no synchronization is +// necessary +struct Connection { + const int fd; + const int64_t id; + struct sockaddr addr; + + Connection(struct sockaddr addr, int fd, int64_t id) + : fd(fd), id(id), addr(addr) {} + + ~Connection() { + int e = close(fd); + if (e == -1) { + perror("close"); + abort(); + } + } + + struct Task { + std::string s; + bool closeConnection{false}; + int written = 0; + }; + + std::deque tasks; + + void readBytes() { + for (;;) { + // TODO make size configurable + char buf[1024]; + int r = read(fd, buf, sizeof(buf)); + if (r == -1) { + if (errno == EINTR) { + continue; + } + if (errno == EAGAIN) { + return; + } + perror("read"); + goto close_connection; + } + if (r == 0) { + goto close_connection; + } + // pump parser + tasks.emplace_back(std::string{buf, size_t(r)}); + } + close_connection: + tasks.emplace_back(std::string{}, true); + } + + bool writeBytes() { + while (!tasks.empty()) { + auto &front = tasks.front(); + if (front.closeConnection) { + return true; + } + int w; + for (;;) { + w = write(fd, front.s.data() + front.written, + front.s.size() - front.written); + if (w == -1) { + if (errno == EINTR) { + continue; + } + if (errno == EAGAIN) { + return false; + } + perror("write"); + return true; + } + break; + } + assert(w != 0); + front.written += w; + if (front.written == int(front.s.size())) { + tasks.pop_front(); + } + } + return false; + } + +#if __has_feature(thread_sanitizer) + void tsan_acquire() { tsan_sync.load(std::memory_order_acquire); } + void tsan_release() { tsan_sync.store(0, std::memory_order_release); } + std::atomic tsan_sync; +#else + void tsan_acquire() {} + void tsan_release() {} +#endif +}; + int main(int argc, char *argv[]) { std::string config_file = "config.toml"; @@ -60,104 +218,126 @@ int main(int argc, char *argv[]) { << config->subscription.keepalive_interval.count() << " seconds" << std::endl; - // Demonstrate CommitRequest functionality - std::cout << "\n--- CommitRequest Demo ---" << std::endl; + signal(SIGPIPE, SIG_IGN); - CommitRequest request; - JsonCommitRequestParser parser; - - const std::string sample_json = R"({ - "request_id": "demo-12345", - "leader_id": "leader-abc", - "read_version": 42, - "preconditions": [ - { - "type": "point_read", - "version": 41, - "key": "dGVzdEtleQ==" - } - ], - "operations": [ - { - "type": "write", - "key": "dGVzdEtleQ==", - "value": "dGVzdFZhbHVl" - } - ] - })"; - auto copy = sample_json; - - auto parse_result = parser.parse(request, copy.data(), copy.size()); - if (parse_result == CommitRequestParser::ParseResult::Success) { - print_stats(request); - } else { - std::cout << "✗ Failed to parse commit request: "; - switch (parse_result) { - case CommitRequestParser::ParseResult::InvalidJson: - std::cout << "Invalid JSON format" << std::endl; - break; - case CommitRequestParser::ParseResult::MissingField: - std::cout << "Missing required field" << std::endl; - break; - case CommitRequestParser::ParseResult::InvalidField: - std::cout << "Invalid field value" << std::endl; - break; - case CommitRequestParser::ParseResult::OutOfMemory: - std::cout << "Out of memory" << std::endl; - break; - default: - std::cout << "Unknown error" << std::endl; - break; - } + int sockfd = getListenFd(config->server.bind_address.c_str(), + std::to_string(config->server.port).c_str()); + std::vector threads; + int epollfd = epoll_create(/*ignored*/ 1); + if (epollfd == -1) { + perror("epoll_create"); + abort(); } - // Demonstrate streaming parsing - std::cout << "\n--- Streaming Parse Demo ---" << std::endl; + // Network threads + // TODO make configurable + int networkThreads = 1; + // TODO make configurable + constexpr int kEventBatchSize = 10; + for (int i = 0; i < networkThreads; ++i) { + threads.emplace_back([epollfd, i]() { + pthread_setname_np(pthread_self(), + ("network-" + std::to_string(i)).c_str()); + for (;;) { + struct epoll_event events[kEventBatchSize]{}; + int eventCount; + for (;;) { + eventCount = + epoll_wait(epollfd, events, kEventBatchSize, /*no timeout*/ -1); + if (eventCount == -1) { + if (errno == EINTR) { + continue; + } + perror("epoll_wait"); + abort(); + } + break; + } - CommitRequest streaming_request; - JsonCommitRequestParser streaming_parser; + for (int i = 0; i < eventCount; ++i) { + std::unique_ptr conn{ + static_cast(events[i].data.ptr)}; + conn->tsan_acquire(); + events[i].data.ptr = nullptr; + const int fd = conn->fd; - if (streaming_parser.begin_streaming_parse(streaming_request)) { - std::cout << "✓ Initialized streaming parser" << std::endl; + if (events[i].events & EPOLLERR) { + // Done with connection + continue; + } + if (events[i].events & EPOLLOUT) { + // Write bytes, maybe close connection + bool finished = conn->writeBytes(); + if (finished) { + // Done with connection + continue; + } + } - // Simulate receiving data in small chunks like from a network socket - std::string copy = sample_json; + if (events[i].events & EPOLLIN) { + conn->readBytes(); + } - size_t chunk_size = 15; // Small chunks to simulate network packets - size_t offset = 0; - int chunk_count = 0; + if (events[i].events & EPOLLOUT) { + bool done = conn->writeBytes(); + if (done) { + continue; + } + } - CommitRequestParser::ParseStatus status = - CommitRequestParser::ParseStatus::Incomplete; + if (conn->tasks.empty()) { + // Transfer back to epoll instance. This thread or another thread + // will wake when fd is ready + events[i].events = EPOLLIN | EPOLLONESHOT; + } else { + events[i].events = EPOLLOUT | EPOLLONESHOT; + } + conn->tsan_release(); + events[i].data.ptr = conn.release(); + int e = epoll_ctl(epollfd, EPOLL_CTL_MOD, fd, &events[i]); + if (e == -1) { + perror("epoll_ctl"); + abort(); + } + } + } + }); + } - while (offset < copy.size() && - status == CommitRequestParser::ParseStatus::Incomplete) { - size_t len = std::min(chunk_size, copy.size() - offset); - std::string chunk = copy.substr(offset, len); + std::atomic connectionId{0}; - std::cout << " Chunk " << ++chunk_count << " (" << len << " bytes): \"" - << chunk << "\"" << std::endl; + // TODO make configurable + int acceptThreads = 1; + for (int i = 0; i < acceptThreads; ++i) { + threads.emplace_back([epollfd, i, sockfd, &connectionId]() { + pthread_setname_np(pthread_self(), + ("accept-" + std::to_string(i)).c_str()); + // Call accept in a loop + for (;;) { + struct sockaddr addr; + int fd = getAcceptFd(sockfd, &addr); + if (fd == -1) { + perror("accept4"); + continue; + } + auto conn = std::make_unique( + addr, fd, connectionId.fetch_add(1, std::memory_order_relaxed)); + // Post to epoll instance + struct epoll_event event{}; + event.events = EPOLLIN | EPOLLONESHOT; + conn->tsan_release(); + event.data.ptr = conn.release(); + int e = epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &event); + if (e == -1) { + perror("epoll_ctl"); + abort(); + } + } + }); + } - // Need mutable data for weaseljson - std::string mutable_chunk = chunk; - status = streaming_parser.parse_chunk(mutable_chunk.data(), - mutable_chunk.size()); - - offset += len; - } - - if (status == CommitRequestParser::ParseStatus::Incomplete) { - std::cout << " Finalizing parse..." << std::endl; - status = streaming_parser.finish_streaming_parse(); - } - - if (status == CommitRequestParser::ParseStatus::Complete) { - print_stats(streaming_request); - } else { - std::cout << "✗ Streaming parse failed" << std::endl; - } - } else { - std::cout << "✗ Failed to initialize streaming parser" << std::endl; + for (auto &t : threads) { + t.join(); } return 0;