weaseldb/tools/load_tester.cpp
Andrew Noyes 8ccb02f450 We don't want to close the connection on EPOLLRDHUP
We'll rely on the errors from reads and writes to close the connections
2025-08-20 14:09:39 -04:00


#include <algorithm>
#include <atomic>
#include <cassert>
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <fcntl.h>
#include <getopt.h>
#include <inttypes.h>
#include <memory>
#include <netdb.h>
#include <netinet/tcp.h>
#include <pthread.h>
#include <semaphore.h>
#include <signal.h>
#include <string>
#include <string_view>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
#include <thread>
#include <time.h>
#include <unistd.h>
#include <vector>
#include <llhttp.h>
// Use shared perfetto categories
#include "../src/perfetto_categories.hpp"
PERFETTO_TRACK_EVENT_STATIC_STORAGE();
namespace {
double now() {
struct timespec t;
int e = clock_gettime(CLOCK_MONOTONIC_RAW, &t);
if (e == -1) {
perror("clock_gettime");
abort();
}
return double(t.tv_sec) + (1e-9 * double(t.tv_nsec));
}
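// Put a file descriptor into non-blocking mode. Errors are detected through
// errno (cleared before each fcntl call) rather than the return value.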
void fd_set_nb(int fd) {
errno = 0;
int flags = fcntl(fd, F_GETFL, 0);
if (errno) {
perror("fcntl");
abort();
}
flags |= O_NONBLOCK;
errno = 0;
(void)fcntl(fd, F_SETFL, flags);
if (errno) {
perror("fcntl");
abort();
}
}
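// Resolve node/service and connect to the first address that accepts the
// connection. The fd is made non-blocking only after the blocking connect
// succeeds; returns -1 if no address worked.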
int getConnectFd(const char *node, const char *service) {
struct addrinfo hints;
struct addrinfo *result, *rp;
int sfd, s;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC; /* Allow IPv4 or IPv6 */
hints.ai_socktype = SOCK_STREAM; /* stream socket */
s = getaddrinfo(node, service, &hints, &result);
if (s != 0) {
fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(s));
abort();
}
for (rp = result; rp != nullptr; rp = rp->ai_next) {
sfd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
if (sfd == -1) {
continue;
}
if (connect(sfd, rp->ai_addr, rp->ai_addrlen) == 0) {
break; /* Success */
}
close(sfd);
}
freeaddrinfo(result); /* No longer needed */
if (rp == nullptr) { /* No address succeeded */
return -1;
}
fd_set_nb(sfd);
return sfd;
}
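// Connect to a unix-domain stream socket at the given path and make the fd
// non-blocking. Unlike the TCP path, any failure here is fatal.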
int getConnectFdUnix(const char *socket_name) {
int sfd = socket(AF_UNIX, SOCK_STREAM, 0);
if (sfd == -1) {
perror("socket");
abort();
}
struct sockaddr_un addr;
memset(&addr, 0, sizeof(addr));
addr.sun_family = AF_UNIX;
strncpy(addr.sun_path, socket_name, sizeof(addr.sun_path) - 1);
int e = connect(sfd, (struct sockaddr *)&addr, sizeof(addr));
if (e == -1) {
perror("connect");
abort();
}
fd_set_nb(sfd);
return sfd;
}
struct Config {
int concurrency = 1000;
int requests_per_connection = 1;
int connect_threads = 0; // 0 means auto-calculate
int network_threads = 0; // 0 means auto-calculate
int event_batch_size = 32;
int connection_buf_size = 1024;
std::string host = "";
std::string port = "";
std::string unix_socket = "weaseldb.sock";
int stats_interval = 1;
int duration = 0; // 0 means run forever
bool use_tcp = false;
};
Config g_config;
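// Counting semaphore that caps the number of live connections at
// g_config.concurrency: connect threads wait on it before dialing and
// ~Connection posts it back.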
sem_t connectionLimit;
// Shutdown mechanism
std::atomic<bool> g_shutdown{false};
void signal_handler(int sig) {
if (sig == SIGTERM || sig == SIGINT) {
g_shutdown.store(true, std::memory_order_relaxed);
}
}
} // namespace
// Connection lifecycle. Only one of these is the case at a time
// - Created on a connect thread from a call to connect
// - Waiting on connection fd to be readable/writable
// - Owned by a network thread, which drains all readable and writable bytes
// - Closed by a network thread according to http protocol
//
// Since only one thread owns a connection at a time, no synchronization is
// necessary
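//
// EPOLLONESHOT enforces this: a connection's fd is delivered to at most one
// network thread per arming, and it is only re-armed after the owning thread
// has put the Connection pointer back into epoll_event.data.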
struct Connection {
static const llhttp_settings_t settings;
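// Process-wide counter used to stamp every request with a unique
// X-Request-Id, which the perfetto flow events key on.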
static std::atomic<uint64_t> requestId;
char buf[1024]; // Backing storage for the formatted request; request points into it
std::string_view request;
uint64_t currentRequestId;
void initRequest() {
currentRequestId = requestId.fetch_add(1, std::memory_order_relaxed);
int len = snprintf(buf, sizeof(buf),
"GET /ok HTTP/1.1\r\nX-Request-Id: %" PRIu64 "\r\n\r\n",
currentRequestId);
if (len < 0 || len >= int(sizeof(buf))) {
abort();
}
request = std::string_view{buf, size_t(len)};
}
Connection(int fd, int64_t id) : fd(fd), id(id) {
llhttp_init(&parser, HTTP_RESPONSE, &settings);
parser.data = this;
initRequest();
}
// Match server's connection state management
bool hasMessages() const { return !request.empty(); }
bool error = false;
~Connection() {
int e = close(fd);
if (e == -1) {
perror("close");
abort();
}
{
e = sem_post(&connectionLimit);
if (e == -1) {
perror("sem_post");
abort();
}
}
}
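// Drain the socket until it would block. Returns true when this connection is
// finished (all responses received, peer closed, or a fatal error) and should
// be destroyed; returns false when it should be re-armed in epoll.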
bool readBytes() {
for (;;) {
char buf[1024]; // Stack buffer; reads are capped below at g_config.connection_buf_size
int buf_size = std::min(int(sizeof(buf)), g_config.connection_buf_size);
int r = read(fd, buf, buf_size);
if (r == -1) {
if (errno == EINTR) {
continue;
}
if (errno == EAGAIN) {
return false;
}
// Any other read error closes the connection, mirroring writeBytes(); we
// rely on these errors rather than EPOLLRDHUP to tear connections down.
perror("read");
error = true;
return true;
}
if (r == 0) {
llhttp_finish(&parser);
return true;
}
auto e = llhttp_execute(&parser, buf, r);
if (e != HPE_OK) {
fprintf(stderr, "Parse error: %s %s\n", llhttp_errno_name(e),
llhttp_get_error_reason(&parser));
error = true;
return true;
}
if (responsesReceived == g_config.requests_per_connection) {
return true;
}
}
}
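// Write the pending request until the socket would block. Returns true once
// every request for this connection has been sent or a write error occurred;
// returns false when the socket must become writable again first.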
bool writeBytes() {
for (;;) {
int w;
w = write(fd, request.data(), request.size());
if (w == -1) {
if (errno == EINTR) {
continue;
}
if (errno == EAGAIN) {
return false;
}
perror("write");
error = true;
return true;
}
assert(w != 0);
request = request.substr(w, request.size() - w);
if (request.empty()) {
++requestsSent;
TRACE_EVENT("http", "Send request",
perfetto::Flow::Global(currentRequestId));
if (requestsSent == g_config.requests_per_connection) {
return true;
}
}
}
}
const int fd;
const int64_t id;
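// The Connection pointer is handed between threads through epoll, which
// ThreadSanitizer cannot see. Under TSan, an acquire/release pair on a dummy
// atomic makes that happens-before edge explicit; otherwise the annotations
// compile away to nothing.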
#if __has_feature(thread_sanitizer)
void tsan_acquire() { tsan_sync.load(std::memory_order_acquire); }
void tsan_release() { tsan_sync.store(0, std::memory_order_release); }
std::atomic<int> tsan_sync;
#else
void tsan_acquire() {}
void tsan_release() {}
#endif
private:
int requestsSent = 0;
int responsesReceived = 0;
template <int (Connection::*Method)()> static int callback(llhttp_t *parser) {
auto &self = *static_cast<Connection *>(parser->data);
return (self.*Method)();
}
template <int (Connection::*Method)(const char *, size_t)>
static int callback(llhttp_t *parser, const char *at, size_t length) {
auto &self = *static_cast<Connection *>(parser->data);
return (self.*Method)(at, length);
}
uint64_t responseId = 0;
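// Accumulates the decimal digits of every header value into responseId. This
// presumes the only header value the server sends back is the echoed
// X-Request-Id; any other numeric header would corrupt the id, which is used
// solely for trace correlation.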
int on_header_value(const char *data, size_t s) {
for (int i = 0; i < int(s); ++i) {
responseId = responseId * 10 + data[i] - '0';
}
return 0;
}
int on_message_complete() {
TRACE_EVENT("http", "Receive response", perfetto::Flow::Global(responseId));
responseId = 0;
++responsesReceived;
initRequest();
return 0;
}
llhttp_t parser;
};
std::atomic<uint64_t> Connection::requestId = {};
const llhttp_settings_t Connection::settings = []() {
llhttp_settings_t settings;
llhttp_settings_init(&settings);
settings.on_message_complete = callback<&Connection::on_message_complete>;
settings.on_header_value = callback<&Connection::on_header_value>;
return settings;
}();
void print_usage(const char *program_name) {
printf("Usage: %s [OPTIONS]\n", program_name);
printf("\nConnection options:\n");
printf(" --host HOST TCP server hostname (default: none, uses "
"unix socket)\n");
printf(" --port PORT TCP server port (default: none)\n");
printf(" --unix-socket PATH Unix socket path (default: "
"weaseldb.sock)\n");
printf("\nLoad options:\n");
printf(" --concurrency N Number of concurrent connections "
"(default: 1000)\n");
printf(" --requests-per-conn N Number of requests per connection "
"(default: 1)\n");
printf("\nThread options:\n");
printf(" --connect-threads N Number of connect threads (default: "
"auto)\n");
printf(" --network-threads N Number of network threads (default: "
"auto)\n");
printf("\nPerformance options:\n");
printf(" --event-batch-size N Epoll event batch size (default: 32)\n");
printf(
" --connection-buf-size N Connection buffer size (default: 1024)\n");
printf("\nStatistics options:\n");
printf(" --stats-interval N Statistics display interval in seconds "
"(default: 1)\n");
printf("\nTiming options:\n");
printf(" --duration N Run for N seconds, 0 means run forever "
"(default: 0)\n");
printf("\nOther options:\n");
printf(" --help Show this help message\n");
}
void parse_args(int argc, char *argv[]) {
static struct option long_options[] = {
{"host", required_argument, 0, 'h'},
{"port", required_argument, 0, 'p'},
{"unix-socket", required_argument, 0, 'u'},
{"concurrency", required_argument, 0, 'c'},
{"requests-per-conn", required_argument, 0, 'r'},
{"connect-threads", required_argument, 0, 'C'},
{"network-threads", required_argument, 0, 'N'},
{"event-batch-size", required_argument, 0, 'E'},
{"connection-buf-size", required_argument, 0, 'B'},
{"stats-interval", required_argument, 0, 'S'},
{"duration", required_argument, 0, 'D'},
{"help", no_argument, 0, '?'},
{0, 0, 0, 0}};
int option_index = 0;
int c;
while ((c = getopt_long(argc, argv, "h:p:u:c:r:C:N:E:B:S:D:?", long_options,
&option_index)) != -1) {
switch (c) {
case 'h':
g_config.host = optarg;
g_config.use_tcp = true;
break;
case 'p':
g_config.port = optarg;
g_config.use_tcp = true;
break;
case 'u':
g_config.unix_socket = optarg;
g_config.use_tcp = false;
break;
case 'c':
g_config.concurrency = atoi(optarg);
if (g_config.concurrency <= 0) {
fprintf(stderr, "Error: concurrency must be positive\n");
exit(1);
}
break;
case 'r':
g_config.requests_per_connection = atoi(optarg);
if (g_config.requests_per_connection <= 0) {
fprintf(stderr, "Error: requests-per-conn must be positive\n");
exit(1);
}
break;
case 'C':
g_config.connect_threads = atoi(optarg);
if (g_config.connect_threads < 0) {
fprintf(stderr, "Error: connect-threads must be non-negative\n");
exit(1);
}
break;
case 'N':
g_config.network_threads = atoi(optarg);
if (g_config.network_threads < 0) {
fprintf(stderr, "Error: network-threads must be non-negative\n");
exit(1);
}
break;
case 'E':
g_config.event_batch_size = atoi(optarg);
if (g_config.event_batch_size <= 0) {
fprintf(stderr, "Error: event-batch-size must be positive\n");
exit(1);
}
break;
case 'B':
g_config.connection_buf_size = atoi(optarg);
if (g_config.connection_buf_size <= 0) {
fprintf(stderr, "Error: connection-buf-size must be positive\n");
exit(1);
}
break;
case 'S':
g_config.stats_interval = atoi(optarg);
if (g_config.stats_interval <= 0) {
fprintf(stderr, "Error: stats-interval must be positive\n");
exit(1);
}
break;
case 'D':
g_config.duration = atoi(optarg);
if (g_config.duration < 0) {
fprintf(stderr, "Error: duration must be non-negative\n");
exit(1);
}
break;
case '?':
default:
print_usage(argv[0]);
exit(c == '?' ? 0 : 1);
}
}
// Validation
if (g_config.use_tcp && (g_config.host.empty() || g_config.port.empty())) {
fprintf(stderr, "Error: Both --host and --port must be specified for TCP "
"connections\n");
exit(1);
}
// Auto-calculate thread counts if not specified
if (g_config.connect_threads == 0) {
g_config.connect_threads = std::min(2, g_config.concurrency);
}
if (g_config.network_threads == 0) {
g_config.network_threads = std::min(8, g_config.concurrency);
}
}
int main(int argc, char *argv[]) {
#if ENABLE_PERFETTO
perfetto::TracingInitArgs args;
args.backends |= perfetto::kSystemBackend;
perfetto::Tracing::Initialize(args);
perfetto::TrackEvent::Register();
#endif
parse_args(argc, argv);
// Print configuration
printf("Load Tester Configuration:\n");
if (g_config.use_tcp) {
printf(" Connection: TCP %s:%s\n", g_config.host.c_str(),
g_config.port.c_str());
} else {
printf(" Connection: Unix socket %s\n", g_config.unix_socket.c_str());
}
printf(" Concurrency: %d connections\n", g_config.concurrency);
printf(" Requests per connection: %d\n", g_config.requests_per_connection);
printf(" Connect threads: %d\n", g_config.connect_threads);
printf(" Network threads: %d\n", g_config.network_threads);
printf(" Event batch size: %d\n", g_config.event_batch_size);
printf(" Connection buffer size: %d bytes\n", g_config.connection_buf_size);
printf(" Stats interval: %d seconds\n", g_config.stats_interval);
if (g_config.duration > 0) {
printf(" Duration: %d seconds\n", g_config.duration);
} else {
printf(" Duration: unlimited\n");
}
printf("\n");
signal(SIGPIPE, SIG_IGN);
signal(SIGTERM, signal_handler);
signal(SIGINT, signal_handler);
int epollfd = epoll_create(/*ignored*/ 1);
if (epollfd == -1) {
perror("epoll_create");
abort();
}
int e = sem_init(&connectionLimit, 0, g_config.concurrency);
if (e == -1) {
perror("sem_init");
abort();
}
std::atomic<int64_t> connectionId{0};
std::vector<std::thread> threads;
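// Network threads: each epoll wake-up delivers a batch of one-shot events.
// The thread temporarily owns every Connection it receives, drains reads and
// writes, and then either destroys the connection or re-arms it in epoll.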
for (int i = 0; i < g_config.network_threads; ++i) {
threads.emplace_back([epollfd, i]() {
pthread_setname_np(pthread_self(),
("network-" + std::to_string(i)).c_str());
while (!g_shutdown.load(std::memory_order_relaxed)) {
struct epoll_event events[64]; // Fixed upper bound; the effective batch size is capped below
int batch_size = std::min(int(sizeof(events) / sizeof(events[0])),
g_config.event_batch_size);
int eventCount;
for (;;) {
eventCount = epoll_wait(epollfd, events, batch_size,
100); // 100ms timeout so the shutdown flag is re-checked regularly
if (eventCount == -1) {
if (errno == EINTR) {
continue;
}
perror("epoll_wait");
abort(); // epoll failures are fatal, matching the server
}
break;
}
if (eventCount == 0) {
// Timed out with no events; loop around and re-check the shutdown flag
continue;
}
for (int i = 0; i < eventCount; ++i) {
std::unique_ptr<Connection> conn{
static_cast<Connection *>(events[i].data.ptr)};
conn->tsan_acquire();
events[i].data.ptr = nullptr;
const int fd = conn->fd;
if (events[i].events & EPOLLERR) {
continue; // Let unique_ptr destructor clean up
}
if (events[i].events & EPOLLOUT) {
bool finished = conn->writeBytes();
if (conn->error) {
continue;
}
if (finished) {
int e = shutdown(conn->fd, SHUT_WR);
if (e == -1) {
perror("shutdown");
conn->error = true;
continue;
}
}
}
if (events[i].events & EPOLLIN) {
bool finished = conn->readBytes();
if (conn->error) {
continue;
}
if (finished) {
continue;
}
}
// Transfer back to epoll instance. This thread or another thread
// will wake when fd is ready
if (conn->hasMessages()) {
events[i].events = EPOLLOUT | EPOLLONESHOT | EPOLLRDHUP;
} else {
events[i].events = EPOLLIN | EPOLLONESHOT | EPOLLRDHUP;
}
conn->tsan_release();
Connection *raw_conn = conn.release();
events[i].data.ptr = raw_conn;
int e = epoll_ctl(epollfd, EPOLL_CTL_MOD, fd, &events[i]);
if (e == -1) {
perror("epoll_ctl MOD");
delete raw_conn; // Clean up on failure like server
continue;
}
}
}
});
}
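// Connect threads: wait for a free slot on the semaphore, open a new
// connection, and register it with epoll armed for writing.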
for (int i = 0; i < g_config.connect_threads; ++i) {
threads.emplace_back([epollfd, i, &connectionId]() {
pthread_setname_np(pthread_self(),
("connect-" + std::to_string(i)).c_str());
while (!g_shutdown.load(std::memory_order_relaxed)) {
int e;
{
e = sem_wait(&connectionLimit);
if (e == -1) {
perror("sem_wait");
abort();
}
}
int fd;
if (g_config.use_tcp) {
fd = getConnectFd(g_config.host.c_str(), g_config.port.c_str());
} else {
fd = getConnectFdUnix(g_config.unix_socket.c_str());
}
if (fd == -1) {
continue; // Connection failed, try again
}
// Create the connection and hand ownership to epoll, mirroring the server:
// the raw pointer lives in epoll_event.data until a network thread claims it.
auto conn = std::make_unique<Connection>(
fd, connectionId.fetch_add(1, std::memory_order_relaxed));
// New connections start armed for writing, since the first request is pending.
struct epoll_event event{};
event.events = EPOLLOUT | EPOLLONESHOT | EPOLLRDHUP;
conn->tsan_release();
Connection *raw_conn = conn.release();
event.data.ptr = raw_conn;
e = epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &event);
if (e == -1) {
perror("epoll_ctl ADD");
delete raw_conn; // Clean up on failure like server
continue;
}
}
});
}
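// Main thread: report throughput once per stats interval. The connection
// counter tracks connections created; scaled by requests_per_connection it
// approximates the request rate, since the semaphore keeps creation and
// completion roughly in lockstep.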
double startTime = now();
for (double prevTime = startTime,
prevConnections = connectionId.load(std::memory_order_relaxed);
!g_shutdown.load(std::memory_order_relaxed);) {
sleep(g_config.stats_interval);
double currTime = now();
double currConnections = connectionId.load(std::memory_order_relaxed);
printf("req/s: %f\n", (currConnections - prevConnections) /
(currTime - prevTime) *
g_config.requests_per_connection);
// Check if we should exit based on duration
if (g_config.duration > 0 && (currTime - startTime) >= g_config.duration) {
printf("Duration of %d seconds reached, exiting...\n", g_config.duration);
g_shutdown.store(true, std::memory_order_relaxed);
break;
}
prevTime = currTime;
prevConnections = currConnections;
}
for (auto &thread : threads) {
thread.join();
}
}