Separate Connection and Request lifetimes

This commit is contained in:
2025-09-14 15:04:37 -04:00
parent cf0c1b7cc2
commit 16c7ee0408
14 changed files with 519 additions and 381 deletions

View File

@@ -250,9 +250,6 @@ target_link_libraries(load_tester Threads::Threads llhttp_static perfetto)
add_test(NAME arena_tests COMMAND test_arena) add_test(NAME arena_tests COMMAND test_arena)
add_test(NAME commit_request_tests COMMAND test_commit_request) add_test(NAME commit_request_tests COMMAND test_commit_request)
add_test(NAME http_handler_tests COMMAND test_http_handler)
add_test(NAME server_connection_return_tests
COMMAND test_server_connection_return)
add_test(NAME arena_benchmarks COMMAND bench_arena) add_test(NAME arena_benchmarks COMMAND bench_arena)
add_test(NAME commit_request_benchmarks COMMAND bench_commit_request) add_test(NAME commit_request_benchmarks COMMAND bench_commit_request)
add_test(NAME parser_comparison_benchmarks COMMAND bench_parser_comparison) add_test(NAME parser_comparison_benchmarks COMMAND bench_parser_comparison)

View File

@@ -406,6 +406,7 @@ public:
* ## Note: * ## Note:
* This method only allocates memory - it does not construct objects. * This method only allocates memory - it does not construct objects.
* Use placement new or other initialization methods as needed. * Use placement new or other initialization methods as needed.
* TODO should this return a std::span<T> ?
*/ */
template <typename T> T *allocate(uint32_t size) { template <typename T> T *allocate(uint32_t size) {
static_assert( static_assert(

View File

@@ -59,25 +59,29 @@ Connection::Connection(struct sockaddr_storage addr, int fd, int64_t id,
} }
Connection::~Connection() { Connection::~Connection() {
if (handler_) {
handler_->on_connection_closed(*this); handler_->on_connection_closed(*this);
assert(fd_ < 0 && "Connection fd was not closed before ~Connection");
} }
// Server may legitimately be gone now
if (auto server_ptr = server_.lock()) { void Connection::close() {
std::lock_guard lock{mutex_};
auto server_ptr = server_.lock();
// Should only be called from the io thread
assert(server_ptr);
server_ptr->active_connections_.fetch_sub(1, std::memory_order_relaxed); server_ptr->active_connections_.fetch_sub(1, std::memory_order_relaxed);
} assert(fd_ >= 0);
int e = ::close(fd_);
// Decrement active connections gauge
connections_active.dec();
int e = close(fd_);
if (e == -1 && errno != EINTR) { if (e == -1 && errno != EINTR) {
perror("close"); perror("close");
std::abort(); std::abort();
} }
// EINTR ignored - fd is guaranteed closed on Linux // EINTR ignored - fd is guaranteed closed on Linux
fd_ = -1;
// Decrement active connections gauge
connections_active.dec();
} }
// May be called off the io thread!
void Connection::append_message(std::span<std::string_view> data_parts, void Connection::append_message(std::span<std::string_view> data_parts,
Arena arena, bool close_after_send) { Arena arena, bool close_after_send) {
// Calculate total bytes for this message. Don't need to hold the lock yet. // Calculate total bytes for this message. Don't need to hold the lock yet.
@@ -86,11 +90,7 @@ void Connection::append_message(std::span<std::string_view> data_parts,
total_bytes += part.size(); total_bytes += part.size();
} }
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock lock(mutex_);
if (is_closed_) {
return; // Connection is closed, ignore message
}
// Check if queue was empty to determine if we need to enable EPOLLOUT // Check if queue was empty to determine if we need to enable EPOLLOUT
bool was_empty = message_queue_.empty(); bool was_empty = message_queue_.empty();
@@ -100,22 +100,15 @@ void Connection::append_message(std::span<std::string_view> data_parts,
Message{std::move(arena), data_parts, close_after_send}); Message{std::move(arena), data_parts, close_after_send});
outgoing_bytes_queued_ += total_bytes; outgoing_bytes_queued_ += total_bytes;
// If this message has close_after_send flag, set connection flag // If queue was empty, we need to add EPOLLOUT interest.
if (close_after_send) { if (was_empty && fd_ >= 0) {
close_after_send_ = true;
}
lock.unlock();
// If queue was empty, we need to add EPOLLOUT interest. We don't need to hold
// the lock
if (was_empty) {
auto server = server_.lock(); auto server = server_.lock();
if (server) { if (server) {
// Add EPOLLOUT interest - pipeline thread manages epoll // Add EPOLLOUT interest - pipeline thread manages epoll
struct epoll_event event; struct epoll_event event;
event.data.fd = fd_; event.data.fd = fd_;
event.events = EPOLLIN | EPOLLOUT; event.events = EPOLLIN | EPOLLOUT;
tsan_release();
epoll_ctl(server->epoll_fds_[epoll_index_], EPOLL_CTL_MOD, fd_, &event); epoll_ctl(server->epoll_fds_[epoll_index_], EPOLL_CTL_MOD, fd_, &event);
} }
} }
@@ -148,16 +141,18 @@ int Connection::readBytes(char *buf, size_t buffer_size) {
} }
} }
bool Connection::writeBytes() { uint32_t Connection::write_bytes() {
ssize_t total_bytes_written = 0; ssize_t total_bytes_written = 0;
uint32_t result = 0;
while (true) { while (true) {
// Build iovec array while holding mutex using thread-local buffer // Build iovec array while holding mutex using thread-local buffer
int iov_count = 0; int iov_count = 0;
{ {
std::lock_guard lock(mutex_); std::lock_guard lock(mutex_);
if (is_closed_ || message_queue_.empty()) { if (message_queue_.empty()) {
break; break;
} }
@@ -204,14 +199,17 @@ bool Connection::writeBytes() {
if (total_bytes_written > 0) { if (total_bytes_written > 0) {
bytes_written.inc(total_bytes_written); bytes_written.inc(total_bytes_written);
} }
return false; return result;
} }
perror("sendmsg"); perror("sendmsg");
return true; result |= Error;
return result;
} }
break; break;
} }
result |= Progress;
assert(w > 0); assert(w > 0);
total_bytes_written += w; total_bytes_written += w;
@@ -244,9 +242,15 @@ bool Connection::writeBytes() {
} }
if (message_complete) { if (message_complete) {
if (front_message.close_after_send) {
result |= Close;
}
// Move arena to thread-local vector for deferred cleanup // Move arena to thread-local vector for deferred cleanup
g_arenas_to_free.emplace_back(std::move(front_message.arena)); g_arenas_to_free.emplace_back(std::move(front_message.arena));
message_queue_.pop_front(); message_queue_.pop_front();
if (result & Close) {
break;
}
} else { } else {
break; break;
} }
@@ -258,11 +262,13 @@ bool Connection::writeBytes() {
{ {
std::lock_guard lock(mutex_); std::lock_guard lock(mutex_);
if (message_queue_.empty()) { if (message_queue_.empty()) {
result |= Drained;
auto server = server_.lock(); auto server = server_.lock();
if (server) { if (server) {
struct epoll_event event; struct epoll_event event;
event.data.fd = fd_; event.data.fd = fd_;
event.events = EPOLLIN; // Remove EPOLLOUT event.events = EPOLLIN; // Remove EPOLLOUT
tsan_release();
epoll_ctl(server->epoll_fds_[epoll_index_], EPOLL_CTL_MOD, fd_, &event); epoll_ctl(server->epoll_fds_[epoll_index_], EPOLL_CTL_MOD, fd_, &event);
} }
} }
@@ -277,5 +283,5 @@ bool Connection::writeBytes() {
// This avoids holding the connection mutex while free() potentially contends // This avoids holding the connection mutex while free() potentially contends
g_arenas_to_free.clear(); g_arenas_to_free.clear();
return false; return result;
} }

View File

@@ -20,6 +20,40 @@
// Forward declaration // Forward declaration
struct Server; struct Server;
/**
* Base interface for sending messages to a connection.
* This restricted interface is safe for use by pipeline threads,
* containing only the append_message method needed for responses.
* Pipeline threads should use WeakRef<MessageSender> to safely
* send responses without accessing other connection functionality
* that should only be used by the I/O thread.
*/
struct MessageSender {
/**
* @brief Append message data to connection's outgoing message queue.
*
* Thread-safe method that can be called from any thread, including
* pipeline processing threads. The arena is moved into the connection
* to maintain data lifetime until the message is sent.
*
* @param data_parts Span of string_view parts to send (arena-allocated)
* @param arena Arena containing the memory for data_parts string_views
* @param close_after_send Whether to close connection after sending
*
* Example usage:
* ```cpp
* auto response_parts = std::span{arena.allocate<std::string_view>(2), 2};
* response_parts[0] = "HTTP/1.1 200 OK\r\n\r\n";
* response_parts[1] = "Hello World";
* conn.append_message(response_parts, std::move(arena));
* ```
*/
virtual void append_message(std::span<std::string_view> data_parts,
Arena arena, bool close_after_send = false) = 0;
virtual ~MessageSender() = default;
};
/** /**
* Represents a single client connection with thread-safe concurrent access. * Represents a single client connection with thread-safe concurrent access.
* *
@@ -42,7 +76,7 @@ struct Server;
* - No connection-owned arena for parsing/response generation * - No connection-owned arena for parsing/response generation
* - Message queue stores spans + owning arenas until I/O completion * - Message queue stores spans + owning arenas until I/O completion
*/ */
struct Connection { struct Connection : MessageSender {
// No public constructor or factory method - only Server can create // No public constructor or factory method - only Server can create
// connections // connections
@@ -91,7 +125,7 @@ struct Connection {
* ``` * ```
*/ */
void append_message(std::span<std::string_view> data_parts, Arena arena, void append_message(std::span<std::string_view> data_parts, Arena arena,
bool close_after_send = false); bool close_after_send = false) override;
/** /**
* @brief Get a WeakRef to this connection for async operations. * @brief Get a WeakRef to this connection for async operations.
@@ -120,7 +154,10 @@ struct Connection {
* }); * });
* ``` * ```
*/ */
WeakRef<Connection> get_weak_ref() const { return self_ref_.copy(); } WeakRef<MessageSender> get_weak_ref() const {
assert(self_ref_.lock());
return self_ref_.copy();
}
/** /**
* @brief Get the unique identifier for this connection. * @brief Get the unique identifier for this connection.
@@ -278,23 +315,26 @@ private:
// Networking interface - only accessible by Server // Networking interface - only accessible by Server
int readBytes(char *buf, size_t buffer_size); int readBytes(char *buf, size_t buffer_size);
bool writeBytes(); enum WriteBytesResult {
Error = 1 << 0,
Progress = 1 << 1,
Drained = 1 << 2,
Close = 1 << 3,
};
uint32_t write_bytes();
// Direct access methods for Server (must hold mutex) // Direct access methods for Server (must hold mutex)
int getFd() const { return fd_; } int getFd() const { return fd_; }
bool has_messages() const { return !message_queue_.empty(); } bool has_messages() const { return !message_queue_.empty(); }
bool should_close() const { return close_after_send_; }
size_t getEpollIndex() const { return epoll_index_; } size_t getEpollIndex() const { return epoll_index_; }
void close();
// Server can set self-reference after creation
void setSelfRef(WeakRef<Connection> self) { self_ref_ = std::move(self); }
// Immutable connection properties // Immutable connection properties
const int fd_; int fd_;
const int64_t id_; const int64_t id_;
const size_t epoll_index_; // Index of the epoll instance this connection uses const size_t epoll_index_; // Index of the epoll instance this connection uses
struct sockaddr_storage addr_; // sockaddr_storage handles IPv4/IPv6 struct sockaddr_storage addr_; // sockaddr_storage handles IPv4/IPv6
ConnectionHandler *handler_; ConnectionHandler *const handler_;
WeakRef<Server> server_; // Weak reference to server for safe cleanup WeakRef<Server> server_; // Weak reference to server for safe cleanup
WeakRef<Connection> self_ref_; // WeakRef to self for get_weak_ref() WeakRef<Connection> self_ref_; // WeakRef to self for get_weak_ref()
@@ -305,6 +345,13 @@ private:
// mutex_, but if non-empty mutex_ can be // mutex_, but if non-empty mutex_ can be
// dropped while server accesses existing elements. // dropped while server accesses existing elements.
int64_t outgoing_bytes_queued_{0}; // Counter of queued bytes int64_t outgoing_bytes_queued_{0}; // Counter of queued bytes
bool close_after_send_{false}; // Close after sending all messages
bool is_closed_{false}; // Connection closed state #if __has_feature(thread_sanitizer)
void tsan_acquire() { tsan_sync.load(std::memory_order_acquire); }
void tsan_release() { tsan_sync.store(0, std::memory_order_release); }
std::atomic<int> tsan_sync;
#else
void tsan_acquire() {}
void tsan_release() {}
#endif
}; };

View File

@@ -3,8 +3,6 @@
#include <span> #include <span>
#include <string_view> #include <string_view>
#include "reference.hpp"
// Forward declaration to avoid circular dependency // Forward declaration to avoid circular dependency
struct Connection; struct Connection;
@@ -36,7 +34,7 @@ public:
* - Use conn.get_weak_ref() for async processing if needed * - Use conn.get_weak_ref() for async processing if needed
* *
* @note `data` lifetime ends after the call to on_data_arrived. * @note `data` lifetime ends after the call to on_data_arrived.
* @note May be called from an arbitrary server thread. * @note Called from this connection's io thread.
* @note Handler can safely access connection concurrently via thread-safe * @note Handler can safely access connection concurrently via thread-safe
* methods. * methods.
*/ */
@@ -51,7 +49,7 @@ public:
* - Progress monitoring for long-running transfers * - Progress monitoring for long-running transfers
* *
* @param conn Connection that made write progress - server retains ownership * @param conn Connection that made write progress - server retains ownership
* @note May be called from an arbitrary server thread. * @note Called from this connection's io thread.
* @note Called during writes, not necessarily when buffer becomes empty * @note Called during writes, not necessarily when buffer becomes empty
*/ */
virtual void on_write_progress(Connection &) {} virtual void on_write_progress(Connection &) {}
@@ -66,7 +64,7 @@ public:
* - Relieving backpressure conditions * - Relieving backpressure conditions
* *
* @param conn Connection with empty write buffer - server retains ownership * @param conn Connection with empty write buffer - server retains ownership
* @note May be called from an arbitrary server thread. * @note Called from this connection's io thread.
* @note Only called on transitions from non-empty → empty buffer * @note Only called on transitions from non-empty → empty buffer
*/ */
virtual void on_write_buffer_drained(Connection &) {} virtual void on_write_buffer_drained(Connection &) {}
@@ -78,7 +76,7 @@ public:
* *
* Use this for: * Use this for:
* - Connection-specific initialization. * - Connection-specific initialization.
* @note May be called from an arbitrary server thread. * @note Called from this connection's io thread.
*/ */
virtual void on_connection_established(Connection &) {} virtual void on_connection_established(Connection &) {}
@@ -89,7 +87,8 @@ public:
* *
* Use this for: * Use this for:
* - Cleanup of connection-specific resources. * - Cleanup of connection-specific resources.
* @note May be called from an arbitrary server thread. * @note Called from this connection's io thread, or possibly a foreign thread
* that has locked the MessageSender associated with this connection.
*/ */
virtual void on_connection_closed(Connection &) {} virtual void on_connection_closed(Connection &) {}
@@ -101,6 +100,7 @@ public:
* All connections remain server-owned. * All connections remain server-owned.
* *
* @param batch A span of connection references in the batch. * @param batch A span of connection references in the batch.
* @note Called from this connection's io thread.
*/ */
virtual void on_batch_complete(std::span<Connection *> /*batch*/) {} virtual void on_batch_complete(std::span<Connection *const> /*batch*/) {}
}; };

View File

@@ -7,13 +7,13 @@
#include "api_url_parser.hpp" #include "api_url_parser.hpp"
#include "arena.hpp" #include "arena.hpp"
#include "connection.hpp"
#include "cpu_work.hpp" #include "cpu_work.hpp"
#include "format.hpp" #include "format.hpp"
#include "json_commit_request_parser.hpp" #include "json_commit_request_parser.hpp"
#include "metric.hpp" #include "metric.hpp"
#include "perfetto_categories.hpp" #include "perfetto_categories.hpp"
#include "pipeline_entry.hpp" #include "pipeline_entry.hpp"
#include "server.hpp"
auto requests_counter_family = metric::create_counter( auto requests_counter_family = metric::create_counter(
"weaseldb_http_requests_total", "Total http requests"); "weaseldb_http_requests_total", "Total http requests");
@@ -37,8 +37,8 @@ auto banned_request_ids_memory_gauge =
.create({}); .create({});
// HttpConnectionState implementation // HttpConnectionState implementation
HttpConnectionState::HttpConnectionState(Arena &arena) HttpConnectionState::HttpConnectionState()
: arena(arena), current_header_field_buf(ArenaStlAllocator<char>(&arena)), : current_header_field_buf(ArenaStlAllocator<char>(&arena)),
current_header_value_buf(ArenaStlAllocator<char>(&arena)) { current_header_value_buf(ArenaStlAllocator<char>(&arena)) {
llhttp_settings_init(&settings); llhttp_settings_init(&settings);
@@ -58,43 +58,36 @@ HttpConnectionState::HttpConnectionState(Arena &arena)
// HttpHandler implementation // HttpHandler implementation
void HttpHandler::on_connection_established(Connection &conn) { void HttpHandler::on_connection_established(Connection &conn) {
// Allocate HTTP state in connection's arena // Allocate HTTP state using server-provided arena for connection lifecycle
Arena &arena = conn.get_arena(); auto *state = new HttpConnectionState();
void *mem = arena.allocate_raw(sizeof(HttpConnectionState),
alignof(HttpConnectionState));
auto *state = new (mem) HttpConnectionState(arena);
conn.user_data = state; conn.user_data = state;
} }
void HttpHandler::on_connection_closed(Connection &conn) { void HttpHandler::on_connection_closed(Connection &conn) {
// Arena cleanup happens automatically when connection is destroyed
auto *state = static_cast<HttpConnectionState *>(conn.user_data); auto *state = static_cast<HttpConnectionState *>(conn.user_data);
if (state) { delete state;
// Arena::Ptr automatically calls destructors
state->~HttpConnectionState();
}
conn.user_data = nullptr; conn.user_data = nullptr;
} }
void HttpHandler::on_write_buffer_drained(Ref<Connection> &conn_ptr) { void HttpHandler::on_write_buffer_drained(Connection &conn) {
// Reset arena after all messages have been written for the next request // Reset state after all messages have been written for the next request
auto *state = static_cast<HttpConnectionState *>(conn_ptr->user_data); auto *state = static_cast<HttpConnectionState *>(conn.user_data);
if (state) { if (state) {
TRACE_EVENT("http", "reply", TRACE_EVENT("http", "reply",
perfetto::Flow::Global(state->http_request_id)); perfetto::Flow::Global(state->http_request_id));
} }
on_connection_closed(*conn_ptr); // TODO we don't need this anymore. Look at removing it.
conn_ptr->reset(); on_connection_closed(conn);
on_connection_established(*conn_ptr); // Note: Connection reset happens at server level, not connection level
on_connection_established(conn);
} }
void HttpHandler::on_batch_complete(std::span<Ref<Connection>> batch) { void HttpHandler::on_batch_complete(std::span<Connection *const> batch) {
// Collect commit, status, and health check requests for pipeline processing // Collect commit, status, and health check requests for pipeline processing
int pipeline_count = 0; int pipeline_count = 0;
// Count commit, status, and health check requests // Count commit, status, and health check requests
for (auto &conn : batch) { for (auto conn : batch) {
if (conn && conn->user_data) {
auto *state = static_cast<HttpConnectionState *>(conn->user_data); auto *state = static_cast<HttpConnectionState *>(conn->user_data);
// Count commit requests that passed basic validation // Count commit requests that passed basic validation
@@ -115,40 +108,42 @@ void HttpHandler::on_batch_complete(std::span<Ref<Connection>> batch) {
pipeline_count++; pipeline_count++;
} }
} }
}
// Send requests to 4-stage pipeline in batch // Send requests to 4-stage pipeline in batch
if (pipeline_count > 0) { if (pipeline_count > 0) {
auto guard = commitPipeline.push(pipeline_count, true); auto guard = commitPipeline.push(pipeline_count, true);
auto out_iter = guard.batch.begin(); auto out_iter = guard.batch.begin();
for (auto &conn : batch) { for (auto conn : batch) {
if (conn && conn->user_data) {
auto *state = static_cast<HttpConnectionState *>(conn->user_data); auto *state = static_cast<HttpConnectionState *>(conn->user_data);
// Create CommitEntry for commit requests // Create CommitEntry for commit requests
if (state->route == HttpRoute::PostCommit && state->commit_request && if (state->route == HttpRoute::PostCommit && state->commit_request &&
state->parsing_commit && state->basic_validation_passed) { state->parsing_commit && state->basic_validation_passed) {
*out_iter++ = CommitEntry{std::move(conn)}; *out_iter++ =
CommitEntry{conn->get_weak_ref(), state->http_request_id,
state->connection_close, state->commit_request.get()};
} }
// Create StatusEntry for status requests // Create StatusEntry for status requests
else if (state->route == HttpRoute::GetStatus) { else if (state->route == HttpRoute::GetStatus) {
*out_iter++ = StatusEntry{std::move(conn)}; *out_iter++ =
StatusEntry{conn->get_weak_ref(), state->http_request_id,
state->connection_close, state->status_request_id};
} }
// Create HealthCheckEntry for health check requests // Create HealthCheckEntry for health check requests
else if (state->route == HttpRoute::GetOk) { else if (state->route == HttpRoute::GetOk) {
*out_iter++ = HealthCheckEntry{std::move(conn)}; *out_iter++ =
} HealthCheckEntry{conn->get_weak_ref(), state->http_request_id,
state->connection_close};
} }
} }
} }
} }
void HttpHandler::on_data_arrived(std::string_view data, void HttpHandler::on_data_arrived(std::string_view data, Connection &conn) {
Ref<Connection> &conn_ptr) { auto *state = static_cast<HttpConnectionState *>(conn.user_data);
auto *state = static_cast<HttpConnectionState *>(conn_ptr->user_data);
if (!state) { if (!state) {
send_error_response(*conn_ptr, 500, "Internal server error", true); send_error_response(conn, 500, "Internal server error", Arena{}, 0, true);
return; return;
} }
@@ -162,15 +157,18 @@ void HttpHandler::on_data_arrived(std::string_view data,
llhttp_execute(&state->parser, data.data(), data.size()); llhttp_execute(&state->parser, data.data(), data.size());
if (err != HPE_OK) { if (err != HPE_OK) {
send_error_response(*conn_ptr, 400, "Bad request", true); send_error_response(conn, 400, "Bad request", Arena{}, 0, true);
return; return;
} }
// If message is complete, route and handle the request // If message is complete, route and handle the request
if (state->message_complete) { if (state->message_complete) {
// Copy URL to arena for in-place decoding // Create request-scoped arena for URL parsing
Arena &arena = conn_ptr->get_arena(); // FIX: request_arena lifetime ends too soon. Should move arena into
char *url_buffer = arena.allocate<char>(state->url.size()); // individual handlers and propagate it all the way through to
// append_message
Arena request_arena;
char *url_buffer = request_arena.allocate<char>(state->url.size());
std::memcpy(url_buffer, state->url.data(), state->url.size()); std::memcpy(url_buffer, state->url.data(), state->url.size());
RouteMatch route_match; RouteMatch route_match;
@@ -180,7 +178,8 @@ void HttpHandler::on_data_arrived(std::string_view data,
if (parse_result != ParseResult::Success) { if (parse_result != ParseResult::Success) {
// Handle malformed URL encoding // Handle malformed URL encoding
send_error_response(*conn_ptr, 400, "Malformed URL encoding", true); send_error_response(conn, 400, "Malformed URL encoding", Arena{}, 0,
true);
return; return;
} }
@@ -189,35 +188,36 @@ void HttpHandler::on_data_arrived(std::string_view data,
// Route to appropriate handler // Route to appropriate handler
switch (state->route) { switch (state->route) {
case HttpRoute::GetVersion: case HttpRoute::GetVersion:
handle_get_version(*conn_ptr, *state); handle_get_version(conn, *state, std::move(request_arena));
break; break;
case HttpRoute::PostCommit: case HttpRoute::PostCommit:
handle_post_commit(*conn_ptr, *state); handle_post_commit(conn, *state, std::move(request_arena));
break; break;
case HttpRoute::GetSubscribe: case HttpRoute::GetSubscribe:
handle_get_subscribe(*conn_ptr, *state); handle_get_subscribe(conn, *state, std::move(request_arena));
break; break;
case HttpRoute::GetStatus: case HttpRoute::GetStatus:
handle_get_status(*conn_ptr, *state, route_match); handle_get_status(conn, *state, route_match, std::move(request_arena));
break; break;
case HttpRoute::PutRetention: case HttpRoute::PutRetention:
handle_put_retention(*conn_ptr, *state, route_match); handle_put_retention(conn, *state, route_match, std::move(request_arena));
break; break;
case HttpRoute::GetRetention: case HttpRoute::GetRetention:
handle_get_retention(*conn_ptr, *state, route_match); handle_get_retention(conn, *state, route_match, std::move(request_arena));
break; break;
case HttpRoute::DeleteRetention: case HttpRoute::DeleteRetention:
handle_delete_retention(*conn_ptr, *state, route_match); handle_delete_retention(conn, *state, route_match,
std::move(request_arena));
break; break;
case HttpRoute::GetMetrics: case HttpRoute::GetMetrics:
handle_get_metrics(*conn_ptr, *state); handle_get_metrics(conn, *state, std::move(request_arena));
break; break;
case HttpRoute::GetOk: case HttpRoute::GetOk:
handle_get_ok(*conn_ptr, *state); handle_get_ok(conn, *state, std::move(request_arena));
break; break;
case HttpRoute::NotFound: case HttpRoute::NotFound:
default: default:
handle_not_found(*conn_ptr, *state); handle_not_found(conn, *state, std::move(request_arena));
break; break;
} }
} }
@@ -225,27 +225,29 @@ void HttpHandler::on_data_arrived(std::string_view data,
// Route handlers (basic implementations) // Route handlers (basic implementations)
void HttpHandler::handle_get_version(Connection &conn, void HttpHandler::handle_get_version(Connection &conn,
const HttpConnectionState &state) { const HttpConnectionState &state,
Arena request_arena) {
version_counter.inc(); version_counter.inc();
send_json_response( send_json_response(
conn, 200, conn, 200,
format(conn.get_arena(), R"({"version":%ld,"leader":""})", format(request_arena, R"({"version":%ld,"leader":""})",
this->committed_version.load(std::memory_order_seq_cst)), this->committed_version.load(std::memory_order_seq_cst)),
state.connection_close); std::move(request_arena), state.http_request_id, state.connection_close);
} }
void HttpHandler::handle_post_commit(Connection &conn, void HttpHandler::handle_post_commit(Connection &conn,
const HttpConnectionState &state) { const HttpConnectionState &state,
Arena request_arena) {
commit_counter.inc(); commit_counter.inc();
// Check if streaming parse was successful // Check if streaming parse was successful
if (!state.commit_request || !state.parsing_commit) { if (!state.commit_request || !state.parsing_commit) {
const char *error = state.commit_parser const char *error = state.commit_parser
? state.commit_parser->get_parse_error() ? state.commit_parser->get_parse_error()
: "No parser initialized"; : "No parser initialized";
Arena &arena = conn.get_arena(); std::string_view error_msg = format(request_arena, "Parse failed: %s",
std::string_view error_msg = error ? error : "Unknown error");
format(arena, "Parse failed: %s", error ? error : "Unknown error"); send_error_response(conn, 400, error_msg, std::move(request_arena),
send_error_response(conn, 400, error_msg, state.connection_close); state.http_request_id, state.connection_close);
return; return;
} }
@@ -285,7 +287,8 @@ void HttpHandler::handle_post_commit(Connection &conn,
} }
if (!valid) { if (!valid) {
send_error_response(conn, 400, error_msg, state.connection_close); send_error_response(conn, 400, error_msg, std::move(request_arena),
state.http_request_id, state.connection_close);
return; return;
} }
@@ -295,17 +298,18 @@ void HttpHandler::handle_post_commit(Connection &conn,
} }
void HttpHandler::handle_get_subscribe(Connection &conn, void HttpHandler::handle_get_subscribe(Connection &conn,
const HttpConnectionState &state) { const HttpConnectionState &state,
Arena) {
// TODO: Implement subscription streaming // TODO: Implement subscription streaming
send_json_response( send_json_response(
conn, 200, conn, 200,
R"({"message":"Subscription endpoint - streaming not yet implemented"})", R"({"message":"Subscription endpoint - streaming not yet implemented"})",
state.connection_close); Arena{}, state.http_request_id, state.connection_close);
} }
void HttpHandler::handle_get_status(Connection &conn, void HttpHandler::handle_get_status(Connection &conn,
HttpConnectionState &state, HttpConnectionState &state,
const RouteMatch &route_match) { const RouteMatch &route_match, Arena) {
status_counter.inc(); status_counter.inc();
// Status requests are processed through the pipeline // Status requests are processed through the pipeline
// Response will be generated in the sequence stage // Response will be generated in the sequence stage
@@ -316,14 +320,14 @@ void HttpHandler::handle_get_status(Connection &conn,
route_match.params[static_cast<int>(ApiParameterKey::RequestId)]; route_match.params[static_cast<int>(ApiParameterKey::RequestId)];
if (!request_id) { if (!request_id) {
send_error_response(conn, 400, send_error_response(conn, 400,
"Missing required query parameter: request_id", "Missing required query parameter: request_id", Arena{},
state.connection_close); state.http_request_id, state.connection_close);
return; return;
} }
if (request_id->empty()) { if (request_id->empty()) {
send_error_response(conn, 400, "Empty request_id parameter", send_error_response(conn, 400, "Empty request_id parameter", Arena{},
state.connection_close); state.http_request_id, state.connection_close);
return; return;
} }
@@ -335,32 +339,33 @@ void HttpHandler::handle_get_status(Connection &conn,
void HttpHandler::handle_put_retention(Connection &conn, void HttpHandler::handle_put_retention(Connection &conn,
const HttpConnectionState &state, const HttpConnectionState &state,
const RouteMatch &route_match) { const RouteMatch &, Arena) {
// TODO: Parse retention policy from body and store // TODO: Parse retention policy from body and store
send_json_response(conn, 200, R"({"policy_id":"example","status":"created"})", send_json_response(conn, 200, R"({"policy_id":"example","status":"created"})",
state.connection_close); Arena{}, state.http_request_id, state.connection_close);
} }
void HttpHandler::handle_get_retention(Connection &conn, void HttpHandler::handle_get_retention(Connection &conn,
const HttpConnectionState &state, const HttpConnectionState &state,
const RouteMatch &route_match) { const RouteMatch &, Arena) {
// TODO: Extract policy_id from URL or return all policies // TODO: Extract policy_id from URL or return all policies
send_json_response(conn, 200, R"({"policies":[]})", state.connection_close); send_json_response(conn, 200, R"({"policies":[]})", Arena{},
state.http_request_id, state.connection_close);
} }
void HttpHandler::handle_delete_retention(Connection &conn, void HttpHandler::handle_delete_retention(Connection &conn,
const HttpConnectionState &state, const HttpConnectionState &state,
const RouteMatch &route_match) { const RouteMatch &, Arena) {
// TODO: Extract policy_id from URL and delete // TODO: Extract policy_id from URL and delete
send_json_response(conn, 200, R"({"policy_id":"example","status":"deleted"})", send_json_response(conn, 200, R"({"policy_id":"example","status":"deleted"})",
state.connection_close); Arena{}, state.http_request_id, state.connection_close);
} }
void HttpHandler::handle_get_metrics(Connection &conn, void HttpHandler::handle_get_metrics(Connection &conn,
const HttpConnectionState &state) { const HttpConnectionState &state,
Arena request_arena) {
metrics_counter.inc(); metrics_counter.inc();
Arena &arena = conn.get_arena(); auto metrics_span = metric::render(request_arena);
auto metrics_span = metric::render(arena);
// Calculate total size for the response body // Calculate total size for the response body
size_t total_size = 0; size_t total_size = 0;
@@ -374,32 +379,33 @@ void HttpHandler::handle_get_metrics(Connection &conn,
std::string_view headers; std::string_view headers;
if (state.connection_close) { if (state.connection_close) {
headers = static_format( headers = static_format(
arena, "HTTP/1.1 200 OK\r\n", request_arena, "HTTP/1.1 200 OK\r\n",
"Content-Type: text/plain; version=0.0.4\r\n", "Content-Type: text/plain; version=0.0.4\r\n",
"Content-Length: ", static_cast<uint64_t>(total_size), "\r\n", "Content-Length: ", static_cast<uint64_t>(total_size), "\r\n",
"X-Response-ID: ", static_cast<int64_t>(http_state->http_request_id), "X-Response-ID: ", static_cast<int64_t>(http_state->http_request_id),
"\r\n", "Connection: close\r\n", "\r\n"); "\r\n", "Connection: close\r\n", "\r\n");
conn.close_after_send();
} else { } else {
headers = static_format( headers = static_format(
arena, "HTTP/1.1 200 OK\r\n", request_arena, "HTTP/1.1 200 OK\r\n",
"Content-Type: text/plain; version=0.0.4\r\n", "Content-Type: text/plain; version=0.0.4\r\n",
"Content-Length: ", static_cast<uint64_t>(total_size), "\r\n", "Content-Length: ", static_cast<uint64_t>(total_size), "\r\n",
"X-Response-ID: ", static_cast<int64_t>(http_state->http_request_id), "X-Response-ID: ", static_cast<int64_t>(http_state->http_request_id),
"\r\n", "Connection: keep-alive\r\n", "\r\n"); "\r\n", "Connection: keep-alive\r\n", "\r\n");
} }
// Send headers auto result = std::span<std::string_view>{
conn.append_message(headers, false); request_arena.allocate<std::string_view>(metrics_span.size() + 1),
metrics_span.size() + 1};
// Send body in chunks auto out = result.begin();
for (const auto &sv : metrics_span) { *out++ = headers;
conn.append_message(sv, false); for (auto sv : metrics_span) {
*out++ = sv;
} }
conn.append_message(result, std::move(request_arena));
} }
void HttpHandler::handle_get_ok(Connection &conn, void HttpHandler::handle_get_ok(Connection &, const HttpConnectionState &state,
const HttpConnectionState &state) { Arena) {
ok_counter.inc(); ok_counter.inc();
TRACE_EVENT("http", "GET /ok", perfetto::Flow::Global(state.http_request_id)); TRACE_EVENT("http", "GET /ok", perfetto::Flow::Global(state.http_request_id));
@@ -408,16 +414,17 @@ void HttpHandler::handle_get_ok(Connection &conn,
} }
void HttpHandler::handle_not_found(Connection &conn, void HttpHandler::handle_not_found(Connection &conn,
const HttpConnectionState &state) { const HttpConnectionState &state, Arena) {
send_error_response(conn, 404, "Not found", state.connection_close); send_error_response(conn, 404, "Not found", Arena{}, state.http_request_id,
state.connection_close);
} }
// HTTP utility methods // HTTP utility methods
void HttpHandler::send_response(Connection &conn, int status_code, void HttpHandler::send_response(MessageSender &conn, int status_code,
std::string_view content_type, std::string_view content_type,
std::string_view body, bool close_connection) { std::string_view body, Arena response_arena,
Arena &arena = conn.get_arena(); int64_t http_request_id,
auto *state = static_cast<HttpConnectionState *>(conn.user_data); bool close_connection) {
// Status text // Status text
std::string_view status_text; std::string_view status_text;
@@ -441,8 +448,10 @@ void HttpHandler::send_response(Connection &conn, int status_code,
const char *connection_header = close_connection ? "close" : "keep-alive"; const char *connection_header = close_connection ? "close" : "keep-alive";
std::string_view response = auto response = std::span{response_arena.allocate<std::string_view>(1), 1};
format(arena,
response[0] =
format(response_arena,
"HTTP/1.1 %d %.*s\r\n" "HTTP/1.1 %d %.*s\r\n"
"Content-Type: %.*s\r\n" "Content-Type: %.*s\r\n"
"Content-Length: %zu\r\n" "Content-Length: %zu\r\n"
@@ -451,32 +460,32 @@ void HttpHandler::send_response(Connection &conn, int status_code,
"\r\n%.*s", "\r\n%.*s",
status_code, static_cast<int>(status_text.size()), status_code, static_cast<int>(status_text.size()),
status_text.data(), static_cast<int>(content_type.size()), status_text.data(), static_cast<int>(content_type.size()),
content_type.data(), body.size(), state->http_request_id, content_type.data(), body.size(), http_request_id,
connection_header, static_cast<int>(body.size()), body.data()); connection_header, static_cast<int>(body.size()), body.data());
if (close_connection) { conn.append_message(response, std::move(response_arena), close_connection);
conn.close_after_send();
} }
conn.append_message(response); void HttpHandler::send_json_response(MessageSender &conn, int status_code,
}
void HttpHandler::send_json_response(Connection &conn, int status_code,
std::string_view json, std::string_view json,
Arena response_arena,
int64_t http_request_id,
bool close_connection) { bool close_connection) {
send_response(conn, status_code, "application/json", json, close_connection); send_response(conn, status_code, "application/json", json,
std::move(response_arena), http_request_id, close_connection);
} }
void HttpHandler::send_error_response(Connection &conn, int status_code, void HttpHandler::send_error_response(MessageSender &conn, int status_code,
std::string_view message, std::string_view message,
Arena response_arena,
int64_t http_request_id,
bool close_connection) { bool close_connection) {
Arena &arena = conn.get_arena();
std::string_view json = std::string_view json =
format(arena, R"({"error":"%.*s"})", static_cast<int>(message.size()), format(response_arena, R"({"error":"%.*s"})",
message.data()); static_cast<int>(message.size()), message.data());
send_json_response(conn, status_code, json, close_connection); send_json_response(conn, status_code, json, std::move(response_arena),
http_request_id, close_connection);
} }
// llhttp callbacks // llhttp callbacks
@@ -623,29 +632,35 @@ bool HttpHandler::process_sequence_batch(BatchType &batch) {
} else if constexpr (std::is_same_v<T, CommitEntry>) { } else if constexpr (std::is_same_v<T, CommitEntry>) {
// Process commit entry: check banned list, assign version // Process commit entry: check banned list, assign version
auto &commit_entry = e; auto &commit_entry = e;
auto *state = static_cast<HttpConnectionState *>( auto conn_ref = commit_entry.connection.lock();
commit_entry.connection->user_data); if (!conn_ref) {
// Connection is gone, drop the entry silently and increment
// metric
// TODO: Add dropped_pipeline_entries metric
return false; // Skip this entry and continue processing
}
if (!state || !state->commit_request) { if (!commit_entry.commit_request) {
// Should not happen - basic validation was done on I/O thread // Should not happen - basic validation was done on I/O thread
send_error_response(*commit_entry.connection, 500, send_error_response(*conn_ref, 500, "Internal server error",
"Internal server error", true); Arena{}, commit_entry.http_request_id, true);
return false; return false;
} }
// Check if request_id is banned (for status queries) // Check if request_id is banned (for status queries)
// Only check CommitRequest request_id, not HTTP header // Only check CommitRequest request_id, not HTTP header
if (state->commit_request && if (commit_entry.commit_request &&
state->commit_request->request_id().has_value()) { commit_entry.commit_request->request_id().has_value()) {
auto commit_request_id = auto commit_request_id =
state->commit_request->request_id().value(); commit_entry.commit_request->request_id().value();
if (banned_request_ids.find(commit_request_id) != if (banned_request_ids.find(commit_request_id) !=
banned_request_ids.end()) { banned_request_ids.end()) {
// Request ID is banned, this commit should fail // Request ID is banned, this commit should fail
send_json_response( send_json_response(
*commit_entry.connection, 409, *conn_ref, 409,
R"({"status": "not_committed", "error": "request_id_banned"})", R"({"status": "not_committed", "error": "request_id_banned"})",
state->connection_close); Arena{}, commit_entry.http_request_id,
commit_entry.connection_close);
return false; return false;
} }
} }
@@ -654,25 +669,30 @@ bool HttpHandler::process_sequence_batch(BatchType &batch) {
commit_entry.assigned_version = next_version++; commit_entry.assigned_version = next_version++;
TRACE_EVENT("http", "sequence_commit", TRACE_EVENT("http", "sequence_commit",
perfetto::Flow::Global(state->http_request_id)); perfetto::Flow::Global(commit_entry.http_request_id));
return false; // Continue processing return false; // Continue processing
} else if constexpr (std::is_same_v<T, StatusEntry>) { } else if constexpr (std::is_same_v<T, StatusEntry>) {
// Process status entry: add request_id to banned list, get version // Process status entry: add request_id to banned list, get version
// upper bound // upper bound
auto &status_entry = e; auto &status_entry = e;
auto *state = static_cast<HttpConnectionState *>( auto conn_ref = status_entry.connection.lock();
status_entry.connection->user_data); if (!conn_ref) {
// Connection is gone, drop the entry silently and increment
// metric
// TODO: Add dropped_pipeline_entries metric
return false; // Skip this entry and continue processing
}
if (state && !state->status_request_id.empty()) { if (!status_entry.status_request_id.empty()) {
// Add request_id to banned list - store the string in arena and // Add request_id to banned list - store the string in arena and
// use string_view // use string_view
char *arena_chars = banned_request_arena.allocate<char>( char *arena_chars = banned_request_arena.allocate<char>(
state->status_request_id.size()); status_entry.status_request_id.size());
std::memcpy(arena_chars, state->status_request_id.data(), std::memcpy(arena_chars, status_entry.status_request_id.data(),
state->status_request_id.size()); status_entry.status_request_id.size());
std::string_view request_id_view(arena_chars, std::string_view request_id_view(
state->status_request_id.size()); arena_chars, status_entry.status_request_id.size());
banned_request_ids.insert(request_id_view); banned_request_ids.insert(request_id_view);
// Update memory usage metric // Update memory usage metric
@@ -684,26 +704,30 @@ bool HttpHandler::process_sequence_batch(BatchType &batch) {
} }
TRACE_EVENT("http", "sequence_status", TRACE_EVENT("http", "sequence_status",
perfetto::Flow::Global(state->http_request_id)); perfetto::Flow::Global(status_entry.http_request_id));
// TODO: Transfer to status threadpool - for now just respond // TODO: Transfer to status threadpool - for now just respond
// not_committed // not_committed
send_json_response(*status_entry.connection, 200, send_json_response(*conn_ref, 200, R"({"status": "not_committed"})",
R"({"status": "not_committed"})", Arena{}, status_entry.http_request_id,
state->connection_close); status_entry.connection_close);
return false; // Continue processing return false; // Continue processing
} else if constexpr (std::is_same_v<T, HealthCheckEntry>) { } else if constexpr (std::is_same_v<T, HealthCheckEntry>) {
// Process health check entry: noop in sequence stage // Process health check entry: noop in sequence stage
auto &health_check_entry = e; auto &health_check_entry = e;
auto *state = static_cast<HttpConnectionState *>( auto conn_ref = health_check_entry.connection.lock();
health_check_entry.connection->user_data); if (!conn_ref) {
// Connection is gone, drop the entry silently and increment
if (state) { // metric
TRACE_EVENT("http", "sequence_health_check", // TODO: Add dropped_pipeline_entries metric
perfetto::Flow::Global(state->http_request_id)); return false; // Skip this entry and continue processing
} }
TRACE_EVENT(
"http", "sequence_health_check",
perfetto::Flow::Global(health_check_entry.http_request_id));
return false; // Continue processing return false; // Continue processing
} }
@@ -736,10 +760,15 @@ bool HttpHandler::process_resolve_batch(BatchType &batch) {
// Process commit entry: accept all commits (simplified // Process commit entry: accept all commits (simplified
// implementation) // implementation)
auto &commit_entry = e; auto &commit_entry = e;
auto *state = static_cast<HttpConnectionState *>( auto conn_ref = commit_entry.connection.lock();
commit_entry.connection->user_data); if (!conn_ref) {
// Connection is gone, drop the entry silently and increment
// metric
// TODO: Add dropped_pipeline_entries metric
return false; // Skip this entry and continue processing
}
if (!state || !state->commit_request) { if (!commit_entry.commit_request) {
// Skip processing for failed sequence stage // Skip processing for failed sequence stage
return false; return false;
} }
@@ -748,7 +777,7 @@ bool HttpHandler::process_resolve_batch(BatchType &batch) {
commit_entry.resolve_success = true; commit_entry.resolve_success = true;
TRACE_EVENT("http", "resolve_commit", TRACE_EVENT("http", "resolve_commit",
perfetto::Flow::Global(state->http_request_id)); perfetto::Flow::Global(commit_entry.http_request_id));
return false; // Continue processing return false; // Continue processing
} else if constexpr (std::is_same_v<T, StatusEntry>) { } else if constexpr (std::is_same_v<T, StatusEntry>) {
@@ -758,14 +787,18 @@ bool HttpHandler::process_resolve_batch(BatchType &batch) {
} else if constexpr (std::is_same_v<T, HealthCheckEntry>) { } else if constexpr (std::is_same_v<T, HealthCheckEntry>) {
// Process health check entry: perform configurable CPU work // Process health check entry: perform configurable CPU work
auto &health_check_entry = e; auto &health_check_entry = e;
auto *state = static_cast<HttpConnectionState *>( auto conn_ref = health_check_entry.connection.lock();
health_check_entry.connection->user_data); if (!conn_ref) {
// Connection is gone, drop the entry silently and increment
if (state) { // metric
TRACE_EVENT("http", "resolve_health_check", // TODO: Add dropped_pipeline_entries metric
perfetto::Flow::Global(state->http_request_id)); return false; // Skip this entry and continue processing
} }
TRACE_EVENT(
"http", "resolve_health_check",
perfetto::Flow::Global(health_check_entry.http_request_id));
// Perform configurable CPU-intensive work for benchmarking // Perform configurable CPU-intensive work for benchmarking
spend_cpu_cycles(config_.benchmark.ok_resolve_iterations); spend_cpu_cycles(config_.benchmark.ok_resolve_iterations);
@@ -799,12 +832,17 @@ bool HttpHandler::process_persist_batch(BatchType &batch) {
} else if constexpr (std::is_same_v<T, CommitEntry>) { } else if constexpr (std::is_same_v<T, CommitEntry>) {
// Process commit entry: mark as durable, generate response // Process commit entry: mark as durable, generate response
auto &commit_entry = e; auto &commit_entry = e;
auto *state = static_cast<HttpConnectionState *>( // Check if connection is still alive first
commit_entry.connection->user_data); auto conn_ref = commit_entry.connection.lock();
if (!conn_ref) {
// Connection is gone, drop the entry silently and increment
// metric
// TODO: Add dropped_pipeline_entries metric
return false; // Skip this entry and continue processing
}
// Skip if resolve failed or connection is in error state // Skip if resolve failed or connection is in error state
if (!state || !state->commit_request || if (!commit_entry.commit_request || !commit_entry.resolve_success) {
!commit_entry.resolve_success) {
return false; return false;
} }
@@ -814,29 +852,31 @@ bool HttpHandler::process_persist_batch(BatchType &batch) {
std::memory_order_seq_cst); std::memory_order_seq_cst);
TRACE_EVENT("http", "persist_commit", TRACE_EVENT("http", "persist_commit",
perfetto::Flow::Global(state->http_request_id)); perfetto::Flow::Global(commit_entry.http_request_id));
const CommitRequest &commit_request = *state->commit_request; const CommitRequest &commit_request = *commit_entry.commit_request;
Arena &arena = commit_entry.connection->get_arena();
Arena response_arena;
std::string_view response; std::string_view response;
// Generate success response with actual assigned version // Generate success response with actual assigned version
if (commit_request.request_id().has_value()) { if (commit_request.request_id().has_value()) {
response = format( response = format(
arena, response_arena,
R"({"request_id":"%.*s","status":"committed","version":%ld,"leader_id":"leader123"})", R"({"request_id":"%.*s","status":"committed","version":%ld,"leader_id":"leader123"})",
static_cast<int>(commit_request.request_id().value().size()), static_cast<int>(commit_request.request_id().value().size()),
commit_request.request_id().value().data(), commit_request.request_id().value().data(),
commit_entry.assigned_version); commit_entry.assigned_version);
} else { } else {
response = format( response = format(
arena, response_arena,
R"({"status":"committed","version":%ld,"leader_id":"leader123"})", R"({"status":"committed","version":%ld,"leader_id":"leader123"})",
commit_entry.assigned_version); commit_entry.assigned_version);
} }
send_json_response(*commit_entry.connection, 200, response, send_json_response(
state->connection_close); *conn_ref, 200, response, std::move(response_arena),
commit_entry.http_request_id, commit_entry.connection_close);
return false; // Continue processing return false; // Continue processing
} else if constexpr (std::is_same_v<T, StatusEntry>) { } else if constexpr (std::is_same_v<T, StatusEntry>) {
@@ -846,17 +886,22 @@ bool HttpHandler::process_persist_batch(BatchType &batch) {
} else if constexpr (std::is_same_v<T, HealthCheckEntry>) { } else if constexpr (std::is_same_v<T, HealthCheckEntry>) {
// Process health check entry: generate OK response // Process health check entry: generate OK response
auto &health_check_entry = e; auto &health_check_entry = e;
auto *state = static_cast<HttpConnectionState *>( auto conn_ref = health_check_entry.connection.lock();
health_check_entry.connection->user_data); if (!conn_ref) {
// Connection is gone, drop the entry silently and increment
// metric
// TODO: Add dropped_pipeline_entries metric
return false; // Skip this entry and continue processing
}
if (state) { TRACE_EVENT(
TRACE_EVENT("http", "persist_health_check", "http", "persist_health_check",
perfetto::Flow::Global(state->http_request_id)); perfetto::Flow::Global(health_check_entry.http_request_id));
// Generate OK response // Generate OK response
send_response(*health_check_entry.connection, 200, "text/plain", send_response(*conn_ref, 200, "text/plain", "OK", Arena{},
"OK", state->connection_close); health_check_entry.http_request_id,
} health_check_entry.connection_close);
return false; // Continue processing return false; // Continue processing
} }
@@ -888,38 +933,48 @@ bool HttpHandler::process_release_batch(BatchType &batch) {
} else if constexpr (std::is_same_v<T, CommitEntry>) { } else if constexpr (std::is_same_v<T, CommitEntry>) {
// Process commit entry: return connection to server // Process commit entry: return connection to server
auto &commit_entry = e; auto &commit_entry = e;
auto *state = static_cast<HttpConnectionState *>( auto conn_ref = commit_entry.connection.lock();
commit_entry.connection->user_data); if (!conn_ref) {
// Connection is gone, drop the entry silently and increment
if (state) { // metric
TRACE_EVENT("http", "release_commit", // TODO: Add dropped_pipeline_entries metric
perfetto::Flow::Global(state->http_request_id)); return false; // Skip this entry and continue processing
} }
TRACE_EVENT("http", "release_commit",
perfetto::Flow::Global(commit_entry.http_request_id));
return false; // Continue processing return false; // Continue processing
} else if constexpr (std::is_same_v<T, StatusEntry>) { } else if constexpr (std::is_same_v<T, StatusEntry>) {
// Process status entry: return connection to server // Process status entry: return connection to server
auto &status_entry = e; auto &status_entry = e;
auto *state = static_cast<HttpConnectionState *>( auto conn_ref = status_entry.connection.lock();
status_entry.connection->user_data); if (!conn_ref) {
// Connection is gone, drop the entry silently and increment
if (state) { // metric
TRACE_EVENT("http", "release_status", // TODO: Add dropped_pipeline_entries metric
perfetto::Flow::Global(state->http_request_id)); return false; // Skip this entry and continue processing
} }
TRACE_EVENT("http", "release_status",
perfetto::Flow::Global(status_entry.http_request_id));
return false; // Continue processing return false; // Continue processing
} else if constexpr (std::is_same_v<T, HealthCheckEntry>) { } else if constexpr (std::is_same_v<T, HealthCheckEntry>) {
// Process health check entry: return connection to server // Process health check entry: return connection to server
auto &health_check_entry = e; auto &health_check_entry = e;
auto *state = static_cast<HttpConnectionState *>( auto conn_ref = health_check_entry.connection.lock();
health_check_entry.connection->user_data); if (!conn_ref) {
// Connection is gone, drop the entry silently and increment
if (state) { // metric
TRACE_EVENT("http", "release_health_check", // TODO: Add dropped_pipeline_entries metric
perfetto::Flow::Global(state->http_request_id)); return false; // Skip this entry and continue processing
} }
TRACE_EVENT(
"http", "release_health_check",
perfetto::Flow::Global(health_check_entry.http_request_id));
return false; // Continue processing return false; // Continue processing
} }

View File

@@ -1,7 +1,6 @@
#pragma once #pragma once
#include <atomic> #include <atomic>
#include <memory>
#include <string_view> #include <string_view>
#include <thread> #include <thread>
#include <unordered_set> #include <unordered_set>
@@ -13,9 +12,7 @@
#include "config.hpp" #include "config.hpp"
#include "connection.hpp" #include "connection.hpp"
#include "connection_handler.hpp" #include "connection_handler.hpp"
#include "perfetto_categories.hpp"
#include "pipeline_entry.hpp" #include "pipeline_entry.hpp"
#include "server.hpp"
#include "thread_pipeline.hpp" #include "thread_pipeline.hpp"
// Forward declarations // Forward declarations
@@ -28,7 +25,7 @@ struct RouteMatch;
* Manages llhttp parser state and request data. * Manages llhttp parser state and request data.
*/ */
struct HttpConnectionState { struct HttpConnectionState {
Arena &arena; Arena arena; // Request-scoped arena for parsing state
llhttp_t parser; llhttp_t parser;
llhttp_settings_t settings; llhttp_settings_t settings;
@@ -62,7 +59,7 @@ struct HttpConnectionState {
bool basic_validation_passed = bool basic_validation_passed =
false; // Set to true if basic validation passes false; // Set to true if basic validation passes
explicit HttpConnectionState(Arena &arena); HttpConnectionState();
}; };
/** /**
@@ -134,10 +131,9 @@ struct HttpHandler : ConnectionHandler {
void on_connection_established(Connection &conn) override; void on_connection_established(Connection &conn) override;
void on_connection_closed(Connection &conn) override; void on_connection_closed(Connection &conn) override;
void on_data_arrived(std::string_view data, void on_data_arrived(std::string_view data, Connection &conn) override;
Ref<Connection> &conn_ptr) override; void on_batch_complete(std::span<Connection *const> batch) override;
void on_write_buffer_drained(Ref<Connection> &conn_ptr) override; void on_write_buffer_drained(Connection &conn_ptr) override;
void on_batch_complete(std::span<Ref<Connection>> /*batch*/) override;
// llhttp callbacks (public for HttpConnectionState access) // llhttp callbacks (public for HttpConnectionState access)
static int onUrl(llhttp_t *parser, const char *at, size_t length); static int onUrl(llhttp_t *parser, const char *at, size_t length);
@@ -193,31 +189,40 @@ private:
bool process_release_batch(BatchType &batch); bool process_release_batch(BatchType &batch);
// Route handlers // Route handlers
void handle_get_version(Connection &conn, const HttpConnectionState &state); void handle_get_version(Connection &conn, const HttpConnectionState &state,
void handle_post_commit(Connection &conn, const HttpConnectionState &state); Arena request_arena);
void handle_get_subscribe(Connection &conn, const HttpConnectionState &state); void handle_post_commit(Connection &conn, const HttpConnectionState &state,
Arena request_arena);
void handle_get_subscribe(Connection &conn, const HttpConnectionState &state,
Arena request_arena);
void handle_get_status(Connection &conn, HttpConnectionState &state, void handle_get_status(Connection &conn, HttpConnectionState &state,
const RouteMatch &route_match); const RouteMatch &route_match, Arena request_arena);
void handle_put_retention(Connection &conn, const HttpConnectionState &state, void handle_put_retention(Connection &conn, const HttpConnectionState &state,
const RouteMatch &route_match); const RouteMatch &route_match, Arena request_arena);
void handle_get_retention(Connection &conn, const HttpConnectionState &state, void handle_get_retention(Connection &conn, const HttpConnectionState &state,
const RouteMatch &route_match); const RouteMatch &route_match, Arena request_arena);
void handle_delete_retention(Connection &conn, void handle_delete_retention(Connection &conn,
const HttpConnectionState &state, const HttpConnectionState &state,
const RouteMatch &route_match); const RouteMatch &route_match,
void handle_get_metrics(Connection &conn, const HttpConnectionState &state); Arena request_arena);
void handle_get_ok(Connection &conn, const HttpConnectionState &state); void handle_get_metrics(Connection &conn, const HttpConnectionState &state,
void handle_not_found(Connection &conn, const HttpConnectionState &state); Arena request_arena);
void handle_get_ok(Connection &conn, const HttpConnectionState &state,
Arena request_arena);
void handle_not_found(Connection &conn, const HttpConnectionState &state,
Arena request_arena);
// HTTP utilities // HTTP utilities
static void send_response(Connection &conn, int status_code, static void send_response(MessageSender &conn, int status_code,
std::string_view content_type, std::string_view content_type,
std::string_view body, std::string_view body, Arena response_arena,
bool close_connection = false); int64_t http_request_id, bool close_connection);
static void send_json_response(Connection &conn, int status_code, static void send_json_response(MessageSender &conn, int status_code,
std::string_view json, std::string_view json, Arena response_arena,
bool close_connection = false); int64_t http_request_id,
static void send_error_response(Connection &conn, int status_code, bool close_connection);
static void send_error_response(MessageSender &conn, int status_code,
std::string_view message, std::string_view message,
bool close_connection = false); Arena response_arena, int64_t http_request_id,
bool close_connection);
}; };

View File

@@ -454,7 +454,7 @@ struct Metric {
Arena arena; Arena arena;
ThreadInit() { ThreadInit() {
// Register this thread's arena for memory tracking // Register this thread's arena for memory tracking
std::unique_lock<std::mutex> _{mutex}; std::unique_lock _{mutex};
get_thread_arenas()[std::this_thread::get_id()] = &arena; get_thread_arenas()[std::this_thread::get_id()] = &arena;
} }
~ThreadInit() { ~ThreadInit() {
@@ -462,7 +462,7 @@ struct Metric {
// THREAD SAFETY: All operations below are protected by the global mutex, // THREAD SAFETY: All operations below are protected by the global mutex,
// including writes to global accumulated state, preventing races with // including writes to global accumulated state, preventing races with
// render thread // render thread
std::unique_lock<std::mutex> _{mutex}; std::unique_lock _{mutex};
// NOTE: registration_version increment is REQUIRED here because: // NOTE: registration_version increment is REQUIRED here because:
// - Cached render plan has pre-resolved pointers to thread-local state // - Cached render plan has pre-resolved pointers to thread-local state
// - When threads disappear, these pointers become invalid // - When threads disappear, these pointers become invalid
@@ -501,7 +501,7 @@ struct Metric {
if (thread_it != family->per_thread_state.end()) { if (thread_it != family->per_thread_state.end()) {
for (auto &[labels_key, instance] : thread_it->second.instances) { for (auto &[labels_key, instance] : thread_it->second.instances) {
// Acquire lock to get consistent snapshot // Acquire lock to get consistent snapshot
std::lock_guard<std::mutex> lock(instance->mutex); std::lock_guard lock(instance->mutex);
// Global accumulator should have been created when we made the // Global accumulator should have been created when we made the
// histogram // histogram
@@ -592,7 +592,7 @@ struct Metric {
// Force thread_local initialization // Force thread_local initialization
(void)thread_init; (void)thread_init;
std::unique_lock<std::mutex> _{mutex}; std::unique_lock _{mutex};
++Metric::registration_version; ++Metric::registration_version;
const LabelsKey &key = intern_labels(labels); const LabelsKey &key = intern_labels(labels);
@@ -633,7 +633,7 @@ struct Metric {
static Gauge create_gauge_instance( static Gauge create_gauge_instance(
Family<Gauge> *family, Family<Gauge> *family,
std::span<const std::pair<std::string_view, std::string_view>> labels) { std::span<const std::pair<std::string_view, std::string_view>> labels) {
std::unique_lock<std::mutex> _{mutex}; std::unique_lock _{mutex};
++Metric::registration_version; ++Metric::registration_version;
const LabelsKey &key = intern_labels(labels); const LabelsKey &key = intern_labels(labels);
@@ -659,7 +659,7 @@ struct Metric {
// Force thread_local initialization // Force thread_local initialization
(void)thread_init; (void)thread_init;
std::unique_lock<std::mutex> _{mutex}; std::unique_lock _{mutex};
++Metric::registration_version; ++Metric::registration_version;
const LabelsKey &key = intern_labels(labels); const LabelsKey &key = intern_labels(labels);
@@ -1137,7 +1137,7 @@ struct Metric {
uint64_t observations_snapshot; uint64_t observations_snapshot;
{ {
std::lock_guard<std::mutex> lock(instance->mutex); std::lock_guard lock(instance->mutex);
for (size_t i = 0; i < instance->counts.size(); ++i) { for (size_t i = 0; i < instance->counts.size(); ++i) {
counts_snapshot[i] = instance->counts[i]; counts_snapshot[i] = instance->counts[i];
} }
@@ -1423,7 +1423,7 @@ update_histogram_buckets_simd(std::span<const double> thresholds,
void Histogram::observe(double x) { void Histogram::observe(double x) {
assert(p->thresholds.size() == p->counts.size()); assert(p->thresholds.size() == p->counts.size());
std::lock_guard<std::mutex> lock(p->mutex); std::lock_guard lock(p->mutex);
// Update bucket counts using SIMD // Update bucket counts using SIMD
update_histogram_buckets_simd(p->thresholds, p->counts, x, 0); update_histogram_buckets_simd(p->thresholds, p->counts, x, 0);
@@ -1458,7 +1458,7 @@ Histogram Family<Histogram>::create(
Family<Counter> create_counter(std::string_view name, std::string_view help) { Family<Counter> create_counter(std::string_view name, std::string_view help) {
validate_or_abort(is_valid_metric_name(name), "invalid counter name", name); validate_or_abort(is_valid_metric_name(name), "invalid counter name", name);
std::unique_lock<std::mutex> _{Metric::mutex}; std::unique_lock _{Metric::mutex};
++Metric::registration_version; ++Metric::registration_version;
auto &global_arena = Metric::get_global_arena(); auto &global_arena = Metric::get_global_arena();
auto name_view = arena_copy_string(name, global_arena); auto name_view = arena_copy_string(name, global_arena);
@@ -1480,7 +1480,7 @@ Family<Counter> create_counter(std::string_view name, std::string_view help) {
Family<Gauge> create_gauge(std::string_view name, std::string_view help) { Family<Gauge> create_gauge(std::string_view name, std::string_view help) {
validate_or_abort(is_valid_metric_name(name), "invalid gauge name", name); validate_or_abort(is_valid_metric_name(name), "invalid gauge name", name);
std::unique_lock<std::mutex> _{Metric::mutex}; std::unique_lock _{Metric::mutex};
++Metric::registration_version; ++Metric::registration_version;
auto &global_arena = Metric::get_global_arena(); auto &global_arena = Metric::get_global_arena();
auto name_view = arena_copy_string(name, global_arena); auto name_view = arena_copy_string(name, global_arena);
@@ -1504,7 +1504,7 @@ Family<Histogram> create_histogram(std::string_view name, std::string_view help,
std::span<const double> buckets) { std::span<const double> buckets) {
validate_or_abort(is_valid_metric_name(name), "invalid histogram name", name); validate_or_abort(is_valid_metric_name(name), "invalid histogram name", name);
std::unique_lock<std::mutex> _{Metric::mutex}; std::unique_lock _{Metric::mutex};
++Metric::registration_version; ++Metric::registration_version;
auto &global_arena = Metric::get_global_arena(); auto &global_arena = Metric::get_global_arena();
auto name_view = arena_copy_string(name, global_arena); auto name_view = arena_copy_string(name, global_arena);
@@ -1693,7 +1693,7 @@ std::span<std::string_view> render(Arena &arena) {
// Hold lock throughout all phases to prevent registry changes // Hold lock throughout all phases to prevent registry changes
// THREAD SAFETY: Global mutex protects cached_plan initialization and access, // THREAD SAFETY: Global mutex protects cached_plan initialization and access,
// prevents races during static member initialization at program startup // prevents races during static member initialization at program startup
std::unique_lock<std::mutex> _{Metric::mutex}; std::unique_lock _{Metric::mutex};
// Call all registered collectors to update their metrics // Call all registered collectors to update their metrics
for (const auto &collector : Metric::get_collectors()) { for (const auto &collector : Metric::get_collectors()) {
@@ -1723,7 +1723,7 @@ template <>
void Family<Counter>::register_callback( void Family<Counter>::register_callback(
std::span<const std::pair<std::string_view, std::string_view>> labels, std::span<const std::pair<std::string_view, std::string_view>> labels,
MetricCallback<Counter> callback) { MetricCallback<Counter> callback) {
std::unique_lock<std::mutex> _{Metric::mutex}; std::unique_lock _{Metric::mutex};
++Metric::registration_version; ++Metric::registration_version;
const LabelsKey &key = Metric::intern_labels(labels); const LabelsKey &key = Metric::intern_labels(labels);
@@ -1748,7 +1748,7 @@ template <>
void Family<Gauge>::register_callback( void Family<Gauge>::register_callback(
std::span<const std::pair<std::string_view, std::string_view>> labels, std::span<const std::pair<std::string_view, std::string_view>> labels,
MetricCallback<Gauge> callback) { MetricCallback<Gauge> callback) {
std::unique_lock<std::mutex> _{Metric::mutex}; std::unique_lock _{Metric::mutex};
++Metric::registration_version; ++Metric::registration_version;
const LabelsKey &key = Metric::intern_labels(labels); const LabelsKey &key = Metric::intern_labels(labels);
@@ -1804,7 +1804,7 @@ void reset_metrics_for_testing() {
} }
void register_collector(Ref<Collector> collector) { void register_collector(Ref<Collector> collector) {
std::unique_lock<std::mutex> _{Metric::mutex}; std::unique_lock _{Metric::mutex};
++Metric::registration_version; ++Metric::registration_version;
Metric::get_collectors().push_back(std::move(collector)); Metric::get_collectors().push_back(std::move(collector));
} }

View File

@@ -1,11 +1,15 @@
#pragma once #pragma once
#define ENABLE_PERFETTO 1 #ifndef ENABLE_PERFETTO
#define ENABLE_PERFETTO 0
#endif
#if ENABLE_PERFETTO #if ENABLE_PERFETTO
#include <perfetto.h> #include <perfetto.h>
#else #else
#define PERFETTO_DEFINE_CATEGORIES(...) #define PERFETTO_DEFINE_CATEGORIES(...)
#define PERFETTO_TRACK_EVENT_STATIC_STORAGE \
void perfetto_track_event_static_storage
#define TRACE_EVENT(...) #define TRACE_EVENT(...)
#endif #endif

View File

@@ -1,21 +1,32 @@
#pragma once #pragma once
#include "connection.hpp" #include "connection.hpp"
#include <memory>
#include <variant> #include <variant>
// Forward declarations
struct CommitRequest;
/** /**
* Pipeline entry for commit requests that need full 4-stage processing. * Pipeline entry for commit requests that need full 4-stage processing.
* Contains connection with parsed CommitRequest. * Contains connection with parsed CommitRequest.
*/ */
struct CommitEntry { struct CommitEntry {
Ref<Connection> connection; WeakRef<MessageSender> connection;
int64_t assigned_version = 0; // Set by sequence stage int64_t assigned_version = 0; // Set by sequence stage
bool resolve_success = false; // Set by resolve stage bool resolve_success = false; // Set by resolve stage
bool persist_success = false; // Set by persist stage bool persist_success = false; // Set by persist stage
// Copied HTTP state (pipeline threads cannot access connection user_data)
int64_t http_request_id = 0;
bool connection_close = false;
const CommitRequest *commit_request =
nullptr; // Points to connection's arena data
CommitEntry() = default; // Default constructor for variant CommitEntry() = default; // Default constructor for variant
explicit CommitEntry(Ref<Connection> conn) : connection(std::move(conn)) {} explicit CommitEntry(WeakRef<MessageSender> conn, int64_t req_id,
bool close_conn, const CommitRequest *req)
: connection(std::move(conn)), http_request_id(req_id),
connection_close(close_conn), commit_request(req) {}
}; };
/** /**
@@ -23,11 +34,19 @@ struct CommitEntry {
* then transfer to status threadpool. * then transfer to status threadpool.
*/ */
struct StatusEntry { struct StatusEntry {
Ref<Connection> connection; WeakRef<MessageSender> connection;
int64_t version_upper_bound = 0; // Set by sequence stage int64_t version_upper_bound = 0; // Set by sequence stage
// Copied HTTP state
int64_t http_request_id = 0;
bool connection_close = false;
std::string_view status_request_id; // Points to connection's arena data
StatusEntry() = default; // Default constructor for variant StatusEntry() = default; // Default constructor for variant
explicit StatusEntry(Ref<Connection> conn) : connection(std::move(conn)) {} explicit StatusEntry(WeakRef<MessageSender> conn, int64_t req_id,
bool close_conn, std::string_view request_id)
: connection(std::move(conn)), http_request_id(req_id),
connection_close(close_conn), status_request_id(request_id) {}
}; };
/** /**
@@ -36,11 +55,17 @@ struct StatusEntry {
* Resolve stage can perform configurable CPU work for benchmarking. * Resolve stage can perform configurable CPU work for benchmarking.
*/ */
struct HealthCheckEntry { struct HealthCheckEntry {
Ref<Connection> connection; WeakRef<MessageSender> connection;
// Copied HTTP state
int64_t http_request_id = 0;
bool connection_close = false;
HealthCheckEntry() = default; // Default constructor for variant HealthCheckEntry() = default; // Default constructor for variant
explicit HealthCheckEntry(Ref<Connection> conn) explicit HealthCheckEntry(WeakRef<MessageSender> conn, int64_t req_id,
: connection(std::move(conn)) {} bool close_conn)
: connection(std::move(conn)), http_request_id(req_id),
connection_close(close_conn) {}
}; };
/** /**

View File

@@ -173,6 +173,8 @@ int Server::create_local_connection() {
auto connection = make_ref<Connection>( auto connection = make_ref<Connection>(
addr, server_fd, connection_id_.fetch_add(1, std::memory_order_relaxed), addr, server_fd, connection_id_.fetch_add(1, std::memory_order_relaxed),
epoll_index, &handler_, self_.copy()); epoll_index, &handler_, self_.copy());
connection->self_ref_ = connection.as_weak();
connection->tsan_release();
// Store in registry // Store in registry
connection_registry_.store(server_fd, std::move(connection)); connection_registry_.store(server_fd, std::move(connection));
@@ -305,10 +307,11 @@ void Server::start_io_threads(std::vector<std::thread> &threads) {
// Handle existing connection events // Handle existing connection events
int fd = events[i].data.fd; int fd = events[i].data.fd;
Ref<Connection> conn = connection_registry_.remove(fd); Ref<Connection> conn = connection_registry_.remove(fd);
conn->tsan_acquire();
assert(conn); assert(conn);
if (events[i].events & (EPOLLERR | EPOLLHUP)) { if (events[i].events & (EPOLLERR | EPOLLHUP)) {
// Connection will be destroyed on scope exit close_connection(conn);
continue; continue;
} }
@@ -361,9 +364,9 @@ void Server::start_io_threads(std::vector<std::thread> &threads) {
perror("setsockopt SO_KEEPALIVE"); perror("setsockopt SO_KEEPALIVE");
} }
// Add to epoll with no interests // Add to epoll
struct epoll_event event{}; struct epoll_event event{};
event.events = 0; event.events = EPOLLIN;
event.data.fd = fd; event.data.fd = fd;
if (epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &event) == -1) { if (epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &event) == -1) {
perror("epoll_ctl ADD"); perror("epoll_ctl ADD");
@@ -376,6 +379,8 @@ void Server::start_io_threads(std::vector<std::thread> &threads) {
addr, fd, addr, fd,
connection_id_.fetch_add(1, std::memory_order_relaxed), connection_id_.fetch_add(1, std::memory_order_relaxed),
epoll_index, &handler_, self_.copy()); epoll_index, &handler_, self_.copy());
batch[batch_count]->self_ref_ = batch[batch_count].as_weak();
batch[batch_count]->tsan_release();
batch_events[batch_count] = batch_events[batch_count] =
EPOLLIN; // New connections always start with read EPOLLIN; // New connections always start with read
batch_count++; batch_count++;
@@ -413,7 +418,7 @@ void Server::process_connection_reads(Ref<Connection> &conn, int events) {
if (r < 0) { if (r < 0) {
// Error or EOF - connection should be closed // Error or EOF - connection should be closed
conn.reset(); close_connection(conn);
return; return;
} }
@@ -429,34 +434,37 @@ void Server::process_connection_reads(Ref<Connection> &conn, int events) {
void Server::process_connection_writes(Ref<Connection> &conn, int /*events*/) { void Server::process_connection_writes(Ref<Connection> &conn, int /*events*/) {
assert(conn); assert(conn);
// For simplicity, we always attempt to write when an event fires. We could be auto result = conn->write_bytes();
// more precise and skip the write if we detect that we've already seen EAGAIN
// on this connection and we don't have EPOLLOUT.
if (conn->has_messages()) {
bool had_messages = conn->has_messages();
bool error = conn->writeBytes();
if (error) { if (result & Connection::WriteBytesResult::Error) {
conn.reset(); // Connection should be closed close_connection(conn);
return; return;
} }
if (result & Connection::WriteBytesResult::Progress) {
// Call handler with connection reference - server retains ownership // Call handler with connection reference - server retains ownership
handler_.on_write_progress(*conn); handler_.on_write_progress(*conn);
}
// Check if buffer became empty (transition from non-empty -> empty) if (result & Connection::WriteBytesResult::Drained) {
if (had_messages && !conn->has_messages()) { // Call handler with connection reference - server retains ownership
handler_.on_write_buffer_drained(*conn); handler_.on_write_buffer_drained(*conn);
} }
// Check if we should close the connection according to application // Check if we should close the connection according to application
if (!conn->has_messages() && conn->should_close()) { if (result & Connection::WriteBytesResult::Close) {
conn.reset(); // Connection should be closed close_connection(conn);
return; return;
} }
} }
void Server::close_connection(Ref<Connection> &conn) {
conn->close();
conn.reset();
} }
static thread_local std::vector<Connection *> conn_ptrs;
void Server::process_connection_batch(int epollfd, void Server::process_connection_batch(int epollfd,
std::span<Ref<Connection>> batch, std::span<Ref<Connection>> batch,
std::span<const int> events) { std::span<const int> events) {
@@ -476,8 +484,7 @@ void Server::process_connection_batch(int epollfd,
} }
// Call batch complete handler with connection pointers // Call batch complete handler with connection pointers
std::vector<Connection *> conn_ptrs; conn_ptrs.clear();
conn_ptrs.reserve(batch.size());
for (auto &conn_ref : batch) { for (auto &conn_ref : batch) {
if (conn_ref) { if (conn_ref) {
conn_ptrs.push_back(conn_ref.get()); conn_ptrs.push_back(conn_ref.get());
@@ -485,26 +492,11 @@ void Server::process_connection_batch(int epollfd,
} }
handler_.on_batch_complete(conn_ptrs); handler_.on_batch_complete(conn_ptrs);
// Transfer all remaining connections back to epoll // Transfer all remaining connections back to registry
for (auto &conn_ptr : batch) { for (auto &conn_ptr : batch) {
if (conn_ptr) { if (conn_ptr) {
int fd = conn_ptr->getFd(); int fd = conn_ptr->getFd();
struct epoll_event event{};
if (!conn_ptr->has_messages()) {
event.events = EPOLLIN;
} else {
event.events = EPOLLIN | EPOLLOUT;
}
event.data.fd = fd;
// Put connection back in registry since handler didn't take ownership
// Must happen before epoll_ctl
connection_registry_.store(fd, std::move(conn_ptr)); connection_registry_.store(fd, std::move(conn_ptr));
if (epoll_ctl(epollfd, EPOLL_CTL_MOD, fd, &event) == -1) {
perror("epoll_ctl MOD");
(void)connection_registry_.remove(fd);
}
} }
} }
} }

View File

@@ -148,6 +148,8 @@ private:
void process_connection_reads(Ref<Connection> &conn_ptr, int events); void process_connection_reads(Ref<Connection> &conn_ptr, int events);
void process_connection_writes(Ref<Connection> &conn_ptr, int events); void process_connection_writes(Ref<Connection> &conn_ptr, int events);
void close_connection(Ref<Connection> &conn);
// Helper for processing a batch of connections with their events // Helper for processing a batch of connections with their events
void process_connection_batch(int epollfd, std::span<Ref<Connection>> batch, void process_connection_batch(int epollfd, std::span<Ref<Connection>> batch,
std::span<const int> events); std::span<const int> events);

View File

@@ -4,7 +4,7 @@
#include "server.hpp" #include "server.hpp"
#include <doctest/doctest.h> #include <doctest/doctest.h>
#include <future> #include <latch>
#include <string_view> #include <string_view>
#include <thread> #include <thread>
@@ -19,19 +19,16 @@ static std::string_view arena_copy_string(std::string_view str, Arena &arena) {
} }
struct EchoHandler : ConnectionHandler { struct EchoHandler : ConnectionHandler {
std::future<void> f;
void on_data_arrived(std::string_view data, Connection &conn) override {
Arena arena; Arena arena;
auto reply = std::span{arena.allocate<std::string_view>(1), 1}; std::span<std::string_view> reply;
WeakRef<MessageSender> wconn;
std::latch done{1};
void on_data_arrived(std::string_view data, Connection &conn) override {
reply = std::span{arena.allocate<std::string_view>(1), 1};
reply[0] = arena_copy_string(data, arena); reply[0] = arena_copy_string(data, arena);
f = std::async(std::launch::async, [wconn = conn.get_weak_ref(), reply, wconn = conn.get_weak_ref();
arena = std::move(arena)]() mutable { CHECK(wconn.lock());
if (auto conn = wconn.lock()) { done.count_down();
conn->append_message(reply, std::move(arena));
} else {
REQUIRE(false);
}
});
} }
}; };
@@ -44,15 +41,22 @@ TEST_CASE("Echo test") {
auto runThread = std::thread{[&]() { server->run(); }}; auto runThread = std::thread{[&]() { server->run(); }};
SUBCASE("writes hello back") {
int w = write(fd, "hello", 5); int w = write(fd, "hello", 5);
REQUIRE(w == 5); REQUIRE(w == 5);
handler.done.wait();
if (auto conn = handler.wconn.lock()) {
conn->append_message(std::exchange(handler.reply, {}),
std::move(handler.arena));
} else {
REQUIRE(false);
}
char buf[6]; char buf[6];
buf[5] = 0; buf[5] = 0;
int r = read(fd, buf, 5); int r = read(fd, buf, 5);
REQUIRE(r == 5); REQUIRE(r == 5);
CHECK(std::string(buf) == "hello"); CHECK(std::string(buf) == "hello");
}
close(fd); close(fd);

View File

@@ -297,7 +297,7 @@ struct Connection {
} }
} }
bool writeBytes() { bool write_bytes() {
for (;;) { for (;;) {
assert(!request.empty()); assert(!request.empty());
int w = send(fd, request.data(), request.size(), MSG_NOSIGNAL); int w = send(fd, request.data(), request.size(), MSG_NOSIGNAL);
@@ -672,7 +672,7 @@ int main(int argc, char *argv[]) {
continue; // Let unique_ptr destructor clean up continue; // Let unique_ptr destructor clean up
} }
if (events[i].events & EPOLLOUT) { if (events[i].events & EPOLLOUT) {
bool finished = conn->writeBytes(); bool finished = conn->write_bytes();
if (conn->error) { if (conn->error) {
continue; continue;
} }
@@ -748,14 +748,14 @@ int main(int argc, char *argv[]) {
// Try to write once in the connect thread before handing off to network // Try to write once in the connect thread before handing off to network
// threads // threads
assert(conn->has_messages()); assert(conn->has_messages());
bool writeFinished = conn->writeBytes(); bool write_finished = conn->write_bytes();
if (conn->error) { if (conn->error) {
continue; // Connection failed, destructor will clean up continue; // Connection failed, destructor will clean up
} }
// Determine the appropriate epoll events based on write result // Determine the appropriate epoll events based on write result
struct epoll_event event{}; struct epoll_event event{};
if (writeFinished) { if (write_finished) {
// All data was written, wait for response // All data was written, wait for response
int shutdown_result = shutdown(conn->fd, SHUT_WR); int shutdown_result = shutdown(conn->fd, SHUT_WR);
if (shutdown_result == -1) { if (shutdown_result == -1) {