Separate Connection, ConnectionHandler, Server

2025-08-19 13:23:18 -04:00
parent addac1b0b7
commit cb322bbb2b
7 changed files with 888 additions and 492 deletions

@@ -1,270 +1,44 @@
#include "arena_allocator.hpp"
#include "config.hpp"
#include "connection.hpp"
#include "connection_handler.hpp"
#include "server.hpp"
#include <atomic>
#include <cassert>
#include <csignal>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <deque>
#include <fcntl.h>
#include <inttypes.h>
#include <iostream>
#include <limits.h>
#include <netdb.h>
#include <netinet/tcp.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/uio.h>
#include <sys/un.h>
#include <thread>
#include <unistd.h>
#include <vector>
std::atomic<int> activeConnections{0};
int shutdown_pipe[2] = {-1, -1};
#ifndef __has_feature
#define __has_feature(x) 0
#endif
// Global server instance for signal handler access
Server *g_server = nullptr;
void signal_handler(int sig) {
if (sig == SIGTERM || sig == SIGINT) {
if (shutdown_pipe[1] != -1) {
char val = 1;
// write() is async-signal-safe per POSIX - safe to use in signal handler
// Write single byte to avoid partial write complexity
while (write(shutdown_pipe[1], &val, 1) == -1) {
if (errno != EINTR) {
abort(); // graceful shutdown didn't work. Let's go ungraceful.
}
}
if (g_server) {
g_server->shutdown();
}
}
}
}
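// Shutdown uses the self-pipe trick: the handler writes a single byte to
// shutdown_pipe[1]; shutdown_pipe[0] is registered EPOLLIN with every epoll
// instance below, so all waiting threads wake up. Threads deliberately never
// read the byte back, so the pipe stays readable for every other waiter.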
// Adapted from getaddrinfo man page
int getListenFd(const char *node, const char *service) {
struct addrinfo hints;
struct addrinfo *result, *rp;
int sfd, s;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC; /* Allow IPv4 or IPv6 */
hints.ai_socktype = SOCK_STREAM; /* stream socket */
hints.ai_flags = AI_PASSIVE; /* For wildcard IP address */
hints.ai_protocol = 0; /* Any protocol */
hints.ai_canonname = nullptr;
hints.ai_addr = nullptr;
hints.ai_next = nullptr;
  s = getaddrinfo(node, service, &hints, &result);
  if (s != 0) {
    fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(s));
    abort();
  }
  /* getaddrinfo() returns a list of address structures.
     Try each address until we successfully bind(2).
     If socket(2) (or bind(2)) fails, we (close the socket
     and) try the next address. */
  for (rp = result; rp != nullptr; rp = rp->ai_next) {
    sfd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
    if (sfd == -1) {
      continue;
    }
    int val = 1;
    if (setsockopt(sfd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val)) == -1) {
      perror("setsockopt SO_REUSEADDR");
      close(sfd);
      continue; // Try next address
    }
    // Enable TCP_NODELAY for low latency (disable Nagle's algorithm)
    if (setsockopt(sfd, IPPROTO_TCP, TCP_NODELAY, &val, sizeof(val)) == -1) {
      perror("setsockopt TCP_NODELAY");
      close(sfd);
      continue; // Try next address
    }
    // Set socket to non-blocking for graceful shutdown
    int flags = fcntl(sfd, F_GETFL, 0);
    if (flags == -1) {
      perror("fcntl F_GETFL");
      close(sfd);
      continue; // Try next address
    }
    if (fcntl(sfd, F_SETFL, flags | O_NONBLOCK) == -1) {
      perror("fcntl F_SETFL");
      close(sfd);
      continue; // Try next address
    }
    if (bind(sfd, rp->ai_addr, rp->ai_addrlen) == 0) {
      break; /* Success */
    }
    close(sfd);
  }
  freeaddrinfo(result); /* No longer needed */
  if (rp == nullptr) { /* No address succeeded */
    fprintf(stderr, "Could not bind\n");
    abort();
  }
  int rv = listen(sfd, SOMAXCONN);
  if (rv) {
    perror("listen");
    abort();
  }
  return sfd;
}
/**
 * Simple echo handler that mirrors received data back to the client.
 *
 * This implementation preserves the current behavior of the server
 * while demonstrating the ConnectionHandler interface pattern.
 */
class EchoHandler : public ConnectionHandler {
public:
  ProcessResult process_data(std::string_view data,
                             std::unique_ptr<Connection> &conn_ptr) override {
    // Echo the received data back to the client
    conn_ptr->appendMessage(data);
    return ProcessResult::Continue;
  }
  void on_connection_established(Connection &conn) override {
    // Could send a welcome message if desired
    // conn.appendMessage("Welcome to WeaselDB echo server\n");
    (void)conn; // Suppress unused parameter warning
  }
};
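// A second handler, sketched here only to illustrate the ConnectionHandler
// interface above; it is not part of this commit. It uses only names that
// already appear in this file (ProcessResult::Continue, appendMessage) and
// assumes <string> and <cctype> are included. The input is copied first,
// because `data` points into the caller's read buffer.
class UpperEchoHandler : public ConnectionHandler {
public:
  ProcessResult process_data(std::string_view data,
                             std::unique_ptr<Connection> &conn_ptr) override {
    std::string upper(data); // own the bytes before transforming them
    for (char &c : upper) {
      c = static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
    }
    conn_ptr->appendMessage(upper); // appendMessage copies into the arena
    return ProcessResult::Continue;
  }
  void on_connection_established(Connection &conn) override { (void)conn; }
};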
// Since only one thread owns a connection at a time, no synchronization is
// necessary
// Connection ownership model:
// - Created by accept thread, transferred to epoll via raw pointer
// - Network threads claim ownership by wrapping raw pointer in unique_ptr
// - Network thread optionally passes ownership to a thread pipeline
// - Owner eventually transfers back to epoll by releasing unique_ptr to raw
// pointer
// - RAII cleanup happens if network thread doesn't transfer back
struct Connection {
const int fd;
const int64_t id;
struct sockaddr_storage addr; // sockaddr_storage handles IPv4/IPv6
ArenaAllocator arena;
Connection(struct sockaddr_storage addr, int fd, int64_t id)
: fd(fd), id(id), addr(addr) {
activeConnections.fetch_add(1, std::memory_order_relaxed);
}
~Connection() {
activeConnections.fetch_sub(1, std::memory_order_relaxed);
int e = close(fd);
if (e == -1) {
perror("close");
abort();
}
}
std::deque<std::string_view, ArenaStlAllocator<std::string_view>> messages{
ArenaStlAllocator<std::string_view>{&arena}};
// Copies s into arena, and appends to messages
void appendMessage(std::string_view s) {
char *arena_str = arena.allocate<char>(s.size());
std::memcpy(arena_str, s.data(), s.size());
messages.emplace_back(arena_str, s.size());
}
// Whether to close the connection once the response has been fully written
bool closeConnection{false};
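// Returns true when the connection should be closed (EOF or fatal read
// error), false when the socket would block and the fd should be re-armed.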
bool readBytes(size_t /*max_request_size*/, size_t buffer_size) {
// Variable-length array (a GCC/Clang extension in C++): keeps the read
// buffer on the stack, sized from configuration
char buf[buffer_size];
for (;;) {
ssize_t r = read(fd, buf, buffer_size);
if (r == -1) {
if (errno == EINTR) {
continue;
}
if (errno == EAGAIN) {
return false;
}
perror("read");
return true;
}
if (r == 0) {
return true;
}
// "pump parser"
// TODO revisit
appendMessage({buf, size_t(r)});
}
}
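// Returns true when the connection should be closed (write error, or
// everything flushed while closeConnection is set); returns false both when
// the socket would block (re-arm EPOLLOUT) and when all messages were
// flushed on a connection that stays open (re-arm EPOLLIN).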
bool writeBytes() {
while (!messages.empty()) {
// Build iovec array up to IOV_MAX limit
struct iovec iov[IOV_MAX];
int iov_count = 0;
for (auto it = messages.begin();
it != messages.end() && iov_count < IOV_MAX; ++it) {
const auto &msg = *it;
if (msg.size() > 0) {
iov[iov_count] = {
const_cast<void *>(static_cast<const void *>(msg.data())),
msg.size()};
iov_count++;
}
}
if (iov_count == 0) {
break;
}
ssize_t w;
for (;;) {
w = writev(fd, iov, iov_count);
if (w == -1) {
if (errno == EINTR) {
continue; // Standard practice: retry on signal interruption
}
if (errno == EAGAIN) {
return false;
}
perror("writev");
return true;
}
break;
}
assert(w > 0);
// Handle partial writes by updating string_view data/size
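// e.g. messages = {"abc", "defg"} and writev() reports w = 5: "abc" is
// popped (3 bytes), then the front view advances to "fg" (2 of 4 consumed)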
size_t bytes_written = static_cast<size_t>(w);
while (bytes_written > 0 && !messages.empty()) {
auto &front = messages.front();
if (bytes_written >= front.size()) {
// This message is completely written
bytes_written -= front.size();
messages.pop_front();
} else {
// Partial write of this message - update string_view
front = std::string_view(front.data() + bytes_written,
front.size() - bytes_written);
bytes_written = 0;
}
}
}
assert(messages.empty());
arena.reset();
return closeConnection;
}
// This is necessary because tsan doesn't (yet?) understand that there's a
// happens-before relationship for epoll_ctl(..., EPOLL_CTL_MOD, ...) and
// epoll_wait
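// Usage below: tsan_release() runs just before the raw pointer is handed to
// epoll, tsan_acquire() just after it is taken back, giving tsan an explicit
// acquire/release edge to observe.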
#if __has_feature(thread_sanitizer)
void tsan_acquire() { tsan_sync.load(std::memory_order_acquire); }
void tsan_release() { tsan_sync.store(0, std::memory_order_release); }
std::atomic<int> tsan_sync;
#else
void tsan_acquire() {}
void tsan_release() {}
#endif
};
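// The ownership round-trip in miniature (the epoll loops below do exactly
// this; `ev` is the epoll_event carrying the connection, `conn_fd` its fd):
//
//   std::unique_ptr<Connection> conn{static_cast<Connection *>(ev.data.ptr)};
//   ...process I/O; this thread is the sole owner...
//   ev.data.ptr = conn.release(); // hand ownership back to epoll
//   epoll_ctl(epfd, EPOLL_CTL_MOD, conn_fd, &ev);
//
// If epoll_ctl fails, or the unique_ptr is never released, RAII closes the
// fd and decrements activeConnections.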
void print_help(const char *program_name) {
@@ -351,251 +125,20 @@ int main(int argc, char *argv[]) {
<< config->subscription.keepalive_interval.count() << " seconds"
<< std::endl;
// Create shutdown pipe for graceful shutdown
if (pipe(shutdown_pipe) == -1) {
perror("pipe");
abort();
}
// Set both ends to close-on-exec
if (fcntl(shutdown_pipe[0], F_SETFD, FD_CLOEXEC) == -1 ||
fcntl(shutdown_pipe[1], F_SETFD, FD_CLOEXEC) == -1) {
perror("fcntl FD_CLOEXEC");
abort();
}
// Create handler and server
EchoHandler echo_handler;
auto server = Server::create(*config, echo_handler);
g_server = server.get();
// Setup signal handling
signal(SIGPIPE, SIG_IGN);
signal(SIGTERM, signal_handler);
signal(SIGINT, signal_handler);
int sockfd = getListenFd(config->server.bind_address.c_str(),
std::to_string(config->server.port).c_str());
std::vector<std::thread> threads;
int network_epollfd = epoll_create1(EPOLL_CLOEXEC);
if (network_epollfd == -1) {
perror("epoll_create");
abort();
}
// Add shutdown pipe read end to network thread epoll
struct epoll_event shutdown_event;
shutdown_event.events = EPOLLIN;
shutdown_event.data.fd = shutdown_pipe[0];
if (epoll_ctl(network_epollfd, EPOLL_CTL_ADD, shutdown_pipe[0],
&shutdown_event) == -1) {
perror("epoll_ctl add shutdown event");
abort();
}
std::atomic<int64_t> connectionId{0};
// Network threads from configuration
int networkThreads = config->server.network_threads;
for (int networkThreadId = 0; networkThreadId < networkThreads;
++networkThreadId) {
threads.emplace_back(
[network_epollfd, networkThreadId,
max_request_size = config->server.max_request_size_bytes,
read_buffer_size = config->server.read_buffer_size,
event_batch_size = config->server.event_batch_size]() {
pthread_setname_np(
pthread_self(),
("network-" + std::to_string(networkThreadId)).c_str());
std::vector<struct epoll_event> events(event_batch_size);
for (;;) {
int eventCount = epoll_wait(network_epollfd, events.data(),
event_batch_size, -1 /* no timeout */);
if (eventCount == -1) {
if (errno == EINTR) {
continue;
}
perror("epoll_wait");
abort();
}
for (int i = 0; i < eventCount; ++i) {
// Check for shutdown event
if (events[i].data.fd == shutdown_pipe[0]) {
// Don't read pipe - all threads need to see shutdown signal
return;
}
// Take ownership from epoll: raw pointer -> unique_ptr
std::unique_ptr<Connection> conn{
static_cast<Connection *>(events[i].data.ptr)};
conn->tsan_acquire();
events[i].data.ptr =
nullptr; // Clear epoll pointer (we own it now)
const int fd = conn->fd;
if (events[i].events & (EPOLLERR | EPOLLHUP | EPOLLRDHUP)) {
// Connection closed or error occurred - unique_ptr destructor
// cleans up
continue;
}
// When we register our epoll interest, if we have something to
// write, we write. Otherwise we read.
assert(!((events[i].events & EPOLLIN) &&
(events[i].events & EPOLLOUT)));
if (events[i].events & EPOLLIN) {
bool done = conn->readBytes(max_request_size, read_buffer_size);
if (done) {
continue;
}
}
if (events[i].events & EPOLLOUT) {
bool done = conn->writeBytes();
if (done) {
continue;
}
}
if (conn->messages.empty()) {
events[i].events = EPOLLIN | EPOLLONESHOT | EPOLLRDHUP;
} else {
events[i].events = EPOLLOUT | EPOLLONESHOT | EPOLLRDHUP;
}
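// EPOLLONESHOT disarms the fd after one event delivery, so exactly one
// network thread can hold the Connection until EPOLL_CTL_MOD re-arms it.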
// Transfer ownership back to epoll: unique_ptr -> raw pointer
conn->tsan_release();
Connection *raw_conn =
conn.release(); // Get raw pointer before epoll_ctl
events[i].data.ptr = raw_conn; // epoll now owns the connection
int e = epoll_ctl(network_epollfd, EPOLL_CTL_MOD, fd, &events[i]);
if (e == -1) {
perror("epoll_ctl");
delete raw_conn; // Clean up connection on epoll failure
continue;
}
}
}
});
}
// Accept threads from configuration
int acceptThreads = config->server.accept_threads;
// epoll instance for accept threads
int accept_epollfd = epoll_create1(EPOLL_CLOEXEC);
if (accept_epollfd == -1) {
perror("epoll_create1");
abort();
}
// Add shutdown pipe read end to accept epoll
if (epoll_ctl(accept_epollfd, EPOLL_CTL_ADD, shutdown_pipe[0],
&shutdown_event) == -1) {
perror("epoll_ctl shutdown pipe");
abort();
}
// Add listen socket to accept epoll with EPOLLEXCLUSIVE for better load
// balancing
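// (EPOLLEXCLUSIVE wakes a single waiter per ready event instead of every
// accept thread, avoiding a thundering herd; it requires Linux 4.5+.)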
struct epoll_event listen_event;
listen_event.events = EPOLLIN | EPOLLEXCLUSIVE;
listen_event.data.fd = sockfd;
if (epoll_ctl(accept_epollfd, EPOLL_CTL_ADD, sockfd, &listen_event) == -1) {
perror("epoll_ctl listen socket");
abort();
}
for (int acceptThreadId = 0; acceptThreadId < acceptThreads;
++acceptThreadId) {
threads.emplace_back([network_epollfd, acceptThreadId, sockfd,
&connectionId,
max_connections = config->server.max_connections,
accept_epollfd]() {
pthread_setname_np(pthread_self(),
("accept-" + std::to_string(acceptThreadId)).c_str());
for (;;) {
struct epoll_event events[2]; // listen socket + shutdown pipe
int ready = epoll_wait(accept_epollfd, events, 2, -1 /* no timeout */);
if (ready == -1) {
if (errno == EINTR)
continue;
perror("epoll_wait");
abort();
}
for (int i = 0; i < ready; ++i) {
if (events[i].data.fd == shutdown_pipe[0]) {
// Don't read pipe - all threads need to see shutdown signal
return;
}
if (events[i].data.fd == sockfd) {
// Listen socket ready - accept connections
for (;;) {
struct sockaddr_storage addr;
socklen_t addrlen = sizeof(addr);
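// accept4 with SOCK_NONBLOCK sets O_NONBLOCK on the accepted fd atomically,
// avoiding a separate fcntl round-trip (Linux-specific).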
int fd = accept4(sockfd, (struct sockaddr *)&addr, &addrlen,
SOCK_NONBLOCK);
if (fd == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK)
break; // No more connections
if (errno == EINTR)
continue;
perror("accept4");
abort();
}
// Check connection limit (0 means unlimited). Limiting
// connections is best effort - race condition between check and
// increment is acceptable for this use case
if (max_connections > 0 &&
activeConnections.load(std::memory_order_relaxed) >=
max_connections) {
// Reject connection by immediately closing it
close(fd);
continue;
}
// Enable keepalive to detect dead connections
int keepalive = 1;
if (setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &keepalive,
sizeof(keepalive)) == -1) {
perror("setsockopt SO_KEEPALIVE");
// Continue anyway - not critical
}
auto conn = std::make_unique<Connection>(
addr, fd,
connectionId.fetch_add(1, std::memory_order_relaxed));
// Transfer new connection to network thread epoll
struct epoll_event event{};
event.events = EPOLLIN | EPOLLONESHOT | EPOLLRDHUP;
conn->tsan_release();
Connection *raw_conn =
conn.release(); // Get raw pointer before epoll_ctl
event.data.ptr =
raw_conn; // network epoll now owns the connection
int e = epoll_ctl(network_epollfd, EPOLL_CTL_ADD, fd, &event);
if (e == -1) {
perror("epoll_ctl");
delete raw_conn; // Clean up connection on epoll failure
continue;
}
}
}
}
}
});
}
std::cout << "Starting WeaselDB server..." << std::endl;
server->run();
for (auto &t : threads) {
  t.join();
}
// Cleanup
close(shutdown_pipe[0]);
close(shutdown_pipe[1]);
close(accept_epollfd);
close(network_epollfd);
close(sockfd);
std::cout << "Server shutdown complete." << std::endl;
g_server = nullptr;
return 0;
}