Compare commits


13 Commits

12 changed files with 674 additions and 329 deletions

View File

@@ -128,6 +128,7 @@ set(SOURCES
src/main.cpp
src/config.cpp
src/connection.cpp
+src/connection_registry.cpp
src/server.cpp
src/json_commit_request_parser.cpp
src/http_handler.cpp
@@ -168,14 +169,39 @@ target_link_libraries(test_commit_request doctest::doctest weaseljson test_data
target_include_directories(test_commit_request PRIVATE src tests)
add_executable(
-test_http_handler tests/test_http_handler.cpp src/http_handler.cpp
-src/arena_allocator.cpp src/connection.cpp)
+test_http_handler
+tests/test_http_handler.cpp src/http_handler.cpp src/arena_allocator.cpp
+src/connection.cpp src/connection_registry.cpp)
target_link_libraries(test_http_handler doctest::doctest llhttp_static
Threads::Threads perfetto)
target_include_directories(test_http_handler PRIVATE src)
target_compile_definitions(test_http_handler
PRIVATE DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN)
+add_executable(
+test_server_connection_return
+tests/test_server_connection_return.cpp
+src/server.cpp
+src/connection.cpp
+src/connection_registry.cpp
+src/arena_allocator.cpp
+src/config.cpp
+src/http_handler.cpp
+${CMAKE_BINARY_DIR}/json_tokens.cpp)
+add_dependencies(test_server_connection_return generate_json_tokens)
+target_link_libraries(
+test_server_connection_return
+doctest::doctest
+llhttp_static
+Threads::Threads
+toml11::toml11
+perfetto
+weaseljson
+simdutf::simdutf)
+target_include_directories(test_server_connection_return PRIVATE src)
+target_compile_definitions(test_server_connection_return
+PRIVATE DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN)
add_executable(bench_arena_allocator benchmarks/bench_arena_allocator.cpp
src/arena_allocator.cpp)
target_link_libraries(bench_arena_allocator nanobench)
@@ -215,6 +241,8 @@ target_link_libraries(load_tester Threads::Threads llhttp_static perfetto)
add_test(NAME arena_allocator_tests COMMAND test_arena_allocator)
add_test(NAME commit_request_tests COMMAND test_commit_request)
add_test(NAME http_handler_tests COMMAND test_http_handler)
+add_test(NAME server_connection_return_tests
+COMMAND test_server_connection_return)
add_test(NAME arena_allocator_benchmarks COMMAND bench_arena_allocator)
add_test(NAME commit_request_benchmarks COMMAND bench_commit_request)
add_test(NAME parser_comparison_benchmarks COMMAND bench_parser_comparison)
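With these targets registered, the new test can be built and run the usual way, e.g. `ninja test_server_connection_return && ctest -R server_connection_return_tests` (names taken from the target and test definitions above).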

View File

@@ -57,6 +57,9 @@ ninja test # or ctest
**Debug tools:**
- `./debug_arena` - Analyze arena allocator behavior
+**Load Testing:**
+- `./load_tester` - A tool to generate load against the server for performance and stability analysis.
### Dependencies
**System requirements:**
@@ -114,6 +117,15 @@ Ultra-fast memory allocator optimized for request/response patterns:
- **Streaming data processing** with partial message handling
- **Connection lifecycle hooks** for initialization and cleanup
+#### **Thread Pipeline** (`src/ThreadPipeline.h`)
+A high-performance, multi-stage, lock-free pipeline for inter-thread communication (see the sketch after this list).
+- **Lock-Free Design**: Uses a shared ring buffer with atomic counters for coordination, avoiding locks for maximum throughput.
+- **Multi-Stage Processing**: Allows items (like connections or data packets) to flow through a series of processing stages (e.g., from I/O threads to worker threads).
+- **Batching Support**: Enables efficient batch processing of items to reduce overhead.
+- **RAII Guards**: Utilizes RAII (`StageGuard`, `ProducerGuard`) to ensure thread-safe publishing and consumption of items in the pipeline, even in the presence of exceptions.
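A minimal usage sketch of the pipeline, inferred from the constructor comment in `ThreadPipeline.h` and the echo test later in this diff; the `push(count, mayBlock)`/`acquire(stage, thread)` calls and `guard.batch` iteration mirror that test, while the sentinel shutdown scheme is illustrative:

```cpp
#include "ThreadPipeline.h"
#include <cstdio>
#include <thread>

int main() {
  // lgSlotCount = 8 -> 256 ring slots; one stage with a single worker thread.
  ThreadPipeline<int> pipeline{8, {1}};

  std::thread worker{[&] {
    for (;;) {
      // Blocks until the producer publishes a batch; the StageGuard
      // releases the consumed slots when it goes out of scope.
      auto guard = pipeline.acquire(0, 0);
      for (int &item : guard.batch) {
        if (item < 0)
          return; // Negative value is our (illustrative) shutdown sentinel.
        std::printf("got %d\n", item);
      }
    }
  }};

  for (int i = 0; i <= 3; ++i) {
    auto guard = pipeline.push(1, true); // Reserve one slot, blocking.
    for (int &item : guard.batch)
      item = (i < 3) ? i : -1; // Last iteration publishes the sentinel.
  } // ProducerGuard publishes the slot to stage 0 on destruction.

  worker.join();
}
```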
#### **Parsing Layer**
**JSON Commit Request Parser** (`src/json_commit_request_parser.{hpp,cpp}`):
@@ -286,6 +298,16 @@ This write-side component is designed to integrate with:
- **Configuration**: All configuration is TOML-based using `config.toml` (see `config.md`)
- **Testing Strategy**: Run unit tests, benchmarks, and debug tools before submitting changes
- **Build System**: CMake generates gperf hash tables at build time; always use ninja
+- **Test Synchronization**:
+  - **ABSOLUTELY NEVER use sleep(), std::this_thread::sleep_for(), or any timeout-based waiting in tests**
+  - **NEVER use condition_variable.wait_for() or other timeout variants**
+  - Use deterministic synchronization only (see the sketch after this list):
+    - **Blocking I/O** (blocking read/write calls that naturally wait)
+    - **condition_variable.wait()** with no timeout (waits indefinitely until condition is met)
+    - **std::latch, std::barrier, futures/promises** for coordination
+    - **RAII guards and resource management** for cleanup
+  - Tests should either pass (correct) or hang forever (indicates a real bug to investigate)
+  - No timeouts, no flaky behavior, no false positives/negatives
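A sketch of the policy in doctest form (the framework the existing tests use): the test blocks on a `std::latch` with no timeout, so it either passes deterministically or hangs on a real bug:

```cpp
#include <doctest/doctest.h>
#include <latch>
#include <thread>

TEST_CASE("worker completion observed without sleeps or timeouts") {
  std::latch done{1};
  int result = 0;

  std::thread worker{[&] {
    result = 42;       // ...real work would happen here...
    done.count_down(); // Signal completion deterministically.
  }};

  done.wait(); // No timeout: hangs forever if the worker is broken, by design.
  CHECK(result == 42);
  worker.join();
}
```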
---

View File

@@ -36,8 +36,6 @@
//
// Memory Model:
// - Ring buffer size must be power of 2 for efficient masking
-// - Uses 64-bit indices to avoid ABA problems (indices never repeat until
-// uint32_t overflow)
// - Actual ring slots accessed via: index & (slotCount - 1)
// - 128-byte aligned atomics prevent false sharing between CPU cache lines
//
@@ -50,7 +48,7 @@ template <class T> struct ThreadPipeline {
// Constructor
// lgSlotCount: log2 of ring buffer size (e.g., 10 -> 1024 slots)
// threadsPerStage: number of threads for each stage (e.g., {1, 4, 2} = 1
-// producer, 4 stage-1 workers, 2 stage-2 workers)
+// stage-0 worker, 4 stage-1 workers, 2 stage-2 workers)
ThreadPipeline(int lgSlotCount, const std::vector<int> &threadsPerStage)
: slotCount(1 << lgSlotCount), slotCountMask(slotCount - 1),
threadState(threadsPerStage.size()), ring(slotCount) {
@@ -230,9 +228,9 @@ private:
return safeLen;
}
-struct alignas(128) ThreadState {
+struct ThreadState {
// Where this thread has published up to
-std::atomic<uint32_t> pops{0};
+alignas(128) std::atomic<uint32_t> pops{0};
// Where this thread will publish to the next time it publishes
uint32_t localPops{0};
// Where the previous stage's threads have published up to last we checked
@@ -317,8 +315,8 @@ public:
// none available) Returns: StageGuard with batch of items to process
[[nodiscard]] StageGuard acquire(int stage, int thread, int maxBatch = 0,
bool mayBlock = true) {
-assert(stage < threadState.size());
-assert(thread < threadState[stage].size());
+assert(stage < int(threadState.size()));
+assert(thread < int(threadState[stage].size()));
auto batch = acquireHelper(stage, thread, maxBatch, mayBlock);
return StageGuard{std::move(batch), &threadState[stage][thread]};
}

View File

@@ -5,31 +5,23 @@
#include <errno.h>
#include <limits.h>
-std::unique_ptr<Connection>
-Connection::createForServer(struct sockaddr_storage addr, int fd, int64_t id,
-ConnectionHandler *handler,
-std::weak_ptr<Server> server) {
-// Use unique_ptr constructor with private access (friend function)
-// We can't use make_unique here because constructor is private
-return std::unique_ptr<Connection>(
-new Connection(addr, fd, id, handler, server));
-}
Connection::Connection(struct sockaddr_storage addr, int fd, int64_t id,
-ConnectionHandler *handler, std::weak_ptr<Server> server)
-: fd_(fd), id_(id), addr_(addr), arena_(), handler_(handler),
-server_(server) {
-activeConnections.fetch_add(1, std::memory_order_relaxed);
-if (handler_) {
+size_t epoll_index, ConnectionHandler *handler,
+Server &server)
+: fd_(fd), id_(id), epoll_index_(epoll_index), addr_(addr), arena_(),
+handler_(handler), server_(server.weak_from_this()) {
+server.active_connections_.fetch_add(1, std::memory_order_relaxed);
+assert(handler_);
handler_->on_connection_established(*this);
-}
}
Connection::~Connection() {
if (handler_) {
handler_->on_connection_closed(*this);
}
-activeConnections.fetch_sub(1, std::memory_order_relaxed);
+if (auto server_ptr = server_.lock()) {
+server_ptr->active_connections_.fetch_sub(1, std::memory_order_relaxed);
+}
int e = close(fd_);
if (e == -1) {
perror("close");
@@ -124,15 +116,3 @@ bool Connection::writeBytes() {
return false;
}
-void Connection::tsan_acquire() {
-#if __has_feature(thread_sanitizer)
-tsan_sync_.load(std::memory_order_acquire);
-#endif
-}
-void Connection::tsan_release() {
-#if __has_feature(thread_sanitizer)
-tsan_sync_.store(0, std::memory_order_release);
-#endif
-}

View File

@@ -2,7 +2,6 @@
#include "arena_allocator.hpp"
#include "connection_handler.hpp"
#include <atomic>
#include <cassert>
#include <cstring>
#include <deque>
@@ -11,8 +10,6 @@
#include <sys/uio.h>
#include <unistd.h>
-extern std::atomic<int> activeConnections;
#ifndef __has_feature
#define __has_feature(x) 0
#endif
@@ -322,45 +319,23 @@ private:
* @param fd File descriptor for the socket connection
* @param id Unique connection identifier generated by the server
* @param handler Protocol handler for processing connection data
-* @param server Weak reference to the server for safe cleanup
+* @param server Reference to server associated with this connection
*/
Connection(struct sockaddr_storage addr, int fd, int64_t id,
-ConnectionHandler *handler, std::weak_ptr<Server> server);
-/**
-* @brief Server-only factory method for creating connections.
-*
-* This factory method can only be called by the Server class due to friend
-* access. It provides controlled connection creation with proper resource
-* management.
-*
-* @param addr Network address of the remote client (IPv4/IPv6 compatible)
-* @param fd File descriptor for the socket connection
-* @param id Unique connection identifier generated by the server
-* @param handler Protocol handler for processing connection data
-* @param server Weak reference to the server for safe cleanup
-*
-* @return std::unique_ptr<Connection> to the newly created connection
-*
-* @note This method is only accessible to the Server class and should be used
-* exclusively by I/O threads when new connections arrive.
-*/
-static std::unique_ptr<Connection>
-createForServer(struct sockaddr_storage addr, int fd, int64_t id,
-ConnectionHandler *handler, std::weak_ptr<Server> server);
+size_t epoll_index, ConnectionHandler *handler, Server &server);
// Networking interface - only accessible by Server
int readBytes(char *buf, size_t buffer_size);
bool writeBytes();
-void tsan_acquire();
-void tsan_release();
// Direct access methods for Server
int getFd() const { return fd_; }
bool hasMessages() const { return !messages_.empty(); }
bool shouldClose() const { return closeConnection_; }
+size_t getEpollIndex() const { return epoll_index_; }
const int fd_;
const int64_t id_;
+const size_t epoll_index_; // Index of the epoll instance this connection uses
struct sockaddr_storage addr_; // sockaddr_storage handles IPv4/IPv6
ArenaAllocator arena_;
ConnectionHandler *handler_;
@@ -372,9 +347,4 @@ private:
// Whether or not to close the connection after completing writing the
// response
bool closeConnection_{false};
-// TSAN support for epoll synchronization
-#if __has_feature(thread_sanitizer)
-std::atomic<int> tsan_sync_;
-#endif
};

View File

@@ -0,0 +1,60 @@
#include "connection_registry.hpp"
#include "connection.hpp"
#include <atomic>
#include <cstring>
#include <stdexcept>
#include <unistd.h>
ConnectionRegistry::ConnectionRegistry() : connections_(nullptr), max_fds_(0) {
// Get the process file descriptor limit
struct rlimit rlim;
if (getrlimit(RLIMIT_NOFILE, &rlim) == -1) {
throw std::runtime_error("Failed to get RLIMIT_NOFILE");
}
max_fds_ = rlim.rlim_cur;
// Calculate size rounded up to page boundary
size_t array_size = max_fds_ * sizeof(Connection *);
size_t page_size = getpagesize();
size_t aligned_size = (array_size + page_size - 1) & ~(page_size - 1);
// Allocate virtual address space using mmap
// MAP_ANONYMOUS provides zero-initialized pages on-demand (lazy allocation)
connections_ = static_cast<std::atomic<Connection *> *>(
mmap(nullptr, aligned_size, PROT_READ | PROT_WRITE,
MAP_PRIVATE | MAP_ANONYMOUS, -1, 0));
if (connections_ == MAP_FAILED) {
throw std::runtime_error("Failed to mmap for connection registry");
}
// Store aligned size for munmap
aligned_size_ = aligned_size;
}
ConnectionRegistry::~ConnectionRegistry() {
if (connections_ != nullptr) {
for (size_t fd = 0; fd < max_fds_; ++fd) {
delete connections_[fd].load(std::memory_order_relaxed);
}
if (munmap(connections_, aligned_size_) == -1) {
perror("munmap");
}
}
}
void ConnectionRegistry::store(int fd, std::unique_ptr<Connection> connection) {
if (fd < 0 || static_cast<size_t>(fd) >= max_fds_) {
abort();
}
// Release ownership from unique_ptr and store raw pointer
connections_[fd].store(connection.release(), std::memory_order_release);
}
std::unique_ptr<Connection> ConnectionRegistry::remove(int fd) {
if (fd < 0 || static_cast<size_t>(fd) >= max_fds_) {
abort();
}
return std::unique_ptr<Connection>(
connections_[fd].exchange(nullptr, std::memory_order_acquire));
}

View File

@@ -0,0 +1,72 @@
#pragma once
#include <cstddef>
#include <memory>
#include <sys/mman.h>
#include <sys/resource.h>
struct Connection;
/**
* mmap-based Connection Registry for tracking active connections.
*
* This registry provides a lock-free mechanism for tracking all connections
* owned by the server, indexed by file descriptor. The design uses mmap to
* allocate a large virtual address space efficiently, with physical memory
* allocated on-demand as connections are created.
*
*/
class ConnectionRegistry {
public:
/**
* Initialize the connection registry.
* Allocates virtual address space based on RLIMIT_NOFILE.
*
* @throws std::runtime_error if mmap fails or RLIMIT_NOFILE cannot be read
*/
ConnectionRegistry();
/**
* Destructor ensures proper cleanup of mmap'd memory.
*/
~ConnectionRegistry();
/**
* Store a connection in the registry, indexed by its file descriptor.
* Takes ownership of the connection via unique_ptr.
*
* @param fd File descriptor (must be valid and < max_fds_)
* @param connection unique_ptr to the connection (ownership transferred)
*/
void store(int fd, std::unique_ptr<Connection> connection);
/**
* Remove a connection from the registry and transfer ownership to caller.
* This transfers ownership via unique_ptr move semantics.
*
* @param fd File descriptor
* @return unique_ptr to the connection, or nullptr if not found
*/
std::unique_ptr<Connection> remove(int fd);
/**
* Get the maximum number of file descriptors supported.
*
* @return Maximum file descriptor limit
*/
size_t max_fds() const { return max_fds_; }
// Non-copyable and non-movable
ConnectionRegistry(const ConnectionRegistry &) = delete;
ConnectionRegistry &operator=(const ConnectionRegistry &) = delete;
ConnectionRegistry(ConnectionRegistry &&) = delete;
ConnectionRegistry &operator=(ConnectionRegistry &&) = delete;
private:
std::atomic<Connection *>
*connections_; ///< mmap'd array of raw connection pointers. It's
///< thread-safe even without atomics since epoll_ctl
///< happens-before epoll_wait, but the atomics keep tsan
///< happy /shrug.
size_t max_fds_; ///< Maximum file descriptor limit
size_t aligned_size_; ///< Page-aligned size for munmap
};
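As a worked example of the sizing arithmetic in the constructor (a standalone sketch; the 1024-fd limit and 4 KiB page size are illustrative values, not necessarily what `getrlimit`/`getpagesize` return):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  // Mirrors ConnectionRegistry's mmap sizing: one pointer slot per possible
  // fd, rounded up to a whole page. The mask trick requires page_size to be
  // a power of two, which page sizes always are.
  const std::size_t max_fds = 1024;   // e.g. RLIMIT_NOFILE soft limit
  const std::size_t page_size = 4096; // e.g. getpagesize()
  std::size_t array_size = max_fds * sizeof(void *); // 8192 bytes on LP64
  std::size_t aligned_size = (array_size + page_size - 1) & ~(page_size - 1);
  std::printf("array=%zu aligned=%zu\n", array_size, aligned_size); // 8192 8192
}
```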

View File

@@ -6,14 +6,17 @@
#include "server.hpp"
#include <atomic>
#include <csignal>
+#include <fcntl.h>
#include <iostream>
+#include <netdb.h>
+#include <netinet/tcp.h>
+#include <sys/socket.h>
+#include <sys/un.h>
#include <unistd.h>
+#include <vector>
PERFETTO_TRACK_EVENT_STATIC_STORAGE();
-// TODO this should be scoped to a particular Server, and it's definition should
-// be in server.cpp or connection.cpp
-std::atomic<int> activeConnections{0};
// Global server instance for signal handler access
static Server *g_server = nullptr;
@@ -25,6 +28,117 @@ void signal_handler(int sig) {
}
}
+std::vector<int> create_listen_sockets(const weaseldb::Config &config) {
+std::vector<int> listen_fds;
+// Check if unix socket path is specified
+if (!config.server.unix_socket_path.empty()) {
+// Create unix socket
+int sfd = socket(AF_UNIX, SOCK_STREAM, 0);
+if (sfd == -1) {
+perror("socket");
+throw std::runtime_error("Failed to create unix socket");
+}
+// Remove existing socket file if it exists
+unlink(config.server.unix_socket_path.c_str());
+struct sockaddr_un addr;
+memset(&addr, 0, sizeof(addr));
+addr.sun_family = AF_UNIX;
+if (config.server.unix_socket_path.length() >= sizeof(addr.sun_path)) {
+close(sfd);
+throw std::runtime_error("Unix socket path too long");
+}
+strncpy(addr.sun_path, config.server.unix_socket_path.c_str(),
+sizeof(addr.sun_path) - 1);
+if (bind(sfd, (struct sockaddr *)&addr, sizeof(addr)) == -1) {
+perror("bind");
+close(sfd);
+throw std::runtime_error("Failed to bind unix socket");
+}
+if (listen(sfd, SOMAXCONN) == -1) {
+perror("listen");
+close(sfd);
+throw std::runtime_error("Failed to listen on unix socket");
+}
+listen_fds.push_back(sfd);
+return listen_fds;
+}
+// TCP socket creation
+struct addrinfo hints;
+struct addrinfo *result, *rp;
+int s;
+memset(&hints, 0, sizeof(hints));
+hints.ai_family = AF_UNSPEC; /* Allow IPv4 or IPv6 */
+hints.ai_socktype = SOCK_STREAM; /* stream socket */
+hints.ai_flags = AI_PASSIVE; /* For wildcard IP address */
+hints.ai_protocol = 0; /* Any protocol */
+hints.ai_canonname = nullptr;
+hints.ai_addr = nullptr;
+hints.ai_next = nullptr;
+s = getaddrinfo(config.server.bind_address.c_str(),
+std::to_string(config.server.port).c_str(), &hints, &result);
+if (s != 0) {
+fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(s));
+throw std::runtime_error("Failed to resolve bind address");
+}
+int sfd = -1;
+for (rp = result; rp != nullptr; rp = rp->ai_next) {
+sfd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
+if (sfd == -1) {
+continue;
+}
+int val = 1;
+if (setsockopt(sfd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val)) == -1) {
+perror("setsockopt SO_REUSEADDR");
+close(sfd);
+continue;
+}
+// Enable TCP_NODELAY for low latency (only for TCP sockets)
+if (rp->ai_family == AF_INET || rp->ai_family == AF_INET6) {
+if (setsockopt(sfd, IPPROTO_TCP, TCP_NODELAY, &val, sizeof(val)) == -1) {
+perror("setsockopt TCP_NODELAY");
+close(sfd);
+continue;
+}
+}
+if (bind(sfd, rp->ai_addr, rp->ai_addrlen) == 0) {
+break; /* Success */
+}
+close(sfd);
+sfd = -1;
+}
+freeaddrinfo(result);
+if (rp == nullptr || sfd == -1) {
+throw std::runtime_error("Could not bind to any address");
+}
+if (listen(sfd, SOMAXCONN) == -1) {
+perror("listen");
+close(sfd);
+throw std::runtime_error("Failed to listen on socket");
+}
+listen_fds.push_back(sfd);
+return listen_fds;
+}
void print_help(const char *program_name) {
std::cout << "WeaselDB - High-performance write-side database component\n\n";
std::cout << "Usage: " << program_name << " [OPTIONS]\n\n";
@@ -121,9 +235,18 @@ int main(int argc, char *argv[]) {
<< config->subscription.keepalive_interval.count() << " seconds"
<< std::endl;
+// Create listen sockets
+std::vector<int> listen_fds;
+try {
+listen_fds = create_listen_sockets(*config);
+} catch (const std::exception &e) {
+std::cerr << "Failed to create listen sockets: " << e.what() << std::endl;
+return 1;
+}
// Create handler and server
HttpHandler http_handler;
-auto server = Server::create(*config, http_handler);
+auto server = Server::create(*config, http_handler, listen_fds);
g_server = server.get();
// Setup signal handling

View File

@@ -1,10 +1,12 @@
#include "server.hpp"
#include "connection.hpp"
#include "connection_registry.hpp"
#include <csignal>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fcntl.h>
#include <memory>
#include <netdb.h>
#include <netinet/tcp.h>
#include <pthread.h>
@@ -15,33 +17,91 @@
#include <unistd.h>
#include <vector>
-extern std::atomic<int> activeConnections;
std::shared_ptr<Server> Server::create(const weaseldb::Config &config,
-ConnectionHandler &handler) {
+ConnectionHandler &handler,
+const std::vector<int> &listen_fds) {
// Use std::shared_ptr constructor with private access
// We can't use make_shared here because constructor is private
-return std::shared_ptr<Server>(new Server(config, handler));
+return std::shared_ptr<Server>(new Server(config, handler, listen_fds));
}
-Server::Server(const weaseldb::Config &config, ConnectionHandler &handler)
-: config_(config), handler_(handler) {}
+Server::Server(const weaseldb::Config &config, ConnectionHandler &handler,
+const std::vector<int> &provided_listen_fds)
+: config_(config), handler_(handler), connection_registry_(),
+listen_fds_(provided_listen_fds) {
+// Server takes ownership of all provided listen fds
+// Ensure all listen fds are non-blocking for safe epoll usage
+for (int fd : listen_fds_) {
+int flags = fcntl(fd, F_GETFL, 0);
+if (flags == -1) {
+perror("fcntl F_GETFL on provided listen fd");
+throw std::runtime_error("Failed to get flags for provided listen fd");
+}
+if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1) {
+perror("fcntl F_SETFL O_NONBLOCK on provided listen fd");
+throw std::runtime_error("Failed to set provided listen fd non-blocking");
+}
+}
-Server::~Server() { cleanup_resources(); }
-void Server::run() {
// Setup shutdown pipe for graceful shutdown
setup_shutdown_pipe();
-listen_sockfd_ = create_listen_socket();
+// Create epoll instances immediately for createLocalConnection() support
create_epoll_instances();
-start_io_threads();
+// If empty vector provided, listen_fds_ will be empty (no listening)
+// Server works purely with createLocalConnection()
+}
-// Wait for all threads to complete
-for (auto &thread : threads_) {
+Server::~Server() {
+if (shutdown_pipe_[0] != -1) {
+close(shutdown_pipe_[0]);
+shutdown_pipe_[0] = -1;
+}
+if (shutdown_pipe_[1] != -1) {
+close(shutdown_pipe_[1]);
+shutdown_pipe_[1] = -1;
+}
+// Close all epoll instances
+for (int epollfd : epoll_fds_) {
+if (epollfd != -1) {
+close(epollfd);
+}
+}
+epoll_fds_.clear();
+// Close all listen sockets (Server always owns them)
+for (int fd : listen_fds_) {
+if (fd != -1) {
+close(fd);
+}
+}
+// Clean up unix socket file if it exists
+if (!config_.server.unix_socket_path.empty()) {
+unlink(config_.server.unix_socket_path.c_str());
+}
+}
+void Server::run() {
+// Shutdown pipe and epoll instances are now created in constructor
+// Create I/O threads locally in this call frame
+// CRITICAL: By owning threads in run()'s call frame, we guarantee they are
+// joined before run() returns, eliminating any race conditions in ~Server()
+std::vector<std::thread> threads;
+start_io_threads(threads);
+// Wait for all threads to complete before returning
+// This ensures all I/O threads are fully stopped before the Server
+// destructor can be called, preventing race conditions during connection
+// cleanup
+for (auto &thread : threads) {
thread.join();
}
+// At this point, all threads are joined and it's safe to destroy the Server
}
void Server::shutdown() {
@@ -64,16 +124,19 @@ void Server::releaseBackToServer(std::unique_ptr<Connection> connection) {
// Try to get the server from the connection's weak_ptr
if (auto server = connection->server_.lock()) {
-// Server still exists - release raw pointer and let server take over
-Connection *raw_conn = connection.release();
-server->receiveConnectionBack(raw_conn);
+// Server still exists - pass unique_ptr directly
+server->receiveConnectionBack(std::move(connection));
}
// If server is gone, connection will be automatically cleaned up when
// unique_ptr destructs
}
-void Server::receiveConnectionBack(Connection *connection) {
+void Server::receiveConnectionBack(std::unique_ptr<Connection> connection) {
if (!connection) {
return; // Nothing to process
}
// Re-add the connection to epoll for continued processing
struct epoll_event event{};
@@ -83,19 +146,76 @@ void Server::receiveConnectionBack(Connection *connection) {
event.events = EPOLLOUT | EPOLLONESHOT;
}
-connection->tsan_release();
-event.data.ptr = connection;
+int fd = connection->getFd();
+event.data.fd = fd;
-// Distribute connections round-robin across epoll instances
+// Store connection in registry before adding to epoll
+// This mirrors the pattern used in process_connection_batch
+size_t epoll_index = connection->getEpollIndex();
+int epollfd = epoll_fds_[epoll_index];
+connection_registry_.store(fd, std::move(connection));
+if (epoll_ctl(epollfd, EPOLL_CTL_MOD, fd, &event) == -1) {
+perror("epoll_ctl MOD in receiveConnectionBack");
+// Remove from registry and clean up on failure
+(void)connection_registry_.remove(fd);
}
}
+int Server::createLocalConnection() {
+int sockets[2];
+if (socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) != 0) {
+perror("socketpair");
+return -1;
+}
+int server_fd = sockets[0]; // Server keeps this end
+int client_fd = sockets[1]; // Return this end to caller
+int flags = fcntl(server_fd, F_GETFL, 0);
+if (flags == -1) {
+perror("fcntl F_GETFL on local connection socket");
+throw std::runtime_error(
+"Failed to get flags for server side of local connection");
+}
+if (fcntl(server_fd, F_SETFL, flags | O_NONBLOCK) == -1) {
+perror("fcntl F_SETFL O_NONBLOCK on local connection socket");
+throw std::runtime_error(
+"Failed to set server side of local connection to non-blocking");
+}
+// Create sockaddr_storage for the connection
+struct sockaddr_storage addr{};
+addr.ss_family = AF_UNIX;
+// Calculate epoll_index for connection distribution
size_t epoll_index =
connection_distribution_counter_.fetch_add(1, std::memory_order_relaxed) %
epoll_fds_.size();
-int epollfd = epoll_fds_[epoll_index];
-if (epoll_ctl(epollfd, EPOLL_CTL_ADD, connection->getFd(), &event) == -1) {
-perror("epoll_ctl ADD in receiveConnectionBack");
-delete connection; // Clean up on failure
+// Create Connection object
+auto connection = std::unique_ptr<Connection>(new Connection(
+addr, server_fd, connection_id_.fetch_add(1, std::memory_order_relaxed),
+epoll_index, &handler_, *this));
+// Store in registry
+connection_registry_.store(server_fd, std::move(connection));
+// Add to appropriate epoll instance
+struct epoll_event event{};
+event.events = EPOLLIN | EPOLLONESHOT;
+event.data.fd = server_fd;
+int epollfd = epoll_fds_[epoll_index];
+if (epoll_ctl(epollfd, EPOLL_CTL_ADD, server_fd, &event) == -1) {
+perror("epoll_ctl ADD local connection");
+connection_registry_.remove(server_fd);
+close(server_fd);
+close(client_fd);
+return -1;
}
+return client_fd;
}
void Server::setup_shutdown_pipe() {
@@ -112,139 +232,6 @@ void Server::setup_shutdown_pipe() {
}
}
int Server::create_listen_socket() {
int sfd;
// Check if unix socket path is specified
if (!config_.server.unix_socket_path.empty()) {
// Create unix socket
sfd = socket(AF_UNIX, SOCK_STREAM, 0);
if (sfd == -1) {
perror("socket");
throw std::runtime_error("Failed to create unix socket");
}
// Remove existing socket file if it exists
unlink(config_.server.unix_socket_path.c_str());
struct sockaddr_un addr;
memset(&addr, 0, sizeof(addr));
addr.sun_family = AF_UNIX;
if (config_.server.unix_socket_path.length() >= sizeof(addr.sun_path)) {
close(sfd);
throw std::runtime_error("Unix socket path too long");
}
strncpy(addr.sun_path, config_.server.unix_socket_path.c_str(),
sizeof(addr.sun_path) - 1);
// Set socket to non-blocking for graceful shutdown
int flags = fcntl(sfd, F_GETFL, 0);
if (flags == -1) {
perror("fcntl F_GETFL");
close(sfd);
throw std::runtime_error("Failed to get socket flags");
}
if (fcntl(sfd, F_SETFL, flags | O_NONBLOCK) == -1) {
perror("fcntl F_SETFL");
close(sfd);
throw std::runtime_error("Failed to set socket non-blocking");
}
if (bind(sfd, (struct sockaddr *)&addr, sizeof(addr)) == -1) {
perror("bind");
close(sfd);
throw std::runtime_error("Failed to bind unix socket");
}
if (listen(sfd, SOMAXCONN) == -1) {
perror("listen");
close(sfd);
throw std::runtime_error("Failed to listen on unix socket");
}
return sfd;
}
// TCP socket creation (original code)
struct addrinfo hints;
struct addrinfo *result, *rp;
int s;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC; /* Allow IPv4 or IPv6 */
hints.ai_socktype = SOCK_STREAM; /* stream socket */
hints.ai_flags = AI_PASSIVE; /* For wildcard IP address */
hints.ai_protocol = 0; /* Any protocol */
hints.ai_canonname = nullptr;
hints.ai_addr = nullptr;
hints.ai_next = nullptr;
s = getaddrinfo(config_.server.bind_address.c_str(),
std::to_string(config_.server.port).c_str(), &hints, &result);
if (s != 0) {
fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(s));
throw std::runtime_error("Failed to resolve bind address");
}
for (rp = result; rp != nullptr; rp = rp->ai_next) {
sfd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
if (sfd == -1) {
continue;
}
int val = 1;
if (setsockopt(sfd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val)) == -1) {
perror("setsockopt SO_REUSEADDR");
close(sfd);
continue;
}
// Enable TCP_NODELAY for low latency (only for TCP sockets)
if (rp->ai_family == AF_INET || rp->ai_family == AF_INET6) {
if (setsockopt(sfd, IPPROTO_TCP, TCP_NODELAY, &val, sizeof(val)) == -1) {
perror("setsockopt TCP_NODELAY");
close(sfd);
continue;
}
}
// Set socket to non-blocking for graceful shutdown
int flags = fcntl(sfd, F_GETFL, 0);
if (flags == -1) {
perror("fcntl F_GETFL");
close(sfd);
continue;
}
if (fcntl(sfd, F_SETFL, flags | O_NONBLOCK) == -1) {
perror("fcntl F_SETFL");
close(sfd);
continue;
}
if (bind(sfd, rp->ai_addr, rp->ai_addrlen) == 0) {
break; /* Success */
}
close(sfd);
}
freeaddrinfo(result);
if (rp == nullptr) {
throw std::runtime_error("Could not bind to any address");
}
if (listen(sfd, SOMAXCONN) == -1) {
perror("listen");
close(sfd);
throw std::runtime_error("Failed to listen on socket");
}
return sfd;
}
void Server::create_epoll_instances() {
// Create multiple epoll instances to reduce contention
epoll_fds_.resize(config_.server.epoll_instances);
@@ -259,7 +246,7 @@ void Server::create_epoll_instances() {
// Add shutdown pipe to each epoll instance
struct epoll_event shutdown_event;
shutdown_event.events = EPOLLIN;
-shutdown_event.data.ptr = const_cast<char *>(&shutdown_pipe_tag);
+shutdown_event.data.fd = shutdown_pipe_[0];
if (epoll_ctl(epoll_fds_[i], EPOLL_CTL_ADD, shutdown_pipe_[0],
&shutdown_event) == -1) {
@@ -267,17 +254,19 @@ void Server::create_epoll_instances() {
throw std::runtime_error("Failed to add shutdown pipe to epoll");
}
-// Add listen socket to each epoll instance with EPOLLEXCLUSIVE to prevent
-// thundering herd
+// Add all listen sockets to each epoll instance with EPOLLEXCLUSIVE to
+// prevent thundering herd
+for (int listen_fd : listen_fds_) {
struct epoll_event listen_event;
listen_event.events = EPOLLIN | EPOLLEXCLUSIVE;
-listen_event.data.ptr = const_cast<char *>(&listen_socket_tag);
-if (epoll_ctl(epoll_fds_[i], EPOLL_CTL_ADD, listen_sockfd_,
-&listen_event) == -1) {
+listen_event.data.fd = listen_fd;
+if (epoll_ctl(epoll_fds_[i], EPOLL_CTL_ADD, listen_fd, &listen_event) ==
+-1) {
perror("epoll_ctl listen socket");
throw std::runtime_error("Failed to add listen socket to epoll");
}
+}
}
}
int Server::get_epoll_for_thread(int thread_id) const {
@@ -285,11 +274,11 @@ int Server::get_epoll_for_thread(int thread_id) const {
return epoll_fds_[thread_id % epoll_fds_.size()];
}
-void Server::start_io_threads() {
+void Server::start_io_threads(std::vector<std::thread> &threads) {
int io_threads = config_.server.io_threads;
for (int thread_id = 0; thread_id < io_threads; ++thread_id) {
-threads_.emplace_back([this, thread_id]() {
+threads.emplace_back([this, thread_id]() {
pthread_setname_np(pthread_self(),
("io-" + std::to_string(thread_id)).c_str());
@@ -299,6 +288,8 @@ void Server::start_io_threads() {
struct epoll_event events[config_.server.event_batch_size];
std::unique_ptr<Connection> batch[config_.server.event_batch_size];
int batch_events[config_.server.event_batch_size];
+std::vector<int>
+ready_listen_fds; // Reused across iterations to avoid allocation
for (;;) {
int event_count =
@@ -311,30 +302,33 @@ void Server::start_io_threads() {
abort();
}
-bool listenReady = false;
+ready_listen_fds.clear(); // Clear from previous iteration
int batch_count = 0;
for (int i = 0; i < event_count; ++i) {
// Check for shutdown event
-if (events[i].data.ptr == &shutdown_pipe_tag) {
+if (events[i].data.fd == shutdown_pipe_[0]) {
return;
}
-// Check for new connections
-if (events[i].data.ptr == &listen_socket_tag) {
-listenReady = true;
+// Check for new connections on any listen socket
+bool isListenSocket =
+std::find(listen_fds_.begin(), listen_fds_.end(),
+events[i].data.fd) != listen_fds_.end();
+if (isListenSocket) {
+ready_listen_fds.push_back(events[i].data.fd);
continue;
}
// Handle existing connection events
-std::unique_ptr<Connection> conn{
-static_cast<Connection *>(events[i].data.ptr)};
-conn->tsan_acquire();
-events[i].data.ptr = nullptr;
+int fd = events[i].data.fd;
+std::unique_ptr<Connection> conn = connection_registry_.remove(fd);
+assert(conn);
-if (events[i].events & EPOLLERR) {
-continue; // Connection closed - unique_ptr destructor cleans up
+if (events[i].events & (EPOLLERR | EPOLLHUP)) {
+// unique_ptr will automatically delete on scope exit
+continue;
}
-// Add to regular batch - I/O will be processed in batch
+// Transfer ownership from registry to batch processing
batch[batch_count] = std::move(conn);
batch_events[batch_count] = events[i].events;
batch_count++;
@@ -346,17 +340,17 @@ void Server::start_io_threads() {
{batch_events, (size_t)batch_count}, false);
}
// Reuse same batch array for accepting connections
-if (listenReady) {
+// Only accept on listen sockets that epoll indicates are ready
+for (int listen_fd : ready_listen_fds) {
for (;;) {
struct sockaddr_storage addr;
socklen_t addrlen = sizeof(addr);
-int fd = accept4(listen_sockfd_, (struct sockaddr *)&addr, &addrlen,
+int fd = accept4(listen_fd, (struct sockaddr *)&addr, &addrlen,
SOCK_NONBLOCK);
if (fd == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK)
-break;
+break; // Try next listen socket
if (errno == EINTR)
continue;
perror("accept4");
@@ -365,7 +359,7 @@ void Server::start_io_threads() {
// Check connection limit
if (config_.server.max_connections > 0 &&
-activeConnections.load(std::memory_order_relaxed) >=
+active_connections_.load(std::memory_order_relaxed) >=
config_.server.max_connections) {
close(fd);
continue;
@@ -378,11 +372,12 @@ void Server::start_io_threads() {
perror("setsockopt SO_KEEPALIVE");
}
-// Add to batch - I/O will be processed in batch
-batch[batch_count] = Connection::createForServer(
+// Transfer ownership from registry to batch processing
+size_t epoll_index = thread_id % epoll_fds_.size();
+batch[batch_count] = std::unique_ptr<Connection>(new Connection(
addr, fd,
connection_id_.fetch_add(1, std::memory_order_relaxed),
-&handler_, weak_from_this());
+epoll_index, &handler_, *this));
batch_events[batch_count] =
EPOLLIN; // New connections always start with read
batch_count++;
@@ -394,7 +389,8 @@ void Server::start_io_threads() {
true);
batch_count = 0;
}
-}
+} // End inner accept loop
+} // End loop over listen_fds_
// Process remaining accepted connections
if (batch_count > 0) {
@@ -403,7 +399,6 @@ void Server::start_io_threads() {
batch_count = 0;
}
}
-}
});
}
}
@@ -473,58 +468,30 @@ void Server::process_connection_batch(
}
}
-// Call post-batch handler
+// Call post-batch handler - handlers can take ownership here
handler_.on_post_batch(batch);
// Transfer all remaining connections back to epoll
-for (auto &conn : batch) {
-if (conn) {
+for (auto &conn_ptr : batch) {
+if (conn_ptr) {
+int fd = conn_ptr->getFd();
struct epoll_event event{};
-if (!conn->hasMessages()) {
+if (!conn_ptr->hasMessages()) {
event.events = EPOLLIN | EPOLLONESHOT;
} else {
event.events = EPOLLOUT | EPOLLONESHOT;
}
-int fd = conn->getFd();
-conn->tsan_release();
-Connection *raw_conn = conn.release();
-event.data.ptr = raw_conn;
+event.data.fd = fd; // Use file descriptor for epoll
+// Put connection back in registry since handler didn't take ownership.
+// Must happen before epoll_ctl
+connection_registry_.store(fd, std::move(conn_ptr));
int epoll_op = is_new ? EPOLL_CTL_ADD : EPOLL_CTL_MOD;
if (epoll_ctl(epollfd, epoll_op, fd, &event) == -1) {
perror(is_new ? "epoll_ctl ADD" : "epoll_ctl MOD");
-delete raw_conn;
+(void)connection_registry_.remove(fd);
}
}
}
}
-void Server::cleanup_resources() {
-if (shutdown_pipe_[0] != -1) {
-close(shutdown_pipe_[0]);
-shutdown_pipe_[0] = -1;
-}
-if (shutdown_pipe_[1] != -1) {
-close(shutdown_pipe_[1]);
-shutdown_pipe_[1] = -1;
-}
-// Close all epoll instances
-for (int epollfd : epoll_fds_) {
-if (epollfd != -1) {
-close(epollfd);
-}
-}
-epoll_fds_.clear();
-if (listen_sockfd_ != -1) {
-close(listen_sockfd_);
-listen_sockfd_ = -1;
-}
-// Clean up unix socket file if it exists
-if (!config_.server.unix_socket_path.empty()) {
-unlink(config_.server.unix_socket_path.c_str());
-}
-}

View File

@@ -2,6 +2,7 @@
#include "config.hpp"
#include "connection_handler.hpp"
#include "connection_registry.hpp"
#include <atomic>
#include <memory>
#include <span>
@@ -42,10 +43,15 @@ public:
*
* @param config Server configuration (threads, ports, limits, etc.)
* @param handler Protocol handler for processing connection data
+* @param listen_fds Vector of file descriptors to accept connections on.
+* Server takes ownership and will close them on
+* destruction. Server will set these to non-blocking mode for safe epoll
+* usage. Empty vector means no listening sockets.
* @return shared_ptr to the newly created Server
*/
static std::shared_ptr<Server> create(const weaseldb::Config &config,
-ConnectionHandler &handler);
+ConnectionHandler &handler,
+const std::vector<int> &listen_fds);
/**
* Destructor ensures proper cleanup of all resources.
@@ -74,6 +80,20 @@ public:
*/
void shutdown();
+/**
+* Creates a local connection using socketpair() for testing or local IPC.
+*
+* Creates a socketpair, registers one end as a Connection in the server,
+* and returns the other end to the caller for communication.
+*
+* The caller takes ownership of the returned file descriptor and must close
+* it.
+*
+* @return File descriptor for the client end of the socketpair, or -1 on
+* error
+*/
+int createLocalConnection();
/**
* Release a connection back to its server for continued processing.
*
@@ -88,20 +108,29 @@ public:
static void releaseBackToServer(std::unique_ptr<Connection> connection);
private:
+friend struct Connection;
/**
* Private constructor - use create() factory method instead.
*
* @param config Server configuration (threads, ports, limits, etc.)
* @param handler Protocol handler for processing connection data
+* @param listen_fds Vector of file descriptors to accept connections on.
+* Server takes ownership and will close them on
+* destruction. Server will set these to non-blocking mode for safe epoll
+* usage.
*/
-explicit Server(const weaseldb::Config &config, ConnectionHandler &handler);
+explicit Server(const weaseldb::Config &config, ConnectionHandler &handler,
+const std::vector<int> &listen_fds);
const weaseldb::Config &config_;
ConnectionHandler &handler_;
-// Thread management
-std::vector<std::thread> threads_;
+// Connection registry
+ConnectionRegistry connection_registry_;
// Connection management
std::atomic<int64_t> connection_id_{0};
+std::atomic<int> active_connections_{0};
// Round-robin counter for connection distribution
std::atomic<size_t> connection_distribution_counter_{0};
@@ -111,25 +140,20 @@ private:
// Multiple epoll file descriptors to reduce contention
std::vector<int> epoll_fds_;
-int listen_sockfd_ = -1;
-// Unique tags for special events to avoid type confusion in epoll data union
-static inline const char listen_socket_tag = 0;
-static inline const char shutdown_pipe_tag = 0;
+std::vector<int>
+listen_fds_; // FDs to accept connections on (Server owns these)
// Private helper methods
void setup_shutdown_pipe();
void setup_signal_handling();
-int create_listen_socket();
void create_epoll_instances();
-void start_io_threads();
-void cleanup_resources();
+void start_io_threads(std::vector<std::thread> &threads);
// Helper to get epoll fd for a thread using round-robin
int get_epoll_for_thread(int thread_id) const;
// Helper for processing connection I/O
-void process_connection_io(std::unique_ptr<Connection> &conn, int events);
+void process_connection_io(std::unique_ptr<Connection> &conn_ptr, int events);
// Helper for processing a batch of connections with their events
void process_connection_batch(int epollfd,
@@ -142,9 +166,9 @@ private:
* This method is thread-safe and can be called from any thread.
* The connection will be re-added to the epoll for continued processing.
*
-* @param connection Raw pointer to the connection being released back
+* @param connection Unique pointer to the connection being released back
*/
-void receiveConnectionBack(Connection *connection);
+void receiveConnectionBack(std::unique_ptr<Connection> connection);
// Make non-copyable and non-movable
Server(const Server &) = delete;

View File

@@ -0,0 +1,96 @@
#include "../src/ThreadPipeline.h"
#include "config.hpp"
#include "connection.hpp"
#include "perfetto_categories.hpp"
#include "server.hpp"
#include <doctest/doctest.h>
#include <thread>
// Perfetto static storage for tests
PERFETTO_TRACK_EVENT_STATIC_STORAGE();
struct Message {
std::unique_ptr<Connection> conn;
std::string data;
bool done;
};
struct EchoHandler : public ConnectionHandler {
private:
ThreadPipeline<Message> &pipeline;
public:
explicit EchoHandler(ThreadPipeline<Message> &pipeline)
: pipeline(pipeline) {}
void on_data_arrived(std::string_view data,
std::unique_ptr<Connection> &conn_ptr) override {
assert(conn_ptr);
auto guard = pipeline.push(1, true);
for (auto &message : guard.batch) {
message.conn = std::move(conn_ptr);
message.data = data;
message.done = false;
}
}
};
TEST_CASE("Echo server with connection ownership transfer") {
weaseldb::Config config;
config.server.io_threads = 1;
config.server.epoll_instances = 1;
ThreadPipeline<Message> pipeline{10, {1}};
EchoHandler handler{pipeline};
auto echoThread = std::thread{[&]() {
for (;;) {
auto guard = pipeline.acquire(0, 0);
for (auto &message : guard.batch) {
bool done = message.done;
if (done) {
return;
}
assert(message.conn);
message.conn->appendMessage(message.data);
Server::releaseBackToServer(std::move(message.conn));
}
}
}};
// Create server with NO listen sockets (empty vector)
auto server = Server::create(config, handler, {});
std::thread server_thread([&server]() { server->run(); });
// Create local connection
int client_fd = server->createLocalConnection();
REQUIRE(client_fd > 0);
// Write some test data
const char *test_message = "Hello, World!";
ssize_t bytes_written = write(client_fd, test_message, strlen(test_message));
REQUIRE(bytes_written == strlen(test_message));
// Read the echoed response
char buffer[1024] = {0};
ssize_t bytes_read = read(client_fd, buffer, sizeof(buffer) - 1);
if (bytes_read == -1) {
perror("read failed");
}
REQUIRE(bytes_read == strlen(test_message));
// Verify we got back exactly what we sent
CHECK(std::string(buffer, bytes_read) == std::string(test_message));
// Cleanup
close(client_fd);
server->shutdown();
server_thread.join();
{
auto guard = pipeline.push(1, true);
for (auto &message : guard.batch) {
message.done = true;
}
}
echoThread.join();
}

tsan.suppressions Normal file
View File

@@ -0,0 +1,5 @@
# ThreadSanitizer suppressions for WeaselDB
# This file suppresses known false positives in ThreadSanitizer
# tsan doesn't seem to understand that epoll_ctl happens before epoll_wait returns
race:epoll_ctl
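At runtime ThreadSanitizer picks these up via its options string, e.g. `TSAN_OPTIONS=suppressions=tsan.suppressions ./test_server_connection_return` (binary name assumed from the CMake targets above).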