#include "config.hpp"

#include <fcntl.h>
#include <netdb.h>
#include <pthread.h>
#include <signal.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <algorithm>
#include <atomic>
#include <cassert>
#include <cerrno>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <deque>
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <vector>

// Count of live Connection objects; used for best-effort enforcement of the
// configured max_connections limit (incremented/decremented in Connection's
// ctor/dtor, read by the accept threads).
std::atomic<int> activeConnections{0};

// eventfd used to broadcast shutdown to every epoll loop. The signal handler
// writes to it (write(2) is async-signal-safe); nobody ever reads it, so it
// stays readable and every level-triggered epoll loop observes it.
int shutdown_eventfd = -1;

#ifndef __has_feature
#define __has_feature(x) 0
#endif

// Async-signal-safe handler for SIGTERM/SIGINT: wake all epoll loops by
// making shutdown_eventfd readable.
void signal_handler(int sig) {
  if (sig == SIGTERM || sig == SIGINT) {
    if (shutdown_eventfd != -1) {
      uint64_t val = 1;
      if (write(shutdown_eventfd, &val, sizeof(val)) == -1) {
        abort(); // Critical failure - can't signal shutdown
      }
    }
  }
}

// Create, bind, and listen on a non-blocking TCP socket for node:service.
// Aborts the process on any unrecoverable failure.
// Adapted from the getaddrinfo man page.
int getListenFd(const char *node, const char *service) {
  struct addrinfo hints;
  struct addrinfo *result, *rp;
  int sfd, s;

  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;     /* Allow IPv4 or IPv6 */
  hints.ai_socktype = SOCK_STREAM; /* stream socket */
  hints.ai_flags = AI_PASSIVE;     /* For wildcard IP address */
  hints.ai_protocol = 0;           /* Any protocol */
  hints.ai_canonname = nullptr;
  hints.ai_addr = nullptr;
  hints.ai_next = nullptr;

  s = getaddrinfo(node, service, &hints, &result);
  if (s != 0) {
    fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(s));
    abort();
  }

  /* getaddrinfo() returns a list of address structures. Try each address
     until we successfully bind(2). If socket(2) (or bind(2)) fails, we
     (close the socket and) try the next address. */
  for (rp = result; rp != nullptr; rp = rp->ai_next) {
    sfd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
    if (sfd == -1) {
      continue;
    }

    int val = 1;
    if (setsockopt(sfd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val)) == -1) {
      perror("setsockopt(SO_REUSEADDR)"); // non-fatal: bind may still succeed
    }

    // Set socket to non-blocking so the accept loops never stall.
    int flags = fcntl(sfd, F_GETFL, 0);
    if (flags != -1) {
      fcntl(sfd, F_SETFL, flags | O_NONBLOCK);
    }

    if (bind(sfd, rp->ai_addr, rp->ai_addrlen) == 0) {
      break; /* Success */
    }

    close(sfd);
  }

  freeaddrinfo(result); /* No longer needed */

  if (rp == nullptr) { /* No address succeeded */
    fprintf(stderr, "Could not bind\n");
    abort();
  }

  int rv = listen(sfd, SOMAXCONN);
  if (rv) {
    perror("listen");
    abort();
  }
  return sfd;
}

// Accept one pending connection on listenFd (non-blocking), filling *addr.
// Returns the new fd or -1 (errno set). NOTE(review): currently unused by
// main(), which inlines its own accept4 loop; kept for callers elsewhere.
int getAcceptFd(int listenFd, struct sockaddr_storage *addr) {
  // Use sockaddr_storage (not sockaddr) to handle both IPv4 and IPv6
  socklen_t addrlen = sizeof(sockaddr_storage);
  int fd = accept4(listenFd, (struct sockaddr *)addr, &addrlen, SOCK_NONBLOCK);
  return fd;
}

// Since only one thread owns a connection at a time, no synchronization is
// necessary
// Connection ownership model:
// - Created by accept thread, transferred to epoll via raw pointer
// - Network threads claim ownership by wrapping raw pointer in unique_ptr
// - Network thread optionally passes ownership to a thread pipeline
// - Owner eventually transfers back to epoll by releasing unique_ptr to raw
//   pointer
// - RAII cleanup happens if network thread doesn't transfer back
struct Connection {
  const int fd;
  const int64_t id;
  struct sockaddr_storage addr; // sockaddr_storage handles IPv4/IPv6

  Connection(struct sockaddr_storage addr, int fd, int64_t id)
      : fd(fd), id(id), addr(addr) {
    activeConnections.fetch_add(1, std::memory_order_relaxed);
  }

  ~Connection() {
    activeConnections.fetch_sub(1, std::memory_order_relaxed);
    int e = close(fd);
    if (e == -1) {
      perror("close");
      abort();
    }
  }

  // One queued unit of outbound work. closeConnection marks a sentinel task:
  // when it reaches the front of the queue the connection should be dropped.
  struct Task {
    std::string s;
    bool closeConnection{false};
    int written = 0; // bytes of `s` already written to the socket

    // Explicit ctor so emplace_back works in C++17 (parenthesized aggregate
    // init is C++20-only).
    explicit Task(std::string data, bool close_conn = false)
        : s(std::move(data)), closeConnection(close_conn) {}
  };

  std::deque<Task> tasks;

  // Drain the socket until EAGAIN, queuing each chunk read as a Task.
  // EOF or a hard read error queues a closeConnection sentinel instead.
  void readBytes(size_t max_request_size) {
    // Use smaller buffer size but respect max request size
    // TODO revisit
    size_t buf_size = std::min(size_t(4096), max_request_size);
    std::vector<char> buf(buf_size);
    for (;;) {
      // ssize_t, not int: read(2) returns ssize_t.
      ssize_t r = read(fd, buf.data(), buf.size());
      if (r == -1) {
        if (errno == EINTR) {
          continue;
        }
        if (errno == EAGAIN || errno == EWOULDBLOCK) {
          return; // socket drained
        }
        perror("read");
        goto close_connection;
      }
      if (r == 0) {
        goto close_connection; // peer closed
      }
      // "pump parser"
      // TODO revisit
      tasks.emplace_back(std::string{buf.data(), size_t(r)});
    }
  close_connection:
    tasks.emplace_back(std::string{}, true);
  }

  // Flush queued tasks to the socket. Returns true when the connection
  // should be closed (sentinel reached or hard write error), false when the
  // socket would block or the queue is empty.
  bool writeBytes() {
    while (!tasks.empty()) {
      auto &front = tasks.front();
      if (front.closeConnection) {
        return true;
      }
      ssize_t w;
      for (;;) {
        w = write(fd, front.s.data() + front.written,
                  front.s.size() - front.written);
        if (w == -1) {
          if (errno == EINTR) {
            continue; // Standard practice: retry on signal interruption
          }
          if (errno == EAGAIN) {
            return false; // kernel buffer full; wait for EPOLLOUT
          }
          perror("write");
          return true;
        }
        break;
      }
      assert(w != 0);
      front.written += int(w);
      if (front.written == int(front.s.size())) {
        tasks.pop_front();
      }
    }
    return false;
  }

  // This is necessary because tsan doesn't (yet?) understand that there's a
  // happens-before relationship for epoll_ctl(..., EPOLL_CTL_MOD, ...) and
  // epoll_wait
#if __has_feature(thread_sanitizer)
  void tsan_acquire() { tsan_sync.load(std::memory_order_acquire); }
  void tsan_release() { tsan_sync.store(0, std::memory_order_release); }
  std::atomic<int> tsan_sync{0}; // initialized: reading an uninit atomic is UB
#else
  void tsan_acquire() {}
  void tsan_release() {}
#endif
};

int main(int argc, char *argv[]) {
  std::string config_file = "config.toml";
  if (argc > 1) {
    config_file = argv[1];
  }

  auto config = weaseldb::ConfigParser::load_from_file(config_file);
  if (!config) {
    std::cerr << "Failed to load config from: " << config_file << std::endl;
    std::cerr << "Using default configuration..." << std::endl;
    config = weaseldb::Config{};
  }

  std::cout << "Configuration loaded successfully:" << std::endl;
  std::cout << "Server bind address: " << config->server.bind_address
            << std::endl;
  std::cout << "Server port: " << config->server.port << std::endl;
  std::cout << "Max request size: " << config->server.max_request_size_bytes
            << " bytes" << std::endl;
  std::cout << "Accept threads: " << config->server.accept_threads
            << std::endl;
  std::cout << "Network threads: " << config->server.network_threads
            << std::endl;
  std::cout << "Event batch size: " << config->server.event_batch_size
            << std::endl;
  std::cout << "Max connections: " << config->server.max_connections
            << std::endl;
  std::cout << "Min request ID length: "
            << config->commit.min_request_id_length << std::endl;
  std::cout << "Request ID retention: "
            << config->commit.request_id_retention_hours.count() << " hours"
            << std::endl;
  std::cout << "Subscription buffer size: "
            << config->subscription.max_buffer_size_bytes << " bytes"
            << std::endl;
  std::cout << "Keepalive interval: "
            << config->subscription.keepalive_interval.count() << " seconds"
            << std::endl;

  // Create shutdown eventfd for graceful shutdown
  shutdown_eventfd = eventfd(0, EFD_CLOEXEC);
  if (shutdown_eventfd == -1) {
    perror("eventfd");
    abort();
  }

  signal(SIGPIPE, SIG_IGN);
  signal(SIGTERM, signal_handler);
  signal(SIGINT, signal_handler);

  int sockfd = getListenFd(config->server.bind_address.c_str(),
                           std::to_string(config->server.port).c_str());

  std::vector<std::thread> threads;

  int network_epollfd = epoll_create1(EPOLL_CLOEXEC);
  if (network_epollfd == -1) {
    perror("epoll_create");
    abort();
  }

  // Add shutdown eventfd to the network epoll. Connection events carry a
  // Connection* in epoll_data's union, so we cannot safely test data.fd
  // there (a pointer's low bits could alias a small fd number). Register
  // the shutdown eventfd with data.ptr == nullptr instead, which no
  // Connection event can ever carry.
  struct epoll_event net_shutdown_event {};
  net_shutdown_event.events = EPOLLIN;
  net_shutdown_event.data.ptr = nullptr;
  if (epoll_ctl(network_epollfd, EPOLL_CTL_ADD, shutdown_eventfd,
                &net_shutdown_event) == -1) {
    perror("epoll_ctl shutdown eventfd");
    abort();
  }

  std::atomic<int64_t> connectionId{0};

  // Network threads from configuration
  int networkThreads = config->server.network_threads;
  for (int i = 0; i < networkThreads; ++i) {
    threads.emplace_back(
        [network_epollfd, i,
         max_request_size = config->server.max_request_size_bytes,
         event_batch_size = config->server.event_batch_size]() {
          pthread_setname_np(pthread_self(),
                             ("network-" + std::to_string(i)).c_str());
          std::vector<struct epoll_event> events(event_batch_size);
          for (;;) {
            int eventCount = epoll_wait(network_epollfd, events.data(),
                                        event_batch_size,
                                        -1 /* no timeout */);
            if (eventCount == -1) {
              if (errno == EINTR) {
                continue;
              }
              perror("epoll_wait");
              abort();
            }
            // `k`, not `i`: don't shadow the captured thread index.
            for (int k = 0; k < eventCount; ++k) {
              // Shutdown eventfd is the only registration with a null ptr.
              if (events[k].data.ptr == nullptr) {
                // Don't read - let other threads see it too
                return;
              }

              // Take ownership from epoll: raw pointer -> unique_ptr
              std::unique_ptr<Connection> conn{
                  static_cast<Connection *>(events[k].data.ptr)};
              conn->tsan_acquire();
              events[k].data.ptr = nullptr; // Clear epoll pointer (we own it)
              const int fd = conn->fd;

              if (events[k].events & (EPOLLERR | EPOLLHUP | EPOLLRDHUP)) {
                // Connection closed or error occurred - unique_ptr
                // destructor cleans up (close(fd) also deregisters from
                // epoll)
                continue;
              }
              if (events[k].events & EPOLLIN) {
                conn->readBytes(max_request_size);
              }
              if (events[k].events & EPOLLOUT) {
                bool done = conn->writeBytes();
                if (done) {
                  continue;
                }
              }

              // Re-arm: wait for readability when idle, writability when
              // output is queued.
              if (conn->tasks.empty()) {
                events[k].events = EPOLLIN | EPOLLONESHOT | EPOLLRDHUP;
              } else {
                events[k].events = EPOLLOUT | EPOLLONESHOT | EPOLLRDHUP;
              }

              // Transfer ownership back to epoll: unique_ptr -> raw pointer
              conn->tsan_release();
              events[k].data.ptr = conn.release(); // epoll now owns it
              int e = epoll_ctl(network_epollfd, EPOLL_CTL_MOD, fd,
                                &events[k]);
              if (e == -1) {
                perror("epoll_ctl");
                abort(); // Process termination - OS cleans up leaked conn
              }
            }
          }
        });
  }

  // Accept threads from configuration
  int acceptThreads = config->server.accept_threads;

  // epoll instance for accept threads
  int accept_epollfd = epoll_create1(EPOLL_CLOEXEC);
  if (accept_epollfd == -1) {
    perror("epoll_create1");
    abort();
  }

  // Add shutdown eventfd to accept epoll. Here every registration uses
  // data.fd, so an fd comparison is unambiguous.
  struct epoll_event shutdown_event;
  shutdown_event.events = EPOLLIN;
  shutdown_event.data.fd = shutdown_eventfd;
  if (epoll_ctl(accept_epollfd, EPOLL_CTL_ADD, shutdown_eventfd,
                &shutdown_event) == -1) {
    perror("epoll_ctl shutdown eventfd");
    abort();
  }

  // Add listen socket to accept epoll
  struct epoll_event listen_event;
  listen_event.events = EPOLLIN;
  listen_event.data.fd = sockfd;
  if (epoll_ctl(accept_epollfd, EPOLL_CTL_ADD, sockfd, &listen_event) == -1) {
    perror("epoll_ctl listen socket");
    abort();
  }

  for (int i = 0; i < acceptThreads; ++i) {
    threads.emplace_back([network_epollfd, i, sockfd, &connectionId,
                          max_connections = config->server.max_connections,
                          accept_epollfd]() {
      pthread_setname_np(pthread_self(),
                         ("accept-" + std::to_string(i)).c_str());
      for (;;) {
        struct epoll_event events[2]; // listen socket + shutdown eventfd
        int ready = epoll_wait(accept_epollfd, events, 2,
                               -1 /* no timeout */);
        if (ready == -1) {
          if (errno == EINTR)
            continue;
          perror("epoll_wait");
          abort();
        }
        for (int j = 0; j < ready; ++j) {
          if (events[j].data.fd == shutdown_eventfd) {
            // Don't read - let other threads see it too.
            // main() closes accept_epollfd after joining all threads.
            return;
          }
          if (events[j].data.fd == sockfd) {
            // Listen socket ready - accept connections until drained
            for (;;) {
              struct sockaddr_storage addr;
              socklen_t addrlen = sizeof(addr);
              int fd = accept4(sockfd, (struct sockaddr *)&addr, &addrlen,
                               SOCK_NONBLOCK);
              if (fd == -1) {
                if (errno == EAGAIN || errno == EWOULDBLOCK)
                  break; // No more connections
                if (errno == EINTR)
                  continue;
                perror("accept4");
                abort();
              }

              // Check connection limit (0 means unlimited). Limiting
              // connections is best effort but it's the best we can do.
              if (max_connections > 0 &&
                  activeConnections.load(std::memory_order_relaxed) >=
                      max_connections) {
                // Reject connection by immediately closing it
                close(fd);
                continue;
              }

              auto conn = std::make_unique<Connection>(
                  addr, fd,
                  connectionId.fetch_add(1, std::memory_order_relaxed));

              // Transfer new connection to network thread epoll
              struct epoll_event event {};
              event.events = EPOLLIN | EPOLLONESHOT | EPOLLRDHUP;
              conn->tsan_release();
              event.data.ptr = conn.release(); // network epoll now owns it
              int e = epoll_ctl(network_epollfd, EPOLL_CTL_ADD, fd, &event);
              if (e == -1) {
                perror("epoll_ctl");
                abort();
              }
            }
          }
        }
      }
    });
  }

  for (auto &t : threads) {
    t.join();
  }

  // Cleanup
  close(shutdown_eventfd);
  close(accept_epollfd);
  close(network_epollfd);
  close(sockfd);
  return 0;
}