From 4044f0a8713572b69a4c1a74db19d5dcbad7b5c6 Mon Sep 17 00:00:00 2001 From: Andrew Noyes Date: Tue, 19 Aug 2025 17:57:07 -0400 Subject: [PATCH] Add unix socket listening mode --- CMakeLists.txt | 4 ++ src/config.cpp | 24 +++++++-- src/config.hpp | 2 + src/http_handler.cpp | 123 ++++++++++++++++++++++++++++--------------- src/http_handler.hpp | 14 +++-- src/main.cpp | 11 ++-- src/server.cpp | 75 +++++++++++++++++++++++--- 7 files changed, 192 insertions(+), 61 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b41d002..8218bf6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -200,6 +200,10 @@ add_dependencies(debug_arena generate_json_tokens) target_link_libraries(debug_arena weaseljson simdutf::simdutf) target_include_directories(debug_arena PRIVATE src) +# Load tester +add_executable(load_tester tools/load_tester.cpp) +target_link_libraries(load_tester Threads::Threads) + add_test(NAME arena_allocator_tests COMMAND test_arena_allocator) add_test(NAME commit_request_tests COMMAND test_commit_request) add_test(NAME http_handler_tests COMMAND test_http_handler) diff --git a/src/config.cpp b/src/config.cpp index cae6893..339e195 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -81,6 +81,7 @@ void ConfigParser::parse_server_config(const auto &toml_data, parse_section(toml_data, "server", [&](const auto &srv) { parse_field(srv, "bind_address", config.bind_address); parse_field(srv, "port", config.port); + parse_field(srv, "unix_socket_path", config.unix_socket_path); parse_field(srv, "max_request_size_bytes", config.max_request_size_bytes); parse_field(srv, "accept_threads", config.accept_threads); parse_field(srv, "network_threads", config.network_threads); @@ -114,11 +115,24 @@ bool ConfigParser::validate_config(const Config &config) { bool valid = true; // Validate server configuration - if (config.server.port <= 0 || config.server.port > 65535) { - std::cerr - << "Configuration error: server.port must be between 1 and 65535, got " - << config.server.port << std::endl; - valid = false; + if (config.server.unix_socket_path.empty()) { + // TCP mode validation + if (config.server.port <= 0 || config.server.port > 65535) { + std::cerr << "Configuration error: server.port must be between 1 and " + "65535, got " + << config.server.port << std::endl; + valid = false; + } + } else { + // Unix socket mode validation + if (config.server.unix_socket_path.length() > + 107) { // UNIX_PATH_MAX is typically 108 + std::cerr << "Configuration error: unix_socket_path too long (max 107 " + "chars), got " + << config.server.unix_socket_path.length() << " chars" + << std::endl; + valid = false; + } } if (config.server.max_request_size_bytes == 0) { diff --git a/src/config.hpp b/src/config.hpp index 8e3f43f..59b9c42 100644 --- a/src/config.hpp +++ b/src/config.hpp @@ -14,6 +14,8 @@ struct ServerConfig { std::string bind_address = "127.0.0.1"; /// TCP port number for the server to listen on int port = 8080; + /// Unix socket path (if specified, takes precedence over TCP) + std::string unix_socket_path; /// Maximum size in bytes for incoming HTTP requests (default: 1MB) size_t max_request_size_bytes = 1024 * 1024; /// Number of accept threads for handling incoming connections diff --git a/src/http_handler.cpp b/src/http_handler.cpp index daef138..db5ecdb 100644 --- a/src/http_handler.cpp +++ b/src/http_handler.cpp @@ -2,6 +2,7 @@ #include "arena_allocator.hpp" #include #include +#include // HttpConnectionState implementation HttpConnectionState::HttpConnectionState( @@ -45,7 +46,7 @@ void HttpHandler::on_data_arrived(std::string_view data, std::unique_ptr &conn_ptr) { auto *state = static_cast(conn_ptr->user_data); if (!state) { - sendErrorResponse(*conn_ptr, 500, "Internal server error"); + sendErrorResponse(*conn_ptr, 500, "Internal server error", true); return; } @@ -59,7 +60,7 @@ void HttpHandler::on_data_arrived(std::string_view data, llhttp_execute(&state->parser, data.data(), data.size()); if (err != HPE_OK) { - sendErrorResponse(*conn_ptr, 400, "Bad request"); + sendErrorResponse(*conn_ptr, 400, "Bad request", true); return; } @@ -94,6 +95,9 @@ void HttpHandler::on_data_arrived(std::string_view data, case HttpRoute::GET_metrics: handleGetMetrics(*conn_ptr, *state); break; + case HttpRoute::GET_ok: + handleGetOk(*conn_ptr, *state); + break; case HttpRoute::NotFound: default: handleNotFound(*conn_ptr, *state); @@ -127,6 +131,8 @@ HttpRoute HttpHandler::parseRoute(std::string_view method, } if (url == "/metrics") return HttpRoute::GET_metrics; + if (url == "/ok") + return HttpRoute::GET_ok; } else if (method == "POST") { if (url == "/v1/commit") return HttpRoute::POST_commit; @@ -142,71 +148,83 @@ HttpRoute HttpHandler::parseRoute(std::string_view method, } // Route handlers (basic implementations) -void HttpHandler::handleGetVersion( - Connection &conn, [[maybe_unused]] const HttpConnectionState &state) { +void HttpHandler::handleGetVersion(Connection &conn, + const HttpConnectionState &state) { sendJsonResponse( conn, 200, - R"({"version":"0.0.1","leader":"node-1","committed_version":42})"); + R"({"version":"0.0.1","leader":"node-1","committed_version":42})", + state.connection_close); } -void HttpHandler::handlePostCommit( - Connection &conn, [[maybe_unused]] const HttpConnectionState &state) { +void HttpHandler::handlePostCommit(Connection &conn, + const HttpConnectionState &state) { // TODO: Parse commit request from state.body and process sendJsonResponse( conn, 200, - R"({"request_id":"example","status":"committed","version":43})"); + R"({"request_id":"example","status":"committed","version":43})", + state.connection_close); } -void HttpHandler::handleGetSubscribe( - Connection &conn, [[maybe_unused]] const HttpConnectionState &state) { +void HttpHandler::handleGetSubscribe(Connection &conn, + const HttpConnectionState &state) { // TODO: Implement subscription streaming sendJsonResponse( conn, 200, - R"({"message":"Subscription endpoint - streaming not yet implemented"})"); + R"({"message":"Subscription endpoint - streaming not yet implemented"})", + state.connection_close); } -void HttpHandler::handleGetStatus( - Connection &conn, [[maybe_unused]] const HttpConnectionState &state) { +void HttpHandler::handleGetStatus(Connection &conn, + const HttpConnectionState &state) { // TODO: Extract request_id from URL and check status sendJsonResponse( conn, 200, - R"({"request_id":"example","status":"committed","version":43})"); + R"({"request_id":"example","status":"committed","version":43})", + state.connection_close); } -void HttpHandler::handlePutRetention( - Connection &conn, [[maybe_unused]] const HttpConnectionState &state) { +void HttpHandler::handlePutRetention(Connection &conn, + const HttpConnectionState &state) { // TODO: Parse retention policy from body and store - sendJsonResponse(conn, 200, R"({"policy_id":"example","status":"created"})"); + sendJsonResponse(conn, 200, R"({"policy_id":"example","status":"created"})", + state.connection_close); } -void HttpHandler::handleGetRetention( - Connection &conn, [[maybe_unused]] const HttpConnectionState &state) { +void HttpHandler::handleGetRetention(Connection &conn, + const HttpConnectionState &state) { // TODO: Extract policy_id from URL or return all policies - sendJsonResponse(conn, 200, R"({"policies":[]})"); + sendJsonResponse(conn, 200, R"({"policies":[]})", state.connection_close); } -void HttpHandler::handleDeleteRetention( - Connection &conn, [[maybe_unused]] const HttpConnectionState &state) { +void HttpHandler::handleDeleteRetention(Connection &conn, + const HttpConnectionState &state) { // TODO: Extract policy_id from URL and delete - sendJsonResponse(conn, 200, R"({"policy_id":"example","status":"deleted"})"); + sendJsonResponse(conn, 200, R"({"policy_id":"example","status":"deleted"})", + state.connection_close); } -void HttpHandler::handleGetMetrics( - Connection &conn, [[maybe_unused]] const HttpConnectionState &state) { +void HttpHandler::handleGetMetrics(Connection &conn, + const HttpConnectionState &state) { // TODO: Implement metrics collection and formatting sendResponse(conn, 200, "text/plain", - "# WeaselDB metrics\nweaseldb_requests_total 0\n"); + "# WeaselDB metrics\nweaseldb_requests_total 0\n", + state.connection_close); } -void HttpHandler::handleNotFound( - Connection &conn, [[maybe_unused]] const HttpConnectionState &state) { - sendErrorResponse(conn, 404, "Not found"); +void HttpHandler::handleGetOk(Connection &conn, + const HttpConnectionState &state) { + sendResponse(conn, 200, "text/plain", "OK", state.connection_close); +} + +void HttpHandler::handleNotFound(Connection &conn, + const HttpConnectionState &state) { + sendErrorResponse(conn, 404, "Not found", state.connection_close); } // HTTP utility methods void HttpHandler::sendResponse(Connection &conn, int status_code, std::string_view content_type, - std::string_view body) { + std::string_view body, bool close_connection) { [[maybe_unused]] ArenaAllocator &arena = conn.getArena(); // Build HTTP response using arena @@ -243,7 +261,14 @@ void HttpHandler::sendResponse(Connection &conn, int status_code, response += "Content-Length: "; response += std::to_string(body.size()); response += "\r\n"; - response += "Connection: keep-alive\r\n"; + + if (close_connection) { + response += "Connection: close\r\n"; + conn.closeAfterSend(); // Signal connection should be closed after sending + } else { + response += "Connection: keep-alive\r\n"; + } + response += "\r\n"; response += body; @@ -251,19 +276,21 @@ void HttpHandler::sendResponse(Connection &conn, int status_code, } void HttpHandler::sendJsonResponse(Connection &conn, int status_code, - std::string_view json) { - sendResponse(conn, status_code, "application/json", json); + std::string_view json, + bool close_connection) { + sendResponse(conn, status_code, "application/json", json, close_connection); } void HttpHandler::sendErrorResponse(Connection &conn, int status_code, - std::string_view message) { + std::string_view message, + bool close_connection) { [[maybe_unused]] ArenaAllocator &arena = conn.getArena(); std::string json = R"({"error":")"; json += message; json += R"("})"; - sendJsonResponse(conn, status_code, json); + sendJsonResponse(conn, status_code, json, close_connection); } // llhttp callbacks @@ -274,17 +301,27 @@ int HttpHandler::onUrl(llhttp_t *parser, const char *at, size_t length) { return 0; } -int HttpHandler::onHeaderField([[maybe_unused]] llhttp_t *parser, - [[maybe_unused]] const char *at, - [[maybe_unused]] size_t length) { - // TODO: Store headers if needed +int HttpHandler::onHeaderField(llhttp_t *parser, const char *at, + size_t length) { + auto *state = static_cast(parser->data); + // Store current header field name for processing in onHeaderValue + state->current_header_field = std::string_view(at, length); return 0; } -int HttpHandler::onHeaderValue([[maybe_unused]] llhttp_t *parser, - [[maybe_unused]] const char *at, - [[maybe_unused]] size_t length) { - // TODO: Store header values if needed +int HttpHandler::onHeaderValue(llhttp_t *parser, const char *at, + size_t length) { + auto *state = static_cast(parser->data); + std::string_view value(at, length); + + // Check for Connection header + if (state->current_header_field.size() == 10 && + strncasecmp(state->current_header_field.data(), "connection", 10) == 0) { + if (value.size() == 5 && strncasecmp(value.data(), "close", 5) == 0) { + state->connection_close = true; + } + } + return 0; } diff --git a/src/http_handler.hpp b/src/http_handler.hpp index 50b5b93..2e87b5c 100644 --- a/src/http_handler.hpp +++ b/src/http_handler.hpp @@ -19,6 +19,7 @@ enum class HttpRoute { GET_retention, DELETE_retention, GET_metrics, + GET_ok, NotFound }; @@ -38,7 +39,9 @@ struct HttpConnectionState { // Parse state bool headers_complete = false; bool message_complete = false; + bool connection_close = false; // Client requested connection close HttpRoute route = HttpRoute::NotFound; + std::string_view current_header_field; // Current header being parsed explicit HttpConnectionState(ArenaAllocator &arena); }; @@ -79,14 +82,17 @@ private: void handleDeleteRetention(Connection &conn, const HttpConnectionState &state); void handleGetMetrics(Connection &conn, const HttpConnectionState &state); + void handleGetOk(Connection &conn, const HttpConnectionState &state); void handleNotFound(Connection &conn, const HttpConnectionState &state); // HTTP utilities static void sendResponse(Connection &conn, int status_code, - std::string_view content_type, - std::string_view body); + std::string_view content_type, std::string_view body, + bool close_connection = false); static void sendJsonResponse(Connection &conn, int status_code, - std::string_view json); + std::string_view json, + bool close_connection = false); static void sendErrorResponse(Connection &conn, int status_code, - std::string_view message); + std::string_view message, + bool close_connection = false); }; \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index 1203a62..5fa19c5 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -80,9 +80,14 @@ int main(int argc, char *argv[]) { } std::cout << "Configuration loaded successfully:" << std::endl; - std::cout << "Server bind address: " << config->server.bind_address - << std::endl; - std::cout << "Server port: " << config->server.port << std::endl; + if (!config->server.unix_socket_path.empty()) { + std::cout << "Unix socket path: " << config->server.unix_socket_path + << std::endl; + } else { + std::cout << "Server bind address: " << config->server.bind_address + << std::endl; + std::cout << "Server port: " << config->server.port << std::endl; + } std::cout << "Max request size: " << config->server.max_request_size_bytes << " bytes" << std::endl; std::cout << "Accept threads: " << config->server.accept_threads << std::endl; diff --git a/src/server.cpp b/src/server.cpp index fc29bba..711b217 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include extern std::atomic activeConnections; @@ -145,9 +146,64 @@ void Server::setup_shutdown_pipe() { } int Server::create_listen_socket() { + int sfd; + + // Check if unix socket path is specified + if (!config_.server.unix_socket_path.empty()) { + // Create unix socket + sfd = socket(AF_UNIX, SOCK_STREAM, 0); + if (sfd == -1) { + perror("socket"); + throw std::runtime_error("Failed to create unix socket"); + } + + // Remove existing socket file if it exists + unlink(config_.server.unix_socket_path.c_str()); + + struct sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + + if (config_.server.unix_socket_path.length() >= sizeof(addr.sun_path)) { + close(sfd); + throw std::runtime_error("Unix socket path too long"); + } + + strncpy(addr.sun_path, config_.server.unix_socket_path.c_str(), + sizeof(addr.sun_path) - 1); + + // Set socket to non-blocking for graceful shutdown + int flags = fcntl(sfd, F_GETFL, 0); + if (flags == -1) { + perror("fcntl F_GETFL"); + close(sfd); + throw std::runtime_error("Failed to get socket flags"); + } + if (fcntl(sfd, F_SETFL, flags | O_NONBLOCK) == -1) { + perror("fcntl F_SETFL"); + close(sfd); + throw std::runtime_error("Failed to set socket non-blocking"); + } + + if (bind(sfd, (struct sockaddr *)&addr, sizeof(addr)) == -1) { + perror("bind"); + close(sfd); + throw std::runtime_error("Failed to bind unix socket"); + } + + if (listen(sfd, SOMAXCONN) == -1) { + perror("listen"); + close(sfd); + throw std::runtime_error("Failed to listen on unix socket"); + } + + return sfd; + } + + // TCP socket creation (original code) struct addrinfo hints; struct addrinfo *result, *rp; - int sfd, s; + int s; memset(&hints, 0, sizeof(hints)); hints.ai_family = AF_UNSPEC; /* Allow IPv4 or IPv6 */ @@ -178,11 +234,13 @@ int Server::create_listen_socket() { continue; } - // Enable TCP_NODELAY for low latency - if (setsockopt(sfd, IPPROTO_TCP, TCP_NODELAY, &val, sizeof(val)) == -1) { - perror("setsockopt TCP_NODELAY"); - close(sfd); - continue; + // Enable TCP_NODELAY for low latency (only for TCP sockets) + if (rp->ai_family == AF_INET || rp->ai_family == AF_INET6) { + if (setsockopt(sfd, IPPROTO_TCP, TCP_NODELAY, &val, sizeof(val)) == -1) { + perror("setsockopt TCP_NODELAY"); + close(sfd); + continue; + } } // Set socket to non-blocking for graceful shutdown @@ -438,4 +496,9 @@ void Server::cleanup_resources() { close(listen_sockfd_); listen_sockfd_ = -1; } + + // Clean up unix socket file if it exists + if (!config_.server.unix_socket_path.empty()) { + unlink(config_.server.unix_socket_path.c_str()); + } }