From d1b1e6d589b455dec12a760c67224c6be5a9c50a Mon Sep 17 00:00:00 2001 From: Andrew Noyes Date: Thu, 21 Aug 2025 18:09:36 -0400 Subject: [PATCH] Connection registry Now we can use leak sanitizer. Yay! --- CMakeLists.txt | 11 +- src/connection.cpp | 2 + src/connection_registry.cpp | 83 ++++++++++ src/connection_registry.hpp | 93 +++++++++++ src/server.cpp | 82 ++++++---- src/server.hpp | 11 +- tests/test_connection_registry.cpp | 252 +++++++++++++++++++++++++++++ 7 files changed, 500 insertions(+), 34 deletions(-) create mode 100644 src/connection_registry.cpp create mode 100644 src/connection_registry.hpp create mode 100644 tests/test_connection_registry.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index dbace11..c9fdc6e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -128,6 +128,7 @@ set(SOURCES src/main.cpp src/config.cpp src/connection.cpp + src/connection_registry.cpp src/server.cpp src/json_commit_request_parser.cpp src/http_handler.cpp @@ -157,6 +158,10 @@ add_executable(test_arena_allocator tests/test_arena_allocator.cpp target_link_libraries(test_arena_allocator doctest::doctest) target_include_directories(test_arena_allocator PRIVATE src) +add_executable(test_connection_registry tests/test_connection_registry.cpp) +target_link_libraries(test_connection_registry doctest::doctest) +target_include_directories(test_connection_registry PRIVATE src) + add_executable( test_commit_request tests/test_commit_request.cpp src/json_commit_request_parser.cpp @@ -168,8 +173,9 @@ target_link_libraries(test_commit_request doctest::doctest weaseljson test_data target_include_directories(test_commit_request PRIVATE src tests) add_executable( - test_http_handler tests/test_http_handler.cpp src/http_handler.cpp - src/arena_allocator.cpp src/connection.cpp) + test_http_handler + tests/test_http_handler.cpp src/http_handler.cpp src/arena_allocator.cpp + src/connection.cpp src/connection_registry.cpp) target_link_libraries(test_http_handler doctest::doctest llhttp_static 
Threads::Threads perfetto) target_include_directories(test_http_handler PRIVATE src) @@ -213,6 +219,7 @@ add_executable(load_tester tools/load_tester.cpp) target_link_libraries(load_tester Threads::Threads llhttp_static perfetto) add_test(NAME arena_allocator_tests COMMAND test_arena_allocator) +add_test(NAME connection_registry_tests COMMAND test_connection_registry) add_test(NAME commit_request_tests COMMAND test_commit_request) add_test(NAME http_handler_tests COMMAND test_http_handler) add_test(NAME arena_allocator_benchmarks COMMAND bench_arena_allocator) diff --git a/src/connection.cpp b/src/connection.cpp index 97ba56c..94a7f55 100644 --- a/src/connection.cpp +++ b/src/connection.cpp @@ -5,6 +5,8 @@ #include #include +// TODO fix up this whole thing + std::unique_ptr Connection::createForServer(struct sockaddr_storage addr, int fd, int64_t id, ConnectionHandler *handler, diff --git a/src/connection_registry.cpp b/src/connection_registry.cpp new file mode 100644 index 0000000..2285d2e --- /dev/null +++ b/src/connection_registry.cpp @@ -0,0 +1,83 @@ +#include "connection_registry.hpp" +#include "connection.hpp" +#include +#include +#include + +ConnectionRegistry::ConnectionRegistry() : connections_(nullptr), max_fds_(0) { + // Get the process file descriptor limit + struct rlimit rlim; + if (getrlimit(RLIMIT_NOFILE, &rlim) == -1) { + throw std::runtime_error("Failed to get RLIMIT_NOFILE"); + } + max_fds_ = rlim.rlim_cur; + + // Allocate virtual address space using mmap + // This reserves virtual memory but doesn't allocate physical pages until + // touched + connections_ = static_cast( + mmap(nullptr, max_fds_ * sizeof(Connection *), PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0)); + + if (connections_ == MAP_FAILED) { + throw std::runtime_error("Failed to mmap for connection registry"); + } + + // No explicit zeroing: MAP_ANONYMOUS memory is already zero-filled, so every + // slot starts out null. A memset here would touch (and commit) every page, + // defeating the on-demand allocation of physical pages described above.
+} + +ConnectionRegistry::~ConnectionRegistry() { + if (connections_ != MAP_FAILED && connections_ != nullptr) { + munmap(connections_, max_fds_ * sizeof(Connection *)); + } +} + +void ConnectionRegistry::store(int fd, std::unique_ptr connection) { + if (fd < 0 || static_cast(fd) >= max_fds_) { + return; // Invalid fd - not stored; connection is deleted on return + } + // Release ownership from unique_ptr and store raw pointer + connections_[fd] = connection.release(); +} + +bool ConnectionRegistry::has(int fd) const { + if (fd < 0 || static_cast(fd) >= max_fds_) { + return false; // Invalid fd + } + return connections_[fd] != nullptr; +} + +std::unique_ptr ConnectionRegistry::remove(int fd) { + if (fd < 0 || static_cast(fd) >= max_fds_) { + return nullptr; // Invalid fd + } + + Connection *conn = connections_[fd]; + connections_[fd] = nullptr; + // Transfer ownership back to unique_ptr + return std::unique_ptr(conn); +} + +void ConnectionRegistry::shutdown_cleanup() { + // Iterate through all possible file descriptors and clean up any connections + // Following the critical ordering: remove -> delete (destructor handles + // close) + size_t connections_found = 0; + for (size_t fd = 0; fd < max_fds_; ++fd) { + Connection *conn = connections_[fd]; + if (conn != nullptr) { + connections_found++; + // Step 1: Remove from registry (set to null) + connections_[fd] = nullptr; + + // Step 2: Delete the connection object (its destructor handles closing + // the fd) + delete conn; + } + } + + // Note: In normal shutdown, this should be 0 since all connections + // should have been properly cleaned up during normal operation +} \ No newline at end of file diff --git a/src/connection_registry.hpp b/src/connection_registry.hpp new file mode 100644 index 0000000..05878e0 --- /dev/null +++ b/src/connection_registry.hpp @@ -0,0 +1,93 @@ +#pragma once + +#include +#include +#include +#include + +class Connection; + +/** + * mmap-based Connection Registry for tracking active connections.
+ * + * This registry provides a lock-free mechanism for tracking all connections + * owned by the server, indexed by file descriptor. The design uses mmap to + * allocate a large virtual address space efficiently, with physical memory + * allocated on-demand as connections are created. + * + * CRITICAL ORDERING REQUIREMENT: + * All connection cleanup MUST follow this exact sequence: + * 1. Remove from registry: auto conn = registry.remove(fd) + * 2. Delete the connection: unique_ptr destructor handles it automatically + * + * This ordering prevents race conditions between cleanup and fd reuse. + * The unique_ptr interface ensures ownership is always clear and prevents + * double-delete bugs. + */ +class ConnectionRegistry { +public: + /** + * Initialize the connection registry. + * Allocates virtual address space based on RLIMIT_NOFILE. + * + * @throws std::runtime_error if mmap fails or RLIMIT_NOFILE cannot be read + */ + ConnectionRegistry(); + + /** + * Destructor ensures proper cleanup of mmap'd memory. + */ + ~ConnectionRegistry(); + + /** + * Store a connection in the registry, indexed by its file descriptor. + * Takes ownership of the connection via unique_ptr. + * + * @param fd File descriptor (must be valid and < max_fds_) + * @param connection unique_ptr to the connection (ownership transferred) + */ + void store(int fd, std::unique_ptr connection); + + /** + * Check if a connection exists in the registry by file descriptor. + * + * @param fd File descriptor + * @return true if connection exists, false otherwise + */ + bool has(int fd) const; + + /** + * Remove a connection from the registry and transfer ownership to caller. + * This transfers ownership via unique_ptr move semantics. + * + * @param fd File descriptor + * @return unique_ptr to the connection, or nullptr if not found + */ + std::unique_ptr remove(int fd); + + /** + * Get the maximum number of file descriptors supported. 
+ * + * @return Maximum file descriptor limit + */ + size_t max_fds() const { return max_fds_; } + + /** + * Perform graceful shutdown cleanup. + * Iterates through all registry entries and cleans up any remaining + * connections using the critical ordering: remove -> delete (closes fd). + * + * This method is called during server shutdown to ensure no connections leak. + */ + void shutdown_cleanup(); + + // Non-copyable and non-movable + ConnectionRegistry(const ConnectionRegistry &) = delete; + ConnectionRegistry &operator=(const ConnectionRegistry &) = delete; + ConnectionRegistry(ConnectionRegistry &&) = delete; + ConnectionRegistry &operator=(ConnectionRegistry &&) = delete; + +private: + Connection **connections_; ///< mmap'd array of raw connection pointers + size_t max_fds_; ///< Maximum file descriptor limit +}; \ No newline at end of file diff --git a/src/server.cpp b/src/server.cpp index 2f4109f..1bb6994 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -1,10 +1,12 @@ #include "server.hpp" #include "connection.hpp" +#include "connection_registry.hpp" #include #include #include #include #include +#include #include #include #include @@ -25,9 +27,18 @@ std::shared_ptr Server::create(const weaseldb::Config &config, } Server::Server(const weaseldb::Config &config, ConnectionHandler &handler) - : config_(config), handler_(handler) {} + : config_(config), handler_(handler), connection_registry_() {} -Server::~Server() { cleanup_resources(); } +Server::~Server() { + // CRITICAL: All I/O threads are guaranteed to be joined before the destructor + // is called because they are owned by the run() method's call frame. + // This eliminates any possibility of race conditions during connection + // cleanup.
+ + // Clean up any remaining connections using proper ordering + connection_registry_.shutdown_cleanup(); + cleanup_resources(); +} void Server::run() { setup_shutdown_pipe(); @@ -36,12 +47,21 @@ void Server::run() { create_epoll_instances(); - start_io_threads(); + // Create I/O threads locally in this call frame + // CRITICAL: By owning threads in run()'s call frame, we guarantee they are + // joined before run() returns, eliminating any race conditions in ~Server() + std::vector threads; + start_io_threads(threads); - // Wait for all threads to complete - for (auto &thread : threads_) { + // Wait for all threads to complete before returning + // This ensures all I/O threads are fully stopped before the Server + // destructor can be called, preventing race conditions during connection + // cleanup + for (auto &thread : threads) { thread.join(); } + + // At this point, all threads are joined and it's safe to destroy the Server } void Server::shutdown() { @@ -285,11 +305,11 @@ int Server::get_epoll_for_thread(int thread_id) const { return epoll_fds_[thread_id % epoll_fds_.size()]; } -void Server::start_io_threads() { +void Server::start_io_threads(std::vector &threads) { int io_threads = config_.server.io_threads; for (int thread_id = 0; thread_id < io_threads; ++thread_id) { - threads_.emplace_back([this, thread_id]() { + threads.emplace_back([this, thread_id]() { pthread_setname_np(pthread_self(), ("io-" + std::to_string(thread_id)).c_str()); @@ -325,16 +345,20 @@ void Server::start_io_threads() { } // Handle existing connection events - std::unique_ptr conn{ - static_cast(events[i].data.ptr)}; - conn->tsan_acquire(); - events[i].data.ptr = nullptr; - - if (events[i].events & EPOLLERR) { - continue; // Connection closed - unique_ptr destructor cleans up + std::unique_ptr conn; + { + // borrowed + Connection *conn_ = static_cast(events[i].data.ptr); + conn_->tsan_acquire(); + conn = connection_registry_.remove(conn_->getFd()); } - // Add to regular batch - I/O 
will be processed in batch + if (events[i].events & (EPOLLERR | EPOLLHUP)) { + // unique_ptr will automatically delete on scope exit + continue; + } + + // Transfer ownership from registry to batch processing batch[batch_count] = std::move(conn); batch_events[batch_count] = events[i].events; batch_count++; @@ -378,11 +402,11 @@ void Server::start_io_threads() { perror("setsockopt SO_KEEPALIVE"); } - // Add to batch - I/O will be processed in batch - batch[batch_count] = Connection::createForServer( + // New connection: not in the registry yet, goes straight into the batch + batch[batch_count] = std::unique_ptr(new Connection( addr, fd, connection_id_.fetch_add(1, std::memory_order_relaxed), - &handler_, weak_from_this()); + &handler_, weak_from_this())); batch_events[batch_count] = EPOLLIN; // New connections always start with read batch_count++; @@ -473,28 +497,30 @@ void Server::process_connection_batch( } } - // Call post-batch handler + // Call post-batch handler - handlers can take ownership here handler_.on_post_batch(batch); // Transfer all remaining connections back to epoll - for (auto &conn : batch) { - if (conn) { + for (auto &conn_ptr : batch) { + if (conn_ptr) { + int fd = conn_ptr->getFd(); + struct epoll_event event{}; - if (!conn->hasMessages()) { + if (!conn_ptr->hasMessages()) { event.events = EPOLLIN | EPOLLONESHOT; } else { event.events = EPOLLOUT | EPOLLONESHOT; } - int fd = conn->getFd(); - conn->tsan_release(); - Connection *raw_conn = conn.release(); - event.data.ptr = raw_conn; - + conn_ptr->tsan_release(); + event.data.ptr = conn_ptr.get(); // Use raw pointer for epoll + // Put connection back in registry since handler didn't take ownership. + // Must happen before epoll_ctl + connection_registry_.store(fd, std::move(conn_ptr)); int epoll_op = is_new ? EPOLL_CTL_ADD : EPOLL_CTL_MOD; if (epoll_ctl(epollfd, epoll_op, fd, &event) == -1) { perror(is_new ?
"epoll_ctl ADD" : "epoll_ctl MOD"); - delete raw_conn; + (void)connection_registry_.remove(fd); } } } diff --git a/src/server.hpp b/src/server.hpp index 9288421..715c817 100644 --- a/src/server.hpp +++ b/src/server.hpp @@ -2,6 +2,7 @@ #include "config.hpp" #include "connection_handler.hpp" +#include "connection_registry.hpp" #include #include #include @@ -99,8 +100,10 @@ private: const weaseldb::Config &config_; ConnectionHandler &handler_; - // Thread management - std::vector threads_; + // Connection registry + ConnectionRegistry connection_registry_; + + // Connection management std::atomic connection_id_{0}; // Round-robin counter for connection distribution @@ -122,14 +125,14 @@ private: void setup_signal_handling(); int create_listen_socket(); void create_epoll_instances(); - void start_io_threads(); + void start_io_threads(std::vector &threads); void cleanup_resources(); // Helper to get epoll fd for a thread using round-robin int get_epoll_for_thread(int thread_id) const; // Helper for processing connection I/O - void process_connection_io(std::unique_ptr &conn, int events); + void process_connection_io(std::unique_ptr &conn_ptr, int events); // Helper for processing a batch of connections with their events void process_connection_batch(int epollfd, diff --git a/tests/test_connection_registry.cpp b/tests/test_connection_registry.cpp new file mode 100644 index 0000000..2929b0f --- /dev/null +++ b/tests/test_connection_registry.cpp @@ -0,0 +1,252 @@ +#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN +#include + +#include +#include +#include +#include +#include + +// Forward declare Connection for registry +class Connection; + +// Simplified connection registry for testing (avoid linking issues) +class TestConnectionRegistry { +public: + TestConnectionRegistry() : connections_(nullptr), max_fds_(0) { + struct rlimit rlim; + if (getrlimit(RLIMIT_NOFILE, &rlim) == -1) { + throw std::runtime_error("Failed to get RLIMIT_NOFILE"); + } + max_fds_ = rlim.rlim_cur; + + 
connections_ = static_cast( + mmap(nullptr, max_fds_ * sizeof(Connection *), PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0)); + + if (connections_ == MAP_FAILED) { + throw std::runtime_error("Failed to mmap for connection registry"); + } + + memset(connections_, 0, max_fds_ * sizeof(Connection *)); + } + + ~TestConnectionRegistry() { + if (connections_ != MAP_FAILED && connections_ != nullptr) { + munmap(connections_, max_fds_ * sizeof(Connection *)); + } + } + + void store(int fd, Connection *connection) { + if (fd < 0 || static_cast(fd) >= max_fds_) { + return; + } + connections_[fd] = connection; + } + + Connection *get(int fd) const { + if (fd < 0 || static_cast(fd) >= max_fds_) { + return nullptr; + } + return connections_[fd]; + } + + Connection *remove(int fd) { + if (fd < 0 || static_cast(fd) >= max_fds_) { + return nullptr; + } + + Connection *conn = connections_[fd]; + connections_[fd] = nullptr; + return conn; + } + + size_t max_fds() const { return max_fds_; } + +private: + Connection **connections_; + size_t max_fds_; +}; + +// Mock Connection class for testing +class MockConnection { +public: + MockConnection(int id) : id_(id) {} + int getId() const { return id_; } + +private: + int id_; +}; + +TEST_CASE("ConnectionRegistry basic functionality") { + TestConnectionRegistry registry; + + SUBCASE("max_fds returns valid limit") { + struct rlimit rlim; + getrlimit(RLIMIT_NOFILE, &rlim); + CHECK(registry.max_fds() == rlim.rlim_cur); + CHECK(registry.max_fds() > 0); + } + + SUBCASE("get returns nullptr for empty registry") { + CHECK(registry.get(0) == nullptr); + CHECK(registry.get(100) == nullptr); + CHECK(registry.get(1000) == nullptr); + } + + SUBCASE("get handles invalid file descriptors") { + CHECK(registry.get(-1) == nullptr); + CHECK(registry.get(static_cast(registry.max_fds())) == nullptr); + } +} + +TEST_CASE("ConnectionRegistry store and retrieve") { + TestConnectionRegistry registry; + + // Create some mock connections (using 
reinterpret_cast for testing) + MockConnection mock1(1); + MockConnection mock2(2); + Connection *conn1 = reinterpret_cast(&mock1); + Connection *conn2 = reinterpret_cast(&mock2); + + SUBCASE("store and get single connection") { + registry.store(5, conn1); + CHECK(registry.get(5) == conn1); + + // Other fds should still return nullptr + CHECK(registry.get(4) == nullptr); + CHECK(registry.get(6) == nullptr); + } + + SUBCASE("store multiple connections") { + registry.store(5, conn1); + registry.store(10, conn2); + + CHECK(registry.get(5) == conn1); + CHECK(registry.get(10) == conn2); + CHECK(registry.get(7) == nullptr); + } + + SUBCASE("overwrite existing connection") { + registry.store(5, conn1); + CHECK(registry.get(5) == conn1); + + registry.store(5, conn2); + CHECK(registry.get(5) == conn2); + } + + SUBCASE("store handles invalid file descriptors safely") { + registry.store(-1, conn1); // Should not crash + registry.store(static_cast(registry.max_fds()), + conn1); // Should not crash + + CHECK(registry.get(-1) == nullptr); + CHECK(registry.get(static_cast(registry.max_fds())) == nullptr); + } +} + +TEST_CASE("ConnectionRegistry remove functionality") { + TestConnectionRegistry registry; + + MockConnection mock1(1); + MockConnection mock2(2); + Connection *conn1 = reinterpret_cast(&mock1); + Connection *conn2 = reinterpret_cast(&mock2); + + SUBCASE("remove existing connection") { + registry.store(5, conn1); + CHECK(registry.get(5) == conn1); + + Connection *removed = registry.remove(5); + CHECK(removed == conn1); + CHECK(registry.get(5) == nullptr); + } + + SUBCASE("remove non-existing connection") { + Connection *removed = registry.remove(5); + CHECK(removed == nullptr); + } + + SUBCASE("remove after remove returns nullptr") { + registry.store(5, conn1); + Connection *removed1 = registry.remove(5); + Connection *removed2 = registry.remove(5); + + CHECK(removed1 == conn1); + CHECK(removed2 == nullptr); + } + + SUBCASE("remove handles invalid file descriptors") { + 
CHECK(registry.remove(-1) == nullptr); + CHECK(registry.remove(static_cast(registry.max_fds())) == nullptr); + } + + SUBCASE("remove doesn't affect other connections") { + registry.store(5, conn1); + registry.store(10, conn2); + + Connection *removed = registry.remove(5); + CHECK(removed == conn1); + CHECK(registry.get(5) == nullptr); + CHECK(registry.get(10) == conn2); // Should remain unchanged + } +} + +TEST_CASE("ConnectionRegistry large file descriptor handling") { + TestConnectionRegistry registry; + + MockConnection mock1(1); + Connection *conn1 = reinterpret_cast(&mock1); + + // Test with a large but valid file descriptor + int large_fd = static_cast(registry.max_fds()) - 1; + + SUBCASE("large valid fd works") { + registry.store(large_fd, conn1); + CHECK(registry.get(large_fd) == conn1); + + Connection *removed = registry.remove(large_fd); + CHECK(removed == conn1); + CHECK(registry.get(large_fd) == nullptr); + } +} + +TEST_CASE("ConnectionRegistry critical ordering simulation") { + TestConnectionRegistry registry; + + MockConnection mock1(1); + Connection *conn1 = reinterpret_cast(&mock1); + int fd = 5; + + SUBCASE("simulate proper cleanup ordering") { + // Step 1: Store connection + registry.store(fd, conn1); + CHECK(registry.get(fd) == conn1); + + // Step 2: Remove from registry (critical ordering step 1) + Connection *removed = registry.remove(fd); + CHECK(removed == conn1); + CHECK(registry.get(fd) == nullptr); + + // Critical ordering step 2 would be deleting conn (and closing its fd), + // but we can't test that with mock objects + } + + SUBCASE("simulate fd reuse safety") { + // Store connection + registry.store(fd, conn1); + + // Remove from registry first (step 1) + Connection *removed = registry.remove(fd); + CHECK(removed == conn1); + + // Registry is now clear - safe for fd reuse + CHECK(registry.get(fd) == nullptr); + + // New connection could safely use same fd + MockConnection mock2(2); + Connection *conn2 = reinterpret_cast(&mock2); + registry.store(fd, conn2);
+ CHECK(registry.get(fd) == conn2); + } +} \ No newline at end of file