diff --git a/CMakeLists.txt b/CMakeLists.txt index c9fdc6e..e71a0f5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -158,10 +158,6 @@ add_executable(test_arena_allocator tests/test_arena_allocator.cpp target_link_libraries(test_arena_allocator doctest::doctest) target_include_directories(test_arena_allocator PRIVATE src) -add_executable(test_connection_registry tests/test_connection_registry.cpp) -target_link_libraries(test_connection_registry doctest::doctest) -target_include_directories(test_connection_registry PRIVATE src) - add_executable( test_commit_request tests/test_commit_request.cpp src/json_commit_request_parser.cpp diff --git a/src/connection.cpp b/src/connection.cpp index 94a7f55..6d412b1 100644 --- a/src/connection.cpp +++ b/src/connection.cpp @@ -21,7 +21,9 @@ Connection::Connection(struct sockaddr_storage addr, int fd, int64_t id, ConnectionHandler *handler, std::weak_ptr server) : fd_(fd), id_(id), addr_(addr), arena_(), handler_(handler), server_(server) { - activeConnections.fetch_add(1, std::memory_order_relaxed); + if (auto server_ptr = server_.lock()) { + server_ptr->active_connections_.fetch_add(1, std::memory_order_relaxed); + } if (handler_) { handler_->on_connection_established(*this); } @@ -31,7 +33,9 @@ Connection::~Connection() { if (handler_) { handler_->on_connection_closed(*this); } - activeConnections.fetch_sub(1, std::memory_order_relaxed); + if (auto server_ptr = server_.lock()) { + server_ptr->active_connections_.fetch_sub(1, std::memory_order_relaxed); + } int e = close(fd_); if (e == -1) { perror("close"); diff --git a/src/connection.hpp b/src/connection.hpp index fec1fa6..37b2802 100644 --- a/src/connection.hpp +++ b/src/connection.hpp @@ -11,8 +11,6 @@ #include #include -extern std::atomic activeConnections; - #ifndef __has_feature #define __has_feature(x) 0 #endif diff --git a/src/connection_registry.cpp b/src/connection_registry.cpp index ae1caf0..a6ae0cb 100644 --- a/src/connection_registry.cpp +++ b/src/connection_registry.cpp @@ -36,7 +36,9 @@ ConnectionRegistry::~ConnectionRegistry() { for (size_t fd = 0; fd < max_fds_; ++fd) { delete connections_[fd]; } - munmap(connections_, aligned_size_); + if (munmap(connections_, aligned_size_) == -1) { + perror("munmap"); + } } } diff --git a/src/main.cpp b/src/main.cpp index be1161b..9d0c8d3 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -10,10 +10,6 @@ PERFETTO_TRACK_EVENT_STATIC_STORAGE(); -// TODO this should be scoped to a particular Server, and it's definition should -// be in server.cpp or connection.cpp -std::atomic activeConnections{0}; - // Global server instance for signal handler access static Server *g_server = nullptr; diff --git a/src/server.cpp b/src/server.cpp index 0e5510d..277776f 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -17,8 +17,6 @@ #include #include -extern std::atomic activeConnections; - std::shared_ptr Server::create(const weaseldb::Config &config, ConnectionHandler &handler) { // Use std::shared_ptr constructor with private access @@ -407,7 +405,7 @@ void Server::start_io_threads(std::vector &threads) { // Check connection limit if (config_.server.max_connections > 0 && - activeConnections.load(std::memory_order_relaxed) >= + active_connections_.load(std::memory_order_relaxed) >= config_.server.max_connections) { close(fd); continue; diff --git a/src/server.hpp b/src/server.hpp index 97e4847..b63f497 100644 --- a/src/server.hpp +++ b/src/server.hpp @@ -89,6 +89,7 @@ public: static void releaseBackToServer(std::unique_ptr connection); private: + friend class Connection; /** * Private constructor - use create() factory method instead. * @@ -105,6 +106,7 @@ private: // Connection management std::atomic connection_id_{0}; + std::atomic active_connections_{0}; // Round-robin counter for connection distribution std::atomic connection_distribution_counter_{0}; diff --git a/tests/test_connection_registry.cpp b/tests/test_connection_registry.cpp deleted file mode 100644 index 2929b0f..0000000 --- a/tests/test_connection_registry.cpp +++ /dev/null @@ -1,252 +0,0 @@ -#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN -#include - -#include -#include -#include -#include -#include - -// Forward declare Connection for registry -class Connection; - -// Simplified connection registry for testing (avoid linking issues) -class TestConnectionRegistry { -public: - TestConnectionRegistry() : connections_(nullptr), max_fds_(0) { - struct rlimit rlim; - if (getrlimit(RLIMIT_NOFILE, &rlim) == -1) { - throw std::runtime_error("Failed to get RLIMIT_NOFILE"); - } - max_fds_ = rlim.rlim_cur; - - connections_ = static_cast( - mmap(nullptr, max_fds_ * sizeof(Connection *), PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0)); - - if (connections_ == MAP_FAILED) { - throw std::runtime_error("Failed to mmap for connection registry"); - } - - memset(connections_, 0, max_fds_ * sizeof(Connection *)); - } - - ~TestConnectionRegistry() { - if (connections_ != MAP_FAILED && connections_ != nullptr) { - munmap(connections_, max_fds_ * sizeof(Connection *)); - } - } - - void store(int fd, Connection *connection) { - if (fd < 0 || static_cast(fd) >= max_fds_) { - return; - } - connections_[fd] = connection; - } - - Connection *get(int fd) const { - if (fd < 0 || static_cast(fd) >= max_fds_) { - return nullptr; - } - return connections_[fd]; - } - - Connection *remove(int fd) { - if (fd < 0 || static_cast(fd) >= max_fds_) { - return nullptr; - } - - Connection *conn = connections_[fd]; - connections_[fd] = nullptr; - return conn; - } - - size_t max_fds() const { return max_fds_; } - -private: - Connection **connections_; - size_t max_fds_; -}; - -// Mock Connection class for testing -class MockConnection { -public: - MockConnection(int id) : id_(id) {} - int getId() const { return id_; } - -private: - int id_; -}; - -TEST_CASE("ConnectionRegistry basic functionality") { - TestConnectionRegistry registry; - - SUBCASE("max_fds returns valid limit") { - struct rlimit rlim; - getrlimit(RLIMIT_NOFILE, &rlim); - CHECK(registry.max_fds() == rlim.rlim_cur); - CHECK(registry.max_fds() > 0); - } - - SUBCASE("get returns nullptr for empty registry") { - CHECK(registry.get(0) == nullptr); - CHECK(registry.get(100) == nullptr); - CHECK(registry.get(1000) == nullptr); - } - - SUBCASE("get handles invalid file descriptors") { - CHECK(registry.get(-1) == nullptr); - CHECK(registry.get(static_cast(registry.max_fds())) == nullptr); - } -} - -TEST_CASE("ConnectionRegistry store and retrieve") { - TestConnectionRegistry registry; - - // Create some mock connections (using reinterpret_cast for testing) - MockConnection mock1(1); - MockConnection mock2(2); - Connection *conn1 = reinterpret_cast(&mock1); - Connection *conn2 = reinterpret_cast(&mock2); - - SUBCASE("store and get single connection") { - registry.store(5, conn1); - CHECK(registry.get(5) == conn1); - - // Other fds should still return nullptr - CHECK(registry.get(4) == nullptr); - CHECK(registry.get(6) == nullptr); - } - - SUBCASE("store multiple connections") { - registry.store(5, conn1); - registry.store(10, conn2); - - CHECK(registry.get(5) == conn1); - CHECK(registry.get(10) == conn2); - CHECK(registry.get(7) == nullptr); - } - - SUBCASE("overwrite existing connection") { - registry.store(5, conn1); - CHECK(registry.get(5) == conn1); - - registry.store(5, conn2); - CHECK(registry.get(5) == conn2); - } - - SUBCASE("store handles invalid file descriptors safely") { - registry.store(-1, conn1); // Should not crash - registry.store(static_cast(registry.max_fds()), - conn1); // Should not crash - - CHECK(registry.get(-1) == nullptr); - CHECK(registry.get(static_cast(registry.max_fds())) == nullptr); - } -} - -TEST_CASE("ConnectionRegistry remove functionality") { - TestConnectionRegistry registry; - - MockConnection mock1(1); - MockConnection mock2(2); - Connection *conn1 = reinterpret_cast(&mock1); - Connection *conn2 = reinterpret_cast(&mock2); - - SUBCASE("remove existing connection") { - registry.store(5, conn1); - CHECK(registry.get(5) == conn1); - - Connection *removed = registry.remove(5); - CHECK(removed == conn1); - CHECK(registry.get(5) == nullptr); - } - - SUBCASE("remove non-existing connection") { - Connection *removed = registry.remove(5); - CHECK(removed == nullptr); - } - - SUBCASE("remove after remove returns nullptr") { - registry.store(5, conn1); - Connection *removed1 = registry.remove(5); - Connection *removed2 = registry.remove(5); - - CHECK(removed1 == conn1); - CHECK(removed2 == nullptr); - } - - SUBCASE("remove handles invalid file descriptors") { - CHECK(registry.remove(-1) == nullptr); - CHECK(registry.remove(static_cast(registry.max_fds())) == nullptr); - } - - SUBCASE("remove doesn't affect other connections") { - registry.store(5, conn1); - registry.store(10, conn2); - - Connection *removed = registry.remove(5); - CHECK(removed == conn1); - CHECK(registry.get(5) == nullptr); - CHECK(registry.get(10) == conn2); // Should remain unchanged - } -} - -TEST_CASE("ConnectionRegistry large file descriptor handling") { - TestConnectionRegistry registry; - - MockConnection mock1(1); - Connection *conn1 = reinterpret_cast(&mock1); - - // Test with a large but valid file descriptor - int large_fd = static_cast(registry.max_fds()) - 1; - - SUBCASE("large valid fd works") { - registry.store(large_fd, conn1); - CHECK(registry.get(large_fd) == conn1); - - Connection *removed = registry.remove(large_fd); - CHECK(removed == conn1); - CHECK(registry.get(large_fd) == nullptr); - } -} - -TEST_CASE("ConnectionRegistry critical ordering simulation") { - TestConnectionRegistry registry; - - MockConnection mock1(1); - Connection *conn1 = reinterpret_cast(&mock1); - int fd = 5; - - SUBCASE("simulate proper cleanup ordering") { - // Step 1: Store connection - registry.store(fd, conn1); - CHECK(registry.get(fd) == conn1); - - // Step 2: Remove from registry (critical ordering step 1) - Connection *removed = registry.remove(fd); - CHECK(removed == conn1); - CHECK(registry.get(fd) == nullptr); - - // Steps 2 & 3 would be close(fd) and delete conn - // but we can't test those with mock objects - } - - SUBCASE("simulate fd reuse safety") { - // Store connection - registry.store(fd, conn1); - - // Remove from registry first (step 1) - Connection *removed = registry.remove(fd); - CHECK(removed == conn1); - - // Registry is now clear - safe for fd reuse - CHECK(registry.get(fd) == nullptr); - - // New connection could safely use same fd - MockConnection mock2(2); - Connection *conn2 = reinterpret_cast(&mock2); - registry.store(fd, conn2); - CHECK(registry.get(fd) == conn2); - } -} \ No newline at end of file