std::unique_ptr<Connection> -> Ref<Connection>

This commit is contained in:
2025-09-12 18:31:57 -04:00
parent 1fa3381e4b
commit de6f38694f
11 changed files with 79 additions and 92 deletions

View File

@@ -259,7 +259,7 @@ struct Connection {
* } * }
* *
* void on_data_arrived(std::string_view data, * void on_data_arrived(std::string_view data,
* std::unique_ptr<Connection>& conn_ptr) override { * Ref<Connection>& conn_ptr) override {
* auto* state = static_cast<HttpConnectionState*>(conn_ptr->user_data); * auto* state = static_cast<HttpConnectionState*>(conn_ptr->user_data);
* // Use state for protocol processing... * // Use state for protocol processing...
* } * }
@@ -333,6 +333,9 @@ private:
size_t epoll_index, ConnectionHandler *handler, size_t epoll_index, ConnectionHandler *handler,
WeakRef<Server> server); WeakRef<Server> server);
template <typename T, typename... Args>
friend Ref<T> make_ref(Args &&...args);
// Networking interface - only accessible by Server // Networking interface - only accessible by Server
int readBytes(char *buf, size_t buffer_size); int readBytes(char *buf, size_t buffer_size);
bool writeBytes(); bool writeBytes();

View File

@@ -1,9 +1,10 @@
#pragma once #pragma once
#include <memory>
#include <span> #include <span>
#include <string_view> #include <string_view>
#include "reference.hpp"
// Forward declaration to avoid circular dependency // Forward declaration to avoid circular dependency
struct Connection; struct Connection;
@@ -39,8 +40,7 @@ public:
* after the call to on_data_arrived. * after the call to on_data_arrived.
* @note May be called from an arbitrary server thread. * @note May be called from an arbitrary server thread.
*/ */
virtual void on_data_arrived(std::string_view /*data*/, virtual void on_data_arrived(std::string_view /*data*/, Ref<Connection> &) {};
std::unique_ptr<Connection> &) {};
/** /**
* Called when data has been successfully written to the connection. * Called when data has been successfully written to the connection.
@@ -55,7 +55,7 @@ public:
* @note May be called from an arbitrary server thread. * @note May be called from an arbitrary server thread.
* @note Called during writes, not necessarily when buffer becomes empty * @note Called during writes, not necessarily when buffer becomes empty
*/ */
virtual void on_write_progress(std::unique_ptr<Connection> &) {} virtual void on_write_progress(Ref<Connection> &) {}
/** /**
* Called when the connection's outgoing write buffer becomes empty. * Called when the connection's outgoing write buffer becomes empty.
@@ -72,7 +72,7 @@ public:
* @note May be called from an arbitrary server thread. * @note May be called from an arbitrary server thread.
* @note Only called on transitions from non-empty → empty buffer * @note Only called on transitions from non-empty → empty buffer
*/ */
virtual void on_write_buffer_drained(std::unique_ptr<Connection> &) {} virtual void on_write_buffer_drained(Ref<Connection> &) {}
/** /**
* Called when a new connection is established. * Called when a new connection is established.
@@ -107,6 +107,5 @@ public:
* *
* @param batch A span of unique_ptrs to the connections in the batch. * @param batch A span of unique_ptrs to the connections in the batch.
*/ */
virtual void virtual void on_batch_complete(std::span<Ref<Connection>> /*batch*/) {}
on_batch_complete(std::span<std::unique_ptr<Connection>> /*batch*/) {}
}; };

View File

@@ -1,6 +1,5 @@
#include "connection_registry.hpp" #include "connection_registry.hpp"
#include "connection.hpp" #include "connection.hpp"
#include <atomic>
#include <cstdlib> #include <cstdlib>
#include <cstring> #include <cstring>
#include <unistd.h> #include <unistd.h>
@@ -14,49 +13,49 @@ ConnectionRegistry::ConnectionRegistry() : connections_(nullptr), max_fds_(0) {
} }
max_fds_ = rlim.rlim_cur; max_fds_ = rlim.rlim_cur;
// Calculate size rounded up to page boundary // // Calculate size rounded up to page boundary
size_t array_size = max_fds_ * sizeof(Connection *); // size_t array_size = max_fds_ * sizeof(Connection *);
size_t page_size = getpagesize(); // size_t page_size = getpagesize();
size_t aligned_size = (array_size + page_size - 1) & ~(page_size - 1); // size_t aligned_size = (array_size + page_size - 1) & ~(page_size - 1);
// Allocate virtual address space using mmap // // Allocate virtual address space using mmap
// MAP_ANONYMOUS provides zero-initialized pages on-demand (lazy allocation) // // MAP_ANONYMOUS provides zero-initialized pages on-demand (lazy
connections_ = static_cast<std::atomic<Connection *> *>( // allocation) connections_ = static_cast<std::atomic<Connection *> *>(
mmap(nullptr, aligned_size, PROT_READ | PROT_WRITE, // mmap(nullptr, aligned_size, PROT_READ | PROT_WRITE,
MAP_PRIVATE | MAP_ANONYMOUS, -1, 0)); // MAP_PRIVATE | MAP_ANONYMOUS, -1, 0));
if (connections_ == MAP_FAILED) { // if (connections_ == MAP_FAILED) {
perror("mmap"); // perror("mmap");
std::abort(); // std::abort();
} // }
// Store aligned size for munmap // // Store aligned size for munmap
aligned_size_ = aligned_size; // aligned_size_ = aligned_size;
connections_ = new Ref<Connection>[max_fds_];
} }
ConnectionRegistry::~ConnectionRegistry() { ConnectionRegistry::~ConnectionRegistry() {
if (connections_ != nullptr) { delete[] connections_;
for (int fd = 0; fd < static_cast<int>(max_fds_); ++fd) { // if (connections_ != nullptr) {
delete connections_[fd].load(std::memory_order_relaxed); // for (int fd = 0; fd < static_cast<int>(max_fds_); ++fd) {
} // delete connections_[fd].load(std::memory_order_relaxed);
if (munmap(connections_, aligned_size_) == -1) { // }
perror("munmap"); // if (munmap(connections_, aligned_size_) == -1) {
} // perror("munmap");
} // }
// }
} }
void ConnectionRegistry::store(int fd, std::unique_ptr<Connection> connection) { void ConnectionRegistry::store(int fd, Ref<Connection> connection) {
if (fd < 0 || static_cast<size_t>(fd) >= max_fds_) { if (fd < 0 || static_cast<size_t>(fd) >= max_fds_) {
std::abort(); std::abort();
} }
// Release ownership from unique_ptr and store raw pointer connections_[fd] = std::move(connection);
connections_[fd].store(connection.release(), std::memory_order_release);
} }
std::unique_ptr<Connection> ConnectionRegistry::remove(int fd) { Ref<Connection> ConnectionRegistry::remove(int fd) {
if (fd < 0 || static_cast<size_t>(fd) >= max_fds_) { if (fd < 0 || static_cast<size_t>(fd) >= max_fds_) {
std::abort(); std::abort();
} }
return std::unique_ptr<Connection>( return std::move(connections_[fd]);
connections_[fd].exchange(nullptr, std::memory_order_acquire));
} }

View File

@@ -1,10 +1,11 @@
#pragma once #pragma once
#include <cstddef> #include <cstddef>
#include <memory>
#include <sys/mman.h> #include <sys/mman.h>
#include <sys/resource.h> #include <sys/resource.h>
#include "reference.hpp"
struct Connection; struct Connection;
/** /**
@@ -38,7 +39,7 @@ public:
* @param fd File descriptor (must be valid and < max_fds_) * @param fd File descriptor (must be valid and < max_fds_)
* @param connection unique_ptr to the connection (ownership transferred) * @param connection unique_ptr to the connection (ownership transferred)
*/ */
void store(int fd, std::unique_ptr<Connection> connection); void store(int fd, Ref<Connection> connection);
/** /**
* Remove a connection from the registry and transfer ownership to caller. * Remove a connection from the registry and transfer ownership to caller.
@@ -47,7 +48,7 @@ public:
* @param fd File descriptor * @param fd File descriptor
* @return unique_ptr to the connection, or nullptr if not found * @return unique_ptr to the connection, or nullptr if not found
*/ */
std::unique_ptr<Connection> remove(int fd); Ref<Connection> remove(int fd);
/** /**
* Get the maximum number of file descriptors supported. * Get the maximum number of file descriptors supported.
@@ -63,10 +64,7 @@ public:
ConnectionRegistry &operator=(ConnectionRegistry &&) = delete; ConnectionRegistry &operator=(ConnectionRegistry &&) = delete;
private: private:
std::atomic<Connection *> Ref<Connection> *connections_;
*connections_; ///< mmap'd array of raw connection pointers. It's
///< thread-safe without since epoll_ctl happens before
///< epoll_wait, but this makes tsan happy /shrug.
size_t max_fds_; ///< Maximum file descriptor limit size_t max_fds_; ///< Maximum file descriptor limit
size_t aligned_size_; ///< Page-aligned size for munmap size_t aligned_size_; ///< Page-aligned size for munmap
}; };

View File

@@ -76,8 +76,7 @@ void HttpHandler::on_connection_closed(Connection &conn) {
conn.user_data = nullptr; conn.user_data = nullptr;
} }
void HttpHandler::on_write_buffer_drained( void HttpHandler::on_write_buffer_drained(Ref<Connection> &conn_ptr) {
std::unique_ptr<Connection> &conn_ptr) {
// Reset arena after all messages have been written for the next request // Reset arena after all messages have been written for the next request
auto *state = static_cast<HttpConnectionState *>(conn_ptr->user_data); auto *state = static_cast<HttpConnectionState *>(conn_ptr->user_data);
if (state) { if (state) {
@@ -89,8 +88,7 @@ void HttpHandler::on_write_buffer_drained(
on_connection_established(*conn_ptr); on_connection_established(*conn_ptr);
} }
void HttpHandler::on_batch_complete( void HttpHandler::on_batch_complete(std::span<Ref<Connection>> batch) {
std::span<std::unique_ptr<Connection>> batch) {
// Collect commit, status, and health check requests for pipeline processing // Collect commit, status, and health check requests for pipeline processing
int pipeline_count = 0; int pipeline_count = 0;
@@ -147,7 +145,7 @@ void HttpHandler::on_batch_complete(
} }
void HttpHandler::on_data_arrived(std::string_view data, void HttpHandler::on_data_arrived(std::string_view data,
std::unique_ptr<Connection> &conn_ptr) { Ref<Connection> &conn_ptr) {
auto *state = static_cast<HttpConnectionState *>(conn_ptr->user_data); auto *state = static_cast<HttpConnectionState *>(conn_ptr->user_data);
if (!state) { if (!state) {
send_error_response(*conn_ptr, 500, "Internal server error", true); send_error_response(*conn_ptr, 500, "Internal server error", true);

View File

@@ -135,10 +135,9 @@ struct HttpHandler : ConnectionHandler {
void on_connection_established(Connection &conn) override; void on_connection_established(Connection &conn) override;
void on_connection_closed(Connection &conn) override; void on_connection_closed(Connection &conn) override;
void on_data_arrived(std::string_view data, void on_data_arrived(std::string_view data,
std::unique_ptr<Connection> &conn_ptr) override; Ref<Connection> &conn_ptr) override;
void on_write_buffer_drained(std::unique_ptr<Connection> &conn_ptr) override; void on_write_buffer_drained(Ref<Connection> &conn_ptr) override;
void on_batch_complete( void on_batch_complete(std::span<Ref<Connection>> /*batch*/) override;
std::span<std::unique_ptr<Connection>> /*batch*/) override;
// llhttp callbacks (public for HttpConnectionState access) // llhttp callbacks (public for HttpConnectionState access)
static int onUrl(llhttp_t *parser, const char *at, size_t length); static int onUrl(llhttp_t *parser, const char *at, size_t length);

View File

@@ -9,14 +9,13 @@
* Contains connection with parsed CommitRequest. * Contains connection with parsed CommitRequest.
*/ */
struct CommitEntry { struct CommitEntry {
std::unique_ptr<Connection> connection; Ref<Connection> connection;
int64_t assigned_version = 0; // Set by sequence stage int64_t assigned_version = 0; // Set by sequence stage
bool resolve_success = false; // Set by resolve stage bool resolve_success = false; // Set by resolve stage
bool persist_success = false; // Set by persist stage bool persist_success = false; // Set by persist stage
CommitEntry() = default; // Default constructor for variant CommitEntry() = default; // Default constructor for variant
explicit CommitEntry(std::unique_ptr<Connection> conn) explicit CommitEntry(Ref<Connection> conn) : connection(std::move(conn)) {}
: connection(std::move(conn)) {}
}; };
/** /**
@@ -24,12 +23,11 @@ struct CommitEntry {
* then transfer to status threadpool. * then transfer to status threadpool.
*/ */
struct StatusEntry { struct StatusEntry {
std::unique_ptr<Connection> connection; Ref<Connection> connection;
int64_t version_upper_bound = 0; // Set by sequence stage int64_t version_upper_bound = 0; // Set by sequence stage
StatusEntry() = default; // Default constructor for variant StatusEntry() = default; // Default constructor for variant
explicit StatusEntry(std::unique_ptr<Connection> conn) explicit StatusEntry(Ref<Connection> conn) : connection(std::move(conn)) {}
: connection(std::move(conn)) {}
}; };
/** /**
@@ -38,10 +36,10 @@ struct StatusEntry {
* Resolve stage can perform configurable CPU work for benchmarking. * Resolve stage can perform configurable CPU work for benchmarking.
*/ */
struct HealthCheckEntry { struct HealthCheckEntry {
std::unique_ptr<Connection> connection; Ref<Connection> connection;
HealthCheckEntry() = default; // Default constructor for variant HealthCheckEntry() = default; // Default constructor for variant
explicit HealthCheckEntry(std::unique_ptr<Connection> conn) explicit HealthCheckEntry(Ref<Connection> conn)
: connection(std::move(conn)) {} : connection(std::move(conn)) {}
}; };

View File

@@ -5,7 +5,6 @@
#include <cstdlib> #include <cstdlib>
#include <cstring> #include <cstring>
#include <fcntl.h> #include <fcntl.h>
#include <memory>
#include <netdb.h> #include <netdb.h>
#include <netinet/tcp.h> #include <netinet/tcp.h>
#include <pthread.h> #include <pthread.h>
@@ -139,7 +138,7 @@ void Server::shutdown() {
} }
} }
void Server::release_back_to_server(std::unique_ptr<Connection> connection) { void Server::release_back_to_server(Ref<Connection> connection) {
if (!connection) { if (!connection) {
return; // Nothing to release return; // Nothing to release
} }
@@ -154,7 +153,7 @@ void Server::release_back_to_server(std::unique_ptr<Connection> connection) {
// unique_ptr destructs // unique_ptr destructs
} }
void Server::receiveConnectionBack(std::unique_ptr<Connection> connection) { void Server::receiveConnectionBack(Ref<Connection> connection) {
if (!connection) { if (!connection) {
return; // Nothing to process return; // Nothing to process
} }
@@ -216,9 +215,9 @@ int Server::create_local_connection() {
epoll_fds_.size(); epoll_fds_.size();
// Create Connection object // Create Connection object
auto connection = std::unique_ptr<Connection>(new Connection( auto connection = make_ref<Connection>(
addr, server_fd, connection_id_.fetch_add(1, std::memory_order_relaxed), addr, server_fd, connection_id_.fetch_add(1, std::memory_order_relaxed),
epoll_index, &handler_, self_.copy())); epoll_index, &handler_, self_.copy());
// Store in registry // Store in registry
connection_registry_.store(server_fd, std::move(connection)); connection_registry_.store(server_fd, std::move(connection));
@@ -316,8 +315,7 @@ void Server::start_io_threads(std::vector<std::thread> &threads) {
int epollfd = get_epoll_for_thread(thread_id); int epollfd = get_epoll_for_thread(thread_id);
std::vector<epoll_event> events(config_.server.event_batch_size); std::vector<epoll_event> events(config_.server.event_batch_size);
std::vector<std::unique_ptr<Connection>> batch( std::vector<Ref<Connection>> batch(config_.server.event_batch_size);
config_.server.event_batch_size);
std::vector<int> batch_events(config_.server.event_batch_size); std::vector<int> batch_events(config_.server.event_batch_size);
std::vector<int> std::vector<int>
ready_listen_fds; // Reused across iterations to avoid allocation ready_listen_fds; // Reused across iterations to avoid allocation
@@ -351,7 +349,7 @@ void Server::start_io_threads(std::vector<std::thread> &threads) {
// Handle existing connection events // Handle existing connection events
int fd = events[i].data.fd; int fd = events[i].data.fd;
std::unique_ptr<Connection> conn = connection_registry_.remove(fd); Ref<Connection> conn = connection_registry_.remove(fd);
assert(conn); assert(conn);
if (events[i].events & (EPOLLERR | EPOLLHUP)) { if (events[i].events & (EPOLLERR | EPOLLHUP)) {
@@ -419,10 +417,10 @@ void Server::start_io_threads(std::vector<std::thread> &threads) {
// Transfer ownership from registry to batch processing // Transfer ownership from registry to batch processing
size_t epoll_index = thread_id % epoll_fds_.size(); size_t epoll_index = thread_id % epoll_fds_.size();
batch[batch_count] = std::unique_ptr<Connection>(new Connection( batch[batch_count] = make_ref<Connection>(
addr, fd, addr, fd,
connection_id_.fetch_add(1, std::memory_order_relaxed), connection_id_.fetch_add(1, std::memory_order_relaxed),
epoll_index, &handler_, self_.copy())); epoll_index, &handler_, self_.copy());
batch_events[batch_count] = batch_events[batch_count] =
EPOLLIN; // New connections always start with read EPOLLIN; // New connections always start with read
batch_count++; batch_count++;
@@ -449,8 +447,7 @@ void Server::start_io_threads(std::vector<std::thread> &threads) {
} }
} }
void Server::process_connection_reads(std::unique_ptr<Connection> &conn, void Server::process_connection_reads(Ref<Connection> &conn, int events) {
int events) {
assert(conn); assert(conn);
// Handle EPOLLIN - read data and process it // Handle EPOLLIN - read data and process it
if (events & EPOLLIN) { if (events & EPOLLIN) {
@@ -481,8 +478,7 @@ void Server::process_connection_reads(std::unique_ptr<Connection> &conn,
} }
} }
void Server::process_connection_writes(std::unique_ptr<Connection> &conn, void Server::process_connection_writes(Ref<Connection> &conn, int /*events*/) {
int /*events*/) {
assert(conn); assert(conn);
// For simplicity, we always attempt to write when an event fires. We could be // For simplicity, we always attempt to write when an event fires. We could be
// more precise and skip the write if we detect that we've already seen EAGAIN // more precise and skip the write if we detect that we've already seen EAGAIN
@@ -521,8 +517,8 @@ void Server::process_connection_writes(std::unique_ptr<Connection> &conn,
} }
} }
void Server::process_connection_batch( void Server::process_connection_batch(int epollfd,
int epollfd, std::span<std::unique_ptr<Connection>> batch, std::span<Ref<Connection>> batch,
std::span<const int> events) { std::span<const int> events) {
// First process writes for each connection // First process writes for each connection

View File

@@ -106,7 +106,7 @@ struct Server {
* *
* @param connection unique_ptr to the connection being released back * @param connection unique_ptr to the connection being released back
*/ */
static void release_back_to_server(std::unique_ptr<Connection> connection); static void release_back_to_server(Ref<Connection> connection);
private: private:
friend struct Connection; friend struct Connection;
@@ -158,14 +158,11 @@ private:
int get_epoll_for_thread(int thread_id) const; int get_epoll_for_thread(int thread_id) const;
// Helper for processing connection I/O // Helper for processing connection I/O
void process_connection_reads(std::unique_ptr<Connection> &conn_ptr, void process_connection_reads(Ref<Connection> &conn_ptr, int events);
int events); void process_connection_writes(Ref<Connection> &conn_ptr, int events);
void process_connection_writes(std::unique_ptr<Connection> &conn_ptr,
int events);
// Helper for processing a batch of connections with their events // Helper for processing a batch of connections with their events
void process_connection_batch(int epollfd, void process_connection_batch(int epollfd, std::span<Ref<Connection>> batch,
std::span<std::unique_ptr<Connection>> batch,
std::span<const int> events); std::span<const int> events);
/** /**
@@ -176,7 +173,7 @@ private:
* *
* @param connection Unique pointer to the connection being released back * @param connection Unique pointer to the connection being released back
*/ */
void receiveConnectionBack(std::unique_ptr<Connection> connection); void receiveConnectionBack(Ref<Connection> connection);
// Make non-copyable and non-movable // Make non-copyable and non-movable
Server(const Server &) = delete; Server(const Server &) = delete;

View File

@@ -32,11 +32,11 @@ struct MockConnectionHandler : public ConnectionHandler {
bool write_progress_called = false; bool write_progress_called = false;
bool write_buffer_drained_called = false; bool write_buffer_drained_called = false;
void on_write_progress(std::unique_ptr<Connection> &) override { void on_write_progress(Ref<Connection> &) override {
write_progress_called = true; write_progress_called = true;
} }
void on_write_buffer_drained(std::unique_ptr<Connection> &) override { void on_write_buffer_drained(Ref<Connection> &) override {
write_buffer_drained_called = true; write_buffer_drained_called = true;
} }
}; };
@@ -50,7 +50,7 @@ TEST_CASE("ConnectionHandler hooks") {
CHECK_FALSE(handler.write_buffer_drained_called); CHECK_FALSE(handler.write_buffer_drained_called);
// Would normally be called by Server during write operations // Would normally be called by Server during write operations
std::unique_ptr<Connection> null_conn; Ref<Connection> null_conn;
handler.on_write_progress(null_conn); handler.on_write_progress(null_conn);
handler.on_write_buffer_drained(null_conn); handler.on_write_buffer_drained(null_conn);

View File

@@ -11,7 +11,7 @@
PERFETTO_TRACK_EVENT_STATIC_STORAGE(); PERFETTO_TRACK_EVENT_STATIC_STORAGE();
struct Message { struct Message {
std::unique_ptr<Connection> conn; Ref<Connection> conn;
std::string data; std::string data;
bool done; bool done;
}; };
@@ -27,7 +27,7 @@ public:
: pipeline(pipeline) {} : pipeline(pipeline) {}
void on_data_arrived(std::string_view data, void on_data_arrived(std::string_view data,
std::unique_ptr<Connection> &conn_ptr) override { Ref<Connection> &conn_ptr) override {
assert(conn_ptr); assert(conn_ptr);
auto guard = pipeline.push(1, true); auto guard = pipeline.push(1, true);
for (auto &message : guard.batch) { for (auto &message : guard.batch) {