Connection registry

Now we can use leak sanitizer. Yay!
This commit is contained in:
2025-08-21 18:09:36 -04:00
parent 810b5e006d
commit d1b1e6d589
7 changed files with 500 additions and 34 deletions

View File

@@ -128,6 +128,7 @@ set(SOURCES
src/main.cpp
src/config.cpp
src/connection.cpp
src/connection_registry.cpp
src/server.cpp
src/json_commit_request_parser.cpp
src/http_handler.cpp
@@ -157,6 +158,10 @@ add_executable(test_arena_allocator tests/test_arena_allocator.cpp
target_link_libraries(test_arena_allocator doctest::doctest)
target_include_directories(test_arena_allocator PRIVATE src)
add_executable(test_connection_registry tests/test_connection_registry.cpp)
target_link_libraries(test_connection_registry doctest::doctest)
target_include_directories(test_connection_registry PRIVATE src)
add_executable(
test_commit_request
tests/test_commit_request.cpp src/json_commit_request_parser.cpp
@@ -168,8 +173,9 @@ target_link_libraries(test_commit_request doctest::doctest weaseljson test_data
target_include_directories(test_commit_request PRIVATE src tests)
add_executable(
test_http_handler tests/test_http_handler.cpp src/http_handler.cpp
src/arena_allocator.cpp src/connection.cpp)
test_http_handler
tests/test_http_handler.cpp src/http_handler.cpp src/arena_allocator.cpp
src/connection.cpp src/connection_registry.cpp)
target_link_libraries(test_http_handler doctest::doctest llhttp_static
Threads::Threads perfetto)
target_include_directories(test_http_handler PRIVATE src)
@@ -213,6 +219,7 @@ add_executable(load_tester tools/load_tester.cpp)
target_link_libraries(load_tester Threads::Threads llhttp_static perfetto)
add_test(NAME arena_allocator_tests COMMAND test_arena_allocator)
add_test(NAME connection_registry_tests COMMAND test_connection_registry)
add_test(NAME commit_request_tests COMMAND test_commit_request)
add_test(NAME http_handler_tests COMMAND test_http_handler)
add_test(NAME arena_allocator_benchmarks COMMAND bench_arena_allocator)

View File

@@ -5,6 +5,8 @@
#include <errno.h>
#include <limits.h>
// TODO fix up this whole thing
std::unique_ptr<Connection>
Connection::createForServer(struct sockaddr_storage addr, int fd, int64_t id,
ConnectionHandler *handler,

View File

@@ -0,0 +1,83 @@
#include "connection_registry.hpp"
#include "connection.hpp"
#include <cstring>
#include <stdexcept>
#include <unistd.h>
/**
 * Build an empty registry with one slot per possible file descriptor.
 *
 * The slot count is the soft RLIMIT_NOFILE of the process at construction
 * time. The table is an anonymous private mapping, so the kernel hands it
 * over zero-filled and only backs a page with physical memory once a slot on
 * that page is actually written.
 *
 * @throws std::runtime_error if RLIMIT_NOFILE cannot be read, if the limit is
 *         unbounded or too large to size a table from, or if mmap fails
 */
ConnectionRegistry::ConnectionRegistry() : connections_(nullptr), max_fds_(0) {
  // Get the process file descriptor limit
  struct rlimit rlim;
  if (getrlimit(RLIMIT_NOFILE, &rlim) == -1) {
    throw std::runtime_error("Failed to get RLIMIT_NOFILE");
  }

  // Refuse limits whose table size in bytes would not fit in size_t;
  // multiplying them out would silently wrap and map the wrong length.
  const size_t max_entries = static_cast<size_t>(-1) / sizeof(Connection *);
  if (rlim.rlim_cur == RLIM_INFINITY || rlim.rlim_cur > max_entries) {
    throw std::runtime_error(
        "RLIMIT_NOFILE is too large for connection registry");
  }
  max_fds_ = static_cast<size_t>(rlim.rlim_cur);

  // Reserve virtual address space only. Check the raw mmap result before
  // casting so connections_ never holds MAP_FAILED.
  void *table = mmap(nullptr, max_fds_ * sizeof(Connection *),
                     PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
  if (table == MAP_FAILED) {
    throw std::runtime_error("Failed to mmap for connection registry");
  }
  connections_ = static_cast<Connection **>(table);

  // Deliberately no memset: MAP_ANONYMOUS memory is already zero-initialized,
  // and writing to every slot would fault in every page up front, which
  // defeats the on-demand allocation this design relies on.
}
/**
 * Release the mmap'd slot table.
 *
 * Only the mapping itself is unmapped. Connection objects still referenced
 * from the table are not deleted here; shutdown_cleanup() is the method that
 * frees them.
 */
ConnectionRegistry::~ConnectionRegistry() {
  const bool has_mapping =
      connections_ != nullptr && connections_ != MAP_FAILED;
  if (!has_mapping) {
    return;
  }
  munmap(connections_, max_fds_ * sizeof(Connection *));
}
/**
 * Park a connection in the slot that belongs to its file descriptor.
 *
 * For an fd inside [0, max_fds_) the raw pointer is released from the
 * unique_ptr and written to the slot; whatever pointer was in the slot before
 * is overwritten without being deleted. For an fd outside that range nothing
 * is stored, so the connection is destroyed together with the by-value
 * parameter.
 *
 * @param fd File descriptor used as the table index
 * @param connection Connection whose ownership is handed to the registry
 */
void ConnectionRegistry::store(int fd, std::unique_ptr<Connection> connection) {
  const bool in_range = fd >= 0 && static_cast<size_t>(fd) < max_fds_;
  if (in_range) {
    // The table holds raw pointers, so give up unique_ptr ownership here.
    connections_[fd] = connection.release();
  }
  // Out-of-range fds are ignored on purpose rather than treated as fatal.
}
/**
 * Report whether a slot currently holds a connection.
 *
 * @param fd File descriptor used as the table index
 * @return true when fd is inside [0, max_fds_) and its slot is non-null;
 *         false for an empty slot or an out-of-range fd
 */
bool ConnectionRegistry::has(int fd) const {
  // Short-circuit order matters: the slot is only read once fd is known to be
  // a valid index.
  return fd >= 0 && static_cast<size_t>(fd) < max_fds_ &&
         connections_[fd] != nullptr;
}
/**
 * Take a connection out of the registry and hand ownership to the caller.
 *
 * The slot is cleared before returning, so a second remove() on the same fd
 * yields an empty pointer.
 *
 * @param fd File descriptor used as the table index
 * @return The connection that was stored for fd, or an empty unique_ptr when
 *         the slot was empty or fd is outside [0, max_fds_)
 */
std::unique_ptr<Connection> ConnectionRegistry::remove(int fd) {
  if (fd < 0 || static_cast<size_t>(fd) >= max_fds_) {
    return {};
  }

  // Re-wrap the raw pointer first, then blank the slot.
  std::unique_ptr<Connection> owned(connections_[fd]);
  connections_[fd] = nullptr;
  return owned;
}
void ConnectionRegistry::shutdown_cleanup() {
// Iterate through all possible file descriptors and clean up any connections
// Following the critical ordering: remove -> delete (destructor handles
// close)
size_t connections_found = 0;
for (size_t fd = 0; fd < max_fds_; ++fd) {
Connection *conn = connections_[fd];
if (conn != nullptr) {
connections_found++;
// Step 1: Remove from registry (set to null)
connections_[fd] = nullptr;
// Steps 2 & 3: Delete the connection object (destructor handles closing
// fd)
delete conn;
}
}
// Note: In normal shutdown, this should be 0 since all connections
// should have been properly cleaned up during normal operation
}

View File

@@ -0,0 +1,93 @@
#pragma once
#include <cstddef>
#include <memory>
#include <sys/mman.h>
#include <sys/resource.h>
class Connection;

/**
 * mmap-based Connection Registry for tracking active connections.
 *
 * The registry is a flat table of raw Connection pointers indexed by file
 * descriptor, sized from RLIMIT_NOFILE. The table lives in an mmap'd region so
 * that only the virtual address range is reserved up front; physical memory
 * is allocated on demand as slots are written.
 *
 * The class itself takes no locks and uses no atomics.
 * NOTE(review): it is described as lock-free because each fd slot is expected
 * to be touched by one thread at a time - confirm against the Server I/O loop.
 *
 * CRITICAL ORDERING REQUIREMENT:
 * All connection cleanup MUST follow this exact sequence:
 * 1. Remove from registry: auto conn = registry.remove(fd)
 * 2. Delete the connection: unique_ptr destructor handles it automatically
 *
 * This ordering prevents race conditions between cleanup and fd reuse.
 * The unique_ptr interface ensures ownership is always clear and prevents
 * double-delete bugs.
 */
class ConnectionRegistry {
public:
  /**
   * Initialize the connection registry.
   * Allocates virtual address space based on RLIMIT_NOFILE.
   *
   * @throws std::runtime_error if mmap fails or RLIMIT_NOFILE cannot be read
   */
  ConnectionRegistry();

  /**
   * Destructor releases the mmap'd table.
   *
   * Declared virtual because the class exposes a virtual method
   * (shutdown_cleanup), so destroying a derived registry through a base
   * pointer must be well defined.
   */
  virtual ~ConnectionRegistry();

  /**
   * Store a connection in the registry, indexed by its file descriptor.
   * Takes ownership of the connection via unique_ptr.
   *
   * @param fd File descriptor; expected to be in [0, max_fds()). An fd
   *        outside that range is ignored and the connection is destroyed
   *        instead of being stored.
   * @param connection unique_ptr to the connection (ownership transferred)
   */
  void store(int fd, std::unique_ptr<Connection> connection);

  /**
   * Check if a connection exists in the registry by file descriptor.
   *
   * @param fd File descriptor
   * @return true if connection exists, false otherwise (including for an
   *         out-of-range fd)
   */
  bool has(int fd) const;

  /**
   * Remove a connection from the registry and transfer ownership to caller.
   * This transfers ownership via unique_ptr move semantics.
   *
   * @param fd File descriptor
   * @return unique_ptr to the connection, or nullptr if not found
   */
  std::unique_ptr<Connection> remove(int fd);

  /**
   * Get the maximum number of file descriptors supported.
   *
   * @return Maximum file descriptor limit (number of slots in the table)
   */
  size_t max_fds() const { return max_fds_; }

  /**
   * Perform graceful shutdown cleanup.
   * Iterates through all registry entries and cleans up any remaining
   * connections using the critical ordering: remove -> delete.
   *
   * This method is called during server shutdown to ensure no connections
   * leak.
   */
  virtual void shutdown_cleanup();

  // Non-copyable and non-movable
  ConnectionRegistry(const ConnectionRegistry &) = delete;
  ConnectionRegistry &operator=(const ConnectionRegistry &) = delete;
  ConnectionRegistry(ConnectionRegistry &&) = delete;
  ConnectionRegistry &operator=(ConnectionRegistry &&) = delete;

private:
  Connection **connections_; ///< mmap'd array of raw connection pointers
  size_t max_fds_;           ///< Maximum file descriptor limit
};

View File

@@ -1,10 +1,12 @@
#include "server.hpp"
#include "connection.hpp"
#include "connection_registry.hpp"
#include <csignal>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fcntl.h>
#include <memory>
#include <netdb.h>
#include <netinet/tcp.h>
#include <pthread.h>
@@ -25,9 +27,18 @@ std::shared_ptr<Server> Server::create(const weaseldb::Config &config,
}
Server::Server(const weaseldb::Config &config, ConnectionHandler &handler)
: config_(config), handler_(handler) {}
: config_(config), handler_(handler), connection_registry_() {}
Server::~Server() { cleanup_resources(); }
Server::~Server() {
// CRITICAL: All I/O threads are guaranteed to be joined before the destructor
// is called because they are owned by the run() method's call frame.
// This eliminates any possibility of race conditions during connection
// cleanup.
// Clean up any remaining connections using proper ordering
connection_registry_.shutdown_cleanup();
cleanup_resources();
}
void Server::run() {
setup_shutdown_pipe();
@@ -36,12 +47,21 @@ void Server::run() {
create_epoll_instances();
start_io_threads();
// Create I/O threads locally in this call frame
// CRITICAL: By owning threads in run()'s call frame, we guarantee they are
// joined before run() returns, eliminating any race conditions in ~Server()
std::vector<std::thread> threads;
start_io_threads(threads);
// Wait for all threads to complete
for (auto &thread : threads_) {
// Wait for all threads to complete before returning
// This ensures all I/O threads are fully stopped before the Server
// destructor can be called, preventing race conditions during connection
// cleanup
for (auto &thread : threads) {
thread.join();
}
// At this point, all threads are joined and it's safe to destroy the Server
}
void Server::shutdown() {
@@ -285,11 +305,11 @@ int Server::get_epoll_for_thread(int thread_id) const {
return epoll_fds_[thread_id % epoll_fds_.size()];
}
void Server::start_io_threads() {
void Server::start_io_threads(std::vector<std::thread> &threads) {
int io_threads = config_.server.io_threads;
for (int thread_id = 0; thread_id < io_threads; ++thread_id) {
threads_.emplace_back([this, thread_id]() {
threads.emplace_back([this, thread_id]() {
pthread_setname_np(pthread_self(),
("io-" + std::to_string(thread_id)).c_str());
@@ -325,16 +345,20 @@ void Server::start_io_threads() {
}
// Handle existing connection events
std::unique_ptr<Connection> conn{
static_cast<Connection *>(events[i].data.ptr)};
conn->tsan_acquire();
events[i].data.ptr = nullptr;
if (events[i].events & EPOLLERR) {
continue; // Connection closed - unique_ptr destructor cleans up
std::unique_ptr<Connection> conn;
{
// borrowed
Connection *conn_ = static_cast<Connection *>(events[i].data.ptr);
conn_->tsan_acquire();
conn = connection_registry_.remove(conn_->getFd());
}
// Add to regular batch - I/O will be processed in batch
if (events[i].events & (EPOLLERR | EPOLLHUP)) {
// unique_ptr will automatically delete on scope exit
continue;
}
// Transfer ownership from registry to batch processing
batch[batch_count] = std::move(conn);
batch_events[batch_count] = events[i].events;
batch_count++;
@@ -378,11 +402,11 @@ void Server::start_io_threads() {
perror("setsockopt SO_KEEPALIVE");
}
// Add to batch - I/O will be processed in batch
batch[batch_count] = Connection::createForServer(
// Transfer ownership from registry to batch processing
batch[batch_count] = std::unique_ptr<Connection>(new Connection(
addr, fd,
connection_id_.fetch_add(1, std::memory_order_relaxed),
&handler_, weak_from_this());
&handler_, weak_from_this()));
batch_events[batch_count] =
EPOLLIN; // New connections always start with read
batch_count++;
@@ -473,28 +497,30 @@ void Server::process_connection_batch(
}
}
// Call post-batch handler
// Call post-batch handler - handlers can take ownership here
handler_.on_post_batch(batch);
// Transfer all remaining connections back to epoll
for (auto &conn : batch) {
if (conn) {
for (auto &conn_ptr : batch) {
if (conn_ptr) {
int fd = conn_ptr->getFd();
struct epoll_event event{};
if (!conn->hasMessages()) {
if (!conn_ptr->hasMessages()) {
event.events = EPOLLIN | EPOLLONESHOT;
} else {
event.events = EPOLLOUT | EPOLLONESHOT;
}
int fd = conn->getFd();
conn->tsan_release();
Connection *raw_conn = conn.release();
event.data.ptr = raw_conn;
conn_ptr->tsan_release();
event.data.ptr = conn_ptr.get(); // Use raw pointer for epoll
// Put connection back in registry since handler didn't take ownership.
// Must happen before epoll_ctl
connection_registry_.store(fd, std::move(conn_ptr));
int epoll_op = is_new ? EPOLL_CTL_ADD : EPOLL_CTL_MOD;
if (epoll_ctl(epollfd, epoll_op, fd, &event) == -1) {
perror(is_new ? "epoll_ctl ADD" : "epoll_ctl MOD");
delete raw_conn;
(void)connection_registry_.remove(fd);
}
}
}

View File

@@ -2,6 +2,7 @@
#include "config.hpp"
#include "connection_handler.hpp"
#include "connection_registry.hpp"
#include <atomic>
#include <memory>
#include <span>
@@ -99,8 +100,10 @@ private:
const weaseldb::Config &config_;
ConnectionHandler &handler_;
// Thread management
std::vector<std::thread> threads_;
// Connection registry
ConnectionRegistry connection_registry_;
// Connection management
std::atomic<int64_t> connection_id_{0};
// Round-robin counter for connection distribution
@@ -122,14 +125,14 @@ private:
void setup_signal_handling();
int create_listen_socket();
void create_epoll_instances();
void start_io_threads();
void start_io_threads(std::vector<std::thread> &threads);
void cleanup_resources();
// Helper to get epoll fd for a thread using round-robin
int get_epoll_for_thread(int thread_id) const;
// Helper for processing connection I/O
void process_connection_io(std::unique_ptr<Connection> &conn, int events);
void process_connection_io(std::unique_ptr<Connection> &conn_ptr, int events);
// Helper for processing a batch of connections with their events
void process_connection_batch(int epollfd,

View File

@@ -0,0 +1,252 @@
#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN
#include <doctest/doctest.h>
#include <cstring>
#include <stdexcept>
#include <sys/mman.h>
#include <sys/resource.h>
#include <unistd.h>
// Forward declare Connection for registry
class Connection;

// Simplified connection registry for testing (avoid linking issues).
//
// A stand-alone table of raw Connection pointers indexed by file descriptor
// and sized from the soft RLIMIT_NOFILE. Unlike a registry that owns its
// entries, this one only stores and returns the raw pointers it is given and
// never deletes them.
class TestConnectionRegistry {
public:
  // Reserves one pointer-sized slot per possible fd.
  // Throws std::runtime_error if the limit cannot be read, is too large to
  // size a table from, or if mmap fails.
  TestConnectionRegistry() : connections_(nullptr), max_fds_(0) {
    struct rlimit rlim;
    if (getrlimit(RLIMIT_NOFILE, &rlim) == -1) {
      throw std::runtime_error("Failed to get RLIMIT_NOFILE");
    }

    // Guard the size computation: an unbounded or huge limit would make
    // max_fds_ * sizeof(Connection *) wrap around.
    const size_t max_entries = static_cast<size_t>(-1) / sizeof(Connection *);
    if (rlim.rlim_cur == RLIM_INFINITY || rlim.rlim_cur > max_entries) {
      throw std::runtime_error(
          "RLIMIT_NOFILE is too large for connection registry");
    }
    max_fds_ = static_cast<size_t>(rlim.rlim_cur);

    void *table =
        mmap(nullptr, max_fds_ * sizeof(Connection *), PROT_READ | PROT_WRITE,
             MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
    if (table == MAP_FAILED) {
      throw std::runtime_error("Failed to mmap for connection registry");
    }
    connections_ = static_cast<Connection **>(table);

    // No memset: anonymous mappings start out zero-filled, and writing every
    // slot would force the kernel to back the whole table with real pages.
  }

  ~TestConnectionRegistry() {
    if (connections_ != MAP_FAILED && connections_ != nullptr) {
      munmap(connections_, max_fds_ * sizeof(Connection *));
    }
  }

  // The object owns an mmap region; a copy would unmap it twice.
  TestConnectionRegistry(const TestConnectionRegistry &) = delete;
  TestConnectionRegistry &operator=(const TestConnectionRegistry &) = delete;

  // Writes the pointer into the slot for fd; out-of-range fds are ignored.
  void store(int fd, Connection *connection) {
    if (fd < 0 || static_cast<size_t>(fd) >= max_fds_) {
      return;
    }
    connections_[fd] = connection;
  }

  // Returns the pointer stored for fd, or nullptr for an empty slot or an
  // out-of-range fd.
  Connection *get(int fd) const {
    if (fd < 0 || static_cast<size_t>(fd) >= max_fds_) {
      return nullptr;
    }
    return connections_[fd];
  }

  // Clears the slot for fd and returns what it held (nullptr if nothing or if
  // fd is out of range).
  Connection *remove(int fd) {
    if (fd < 0 || static_cast<size_t>(fd) >= max_fds_) {
      return nullptr;
    }
    Connection *conn = connections_[fd];
    connections_[fd] = nullptr;
    return conn;
  }

  // Number of slots in the table.
  size_t max_fds() const { return max_fds_; }

private:
  Connection **connections_;
  size_t max_fds_;
};
// Minimal stand-in object for the tests. It only carries an integer id; the
// tests use instances of it as distinct addresses to store in the registry.
class MockConnection {
  int id_;

public:
  MockConnection(int id) : id_{id} {}

  // Returns the id passed at construction.
  int getId() const { return id_; }
};
// Checks the properties of a freshly constructed, empty registry: its size
// matches the soft RLIMIT_NOFILE, every lookup yields nullptr, and
// out-of-range descriptors are rejected.
TEST_CASE("ConnectionRegistry basic functionality") {
  TestConnectionRegistry reg;

  SUBCASE("max_fds returns valid limit") {
    struct rlimit limits;
    getrlimit(RLIMIT_NOFILE, &limits);
    CHECK(reg.max_fds() == limits.rlim_cur);
    CHECK(reg.max_fds() > 0);
  }

  SUBCASE("get returns nullptr for empty registry") {
    const int probes[] = {0, 100, 1000};
    for (int fd : probes) {
      CHECK(reg.get(fd) == nullptr);
    }
  }

  SUBCASE("get handles invalid file descriptors") {
    // One below the first valid index and one past the last valid index.
    const int past_end = static_cast<int>(reg.max_fds());
    CHECK(reg.get(-1) == nullptr);
    CHECK(reg.get(past_end) == nullptr);
  }
}
// Exercises store()/get(): single and multiple entries, overwriting a slot,
// and storing under descriptors that are out of range.
TEST_CASE("ConnectionRegistry store and retrieve") {
  TestConnectionRegistry reg;

  // The mocks serve purely as two distinct addresses; the registry never
  // dereferences the pointers it stores.
  MockConnection first_mock(1);
  MockConnection second_mock(2);
  Connection *const first = reinterpret_cast<Connection *>(&first_mock);
  Connection *const second = reinterpret_cast<Connection *>(&second_mock);

  SUBCASE("store and get single connection") {
    reg.store(5, first);
    CHECK(reg.get(5) == first);

    // Neighbouring slots must be untouched.
    CHECK(reg.get(4) == nullptr);
    CHECK(reg.get(6) == nullptr);
  }

  SUBCASE("store multiple connections") {
    reg.store(5, first);
    reg.store(10, second);

    CHECK(reg.get(5) == first);
    CHECK(reg.get(10) == second);
    CHECK(reg.get(7) == nullptr);
  }

  SUBCASE("overwrite existing connection") {
    reg.store(5, first);
    CHECK(reg.get(5) == first);

    // A second store on the same fd replaces the previous pointer.
    reg.store(5, second);
    CHECK(reg.get(5) == second);
  }

  SUBCASE("store handles invalid file descriptors safely") {
    const int past_end = static_cast<int>(reg.max_fds());

    reg.store(-1, first);       // Should not crash
    reg.store(past_end, first); // Should not crash

    CHECK(reg.get(-1) == nullptr);
    CHECK(reg.get(past_end) == nullptr);
  }
}
// Exercises remove(): it returns the stored pointer exactly once, clears the
// slot, tolerates empty and out-of-range descriptors, and leaves other slots
// alone.
TEST_CASE("ConnectionRegistry remove functionality") {
  TestConnectionRegistry reg;

  MockConnection first_mock(1);
  MockConnection second_mock(2);
  Connection *const first = reinterpret_cast<Connection *>(&first_mock);
  Connection *const second = reinterpret_cast<Connection *>(&second_mock);

  SUBCASE("remove existing connection") {
    reg.store(5, first);
    CHECK(reg.get(5) == first);

    Connection *const taken = reg.remove(5);
    CHECK(taken == first);
    CHECK(reg.get(5) == nullptr);
  }

  SUBCASE("remove non-existing connection") {
    Connection *const taken = reg.remove(5);
    CHECK(taken == nullptr);
  }

  SUBCASE("remove after remove returns nullptr") {
    reg.store(5, first);

    Connection *const first_take = reg.remove(5);
    Connection *const second_take = reg.remove(5);

    CHECK(first_take == first);
    CHECK(second_take == nullptr);
  }

  SUBCASE("remove handles invalid file descriptors") {
    const int past_end = static_cast<int>(reg.max_fds());
    CHECK(reg.remove(-1) == nullptr);
    CHECK(reg.remove(past_end) == nullptr);
  }

  SUBCASE("remove doesn't affect other connections") {
    reg.store(5, first);
    reg.store(10, second);

    Connection *const taken = reg.remove(5);
    CHECK(taken == first);
    CHECK(reg.get(5) == nullptr);
    CHECK(reg.get(10) == second); // Should remain unchanged
  }
}
// Verifies that the very last valid slot (max_fds - 1) behaves like any other.
TEST_CASE("ConnectionRegistry large file descriptor handling") {
  TestConnectionRegistry reg;

  MockConnection mock(1);
  Connection *const conn = reinterpret_cast<Connection *>(&mock);

  // Highest descriptor the table can index.
  const int last_fd = static_cast<int>(reg.max_fds()) - 1;

  SUBCASE("large valid fd works") {
    reg.store(last_fd, conn);
    CHECK(reg.get(last_fd) == conn);

    Connection *const taken = reg.remove(last_fd);
    CHECK(taken == conn);
    CHECK(reg.get(last_fd) == nullptr);
  }
}
// Walks through the registry side of the cleanup ordering: an entry is taken
// out of the table first, after which the same fd can be reused for a new
// connection.
TEST_CASE("ConnectionRegistry critical ordering simulation") {
  TestConnectionRegistry reg;

  MockConnection original_mock(1);
  Connection *const original = reinterpret_cast<Connection *>(&original_mock);
  const int fd = 5;

  SUBCASE("simulate proper cleanup ordering") {
    // Setup: the connection is registered under its fd.
    reg.store(fd, original);
    CHECK(reg.get(fd) == original);

    // Ordering step 1: take the connection out of the registry.
    Connection *const taken = reg.remove(fd);
    CHECK(taken == original);
    CHECK(reg.get(fd) == nullptr);

    // The remaining steps (closing the fd and deleting the connection) cannot
    // be exercised with mock objects.
  }

  SUBCASE("simulate fd reuse safety") {
    reg.store(fd, original);

    // Ordering step 1: remove from the registry first.
    Connection *const taken = reg.remove(fd);
    CHECK(taken == original);

    // The slot is empty again, so the fd may be handed out anew.
    CHECK(reg.get(fd) == nullptr);

    // A different connection can now occupy the same fd.
    MockConnection replacement_mock(2);
    Connection *const replacement =
        reinterpret_cast<Connection *>(&replacement_mock);
    reg.store(fd, replacement);
    CHECK(reg.get(fd) == replacement);
  }
}