diff --git a/tools/load_tester.cpp b/tools/load_tester.cpp index a1e71ac..024e686 100644 --- a/tools/load_tester.cpp +++ b/tools/load_tester.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -116,19 +117,34 @@ int getConnectFdUnix(const char *socket_name) { return sfd; } -constexpr int kConcurrency = 1000; -constexpr int kRequestsPerConnection = 1; -constexpr std::string_view kRequestFmt = - "GET /ok HTTP/1.1\r\nX-Request-Id: %" PRIu64 "\r\n\r\n"; +struct Config { + int concurrency = 1000; + int requests_per_connection = 1; + int connect_threads = 0; // 0 means auto-calculate + int network_threads = 0; // 0 means auto-calculate + int event_batch_size = 32; + int connection_buf_size = 1024; + std::string host = ""; + std::string port = ""; + std::string unix_socket = "weaseldb.sock"; + int stats_interval = 1; + int duration = 0; // 0 means run forever + bool use_tcp = false; +}; -constexpr int kConnectThreads = std::min(2, kConcurrency); -constexpr int kNetworkThreads = std::min(8, kConcurrency); - -constexpr int kEventBatchSize = 32; -constexpr int kConnectionBufSize = 1024; +Config g_config; sem_t connectionLimit; +// Shutdown mechanism +std::atomic g_shutdown{false}; + +void signal_handler(int sig) { + if (sig == SIGTERM || sig == SIGINT) { + g_shutdown.store(true, std::memory_order_relaxed); + } +} + } // namespace // Connection lifecycle. Only one of these is the case at a time @@ -144,13 +160,15 @@ struct Connection { static std::atomic requestId; - char buf[kRequestFmt.size() + 64]; + char buf[1024]; // Increased size for dynamic request format std::string_view request; uint64_t currentRequestId; void initRequest() { currentRequestId = requestId.fetch_add(1, std::memory_order_relaxed); - int len = snprintf(buf, sizeof(buf), kRequestFmt.data(), currentRequestId); + int len = snprintf(buf, sizeof(buf), + "GET /ok HTTP/1.1\r\nX-Request-Id: %" PRIu64 "\r\n\r\n", + currentRequestId); if (len == -1 || len > int(sizeof(buf))) { abort(); } @@ -184,8 +202,9 @@ struct Connection { bool readBytes() { for (;;) { - char buf[kConnectionBufSize]; - int r = read(fd, buf, sizeof(buf)); + char buf[1024]; // Use a reasonable default, configurable via g_config + int buf_size = std::min(int(sizeof(buf)), g_config.connection_buf_size); + int r = read(fd, buf, buf_size); if (r == -1) { if (errno == EINTR) { continue; @@ -205,7 +224,7 @@ struct Connection { error = true; return true; } - if (responsesReceived == kRequestsPerConnection) { + if (responsesReceived == g_config.requests_per_connection) { return true; } } @@ -230,7 +249,7 @@ struct Connection { request = request.substr(w, request.size() - w); if (request.empty()) { ++requestsSent; - if (requestsSent == kRequestsPerConnection) { + if (requestsSent == g_config.requests_per_connection) { return true; } } @@ -290,8 +309,157 @@ const llhttp_settings_t Connection::settings = []() { return settings; }(); -int main() { +void print_usage(const char *program_name) { + printf("Usage: %s [OPTIONS]\n", program_name); + printf("\nConnection options:\n"); + printf(" --host HOST TCP server hostname (default: none, uses " + "unix socket)\n"); + printf(" --port PORT TCP server port (default: none)\n"); + printf(" --unix-socket PATH Unix socket path (default: " + "weaseldb.sock)\n"); + printf("\nLoad options:\n"); + printf(" --concurrency N Number of concurrent connections " + "(default: 1000)\n"); + printf(" --requests-per-conn N Number of requests per connection " + "(default: 1)\n"); + printf("\nThread options:\n"); + printf(" --connect-threads N Number of connect threads (default: " + "auto)\n"); + printf(" --network-threads N Number of network threads (default: " + "auto)\n"); + printf("\nPerformance options:\n"); + printf(" --event-batch-size N Epoll event batch size (default: 32)\n"); + printf( + " --connection-buf-size N Connection buffer size (default: 1024)\n"); + printf("\nStatistics options:\n"); + printf(" --stats-interval N Statistics display interval in seconds " + "(default: 1)\n"); + printf("\nTiming options:\n"); + printf(" --duration N Run for N seconds, 0 means run forever " + "(default: 0)\n"); + printf("\nOther options:\n"); + printf(" --help Show this help message\n"); +} + +void parse_args(int argc, char *argv[]) { + static struct option long_options[] = { + {"host", required_argument, 0, 'h'}, + {"port", required_argument, 0, 'p'}, + {"unix-socket", required_argument, 0, 'u'}, + {"concurrency", required_argument, 0, 'c'}, + {"requests-per-conn", required_argument, 0, 'r'}, + {"connect-threads", required_argument, 0, 'C'}, + {"network-threads", required_argument, 0, 'N'}, + {"event-batch-size", required_argument, 0, 'E'}, + {"connection-buf-size", required_argument, 0, 'B'}, + {"stats-interval", required_argument, 0, 'S'}, + {"duration", required_argument, 0, 'D'}, + {"help", no_argument, 0, '?'}, + {0, 0, 0, 0}}; + + int option_index = 0; + int c; + + while ((c = getopt_long(argc, argv, "h:p:u:c:r:C:N:E:B:S:D:?", long_options, + &option_index)) != -1) { + switch (c) { + case 'h': + g_config.host = optarg; + g_config.use_tcp = true; + break; + case 'p': + g_config.port = optarg; + g_config.use_tcp = true; + break; + case 'u': + g_config.unix_socket = optarg; + g_config.use_tcp = false; + break; + case 'c': + g_config.concurrency = atoi(optarg); + if (g_config.concurrency <= 0) { + fprintf(stderr, "Error: concurrency must be positive\n"); + exit(1); + } + break; + case 'r': + g_config.requests_per_connection = atoi(optarg); + if (g_config.requests_per_connection <= 0) { + fprintf(stderr, "Error: requests-per-conn must be positive\n"); + exit(1); + } + break; + case 'C': + g_config.connect_threads = atoi(optarg); + if (g_config.connect_threads < 0) { + fprintf(stderr, "Error: connect-threads must be non-negative\n"); + exit(1); + } + break; + case 'N': + g_config.network_threads = atoi(optarg); + if (g_config.network_threads < 0) { + fprintf(stderr, "Error: network-threads must be non-negative\n"); + exit(1); + } + break; + case 'E': + g_config.event_batch_size = atoi(optarg); + if (g_config.event_batch_size <= 0) { + fprintf(stderr, "Error: event-batch-size must be positive\n"); + exit(1); + } + break; + case 'B': + g_config.connection_buf_size = atoi(optarg); + if (g_config.connection_buf_size <= 0) { + fprintf(stderr, "Error: connection-buf-size must be positive\n"); + exit(1); + } + break; + case 'S': + g_config.stats_interval = atoi(optarg); + if (g_config.stats_interval <= 0) { + fprintf(stderr, "Error: stats-interval must be positive\n"); + exit(1); + } + break; + case 'D': + g_config.duration = atoi(optarg); + if (g_config.duration < 0) { + fprintf(stderr, "Error: duration must be non-negative\n"); + exit(1); + } + break; + case '?': + default: + print_usage(argv[0]); + exit(c == '?' ? 0 : 1); + } + } + + // Validation + if (g_config.use_tcp && (g_config.host.empty() || g_config.port.empty())) { + fprintf(stderr, "Error: Both --host and --port must be specified for TCP " + "connections\n"); + exit(1); + } + + // Auto-calculate thread counts if not specified + if (g_config.connect_threads == 0) { + g_config.connect_threads = std::min(2, g_config.concurrency); + } + if (g_config.network_threads == 0) { + g_config.network_threads = std::min(8, g_config.concurrency); + } +} + +int main(int argc, char *argv[]) { + parse_args(argc, argv); + signal(SIGPIPE, SIG_IGN); + signal(SIGTERM, signal_handler); + signal(SIGINT, signal_handler); int epollfd = epoll_create(/*ignored*/ 1); if (epollfd == -1) { @@ -299,7 +467,7 @@ int main() { abort(); } - int e = sem_init(&connectionLimit, 0, kConcurrency); + int e = sem_init(&connectionLimit, 0, g_config.concurrency); if (e == -1) { perror("sem_init"); abort(); @@ -308,16 +476,18 @@ int main() { std::vector threads; - for (int i = 0; i < kNetworkThreads; ++i) { + for (int i = 0; i < g_config.network_threads; ++i) { threads.emplace_back([epollfd, i]() { pthread_setname_np(pthread_self(), ("network-" + std::to_string(i)).c_str()); - for (;;) { - struct epoll_event events[kEventBatchSize]; + while (!g_shutdown.load(std::memory_order_relaxed)) { + struct epoll_event events[64]; // Use a reasonable max size + int batch_size = std::min(int(sizeof(events) / sizeof(events[0])), + g_config.event_batch_size); int eventCount; for (;;) { - eventCount = - epoll_wait(epollfd, events, kEventBatchSize, /*no timeout*/ -1); + eventCount = epoll_wait(epollfd, events, batch_size, + 100); // 100ms timeout to check shutdown if (eventCount == -1) { if (errno == EINTR) { continue; @@ -325,9 +495,18 @@ int main() { perror("epoll_wait"); abort(); // Keep abort for critical errors like server does } + if (eventCount == 0) { + // Timeout - check shutdown flag + break; + } break; } + if (eventCount == 0) { + // Timeout occurred, continue to check shutdown flag + continue; + } + for (int i = 0; i < eventCount; ++i) { std::unique_ptr conn{ static_cast(events[i].data.ptr)}; @@ -384,11 +563,11 @@ int main() { }); } - for (int i = 0; i < kConnectThreads; ++i) { + for (int i = 0; i < g_config.connect_threads; ++i) { threads.emplace_back([epollfd, i, &connectionId]() { pthread_setname_np(pthread_self(), ("connect-" + std::to_string(i)).c_str()); - for (;;) { + while (!g_shutdown.load(std::memory_order_relaxed)) { int e; { e = sem_wait(&connectionLimit); @@ -397,8 +576,12 @@ int main() { abort(); } } - // int fd = getConnectFd("127.0.0.1", "4569"); - int fd = getConnectFdUnix("weaseldb.sock"); + int fd; + if (g_config.use_tcp) { + fd = getConnectFd(g_config.host.c_str(), g_config.port.c_str()); + } else { + fd = getConnectFdUnix(g_config.unix_socket.c_str()); + } if (fd == -1) { continue; // Connection failed, try again } @@ -423,14 +606,27 @@ int main() { }); } - for (double prevTime = now(), + double startTime = now(); + for (double prevTime = startTime, prevConnections = connectionId.load(std::memory_order_relaxed); - ;) { - sleep(1); + !g_shutdown.load(std::memory_order_relaxed);) { + sleep(g_config.stats_interval); + double currTime = now(); double currConnections = connectionId.load(std::memory_order_relaxed); printf("req/s: %f\n", (currConnections - prevConnections) / - (currTime - prevTime) * kRequestsPerConnection); + (currTime - prevTime) * + g_config.requests_per_connection); + + // Check if we should exit based on duration + if (g_config.duration > 0 && (currTime - startTime) >= g_config.duration) { + printf("Duration of %d seconds reached, exiting...\n", g_config.duration); + g_shutdown.store(true, std::memory_order_relaxed); + break; + } + + prevTime = currTime; + prevConnections = currConnections; } for (auto &thread : threads) {