Files
weaseldb/src/http_handler.cpp

419 lines
13 KiB
C++

#include "http_handler.hpp"
#include "arena_allocator.hpp"
#include "perfetto_categories.hpp"
#include <charconv>
#include <cstdint>
#include <cstring>
#include <string>
#include <string_view>
#include <strings.h>
#include <system_error>
// HttpConnectionState implementation
// Builds the per-connection parser state. Both header scratch buffers draw
// their storage from `arena`, and the embedded llhttp parser is configured as
// an HTTP request parser whose callbacks are HttpHandler's static hooks.
HttpConnectionState::HttpConnectionState(ArenaAllocator &arena)
    : current_header_field_buf(ArenaStlAllocator<char>(&arena)),
      current_header_value_buf(ArenaStlAllocator<char>(&arena)) {
  // Reset the settings struct, then wire in the callbacks we care about.
  llhttp_settings_init(&settings);
  auto &cb = settings;

  // Request line.
  cb.on_url = HttpHandler::onUrl;

  // Header name/value spans and their completion notifications.
  cb.on_header_field = HttpHandler::onHeaderField;
  cb.on_header_field_complete = HttpHandler::onHeaderFieldComplete;
  cb.on_header_value = HttpHandler::onHeaderValue;
  cb.on_header_value_complete = HttpHandler::onHeaderValueComplete;
  cb.on_headers_complete = HttpHandler::onHeadersComplete;

  // Body and end of message.
  cb.on_body = HttpHandler::onBody;
  cb.on_message_complete = HttpHandler::onMessageComplete;

  llhttp_init(&parser, HTTP_REQUEST, &settings);
  // Every callback recovers this object through parser->data.
  parser.data = this;
}
// HttpHandler implementation
// Attaches fresh HTTP parser state to a newly accepted (or recycled)
// connection. The state object is placement-constructed in memory carved from
// the connection's own arena and published through conn.user_data.
void HttpHandler::on_connection_established(Connection &conn) {
  ArenaAllocator &conn_arena = conn.getArena();
  void *storage = conn_arena.allocate_raw(sizeof(HttpConnectionState),
                                          alignof(HttpConnectionState));
  conn.user_data = new (storage) HttpConnectionState(conn_arena);
}
// Tears down the HTTP parser state attached to `conn`.
// Only the destructor runs here; the object's memory came from the
// connection's arena, which (per the original note) is cleaned up when the
// connection itself is destroyed or reset.
// Safe to call when no state is attached: user_data may be null (on_data_arrived
// guards for that case too), and it is nulled after destruction so a repeated
// call does not destroy the same object twice.
void HttpHandler::on_connection_closed(Connection &conn) {
  auto *state = static_cast<HttpConnectionState *>(conn.user_data);
  if (state == nullptr) {
    return;
  }
  state->~HttpConnectionState();
  conn.user_data = nullptr;
}
void HttpHandler::on_write_progress(std::unique_ptr<Connection> &conn_ptr) {
// Reset arena after all messages have been written for the next request
if (conn_ptr->outgoingBytesQueued() == 0) {
on_connection_closed(*conn_ptr);
conn_ptr->reset();
on_connection_established(*conn_ptr);
}
}
// After a batch of connections has been processed, moves every live connection
// that still has response bytes queued into the pipeline, transferring
// ownership. Connections with nothing to send remain in `batch`.
void HttpHandler::on_post_batch(std::span<std::unique_ptr<Connection>> batch) {
  const auto has_pending_output = [](const std::unique_ptr<Connection> &c) {
    return c && c->outgoingBytesQueued() > 0;
  };

  int pending = 0;
  for (const auto &c : batch) {
    if (has_pending_output(c)) {
      ++pending;
    }
  }
  if (pending == 0) {
    return;
  }

  // Reserve exactly `pending` slots (blocking until available), then fill them.
  auto guard = pipeline.push(pending, /*block=*/true);
  auto slot = guard.batch.begin();
  for (auto &c : batch) {
    if (has_pending_output(c)) {
      *slot++ = std::move(c);
    }
  }
}
// Feeds newly received bytes to the connection's llhttp parser and, once a
// complete request has been parsed, routes it to the matching handler.
// A missing parser state yields a 500 and a parse error yields a 400; both
// responses ask for the connection to be closed after sending.
void HttpHandler::on_data_arrived(std::string_view data,
                                  std::unique_ptr<Connection> &conn_ptr) {
  auto *state = static_cast<HttpConnectionState *>(conn_ptr->user_data);
  if (!state) {
    sendErrorResponse(*conn_ptr, 500, "Internal server error", true);
    return;
  }
  // TODO: Enforce the configured max_request_size_bytes limit here.
  // Should track cumulative bytes received for the current HTTP request
  // and send 413 Request Entity Too Large if limit is exceeded.
  // This prevents DoS attacks via oversized HTTP requests.

  // Parse HTTP data with llhttp; callbacks update *state as they fire.
  const enum llhttp_errno err =
      llhttp_execute(&state->parser, data.data(), data.size());
  if (err != HPE_OK) {
    sendErrorResponse(*conn_ptr, 400, "Bad request", true);
    return;
  }
  if (!state->message_complete) {
    // More bytes are needed before the request can be dispatched.
    return;
  }

  // Consume the completion flag before dispatching. Without this, any further
  // call to on_data_arrived before on_write_progress recycles the state would
  // route the same request again, queueing a duplicate response built from a
  // url view into an earlier (possibly reused) input buffer.
  state->message_complete = false;

  // NOTE(review): if several pipelined requests arrive in one buffer, llhttp
  // parses them all in this call and only one dispatch happens (using the last
  // url seen). Proper support needs per-message dispatch, e.g. pausing the
  // parser from on_message_complete.
  state->route = parseRoute(state->method, state->url);
  switch (state->route) {
  case HttpRoute::GET_version:
    handleGetVersion(*conn_ptr, *state);
    break;
  case HttpRoute::POST_commit:
    handlePostCommit(*conn_ptr, *state);
    break;
  case HttpRoute::GET_subscribe:
    handleGetSubscribe(*conn_ptr, *state);
    break;
  case HttpRoute::GET_status:
    handleGetStatus(*conn_ptr, *state);
    break;
  case HttpRoute::PUT_retention:
    handlePutRetention(*conn_ptr, *state);
    break;
  case HttpRoute::GET_retention:
    handleGetRetention(*conn_ptr, *state);
    break;
  case HttpRoute::DELETE_retention:
    handleDeleteRetention(*conn_ptr, *state);
    break;
  case HttpRoute::GET_metrics:
    handleGetMetrics(*conn_ptr, *state);
    break;
  case HttpRoute::GET_ok:
    handleGetOk(*conn_ptr, *state);
    break;
  case HttpRoute::NotFound:
  default:
    handleNotFound(*conn_ptr, *state);
    break;
  }
}
// Maps an HTTP method and request target to a route.
// Only the path takes part in routing: any query string or fragment is
// ignored. Prefix routes match on whole path segments, so "/v1/status/abc"
// routes to GET_status but "/v1/statusfoo" is NotFound.
// Returns HttpRoute::NotFound for any unrecognized method/path pair.
HttpRoute HttpHandler::parseRoute(std::string_view method,
                                  std::string_view url) {
  if (const size_t cut = url.find_first_of("?#");
      cut != std::string_view::npos) {
    url = url.substr(0, cut);
  }

  // True when `path` equals `base` or continues with a '/'-separated segment.
  const auto under_path = [](std::string_view path, std::string_view base) {
    return path.starts_with(base) &&
           (path.size() == base.size() || path[base.size()] == '/');
  };

  if (method == "GET") {
    if (url == "/v1/version")
      return HttpRoute::GET_version;
    if (url == "/v1/subscribe")
      return HttpRoute::GET_subscribe;
    if (under_path(url, "/v1/status"))
      return HttpRoute::GET_status;
    // Both the policy list ("/v1/retention") and a single policy
    // ("/v1/retention/<id>") share one route; the handler tells them apart.
    if (under_path(url, "/v1/retention"))
      return HttpRoute::GET_retention;
    if (url == "/metrics")
      return HttpRoute::GET_metrics;
    if (url == "/ok")
      return HttpRoute::GET_ok;
  } else if (method == "POST") {
    if (url == "/v1/commit")
      return HttpRoute::POST_commit;
  } else if (method == "PUT") {
    if (url.starts_with("/v1/retention/"))
      return HttpRoute::PUT_retention;
  } else if (method == "DELETE") {
    if (url.starts_with("/v1/retention/"))
      return HttpRoute::DELETE_retention;
  }
  return HttpRoute::NotFound;
}
// Route handlers (basic implementations)
// GET /v1/version — currently answers with a fixed placeholder payload.
void HttpHandler::handleGetVersion(Connection &conn,
                                   const HttpConnectionState &state) {
  constexpr std::string_view kPayload =
      R"({"version":"0.0.1","leader":"node-1","committed_version":42})";
  const bool close_after = state.connection_close;
  sendJsonResponse(conn, 200, kPayload, close_after);
}
// POST /v1/commit — placeholder acknowledgement; the request is not inspected.
void HttpHandler::handlePostCommit(Connection &conn,
                                   const HttpConnectionState &state) {
  // TODO: Parse commit request from state.body and process
  constexpr std::string_view kPayload =
      R"({"request_id":"example","status":"committed","version":43})";
  const bool close_after = state.connection_close;
  sendJsonResponse(conn, 200, kPayload, close_after);
}
// GET /v1/subscribe — stub that reports streaming is not implemented yet.
void HttpHandler::handleGetSubscribe(Connection &conn,
                                     const HttpConnectionState &state) {
  // TODO: Implement subscription streaming
  constexpr std::string_view kPayload =
      R"({"message":"Subscription endpoint - streaming not yet implemented"})";
  const bool close_after = state.connection_close;
  sendJsonResponse(conn, 200, kPayload, close_after);
}
// GET /v1/status[/...] — returns a fixed placeholder status for now.
void HttpHandler::handleGetStatus(Connection &conn,
                                  const HttpConnectionState &state) {
  // TODO: Extract request_id from URL and check status
  constexpr std::string_view kPayload =
      R"({"request_id":"example","status":"committed","version":43})";
  const bool close_after = state.connection_close;
  sendJsonResponse(conn, 200, kPayload, close_after);
}
// PUT /v1/retention/<id> — placeholder "created" acknowledgement.
void HttpHandler::handlePutRetention(Connection &conn,
                                     const HttpConnectionState &state) {
  // TODO: Parse retention policy from body and store
  constexpr std::string_view kPayload =
      R"({"policy_id":"example","status":"created"})";
  const bool close_after = state.connection_close;
  sendJsonResponse(conn, 200, kPayload, close_after);
}
// GET /v1/retention[/<id>] — currently always reports an empty policy list.
void HttpHandler::handleGetRetention(Connection &conn,
                                     const HttpConnectionState &state) {
  // TODO: Extract policy_id from URL or return all policies
  constexpr std::string_view kPayload = R"({"policies":[]})";
  const bool close_after = state.connection_close;
  sendJsonResponse(conn, 200, kPayload, close_after);
}
// DELETE /v1/retention/<id> — placeholder "deleted" acknowledgement.
void HttpHandler::handleDeleteRetention(Connection &conn,
                                        const HttpConnectionState &state) {
  // TODO: Extract policy_id from URL and delete
  constexpr std::string_view kPayload =
      R"({"policy_id":"example","status":"deleted"})";
  const bool close_after = state.connection_close;
  sendJsonResponse(conn, 200, kPayload, close_after);
}
// GET /metrics — fixed plain-text metrics body until real collection exists.
void HttpHandler::handleGetMetrics(Connection &conn,
                                   const HttpConnectionState &state) {
  // TODO: Implement metrics collection and formatting
  constexpr std::string_view kContentType = "text/plain";
  constexpr std::string_view kMetrics =
      "# WeaselDB metrics\nweaseldb_requests_total 0\n";
  const bool close_after = state.connection_close;
  sendResponse(conn, 200, kContentType, kMetrics, close_after);
}
// GET /ok — liveness probe answering "OK" in plain text.
// The trace slice is tagged with a global flow id taken from state.request_id
// (the X-Request-Id value parsed from the request headers).
void HttpHandler::handleGetOk(Connection &conn,
                              const HttpConnectionState &state) {
  TRACE_EVENT("http", "GET /ok", perfetto::Flow::Global(state.request_id));
  const bool close_after = state.connection_close;
  sendResponse(conn, 200, "text/plain", "OK", close_after);
}
// Fallback for any method/path combination parseRoute does not recognize.
void HttpHandler::handleNotFound(Connection &conn,
                                 const HttpConnectionState &state) {
  constexpr int kStatus = 404;
  const bool close_after = state.connection_close;
  sendErrorResponse(conn, kStatus, "Not found", close_after);
}
// HTTP utility methods
namespace {
// Reason phrase for the status line (RFC 9110 section 15). Codes outside this
// table fall back to "Unknown".
std::string_view httpReasonPhrase(int status_code) {
  switch (status_code) {
  case 200:
    return "OK";
  case 201:
    return "Created";
  case 202:
    return "Accepted";
  case 400:
    return "Bad Request";
  case 404:
    return "Not Found";
  case 405:
    return "Method Not Allowed";
  case 413:
    return "Content Too Large";
  case 500:
    return "Internal Server Error";
  case 503:
    return "Service Unavailable";
  default:
    return "Unknown";
  }
}
} // namespace

// Serializes a complete HTTP/1.1 response and queues it on `conn`. The response
// carries the status line, Content-Type, Content-Length, X-Response-ID (only
// when HTTP state is attached) and Connection headers, followed by `body`.
// When `close_connection` is set, the connection is also flagged to close once
// the bytes have been sent.
void HttpHandler::sendResponse(Connection &conn, int status_code,
                               std::string_view content_type,
                               std::string_view body, bool close_connection) {
  // user_data may be null: on_data_arrived reports a missing state through
  // sendErrorResponse, which lands here. Dereferencing it unconditionally
  // crashed on exactly that error path.
  auto *state = static_cast<HttpConnectionState *>(conn.user_data);

  // Assembled in a heap-backed std::string (not the arena) and then handed to
  // appendMessage.
  std::string response;
  response.reserve(256 + body.size());
  response += "HTTP/1.1 ";
  response += std::to_string(status_code);
  response += ' ';
  response += httpReasonPhrase(status_code);
  response += "\r\n";
  response += "Content-Type: ";
  response += content_type;
  response += "\r\n";
  response += "Content-Length: ";
  response += std::to_string(body.size());
  response += "\r\n";
  if (state != nullptr) {
    // Echo the id captured from the request's X-Request-Id header.
    response += "X-Response-ID: ";
    response += std::to_string(state->request_id);
    response += "\r\n";
  }
  if (close_connection) {
    response += "Connection: close\r\n";
    conn.closeAfterSend(); // Signal connection should be closed after sending
  } else {
    response += "Connection: keep-alive\r\n";
  }
  response += "\r\n";
  response += body;
  conn.appendMessage(response);
}
// Sends `json` verbatim as the body of a response labelled application/json.
// The caller is responsible for `json` being well-formed.
void HttpHandler::sendJsonResponse(Connection &conn, int status_code,
                                   std::string_view json,
                                   bool close_connection) {
  constexpr std::string_view kJsonMediaType = "application/json";
  sendResponse(conn, status_code, kJsonMediaType, json, close_connection);
}
namespace {
// Appends `text` to `out` as the contents of a JSON string literal, escaping
// quotes, backslashes and control characters as RFC 8259 section 7 requires.
void appendJsonEscaped(std::string &out, std::string_view text) {
  static constexpr char kHex[] = "0123456789abcdef";
  for (const char ch : text) {
    const auto uc = static_cast<unsigned char>(ch);
    switch (ch) {
    case '"':
      out += "\\\"";
      break;
    case '\\':
      out += "\\\\";
      break;
    case '\n':
      out += "\\n";
      break;
    case '\r':
      out += "\\r";
      break;
    case '\t':
      out += "\\t";
      break;
    default:
      if (uc < 0x20) {
        out += "\\u00";
        out += kHex[uc >> 4];
        out += kHex[uc & 0x0F];
      } else {
        out += ch;
      }
      break;
    }
  }
}
} // namespace

// Sends a JSON error body of the form {"error":"<message>"}.
// `message` is JSON-escaped, so arbitrary text (quotes, backslashes, control
// characters) still yields a well-formed body.
void HttpHandler::sendErrorResponse(Connection &conn, int status_code,
                                    std::string_view message,
                                    bool close_connection) {
  std::string json;
  json.reserve(16 + message.size());
  json += R"({"error":")";
  appendJsonEscaped(json, message);
  json += R"("})";
  sendJsonResponse(conn, status_code, json, close_connection);
}
// llhttp callbacks
// llhttp on_url hook: records the request target as a view into the input
// buffer (no copy). Each call replaces the previously stored view.
// NOTE(review): a URL split across two reads keeps only its last fragment, and
// the view is valid only while that input buffer is; accumulating into an
// arena-backed buffer (as done for headers) would address both.
int HttpHandler::onUrl(llhttp_t *parser, const char *at, size_t length) {
  auto &conn_state = *static_cast<HttpConnectionState *>(parser->data);
  conn_state.url = std::string_view{at, length};
  return 0;
}
// llhttp on_header_field hook: a header name may arrive in several pieces, so
// each piece is appended to the arena-backed name buffer until the value
// completes.
int HttpHandler::onHeaderField(llhttp_t *parser, const char *at,
                               size_t length) {
  auto &name_buf =
      static_cast<HttpConnectionState *>(parser->data)->current_header_field_buf;
  name_buf.append(at, length);
  return 0;
}
// llhttp on_header_field_complete hook: marks the buffered header name as
// final so onHeaderValueComplete may process the name/value pair.
int HttpHandler::onHeaderFieldComplete(llhttp_t *parser) {
  static_cast<HttpConnectionState *>(parser->data)->header_field_complete =
      true;
  return 0;
}
// llhttp on_header_value hook: appends the next piece of the current header
// value to the arena-backed value buffer.
int HttpHandler::onHeaderValue(llhttp_t *parser, const char *at,
                               size_t length) {
  auto &value_buf =
      static_cast<HttpConnectionState *>(parser->data)->current_header_value_buf;
  value_buf.append(at, length);
  return 0;
}
namespace {
// Trims HTTP optional whitespace (spaces and horizontal tabs) from both ends.
std::string_view trimHttpOws(std::string_view s) {
  const size_t first = s.find_first_not_of(" \t");
  if (first == std::string_view::npos) {
    return {};
  }
  const size_t last = s.find_last_not_of(" \t");
  return s.substr(first, last - first + 1);
}

// ASCII case-insensitive equality of two header tokens.
bool headerTokenEquals(std::string_view a, std::string_view b) {
  return a.size() == b.size() &&
         (a.empty() || strncasecmp(a.data(), b.data(), a.size()) == 0);
}

// True if a Connection header value's comma-separated token list contains
// "close" in any letter case (e.g. "close", "Close", "keep-alive, close").
bool connectionValueHasClose(std::string_view value) {
  for (;;) {
    const size_t comma = value.find(',');
    if (headerTokenEquals(trimHttpOws(value.substr(0, comma)), "close")) {
      return true;
    }
    if (comma == std::string_view::npos) {
      return false;
    }
    value.remove_prefix(comma + 1);
  }
}

// Parses an X-Request-Id value as a base-10 unsigned 64-bit integer. Returns 0
// when the trimmed value is empty, contains anything other than digits, or
// overflows. Previously non-digits were skipped and overflow wrapped silently,
// turning ids like UUIDs into misleading numbers.
uint64_t parseRequestIdValue(std::string_view value) {
  value = trimHttpOws(value);
  const char *const end = value.data() + value.size();
  uint64_t id = 0;
  const auto [ptr, ec] = std::from_chars(value.data(), end, id);
  return (ec == std::errc() && ptr == end) ? id : 0;
}
} // namespace

// llhttp on_header_value_complete hook: handles one complete header pair.
// Recognized headers:
//   Connection   — sets connection_close if the token list contains "close".
//   X-Request-Id — stores the parsed numeric id in request_id (0 if invalid).
// The name/value buffers are then cleared for the next header.
int HttpHandler::onHeaderValueComplete(llhttp_t *parser) {
  auto *state = static_cast<HttpConnectionState *>(parser->data);
  if (!state->header_field_complete) {
    // Field is not complete yet, wait
    return 0;
  }

  const std::string_view field(state->current_header_field_buf.data(),
                               state->current_header_field_buf.size());
  const std::string_view value(state->current_header_value_buf.data(),
                               state->current_header_value_buf.size());

  if (headerTokenEquals(field, "connection") &&
      connectionValueHasClose(value)) {
    state->connection_close = true;
  }
  if (headerTokenEquals(field, "x-request-id")) {
    state->request_id = parseRequestIdValue(value);
  }

  // Clear buffers for next header
  state->current_header_field_buf.clear();
  state->current_header_value_buf.clear();
  state->header_field_complete = false;
  return 0;
}
// llhttp on_headers_complete hook: records that the header section is done,
// captures the request method name, and decides whether the connection must
// close after this request.
int HttpHandler::onHeadersComplete(llhttp_t *parser) {
  auto *state = static_cast<HttpConnectionState *>(parser->data);
  state->headers_complete = true;

  // llhttp_method_name returns a constant string owned by llhttp, so keeping a
  // view to it is safe.
  const char *method_str =
      llhttp_method_name(static_cast<llhttp_method_t>(parser->method));
  state->method = std::string_view(method_str);

  // Let llhttp apply the full persistence rules on top of any explicit
  // "Connection: close" seen so far. This covers HTTP/1.0 requests without
  // "Connection: keep-alive", which must not be answered as keep-alive and
  // left open.
  if (!llhttp_should_keep_alive(parser)) {
    state->connection_close = true;
  }
  return 0;
}
// llhttp on_body hook: request body bytes are currently discarded; nothing is
// stored on the connection state.
// NOTE(review): a TODO in the commit handler refers to `state.body`; buffering
// the body here would be needed before that can work.
int HttpHandler::onBody(llhttp_t *parser, const char *at, size_t length) {
  static_cast<void>(parser);
  static_cast<void>(at);
  static_cast<void>(length);
  return 0;
}
// llhttp on_message_complete hook: flags the request as fully parsed so that
// on_data_arrived routes it once llhttp_execute returns.
int HttpHandler::onMessageComplete(llhttp_t *parser) {
  static_cast<HttpConnectionState *>(parser->data)->message_complete = true;
  return 0;
}