Have Server take list of listen fds and add createLocalConnection

This commit is contained in:
2025-08-22 12:01:00 -04:00
parent ba3258ab16
commit 0e63d5e80f
3 changed files with 259 additions and 172 deletions

View File

@@ -6,7 +6,14 @@
#include "server.hpp"
#include <atomic>
#include <csignal>
#include <fcntl.h>
#include <iostream>
#include <netdb.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>
#include <vector>
PERFETTO_TRACK_EVENT_STATIC_STORAGE();
@@ -21,6 +28,117 @@ void signal_handler(int sig) {
}
}
std::vector<int> create_listen_sockets(const weaseldb::Config &config) {
std::vector<int> listen_fds;
// Check if unix socket path is specified
if (!config.server.unix_socket_path.empty()) {
// Create unix socket
int sfd = socket(AF_UNIX, SOCK_STREAM, 0);
if (sfd == -1) {
perror("socket");
throw std::runtime_error("Failed to create unix socket");
}
// Remove existing socket file if it exists
unlink(config.server.unix_socket_path.c_str());
struct sockaddr_un addr;
memset(&addr, 0, sizeof(addr));
addr.sun_family = AF_UNIX;
if (config.server.unix_socket_path.length() >= sizeof(addr.sun_path)) {
close(sfd);
throw std::runtime_error("Unix socket path too long");
}
strncpy(addr.sun_path, config.server.unix_socket_path.c_str(),
sizeof(addr.sun_path) - 1);
if (bind(sfd, (struct sockaddr *)&addr, sizeof(addr)) == -1) {
perror("bind");
close(sfd);
throw std::runtime_error("Failed to bind unix socket");
}
if (listen(sfd, SOMAXCONN) == -1) {
perror("listen");
close(sfd);
throw std::runtime_error("Failed to listen on unix socket");
}
listen_fds.push_back(sfd);
return listen_fds;
}
// TCP socket creation
struct addrinfo hints;
struct addrinfo *result, *rp;
int s;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC; /* Allow IPv4 or IPv6 */
hints.ai_socktype = SOCK_STREAM; /* stream socket */
hints.ai_flags = AI_PASSIVE; /* For wildcard IP address */
hints.ai_protocol = 0; /* Any protocol */
hints.ai_canonname = nullptr;
hints.ai_addr = nullptr;
hints.ai_next = nullptr;
s = getaddrinfo(config.server.bind_address.c_str(),
std::to_string(config.server.port).c_str(), &hints, &result);
if (s != 0) {
fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(s));
throw std::runtime_error("Failed to resolve bind address");
}
int sfd = -1;
for (rp = result; rp != nullptr; rp = rp->ai_next) {
sfd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
if (sfd == -1) {
continue;
}
int val = 1;
if (setsockopt(sfd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val)) == -1) {
perror("setsockopt SO_REUSEADDR");
close(sfd);
continue;
}
// Enable TCP_NODELAY for low latency (only for TCP sockets)
if (rp->ai_family == AF_INET || rp->ai_family == AF_INET6) {
if (setsockopt(sfd, IPPROTO_TCP, TCP_NODELAY, &val, sizeof(val)) == -1) {
perror("setsockopt TCP_NODELAY");
close(sfd);
continue;
}
}
if (bind(sfd, rp->ai_addr, rp->ai_addrlen) == 0) {
break; /* Success */
}
close(sfd);
sfd = -1;
}
freeaddrinfo(result);
if (rp == nullptr || sfd == -1) {
throw std::runtime_error("Could not bind to any address");
}
if (listen(sfd, SOMAXCONN) == -1) {
perror("listen");
close(sfd);
throw std::runtime_error("Failed to listen on socket");
}
listen_fds.push_back(sfd);
return listen_fds;
}
void print_help(const char *program_name) {
std::cout << "WeaselDB - High-performance write-side database component\n\n";
std::cout << "Usage: " << program_name << " [OPTIONS]\n\n";
@@ -117,9 +235,18 @@ int main(int argc, char *argv[]) {
<< config->subscription.keepalive_interval.count() << " seconds"
<< std::endl;
// Create listen sockets
std::vector<int> listen_fds;
try {
listen_fds = create_listen_sockets(*config);
} catch (const std::exception &e) {
std::cerr << "Failed to create listen sockets: " << e.what() << std::endl;
return 1;
}
// Create handler and server
HttpHandler http_handler;
auto server = Server::create(*config, http_handler);
auto server = Server::create(*config, http_handler, listen_fds);
g_server = server.get();
// Setup signal handling

View File

@@ -18,14 +18,33 @@
#include <vector>
std::shared_ptr<Server> Server::create(const weaseldb::Config &config,
ConnectionHandler &handler) {
ConnectionHandler &handler,
const std::vector<int> &listen_fds) {
// Use std::shared_ptr constructor with private access
// We can't use make_shared here because constructor is private
return std::shared_ptr<Server>(new Server(config, handler));
return std::shared_ptr<Server>(new Server(config, handler, listen_fds));
}
Server::Server(const weaseldb::Config &config, ConnectionHandler &handler)
: config_(config), handler_(handler), connection_registry_() {}
Server::Server(const weaseldb::Config &config, ConnectionHandler &handler,
const std::vector<int> &provided_listen_fds)
: config_(config), handler_(handler), connection_registry_(),
listen_fds_(provided_listen_fds) {
// Server takes ownership of all provided listen fds
// Ensure all listen fds are non-blocking for safe epoll usage
for (int fd : listen_fds_) {
int flags = fcntl(fd, F_GETFL, 0);
if (flags == -1) {
perror("fcntl F_GETFL on provided listen fd");
throw std::runtime_error("Failed to get flags for provided listen fd");
}
if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1) {
perror("fcntl F_SETFL O_NONBLOCK on provided listen fd");
throw std::runtime_error("Failed to set provided listen fd non-blocking");
}
}
// If empty vector provided, listen_fds_ will be empty (no listening)
// Server works purely with createLocalConnection()
}
Server::~Server() {
if (shutdown_pipe_[0] != -1) {
@@ -45,9 +64,11 @@ Server::~Server() {
}
epoll_fds_.clear();
if (listen_sockfd_ != -1) {
close(listen_sockfd_);
listen_sockfd_ = -1;
// Close all listen sockets (Server always owns them)
for (int fd : listen_fds_) {
if (fd != -1) {
close(fd);
}
}
// Clean up unix socket file if it exists
@@ -58,9 +79,6 @@ Server::~Server() {
void Server::run() {
setup_shutdown_pipe();
listen_sockfd_ = create_listen_socket();
create_epoll_instances();
// Create I/O threads locally in this call frame
@@ -139,6 +157,50 @@ void Server::receiveConnectionBack(std::unique_ptr<Connection> connection) {
}
}
int Server::createLocalConnection() {
int sockets[2];
if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0, sockets) != 0) {
perror("socketpair");
return -1;
}
int server_fd = sockets[0]; // Server keeps this end
int client_fd = sockets[1]; // Return this end to caller
// Create sockaddr_storage for the connection
struct sockaddr_storage addr{};
addr.ss_family = AF_UNIX;
// Calculate epoll_index for connection distribution
size_t epoll_index =
connection_distribution_counter_.fetch_add(1, std::memory_order_relaxed) %
epoll_fds_.size();
// Create Connection object
auto connection = std::unique_ptr<Connection>(new Connection(
addr, server_fd, connection_id_.fetch_add(1, std::memory_order_relaxed),
epoll_index, &handler_, weak_from_this()));
// Store in registry
connection_registry_.store(server_fd, std::move(connection));
// Add to appropriate epoll instance
struct epoll_event event{};
event.events = EPOLLIN | EPOLLONESHOT;
event.data.fd = server_fd;
int epollfd = epoll_fds_[epoll_index];
if (epoll_ctl(epollfd, EPOLL_CTL_ADD, server_fd, &event) == -1) {
perror("epoll_ctl ADD local connection");
connection_registry_.remove(server_fd);
close(server_fd);
close(client_fd);
return -1;
}
return client_fd;
}
void Server::setup_shutdown_pipe() {
if (pipe(shutdown_pipe_) == -1) {
perror("pipe");
@@ -153,139 +215,6 @@ void Server::setup_shutdown_pipe() {
}
}
int Server::create_listen_socket() {
int sfd;
// Check if unix socket path is specified
if (!config_.server.unix_socket_path.empty()) {
// Create unix socket
sfd = socket(AF_UNIX, SOCK_STREAM, 0);
if (sfd == -1) {
perror("socket");
throw std::runtime_error("Failed to create unix socket");
}
// Remove existing socket file if it exists
unlink(config_.server.unix_socket_path.c_str());
struct sockaddr_un addr;
memset(&addr, 0, sizeof(addr));
addr.sun_family = AF_UNIX;
if (config_.server.unix_socket_path.length() >= sizeof(addr.sun_path)) {
close(sfd);
throw std::runtime_error("Unix socket path too long");
}
strncpy(addr.sun_path, config_.server.unix_socket_path.c_str(),
sizeof(addr.sun_path) - 1);
// Set socket to non-blocking for graceful shutdown
int flags = fcntl(sfd, F_GETFL, 0);
if (flags == -1) {
perror("fcntl F_GETFL");
close(sfd);
throw std::runtime_error("Failed to get socket flags");
}
if (fcntl(sfd, F_SETFL, flags | O_NONBLOCK) == -1) {
perror("fcntl F_SETFL");
close(sfd);
throw std::runtime_error("Failed to set socket non-blocking");
}
if (bind(sfd, (struct sockaddr *)&addr, sizeof(addr)) == -1) {
perror("bind");
close(sfd);
throw std::runtime_error("Failed to bind unix socket");
}
if (listen(sfd, SOMAXCONN) == -1) {
perror("listen");
close(sfd);
throw std::runtime_error("Failed to listen on unix socket");
}
return sfd;
}
// TCP socket creation (original code)
struct addrinfo hints;
struct addrinfo *result, *rp;
int s;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC; /* Allow IPv4 or IPv6 */
hints.ai_socktype = SOCK_STREAM; /* stream socket */
hints.ai_flags = AI_PASSIVE; /* For wildcard IP address */
hints.ai_protocol = 0; /* Any protocol */
hints.ai_canonname = nullptr;
hints.ai_addr = nullptr;
hints.ai_next = nullptr;
s = getaddrinfo(config_.server.bind_address.c_str(),
std::to_string(config_.server.port).c_str(), &hints, &result);
if (s != 0) {
fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(s));
throw std::runtime_error("Failed to resolve bind address");
}
for (rp = result; rp != nullptr; rp = rp->ai_next) {
sfd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
if (sfd == -1) {
continue;
}
int val = 1;
if (setsockopt(sfd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val)) == -1) {
perror("setsockopt SO_REUSEADDR");
close(sfd);
continue;
}
// Enable TCP_NODELAY for low latency (only for TCP sockets)
if (rp->ai_family == AF_INET || rp->ai_family == AF_INET6) {
if (setsockopt(sfd, IPPROTO_TCP, TCP_NODELAY, &val, sizeof(val)) == -1) {
perror("setsockopt TCP_NODELAY");
close(sfd);
continue;
}
}
// Set socket to non-blocking for graceful shutdown
int flags = fcntl(sfd, F_GETFL, 0);
if (flags == -1) {
perror("fcntl F_GETFL");
close(sfd);
continue;
}
if (fcntl(sfd, F_SETFL, flags | O_NONBLOCK) == -1) {
perror("fcntl F_SETFL");
close(sfd);
continue;
}
if (bind(sfd, rp->ai_addr, rp->ai_addrlen) == 0) {
break; /* Success */
}
close(sfd);
}
freeaddrinfo(result);
if (rp == nullptr) {
throw std::runtime_error("Could not bind to any address");
}
if (listen(sfd, SOMAXCONN) == -1) {
perror("listen");
close(sfd);
throw std::runtime_error("Failed to listen on socket");
}
return sfd;
}
void Server::create_epoll_instances() {
// Create multiple epoll instances to reduce contention
epoll_fds_.resize(config_.server.epoll_instances);
@@ -308,18 +237,20 @@ void Server::create_epoll_instances() {
throw std::runtime_error("Failed to add shutdown pipe to epoll");
}
// Add listen socket to each epoll instance with EPOLLEXCLUSIVE to prevent
// thundering herd
// Add all listen sockets to each epoll instance with EPOLLEXCLUSIVE to
// prevent thundering herd
for (int listen_fd : listen_fds_) {
struct epoll_event listen_event;
listen_event.events = EPOLLIN | EPOLLEXCLUSIVE;
listen_event.data.fd = listen_sockfd_;
if (epoll_ctl(epoll_fds_[i], EPOLL_CTL_ADD, listen_sockfd_,
&listen_event) == -1) {
listen_event.data.fd = listen_fd;
if (epoll_ctl(epoll_fds_[i], EPOLL_CTL_ADD, listen_fd, &listen_event) ==
-1) {
perror("epoll_ctl listen socket");
throw std::runtime_error("Failed to add listen socket to epoll");
}
}
}
}
int Server::get_epoll_for_thread(int thread_id) const {
// Round-robin assignment of threads to epoll instances
@@ -340,6 +271,8 @@ void Server::start_io_threads(std::vector<std::thread> &threads) {
struct epoll_event events[config_.server.event_batch_size];
std::unique_ptr<Connection> batch[config_.server.event_batch_size];
int batch_events[config_.server.event_batch_size];
std::vector<int>
ready_listen_fds; // Reused across iterations to avoid allocation
for (;;) {
int event_count =
@@ -352,16 +285,19 @@ void Server::start_io_threads(std::vector<std::thread> &threads) {
abort();
}
bool listenReady = false;
ready_listen_fds.clear(); // Clear from previous iteration
int batch_count = 0;
for (int i = 0; i < event_count; ++i) {
// Check for shutdown event
if (events[i].data.fd == shutdown_pipe_[0]) {
return;
}
// Check for new connections
if (events[i].data.fd == listen_sockfd_) {
listenReady = true;
// Check for new connections on any listen socket
bool isListenSocket =
std::find(listen_fds_.begin(), listen_fds_.end(),
events[i].data.fd) != listen_fds_.end();
if (isListenSocket) {
ready_listen_fds.push_back(events[i].data.fd);
continue;
}
@@ -389,17 +325,17 @@ void Server::start_io_threads(std::vector<std::thread> &threads) {
{batch_events, (size_t)batch_count}, false);
}
// Reuse same batch array for accepting connections
if (listenReady) {
// Only accept on listen sockets that epoll indicates are ready
for (int listen_fd : ready_listen_fds) {
for (;;) {
struct sockaddr_storage addr;
socklen_t addrlen = sizeof(addr);
int fd = accept4(listen_sockfd_, (struct sockaddr *)&addr, &addrlen,
int fd = accept4(listen_fd, (struct sockaddr *)&addr, &addrlen,
SOCK_NONBLOCK);
if (fd == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK)
break;
break; // Try next listen socket
if (errno == EINTR)
continue;
perror("accept4");
@@ -438,7 +374,8 @@ void Server::start_io_threads(std::vector<std::thread> &threads) {
true);
batch_count = 0;
}
}
} // End inner accept loop
} // End loop over listen_fds_
// Process remaining accepted connections
if (batch_count > 0) {
@@ -447,7 +384,6 @@ void Server::start_io_threads(std::vector<std::thread> &threads) {
batch_count = 0;
}
}
}
});
}
}

View File

@@ -43,10 +43,15 @@ public:
*
* @param config Server configuration (threads, ports, limits, etc.)
* @param handler Protocol handler for processing connection data
* @param listen_fds Vector of file descriptors to accept connections on.
* Server takes ownership and will close them on
* destruction. Server will set these to non-blocking mode for safe epoll
* usage. Empty vector means no listening sockets.
* @return shared_ptr to the newly created Server
*/
static std::shared_ptr<Server> create(const weaseldb::Config &config,
ConnectionHandler &handler);
ConnectionHandler &handler,
const std::vector<int> &listen_fds);
/**
* Destructor ensures proper cleanup of all resources.
@@ -75,6 +80,20 @@ public:
*/
void shutdown();
/**
* Creates a local connection using socketpair() for testing or local IPC.
*
* Creates a socketpair, registers one end as a Connection in the server,
* and returns the other end to the caller for communication.
*
* The caller takes ownership of the returned file descriptor and must close
* it.
*
* @return File descriptor for the client end of the socketpair, or -1 on
* error
*/
int createLocalConnection();
/**
* Release a connection back to its server for continued processing.
*
@@ -95,8 +114,13 @@ private:
*
* @param config Server configuration (threads, ports, limits, etc.)
* @param handler Protocol handler for processing connection data
* @param listen_fds Vector of file descriptors to accept connections on.
* Server takes ownership and will close them on
* destruction. Server will set these to non-blocking mode for safe epoll
* usage.
*/
explicit Server(const weaseldb::Config &config, ConnectionHandler &handler);
explicit Server(const weaseldb::Config &config, ConnectionHandler &handler,
const std::vector<int> &listen_fds);
const weaseldb::Config &config_;
ConnectionHandler &handler_;
@@ -116,12 +140,12 @@ private:
// Multiple epoll file descriptors to reduce contention
std::vector<int> epoll_fds_;
int listen_sockfd_ = -1;
std::vector<int>
listen_fds_; // FDs to accept connections on (Server owns these)
// Private helper methods
void setup_shutdown_pipe();
void setup_signal_handling();
int create_listen_socket();
void create_epoll_instances();
void start_io_threads(std::vector<std::thread> &threads);