Files
weaseldb/src/http_handler.cpp

692 lines
25 KiB
C++

#include "http_handler.hpp"

#include <strings.h>

#include <cstdint>
#include <cstring>
#include <string>

#include "api_url_parser.hpp"
#include "arena.hpp"
#include "connection.hpp"
#include "format.hpp"
#include "json_commit_request_parser.hpp"
#include "metric.hpp"
#include "pipeline_entry.hpp"
// Counter family from which every per-path HTTP request counter below is
// created; each counter differs only in its "path" label value.
auto requests_counter_family = metric::create_counter(
"weaseldb_http_requests_total", "Total http requests");
// Counter for requests routed to the metrics endpoint (thread_local, like all
// counters in this group).
thread_local auto metrics_counter =
requests_counter_family.create({{"path", "/metrics"}});
// API endpoint request counters
thread_local auto commit_counter =
requests_counter_family.create({{"path", "/v1/commit"}});
thread_local auto status_counter =
requests_counter_family.create({{"path", "/v1/status"}});
thread_local auto version_counter =
requests_counter_family.create({{"path", "/v1/version"}});
thread_local auto ok_counter =
requests_counter_family.create({{"path", "/ok"}});
// Incremented by handle_not_found(); the label is the literal "not_found"
// rather than a URL path.
thread_local auto not_found_counter =
requests_counter_family.create({{"path", "not_found"}});
// Prepares the per-connection llhttp parser: registers the HttpHandler
// callbacks and points the parser's user data at `pending`, the request that
// is currently being assembled.
HttpConnectionState::HttpConnectionState() {
  llhttp_settings_init(&settings);

  // Request-line and message-level events
  settings.on_url = HttpHandler::onUrl;
  settings.on_headers_complete = HttpHandler::onHeadersComplete;
  settings.on_body = HttpHandler::onBody;
  settings.on_message_complete = HttpHandler::onMessageComplete;

  // Header name/value events
  settings.on_header_field = HttpHandler::onHeaderField;
  settings.on_header_field_complete = HttpHandler::onHeaderFieldComplete;
  settings.on_header_value = HttpHandler::onHeaderValue;
  settings.on_header_value_complete = HttpHandler::onHeaderValueComplete;

  // Initialise the parser first, then attach the user data the callbacks
  // read back through parser->data.
  llhttp_init(&parser, HTTP_REQUEST, &settings);
  parser.data = &pending;
}
// HttpRequestState implementation
// The URL buffer and the two header scratch buffers are constructed with an
// allocator bound to this request's own `arena` member.
HttpRequestState::HttpRequestState()
: url(ArenaStlAllocator<char>(&arena)),
current_header_field_buf(ArenaStlAllocator<char>(&arena)),
current_header_value_buf(ArenaStlAllocator<char>(&arena)) {}
// HttpHandler implementation

// Attaches a heap-allocated HttpConnectionState to the new connection via
// conn.user_data. on_connection_closed() deletes it again.
void HttpHandler::on_connection_established(Connection &conn) {
  conn.user_data = new HttpConnectionState();
}
// Destroys the HttpConnectionState stored in conn.user_data and clears the
// pointer so it cannot be reused after the delete.
void HttpHandler::on_connection_closed(Connection &conn) {
  delete static_cast<HttpConnectionState *>(conn.user_data);
  conn.user_data = nullptr;
}
// Wraps each pending response body in an HTTP/1.1 200 response and hands it
// to the connection's reorder queue under the sequence id recorded in its
// HttpResponseContext.
void HttpHandler::on_preprocess_writes(
    Connection &conn, std::span<PendingResponse> pending_responses) {
  auto *conn_state = static_cast<HttpConnectionState *>(conn.user_data);

  // Content-Type is inferred from the body itself:
  //   - exactly "OK"                         -> health check, plain text
  //   - starts with '#' or mentions a metric -> Prometheus text format
  //   - anything else                        -> JSON
  const auto sniff_content_type = [](const auto &body) -> std::string_view {
    if (body == "OK") {
      return "text/plain";
    }
    const bool looks_like_metrics =
        body.starts_with("#") ||
        body.find("_total") != std::string_view::npos ||
        body.find("_counter") != std::string_view::npos;
    return looks_like_metrics ? "text/plain; version=0.0.4"
                              : "application/json";
  };

  // Every response produced on this path is reported as 200.
  constexpr int kStatusOk = 200;

  for (auto &entry : pending_responses) {
    auto *ctx = static_cast<HttpResponseContext *>(entry.protocol_context);
    const std::string_view content_type =
        sniff_content_type(entry.response_json);

    // Format first (this allocates from entry.arena), then move the arena
    // into the reorder queue together with the formatted bytes.
    auto formatted =
        format_response(kStatusOk, content_type, entry.response_json,
                        entry.arena, ctx->http_request_id,
                        ctx->connection_close);
    conn_state->send_ordered_response(conn, ctx->sequence_id, formatted,
                                      std::move(entry.arena),
                                      ctx->connection_close);
  }
}
// Per-thread scratch list of pipeline entries. It is cleared at the end of
// every on_batch_complete() call rather than re-created.
static thread_local std::vector<PipelineEntry> g_batch_entries;

// Routes every request queued on the connections in `batch`.
//
// For each queued request this assigns a response sequence id, allocates an
// HttpResponseContext from the request's arena, parses the URL into a route
// and dispatches to the matching handler. Handlers either answer immediately
// through the connection's reorder queue or leave the request to be wrapped
// in a pipeline entry here. All entries collected across the batch are
// submitted to the commit pipeline in a single call.
void HttpHandler::on_batch_complete(std::span<Connection *const> batch) {
  for (auto conn : batch) {
    auto *state = static_cast<HttpConnectionState *>(conn->user_data);
    for (auto &req : state->queue) {
      // Assign sequence ID for response ordering
      const int64_t sequence_id = state->get_next_sequence_id();
      req.sequence_id = sequence_id;

      // Create HttpResponseContext for this request (allocated from the
      // request's arena).
      auto *ctx = req.arena.allocate<HttpResponseContext>(1);
      ctx->sequence_id = sequence_id;
      ctx->http_request_id = req.http_request_id;
      ctx->connection_close = req.connection_close;

      RouteMatch route_match;
      auto parse_result =
          ApiUrlParser::parse(req.method, const_cast<char *>(req.url.data()),
                              static_cast<int>(req.url.size()), route_match);
      if (parse_result != ParseResult::Success) {
        // Handle malformed URL encoding
        auto json_response = R"({"error":"Malformed URL encoding"})";
        auto http_response =
            format_json_response(400, json_response, req.arena, 0, true);
        // The response above advertises "Connection: close", so the write
        // side must actually be shut down. Passing the request's own
        // keep-alive flag here would leave the connection open while the
        // `break` below discards the remaining queued requests, which would
        // then never be answered.
        state->send_ordered_response(*conn, ctx->sequence_id, http_response,
                                     std::move(req.arena), true);
        break;
      }
      req.route = route_match.route;

      // Route to appropriate handler
      switch (req.route) {
      case HttpRoute::GetVersion:
        handle_get_version(*conn, req);
        break;
      case HttpRoute::PostCommit:
        handle_post_commit(*conn, req);
        break;
      case HttpRoute::GetSubscribe:
        handle_get_subscribe(*conn, req);
        break;
      case HttpRoute::GetStatus:
        handle_get_status(*conn, req, route_match);
        break;
      case HttpRoute::PutRetention:
        handle_put_retention(*conn, req, route_match);
        break;
      case HttpRoute::GetRetention:
        handle_get_retention(*conn, req, route_match);
        break;
      case HttpRoute::DeleteRetention:
        handle_delete_retention(*conn, req, route_match);
        break;
      case HttpRoute::GetMetrics:
        handle_get_metrics(*conn, req);
        break;
      case HttpRoute::GetOk:
        handle_get_ok(*conn, req);
        break;
      case HttpRoute::NotFound:
      default:
        handle_not_found(*conn, req);
        break;
      }

      // Requests that were not answered by their handler are forwarded to the
      // pipeline; the request's arena moves into the entry.
      if (req.route == HttpRoute::PostCommit && req.commit_request &&
          req.parsing_commit && req.basic_validation_passed) {
        // Commit requests that passed basic validation
        g_batch_entries.emplace_back(CommitEntry(conn->get_weak_ref(), ctx,
                                                 req.commit_request.get(),
                                                 std::move(req.arena)));
      } else if (req.route == HttpRoute::GetStatus &&
                 !req.status_request_id.empty()) {
        // Status requests. handle_get_status() only stores a request id when
        // it did NOT already reply with a 400; in the rejected case the arena
        // has been moved into the reorder queue and the request must not be
        // submitted a second time.
        g_batch_entries.emplace_back(StatusEntry(conn->get_weak_ref(), ctx,
                                                 req.status_request_id,
                                                 std::move(req.arena)));
      } else if (req.route == HttpRoute::GetOk) {
        // Health check requests
        g_batch_entries.emplace_back(
            HealthCheckEntry(conn->get_weak_ref(), ctx, std::move(req.arena)));
      } else if (req.route == HttpRoute::GetVersion) {
        // Version requests carry the committed version read at this point
        g_batch_entries.emplace_back(
            GetVersionEntry(conn->get_weak_ref(), ctx, std::move(req.arena),
                            commit_pipeline_.get_committed_version()));
      }
    }
    state->queue.clear();
  }

  // Send requests to commit pipeline in batch. Batching here reduces
  // contention on the way into the pipeline.
  if (!g_batch_entries.empty()) {
    commit_pipeline_.submit_batch(g_batch_entries);
  }
  g_batch_entries.clear();
}
void HttpHandler::on_data_arrived(std::string_view data, Connection &conn) {
auto *state = static_cast<HttpConnectionState *>(conn.user_data);
assert(state);
// TODO: Enforce the configured max_request_size_bytes limit here.
// Should track cumulative bytes received for the current HTTP request
// and send 413 Request Entity Too Large if limit is exceeded.
// This prevents DoS attacks via oversized HTTP requests.
// Parse HTTP data with llhttp
for (;;) {
enum llhttp_errno err =
llhttp_execute(&state->parser, data.data(), data.size());
if (err == HPE_PAUSED) {
assert(state->pending.message_complete);
state->queue.push_back(std::move(state->pending));
state->pending = {};
int consumed = llhttp_get_error_pos(&state->parser) - data.data();
data = data.substr(consumed, data.size() - consumed);
llhttp_resume(&state->parser);
continue;
}
if (err == HPE_OK) {
break;
}
// Parse error - send response directly since this is before sequence
// assignment
auto json_response = R"({"error":"Bad request"})";
auto http_response =
format_json_response(400, json_response, state->pending.arena, 0, true);
state->send_ordered_response(conn, state->get_next_sequence_id(),
http_response, std::move(state->pending.arena),
true);
return;
}
}
// Route handlers (basic implementations)
// Version requests: only the request counter is bumped here. No response is
// written by this handler; on_batch_complete() wraps the request in a
// GetVersionEntry for the pipeline.
void HttpHandler::handle_get_version(Connection &, HttpRequestState &) {
version_counter.inc();
// Sent to commit pipeline
}
// Validates a streamed commit request on the I/O thread.
//
// Replies with a 400 immediately when the streaming parse failed or when a
// basic structural check fails. Otherwise sets basic_validation_passed so the
// caller forwards the request to the pipeline; no response is written here in
// that case.
void HttpHandler::handle_post_commit(Connection &conn,
                                     HttpRequestState &state) {
  commit_counter.inc();

  // Common tail of every rejection: wrap `json_body` in a 400 response and
  // queue it, handing the request's arena over with the response.
  const auto reject = [&](std::string_view json_body) {
    auto http_response =
        format_json_response(400, json_body, state.arena,
                             state.http_request_id, state.connection_close);
    auto *conn_state = static_cast<HttpConnectionState *>(conn.user_data);
    conn_state->send_ordered_response(conn, state.sequence_id, http_response,
                                      std::move(state.arena),
                                      state.connection_close);
  };

  // Check if streaming parse was successful
  if (!state.commit_request || !state.parsing_commit) {
    reject(R"({"error":"Parse failed"})");
    return;
  }

  const CommitRequest &commit_request = *state.commit_request;

  // Basic validation that doesn't require serialization. Returns the message
  // for the first violated rule, or an empty view when the request is fine.
  const auto first_violation = [&commit_request]() -> std::string_view {
    if (commit_request.operations().empty()) {
      return "Commit request must contain at least one operation";
    }
    if (commit_request.leader_id().empty()) {
      return "Commit request must specify a leader_id";
    }
    for (const auto &op : commit_request.operations()) {
      if (op.param1.empty()) {
        return "Operation key cannot be empty";
      }
      if (op.type == Operation::Type::Write && op.param2.empty()) {
        return "Write operation value cannot be empty";
      }
    }
    return {};
  };

  const std::string_view error_msg = first_violation();
  if (!error_msg.empty()) {
    reject(format(state.arena, R"({"error":"%.*s"})",
                  static_cast<int>(error_msg.size()), error_msg.data()));
    return;
  }

  // Basic validation passed - mark for 4-stage pipeline processing.
  // Response will be sent after 4-stage pipeline processing is complete.
  state.basic_validation_passed = true;
}
// Placeholder for the subscription endpoint: answers 200 with a fixed JSON
// message until streaming is implemented.
void HttpHandler::handle_get_subscribe(Connection &conn,
                                       HttpRequestState &state) {
  // TODO: Implement subscription streaming
  static constexpr std::string_view kBody =
      R"({"message":"Subscription endpoint - streaming not yet implemented"})";

  // Build the response before the arena is moved into the reorder queue.
  auto reply = format_json_response(200, kBody, state.arena,
                                    state.http_request_id,
                                    state.connection_close);
  auto *connection_state = static_cast<HttpConnectionState *>(conn.user_data);
  connection_state->send_ordered_response(conn, state.sequence_id, reply,
                                          std::move(state.arena),
                                          state.connection_close);
}
// Prepares a status request for the pipeline.
//
// Reads the request_id route parameter. A missing or empty value is answered
// with a 400 right away; otherwise the id is stored on the request state and
// no response is written here.
void HttpHandler::handle_get_status(Connection &conn, HttpRequestState &state,
                                    const RouteMatch &route_match) {
  status_counter.inc();

  // Queues a 400 carrying `json_body` through the reorder queue so it is
  // still delivered in sequence order.
  const auto reject = [&](std::string_view json_body) {
    auto http_response =
        format_json_response(400, json_body, state.arena,
                             state.http_request_id, state.connection_close);
    auto *conn_state = static_cast<HttpConnectionState *>(conn.user_data);
    conn_state->send_ordered_response(conn, state.sequence_id, http_response,
                                      std::move(state.arena),
                                      state.connection_close);
  };

  const auto &request_id =
      route_match.params[static_cast<int>(ApiParameterKey::RequestId)];

  if (!request_id) {
    reject(R"({"error":"Missing required query parameter: request_id"})");
    return;
  }
  if (request_id->empty()) {
    reject(R"({"error":"Empty request_id parameter"})");
    return;
  }

  // Store the request_id in the state for the pipeline; the response itself
  // is produced after pipeline processing.
  state.status_request_id = *request_id;
}
// Stub for creating a retention policy: always answers 200 with a fixed
// example payload.
void HttpHandler::handle_put_retention(Connection &conn,
                                       HttpRequestState &state,
                                       const RouteMatch &) {
  // TODO: Parse retention policy from body and store
  static constexpr std::string_view kBody =
      R"({"policy_id":"example","status":"created"})";

  auto reply = format_json_response(200, kBody, state.arena,
                                    state.http_request_id,
                                    state.connection_close);

  // Goes through the reorder queue so responses keep request order.
  auto *connection_state = static_cast<HttpConnectionState *>(conn.user_data);
  connection_state->send_ordered_response(conn, state.sequence_id, reply,
                                          std::move(state.arena),
                                          state.connection_close);
}
// Stub for listing retention policies: always answers 200 with an empty
// policy list.
void HttpHandler::handle_get_retention(Connection &conn,
                                       HttpRequestState &state,
                                       const RouteMatch &) {
  // TODO: Extract policy_id from URL or return all policies
  static constexpr std::string_view kBody = R"({"policies":[]})";

  auto reply = format_json_response(200, kBody, state.arena,
                                    state.http_request_id,
                                    state.connection_close);

  // Goes through the reorder queue so responses keep request order.
  auto *connection_state = static_cast<HttpConnectionState *>(conn.user_data);
  connection_state->send_ordered_response(conn, state.sequence_id, reply,
                                          std::move(state.arena),
                                          state.connection_close);
}
// Stub for deleting a retention policy: always answers 200 with a fixed
// example payload.
void HttpHandler::handle_delete_retention(Connection &conn,
                                          HttpRequestState &state,
                                          const RouteMatch &) {
  // TODO: Extract policy_id from URL and delete
  static constexpr std::string_view kBody =
      R"({"policy_id":"example","status":"deleted"})";

  auto reply = format_json_response(200, kBody, state.arena,
                                    state.http_request_id,
                                    state.connection_close);

  // Goes through the reorder queue so responses keep request order.
  auto *connection_state = static_cast<HttpConnectionState *>(conn.user_data);
  connection_state->send_ordered_response(conn, state.sequence_id, reply,
                                          std::move(state.arena),
                                          state.connection_close);
}
// Serves the metrics endpoint directly from the I/O thread.
//
// The rendered metrics come back as a span of string pieces; the response is
// a span one element longer, with the HTTP header block in slot 0 followed by
// those pieces unchanged.
void HttpHandler::handle_get_metrics(Connection &conn,
                                     HttpRequestState &state) {
  metrics_counter.inc();

  auto rendered = metric::render(state.arena);

  // Content-Length is the sum of all rendered pieces.
  size_t body_bytes = 0;
  for (const auto &piece : rendered) {
    body_bytes += piece.size();
  }

  // One extra slot in front for the header block.
  auto response =
      state.arena.allocate_span<std::string_view>(rendered.size() + 1);

  // The two header variants differ only in the Connection line.
  if (state.connection_close) {
    response[0] = static_format(
        state.arena, "HTTP/1.1 200 OK\r\n",
        "Content-Type: text/plain; version=0.0.4\r\n",
        "Content-Length: ", static_cast<uint64_t>(body_bytes), "\r\n",
        "X-Response-ID: ", static_cast<int64_t>(state.http_request_id), "\r\n",
        "Connection: close\r\n", "\r\n");
  } else {
    response[0] = static_format(
        state.arena, "HTTP/1.1 200 OK\r\n",
        "Content-Type: text/plain; version=0.0.4\r\n",
        "Content-Length: ", static_cast<uint64_t>(body_bytes), "\r\n",
        "X-Response-ID: ", static_cast<int64_t>(state.http_request_id), "\r\n",
        "Connection: keep-alive\r\n", "\r\n");
  }

  size_t slot = 1;
  for (auto piece : rendered) {
    response[slot++] = piece;
  }

  auto *connection_state = static_cast<HttpConnectionState *>(conn.user_data);
  connection_state->send_ordered_response(conn, state.sequence_id, response,
                                          std::move(state.arena),
                                          state.connection_close);
}
// Health check requests: only the request counter is bumped here. No response
// is written by this handler; on_batch_complete() wraps the request in a
// HealthCheckEntry for the pipeline.
void HttpHandler::handle_get_ok(Connection &, HttpRequestState &) {
ok_counter.inc();
// Health check requests are processed through the pipeline
// Response will be generated in the release stage after pipeline processing
}
// Fallback for unrouted requests: counts the miss and answers 404 with a JSON
// error body through the reorder queue.
void HttpHandler::handle_not_found(Connection &conn, HttpRequestState &state) {
  not_found_counter.inc();

  static constexpr std::string_view kBody = R"({"error":"Not found"})";

  // Build the response before the arena is moved into the reorder queue.
  auto reply = format_json_response(404, kBody, state.arena,
                                    state.http_request_id,
                                    state.connection_close);
  auto *connection_state = static_cast<HttpConnectionState *>(conn.user_data);
  connection_state->send_ordered_response(conn, state.sequence_id, reply,
                                          std::move(state.arena),
                                          state.connection_close);
}
// Stores a finished response under its sequence id, then writes out every
// stored response whose id is the next one due, in ascending order.
//
// A response that arrives ahead of its turn stays in `ready_responses` until
// all lower sequence ids have been sent.
void HttpConnectionState::send_ordered_response(
    Connection &conn, int64_t sequence_id,
    std::span<std::string_view> http_response, Arena arena,
    bool close_connection) {
  // Add to reorder queue with proper sequencing
  ready_responses[sequence_id] =
      ResponseData{http_response, std::move(arena), close_connection};

  // Drain from the front while the head is exactly the next id to send; each
  // sent entry is erased, which also advances the iterator.
  for (auto head = ready_responses.begin();
       head != ready_responses.end() && head->first == next_sequence_to_send;
       head = ready_responses.erase(head)) {
    auto &ready = head->second;
    const auto shutdown_mode = ready.connection_close
                                   ? ConnectionShutdown::WriteOnly
                                   : ConnectionShutdown::None;

    // append_bytes handles write interest and takes over the arena.
    conn.append_bytes(ready.data, std::move(ready.arena), shutdown_mode);
    next_sequence_to_send++;
  }
}
// Formats a complete HTTP/1.1 response (status line, headers and body) into
// `response_arena` and returns it as a one-element span.
//
// status_code       Numeric status; 200/400/404/500 get their standard reason
//                   phrase, anything else is labelled "Unknown".
// content_type      Value for the Content-Type header.
// body              Payload; its size is used for Content-Length and its
//                   bytes are copied after the header block.
// response_arena    Arena that owns the span and the formatted bytes.
// http_request_id   Echoed back in the X-Response-ID header.
// close_connection  Selects "Connection: close" vs "Connection: keep-alive".
std::span<std::string_view>
HttpHandler::format_response(int status_code, std::string_view content_type,
                             std::string_view body, Arena &response_arena,
                             int64_t http_request_id, bool close_connection) {
  // Status text
  std::string_view status_text;
  switch (status_code) {
  case 200:
    status_text = "OK";
    break;
  case 400:
    status_text = "Bad Request";
    break;
  case 404:
    status_text = "Not Found";
    break;
  case 500:
    status_text = "Internal Server Error";
    break;
  default:
    status_text = "Unknown";
    break;
  }

  const char *connection_header = close_connection ? "close" : "keep-alive";

  auto response = response_arena.allocate_span<std::string_view>(1);
  // X-Response-ID is printed with %lld and an explicit long long cast:
  // int64_t is `long` on LP64 Linux but `long long` on other ABIs, so a bare
  // %ld would be a format/argument mismatch there.
  response[0] =
      format(response_arena,
             "HTTP/1.1 %d %.*s\r\n"
             "Content-Type: %.*s\r\n"
             "Content-Length: %zu\r\n"
             "X-Response-ID: %lld\r\n"
             "Connection: %s\r\n"
             "\r\n%.*s",
             status_code, static_cast<int>(status_text.size()),
             status_text.data(), static_cast<int>(content_type.size()),
             content_type.data(), body.size(),
             static_cast<long long>(http_request_id), connection_header,
             static_cast<int>(body.size()), body.data());
  return response;
}
// Convenience wrapper around format_response() that fixes the Content-Type to
// "application/json"; all other arguments are forwarded unchanged.
std::span<std::string_view> HttpHandler::format_json_response(
int status_code, std::string_view json, Arena &response_arena,
int64_t http_request_id, bool close_connection) {
return format_response(status_code, "application/json", json, response_arena,
http_request_id, close_connection);
}
// llhttp callbacks
// In every callback below, parser->data is the HttpRequestState currently
// being assembled (set in the HttpConnectionState constructor).

// URL data callback: appends the fragment [at, at + length) to the request's
// URL buffer. Appending (rather than assigning) lets the URL be built up
// across several invocations. Always returns 0.
int HttpHandler::onUrl(llhttp_t *parser, const char *at, size_t length) {
auto *state = static_cast<HttpRequestState *>(parser->data);
// Accumulate URL data
state->url.append(at, length);
return 0;
}
// Header-name data callback: appends the fragment [at, at + length) to the
// current header field buffer. Always returns 0.
int HttpHandler::onHeaderField(llhttp_t *parser, const char *at,
size_t length) {
auto *state = static_cast<HttpRequestState *>(parser->data);
// Accumulate header field data
state->current_header_field_buf.append(at, length);
return 0;
}
// Marks the current header name as complete; onHeaderValueComplete() checks
// this flag before interpreting the header. Always returns 0.
int HttpHandler::onHeaderFieldComplete(llhttp_t *parser) {
auto *state = static_cast<HttpRequestState *>(parser->data);
state->header_field_complete = true;
return 0;
}
// Header-value data callback: appends the fragment [at, at + length) to the
// current header value buffer. Always returns 0.
int HttpHandler::onHeaderValue(llhttp_t *parser, const char *at,
size_t length) {
auto *state = static_cast<HttpRequestState *>(parser->data);
// Accumulate header value data
state->current_header_value_buf.append(at, length);
return 0;
}
// Called when a header value is complete. Interprets the accumulated
// name/value pair, then resets the scratch buffers for the next header.
//
// Recognised headers (names and values compared case-insensitively):
//   Connection: close   -> sets connection_close
//   X-Request-Id: <n>   -> sets http_request_id from the decimal digits in
//                          the value; non-digit characters are skipped
//
// Always returns 0.
int HttpHandler::onHeaderValueComplete(llhttp_t *parser) {
  auto *state = static_cast<HttpRequestState *>(parser->data);
  if (!state->header_field_complete) {
    // Field is not complete yet, wait
    return 0;
  }

  // Now we have complete header field and value
  const auto &field = state->current_header_field_buf;
  const auto &value = state->current_header_value_buf;

  // Check for Connection header
  if (field.size() == 10 && strncasecmp(field.data(), "connection", 10) == 0) {
    if (value.size() == 5 && strncasecmp(value.data(), "close", 5) == 0) {
      state->connection_close = true;
    }
  }

  // Check for X-Request-Id header
  if (field.size() == 12 &&
      strncasecmp(field.data(), "x-request-id", 12) == 0) {
    int64_t id = 0;
    for (char c : value) {
      if (c < '0' || c > '9') {
        continue;
      }
      const int digit = c - '0';
      // The header is client-controlled. Saturate instead of letting
      // `id * 10 + digit` overflow, which is undefined behaviour for a
      // signed integer.
      if (id > (INT64_MAX - digit) / 10) {
        id = INT64_MAX;
        break;
      }
      id = id * 10 + digit;
    }
    state->http_request_id = id;
  }

  // Clear buffers for next header
  state->current_header_field_buf.clear();
  state->current_header_value_buf.clear();
  state->header_field_complete = false;
  return 0;
}
// Called once all headers are parsed. Records the HTTP method and, for a POST
// whose URL begins with "/v1/commit", sets up the streaming commit parser so
// body chunks can be consumed as they arrive.
//
// Returns 0 normally, or -1 (a parse error for llhttp) when the streaming
// parser could not be started.
int HttpHandler::onHeadersComplete(llhttp_t *parser) {
  auto *state = static_cast<HttpRequestState *>(parser->data);
  state->headers_complete = true;

  // Get HTTP method
  state->method = std::string_view(
      llhttp_method_name(static_cast<llhttp_method_t>(parser->method)));

  // Only POSTs to the commit endpoint get a streaming body parser.
  const bool is_commit_post =
      state->method == "POST" && state->url.rfind("/v1/commit", 0) == 0;
  if (!is_commit_post) {
    return 0;
  }

  // Initialize streaming commit request parsing
  state->commit_parser = state->arena.construct<JsonCommitRequestParser>();
  state->commit_request = state->arena.construct<CommitRequest>();
  state->parsing_commit =
      state->commit_parser->begin_streaming_parse(*state->commit_request);

  // A failed start is signalled to llhttp as a parsing error.
  return state->parsing_commit ? 0 : -1;
}
// Body data callback. When a streaming commit parse is active, the chunk is
// forwarded to the commit parser; otherwise the body bytes are ignored.
//
// Returns -1 (a parse error for llhttp) if the commit parser reports an
// error, 0 in every other case.
int HttpHandler::onBody(llhttp_t *parser, const char *at, size_t length) {
  auto *state = static_cast<HttpRequestState *>(parser->data);

  // Nothing to do unless this request is a commit being stream-parsed.
  if (!state->parsing_commit || !state->commit_parser) {
    return 0;
  }

  // Stream data to commit request parser
  const auto chunk_status =
      state->commit_parser->parse_chunk(const_cast<char *>(at), length);
  return chunk_status == CommitRequestParser::ParseStatus::Error ? -1 : 0;
}
// Called when a full HTTP message has been parsed. Flags the request as
// complete and returns HPE_PAUSED; on_data_arrived() reacts to that error
// code by queueing the finished request and resuming the parser.
int HttpHandler::onMessageComplete(llhttp_t *parser) {
auto *state = static_cast<HttpRequestState *>(parser->data);
state->message_complete = true;
return HPE_PAUSED;
}