Accumulate headers properly

This commit is contained in:
2025-08-20 14:10:39 -04:00
parent 8ccb02f450
commit abb47ee0c3
3 changed files with 100 additions and 36 deletions

View File

@@ -6,14 +6,17 @@
#include <strings.h> #include <strings.h>
// HttpConnectionState implementation // HttpConnectionState implementation
HttpConnectionState::HttpConnectionState( HttpConnectionState::HttpConnectionState(ArenaAllocator &arena)
[[maybe_unused]] ArenaAllocator &arena) { : current_header_field_buf(ArenaStlAllocator<char>(&arena)),
current_header_value_buf(ArenaStlAllocator<char>(&arena)) {
llhttp_settings_init(&settings); llhttp_settings_init(&settings);
// Set up llhttp callbacks // Set up llhttp callbacks
settings.on_url = HttpHandler::onUrl; settings.on_url = HttpHandler::onUrl;
settings.on_header_field = HttpHandler::onHeaderField; settings.on_header_field = HttpHandler::onHeaderField;
settings.on_header_field_complete = HttpHandler::onHeaderFieldComplete;
settings.on_header_value = HttpHandler::onHeaderValue; settings.on_header_value = HttpHandler::onHeaderValue;
settings.on_header_value_complete = HttpHandler::onHeaderValueComplete;
settings.on_headers_complete = HttpHandler::onHeadersComplete; settings.on_headers_complete = HttpHandler::onHeadersComplete;
settings.on_body = HttpHandler::onBody; settings.on_body = HttpHandler::onBody;
settings.on_message_complete = HttpHandler::onMessageComplete; settings.on_message_complete = HttpHandler::onMessageComplete;
@@ -26,18 +29,23 @@ HttpConnectionState::HttpConnectionState(
void HttpHandler::on_connection_established(Connection &conn) { void HttpHandler::on_connection_established(Connection &conn) {
// Allocate HTTP state in connection's arena // Allocate HTTP state in connection's arena
ArenaAllocator &arena = conn.getArena(); ArenaAllocator &arena = conn.getArena();
auto *state = arena.construct<HttpConnectionState>(arena); void *mem = arena.allocate_raw(sizeof(HttpConnectionState),
alignof(HttpConnectionState));
auto *state = new (mem) HttpConnectionState(arena);
conn.user_data = state; conn.user_data = state;
} }
void HttpHandler::on_connection_closed(Connection &conn) { void HttpHandler::on_connection_closed(Connection &conn) {
// Arena cleanup happens automatically when connection is destroyed // Arena cleanup happens automatically when connection is destroyed
auto *state = static_cast<HttpConnectionState *>(conn.user_data);
state->~HttpConnectionState();
conn.user_data = nullptr; conn.user_data = nullptr;
} }
void HttpHandler::on_write_progress(std::unique_ptr<Connection> &conn_ptr) { void HttpHandler::on_write_progress(std::unique_ptr<Connection> &conn_ptr) {
// Reset arena after all messages have been written for the next request // Reset arena after all messages have been written for the next request
if (conn_ptr->outgoingBytesQueued() == 0) { if (conn_ptr->outgoingBytesQueued() == 0) {
on_connection_closed(*conn_ptr);
conn_ptr->reset(); conn_ptr->reset();
on_connection_established(*conn_ptr); on_connection_established(*conn_ptr);
} }
@@ -312,31 +320,49 @@ int HttpHandler::onUrl(llhttp_t *parser, const char *at, size_t length) {
int HttpHandler::onHeaderField(llhttp_t *parser, const char *at, int HttpHandler::onHeaderField(llhttp_t *parser, const char *at,
size_t length) { size_t length) {
auto *state = static_cast<HttpConnectionState *>(parser->data); auto *state = static_cast<HttpConnectionState *>(parser->data);
// Store current header field name for processing in onHeaderValue // Accumulate header field data
state->current_header_field = std::string_view(at, length); state->current_header_field_buf.append(at, length);
return 0;
}
int HttpHandler::onHeaderFieldComplete(llhttp_t *parser) {
auto *state = static_cast<HttpConnectionState *>(parser->data);
state->header_field_complete = true;
return 0; return 0;
} }
int HttpHandler::onHeaderValue(llhttp_t *parser, const char *at, int HttpHandler::onHeaderValue(llhttp_t *parser, const char *at,
size_t length) { size_t length) {
auto *state = static_cast<HttpConnectionState *>(parser->data); auto *state = static_cast<HttpConnectionState *>(parser->data);
std::string_view value(at, length); // Accumulate header value data
state->current_header_value_buf.append(at, length);
return 0;
}
int HttpHandler::onHeaderValueComplete(llhttp_t *parser) {
auto *state = static_cast<HttpConnectionState *>(parser->data);
if (!state->header_field_complete) {
// Field is not complete yet, wait
return 0;
}
// Now we have complete header field and value
const auto &field = state->current_header_field_buf;
const auto &value = state->current_header_value_buf;
// Check for Connection header // Check for Connection header
if (state->current_header_field.size() == 10 && if (field.size() == 10 && strncasecmp(field.data(), "connection", 10) == 0) {
strncasecmp(state->current_header_field.data(), "connection", 10) == 0) {
if (value.size() == 5 && strncasecmp(value.data(), "close", 5) == 0) { if (value.size() == 5 && strncasecmp(value.data(), "close", 5) == 0) {
state->connection_close = true; state->connection_close = true;
} }
} }
// Check for X-Request-Id header // Check for X-Request-Id header
if (state->current_header_field.size() == 12 && if (field.size() == 12 &&
strncasecmp(state->current_header_field.data(), "x-request-id", 12) == strncasecmp(field.data(), "x-request-id", 12) == 0) {
0) {
uint64_t id = 0; uint64_t id = 0;
for (int i = 0; i < int(length); ++i) { for (char c : value) {
auto c = at[i];
if (c >= '0' && c <= '9') { if (c >= '0' && c <= '9') {
id = id * 10 + (c - '0'); id = id * 10 + (c - '0');
} }
@@ -344,6 +370,11 @@ int HttpHandler::onHeaderValue(llhttp_t *parser, const char *at,
state->request_id = id; state->request_id = id;
} }
// Clear buffers for next header
state->current_header_field_buf.clear();
state->current_header_value_buf.clear();
state->header_field_complete = false;
return 0; return 0;
} }

View File

@@ -41,8 +41,14 @@ struct HttpConnectionState {
bool message_complete = false; bool message_complete = false;
bool connection_close = false; // Client requested connection close bool connection_close = false; // Client requested connection close
HttpRoute route = HttpRoute::NotFound; HttpRoute route = HttpRoute::NotFound;
std::string_view current_header_field; // Current header being parsed
uint64_t request_id = 0; // X-Request-Id header value // Header accumulation buffers (arena-allocated)
using ArenaString =
std::basic_string<char, std::char_traits<char>, ArenaStlAllocator<char>>;
ArenaString current_header_field_buf;
ArenaString current_header_value_buf;
bool header_field_complete = false;
uint64_t request_id = 0; // X-Request-Id header value
explicit HttpConnectionState(ArenaAllocator &arena); explicit HttpConnectionState(ArenaAllocator &arena);
}; };
@@ -67,7 +73,9 @@ public:
// llhttp callbacks (public for HttpConnectionState access) // llhttp callbacks (public for HttpConnectionState access)
static int onUrl(llhttp_t *parser, const char *at, size_t length); static int onUrl(llhttp_t *parser, const char *at, size_t length);
static int onHeaderField(llhttp_t *parser, const char *at, size_t length); static int onHeaderField(llhttp_t *parser, const char *at, size_t length);
static int onHeaderFieldComplete(llhttp_t *parser);
static int onHeaderValue(llhttp_t *parser, const char *at, size_t length); static int onHeaderValue(llhttp_t *parser, const char *at, size_t length);
static int onHeaderValueComplete(llhttp_t *parser);
static int onHeadersComplete(llhttp_t *parser); static int onHeadersComplete(llhttp_t *parser);
static int onBody(llhttp_t *parser, const char *at, size_t length); static int onBody(llhttp_t *parser, const char *at, size_t length);
static int onMessageComplete(llhttp_t *parser); static int onMessageComplete(llhttp_t *parser);

View File

@@ -159,17 +159,17 @@ void signal_handler(int sig) {
struct Connection { struct Connection {
static const llhttp_settings_t settings; static const llhttp_settings_t settings;
static std::atomic<uint64_t> requestId; static std::atomic<uint64_t> nextRequestId;
char buf[1024]; // Increased size for dynamic request format char buf[1024];
std::string_view request; std::string_view request;
uint64_t currentRequestId; uint64_t requestId = 0;
void initRequest() { void initRequest() {
currentRequestId = requestId.fetch_add(1, std::memory_order_relaxed); requestId = nextRequestId.fetch_add(1, std::memory_order_relaxed);
int len = snprintf(buf, sizeof(buf), int len = snprintf(buf, sizeof(buf),
"GET /ok HTTP/1.1\r\nX-Request-Id: %" PRIu64 "\r\n\r\n", "GET /ok HTTP/1.1\r\nX-Request-Id: %" PRIu64 "\r\n\r\n",
currentRequestId); requestId);
if (len == -1 || len > int(sizeof(buf))) { if (len == -1 || len > int(sizeof(buf))) {
abort(); abort();
} }
@@ -203,7 +203,7 @@ struct Connection {
bool readBytes() { bool readBytes() {
for (;;) { for (;;) {
char buf[1024]; // Use a reasonable default, configurable via g_config char buf[64 * (1 << 10)]; // 64 KiB
int buf_size = std::min(int(sizeof(buf)), g_config.connection_buf_size); int buf_size = std::min(int(sizeof(buf)), g_config.connection_buf_size);
int r = read(fd, buf, buf_size); int r = read(fd, buf, buf_size);
if (r == -1) { if (r == -1) {
@@ -214,6 +214,7 @@ struct Connection {
return false; return false;
} }
} }
// printf("read: %.*s\n", r, buf);
if (r == 0) { if (r == 0) {
llhttp_finish(&parser); llhttp_finish(&parser);
return true; return true;
@@ -233,8 +234,8 @@ struct Connection {
bool writeBytes() { bool writeBytes() {
for (;;) { for (;;) {
int w; assert(!request.empty());
w = write(fd, request.data(), request.size()); int w = write(fd, request.data(), request.size());
if (w == -1) { if (w == -1) {
if (errno == EINTR) { if (errno == EINTR) {
continue; continue;
@@ -247,14 +248,12 @@ struct Connection {
return true; return true;
} }
assert(w != 0); assert(w != 0);
// printf("write: %.*s\n", w, request.data());
request = request.substr(w, request.size() - w); request = request.substr(w, request.size() - w);
if (request.empty()) { if (request.empty()) {
++requestsSent; ++requestsSent;
TRACE_EVENT("http", "Send request", TRACE_EVENT("http", "Send request", perfetto::Flow::Global(requestId));
perfetto::Flow::Global(currentRequestId)); return requestsSent == g_config.requests_per_connection;
if (requestsSent == g_config.requests_per_connection) {
return true;
}
} }
} }
} }
@@ -285,14 +284,40 @@ private:
} }
uint64_t responseId = 0; uint64_t responseId = 0;
std::string headerValueBuffer;
bool expectingResponseId = false;
int on_header_field(const char *data, size_t s) {
std::string_view field(data, s);
expectingResponseId = (field == "X-Response-ID");
if (expectingResponseId) {
headerValueBuffer.clear();
}
return 0;
}
int on_header_value(const char *data, size_t s) { int on_header_value(const char *data, size_t s) {
for (int i = 0; i < int(s); ++i) { if (expectingResponseId) {
responseId = responseId * 10 + data[i] - '0'; headerValueBuffer.append(data, s);
}
return 0;
}
int on_header_value_complete() {
if (expectingResponseId) {
responseId = 0;
for (char c : headerValueBuffer) {
if (c >= '0' && c <= '9') {
responseId = responseId * 10 + (c - '0');
}
}
expectingResponseId = false;
} }
return 0; return 0;
} }
int on_message_complete() { int on_message_complete() {
assert(responseId == requestId);
TRACE_EVENT("http", "Receive response", perfetto::Flow::Global(responseId)); TRACE_EVENT("http", "Receive response", perfetto::Flow::Global(responseId));
responseId = 0; responseId = 0;
++responsesReceived; ++responsesReceived;
@@ -303,13 +328,16 @@ private:
llhttp_t parser; llhttp_t parser;
}; };
std::atomic<uint64_t> Connection::requestId = {}; std::atomic<uint64_t> Connection::nextRequestId = {};
const llhttp_settings_t Connection::settings = []() { const llhttp_settings_t Connection::settings = []() {
llhttp_settings_t settings; llhttp_settings_t settings;
llhttp_settings_init(&settings); llhttp_settings_init(&settings);
settings.on_message_complete = callback<&Connection::on_message_complete>; settings.on_message_complete = callback<&Connection::on_message_complete>;
settings.on_header_field = callback<&Connection::on_header_field>;
settings.on_header_value = callback<&Connection::on_header_value>; settings.on_header_value = callback<&Connection::on_header_value>;
settings.on_header_value_complete =
callback<&Connection::on_header_value_complete>;
return settings; return settings;
}(); }();
@@ -514,7 +542,7 @@ int main(int argc, char *argv[]) {
pthread_setname_np(pthread_self(), pthread_setname_np(pthread_self(),
("network-" + std::to_string(i)).c_str()); ("network-" + std::to_string(i)).c_str());
while (!g_shutdown.load(std::memory_order_relaxed)) { while (!g_shutdown.load(std::memory_order_relaxed)) {
struct epoll_event events[64]; // Use a reasonable max size struct epoll_event events[256]; // Use a reasonable max size
int batch_size = std::min(int(sizeof(events) / sizeof(events[0])), int batch_size = std::min(int(sizeof(events) / sizeof(events[0])),
g_config.event_batch_size); g_config.event_batch_size);
int eventCount; int eventCount;
@@ -528,10 +556,6 @@ int main(int argc, char *argv[]) {
perror("epoll_wait"); perror("epoll_wait");
abort(); // Keep abort for critical errors like server does abort(); // Keep abort for critical errors like server does
} }
if (eventCount == 0) {
// Timeout - check shutdown flag
break;
}
break; break;
} }
@@ -624,6 +648,7 @@ int main(int argc, char *argv[]) {
// Add to epoll with proper events matching server pattern // Add to epoll with proper events matching server pattern
struct epoll_event event{}; struct epoll_event event{};
assert(conn->hasMessages());
event.events = EPOLLOUT | EPOLLONESHOT | EPOLLRDHUP; event.events = EPOLLOUT | EPOLLONESHOT | EPOLLRDHUP;
conn->tsan_release(); conn->tsan_release();
Connection *raw_conn = conn.release(); Connection *raw_conn = conn.release();