Add safety comments

commit 5bc78577c6, parent 368ec721d5
2025-08-18 15:50:50 -04:00
2 changed files with 621 additions and 66 deletions

src/Server.cpp (new file, 551 lines)

@@ -0,0 +1,551 @@
#include <algorithm>
#include <atomic>
#include <cassert>
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <deque>
#include <inttypes.h>
#include <memory>
#include <netdb.h>
#include <netinet/tcp.h>
#include <pthread.h>
#include <signal.h>
#include <string>
#include <string_view>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
#include <thread>
#include <time.h>
#include <tuple>
#include <unistd.h>
#include <vector>
#include <llhttp.h>
#define ENABLE_PERFETTO 1
#if ENABLE_PERFETTO
#include <perfetto.h>
#else
#define PERFETTO_DEFINE_CATEGORIES(...)
#define PERFETTO_TRACK_EVENT_STATIC_STORAGE(...)
#define TRACE_EVENT(...)
#endif
#include "ThreadPipeline.h"
#ifndef __has_feature
#define __has_feature(x) 0
#endif
PERFETTO_DEFINE_CATEGORIES(perfetto::Category("network").SetDescription(
"Network upload and download statistics"));
PERFETTO_TRACK_EVENT_STATIC_STORAGE();
namespace {
constexpr int kConnectionQueueLgDepth = 13; // log2 of queue depth: 2^13 = 8192 entries
constexpr int kDefaultPipelineBatchSize = 16;
constexpr std::string_view kResponseFmt =
"HTTP/1.1 204 No Content\r\nX-Response-Id: %" PRId64 "\r\n\r\n";
constexpr int kAcceptThreads = 2;
constexpr int kNetworkThreads = 8;
constexpr int kEventBatchSize = 32;
constexpr int kConnectionBufSize = 1024;
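// EPOLLONESHOT is required by the ownership protocol below: each arm wakes at
// most one network thread, so exactly one thread owns a connection at a time.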
constexpr uint32_t kMandatoryEpollFlags = EPOLLONESHOT;
// Adapted from getaddrinfo man page
int getListenFd(const char *node, const char *service) {
struct addrinfo hints;
struct addrinfo *result, *rp;
int sfd, s;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC; /* Allow IPv4 or IPv6 */
hints.ai_socktype = SOCK_STREAM; /* stream socket */
hints.ai_flags = AI_PASSIVE; /* For wildcard IP address */
hints.ai_protocol = 0; /* Any protocol */
hints.ai_canonname = nullptr;
hints.ai_addr = nullptr;
hints.ai_next = nullptr;
s = getaddrinfo(node, service, &hints, &result);
if (s != 0) {
fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(s));
abort();
}
/* getaddrinfo() returns a list of address structures.
Try each address until we successfully bind(2).
If socket(2) (or bind(2)) fails, we (close the socket
and) try the next address. */
for (rp = result; rp != nullptr; rp = rp->ai_next) {
sfd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
if (sfd == -1) {
continue;
}
int val = 1;
setsockopt(sfd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val));
if (bind(sfd, rp->ai_addr, rp->ai_addrlen) == 0) {
break; /* Success */
}
close(sfd);
}
freeaddrinfo(result); /* No longer needed */
if (rp == nullptr) { /* No address succeeded */
fprintf(stderr, "Could not bind\n");
abort();
}
int rv = listen(sfd, SOMAXCONN);
if (rv) {
perror("listen");
abort();
}
return sfd;
}
int getUnixSocket(const char *socket_name) {
int sfd = socket(AF_UNIX, SOCK_STREAM, 0);
if (sfd == -1) {
perror("socket");
abort();
}
struct sockaddr_un addr;
memset(&addr, 0, sizeof(addr));
addr.sun_family = AF_UNIX;
strncpy(addr.sun_path, socket_name, sizeof(addr.sun_path) - 1);
int e = bind(sfd, (struct sockaddr *)&addr, sizeof(addr));
if (e == -1) {
perror("bind");
abort();
}
e = listen(sfd, SOMAXCONN);
if (e == -1) {
perror("listen");
abort();
}
return sfd;
}
int getAcceptFd(int listenFd, struct sockaddr *addr) {
socklen_t addrlen = sizeof(sockaddr);
int fd = accept4(listenFd, addr, &addrlen, SOCK_NONBLOCK);
return fd;
}
double now() {
struct timespec t;
int e = clock_gettime(CLOCK_MONOTONIC_RAW, &t);
if (e == -1) {
perror("clock_gettime");
abort();
}
return double(t.tv_sec) + (1e-9 * double(t.tv_nsec));
}
} // namespace
struct HttpRequest {
bool closeConnection = false;
int64_t id = 0;
};
struct HttpResponse {
HttpResponse(int64_t id, bool closeConnection)
: id(id), closeConnection(closeConnection) {
int len = snprintf(buf, sizeof(buf), kResponseFmt.data(), id);
if (len < 0 || len >= int(sizeof(buf))) { // >= also catches snprintf truncation
abort();
}
response = std::string_view{buf, size_t(len)};
}
int64_t const id;
bool closeConnection = false;
std::string_view response;
private:
char buf[kResponseFmt.size() + 64];
};
// Connection lifecycle. A connection is in exactly one of these states at a time:
// - Created on an accept thread from a call to accept
// - Waiting on connection fd to be readable/writable
// - Owned by a network thread, which drains all readable and writable bytes
// - Owned by a thread in the request processing pipeline
// - Closed by a network thread according to http protocol
//
// Since only one thread owns a connection at a time, no synchronization is
// necessary
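//
// Ownership is handed off through epoll_event.data.ptr: the current owner
// calls tsan_release() and then epoll_ctl(), and the network thread that
// receives the one-shot event calls tsan_acquire() before touching the
// connection again.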
struct Connection {
static const llhttp_settings_t settings;
std::deque<HttpRequest> requests;
std::deque<HttpResponse> responses;
Connection(struct sockaddr addr, int fd, int64_t id)
: fd(fd), id(id), addr(addr) {
llhttp_init(&parser, HTTP_REQUEST, &settings);
parser.data = this;
}
~Connection() {
int e = close(fd);
if (e == -1) {
perror("close");
abort();
}
}
void readBytes() {
TRACE_EVENT("network", "read", "connectionId", id);
for (;;) {
char buf[kConnectionBufSize];
int r = read(fd, buf, sizeof(buf));
if (r == -1) {
if (errno == EINTR) {
continue;
}
if (errno == EAGAIN) {
return;
}
perror("read");
goto close_connection;
}
if (r == 0) {
llhttp_finish(&parser);
goto close_connection;
}
auto e = llhttp_execute(&parser, buf, r);
if (e != HPE_OK) {
fprintf(stderr, "Parse error: %s %s\n", llhttp_errno_name(e),
llhttp_get_error_reason(&parser));
goto close_connection;
}
}
close_connection:
requests.emplace_back();
requests.back().closeConnection = true;
}
bool writeBytes() {
TRACE_EVENT("network", "write", "connectionId", id);
while (!responses.empty()) {
auto &front = responses.front();
if (front.closeConnection) {
return true;
}
int w;
for (;;) {
w = write(fd, front.response.data(), front.response.size());
if (w == -1) {
if (errno == EINTR) {
continue;
}
if (errno == EAGAIN) {
return false;
}
perror("write");
return true;
}
break;
}
assert(w != 0);
front.response = front.response.substr(w, front.response.size() - w);
if (front.response.empty()) {
TRACE_EVENT("network", "write response",
perfetto::Flow::Global(front.id));
responses.pop_front();
}
}
return false;
}
const int fd;
const int64_t id;
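// ThreadSanitizer cannot see the happens-before edge that epoll_ctl /
// epoll_wait establishes between threads, so under TSan each ownership
// hand-off is mirrored with an explicit release/acquire pair. In normal
// builds these are no-ops.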
#if __has_feature(thread_sanitizer)
void tsan_acquire() { tsan_sync.load(std::memory_order_acquire); }
void tsan_release() { tsan_sync.store(0, std::memory_order_release); }
std::atomic<int> tsan_sync;
#else
void tsan_acquire() {}
void tsan_release() {}
#endif
private:
template <int (Connection::*Method)()> static int callback(llhttp_t *parser) {
auto &self = *static_cast<Connection *>(parser->data);
return (self.*Method)();
}
template <int (Connection::*Method)(const char *, size_t)>
static int callback(llhttp_t *parser, const char *at, size_t length) {
auto &self = *static_cast<Connection *>(parser->data);
return (self.*Method)(at, length);
}
uint64_t requestId = 0;
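// Folds the decimal digits of every header value into requestId; assumes the
// client sends the request id as a header value and nothing non-numeric.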
int on_header_value(const char *data, size_t s) {
for (int i = 0; i < int(s); ++i) {
requestId = requestId * 10 + data[i] - '0';
}
return 0;
}
int on_message_complete() {
requests.emplace_back();
requests.back().id = requestId;
TRACE_EVENT("network", "read request", perfetto::Flow::Global(requestId));
requestId = 0;
return 0;
}
int messages = 0;
llhttp_t parser;
struct sockaddr addr;
};
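// Shared, immutable parser settings; llhttp routes callbacks back to the
// owning Connection through parser.data.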
const llhttp_settings_t Connection::settings = []() {
llhttp_settings_t settings;
llhttp_settings_init(&settings);
settings.on_message_complete = callback<&Connection::on_message_complete>;
settings.on_header_value = callback<&Connection::on_header_value>;
return settings;
}();
int main() {
signal(SIGPIPE, SIG_IGN);
#if ENABLE_PERFETTO
perfetto::TracingInitArgs args;
args.backends |= perfetto::kSystemBackend;
perfetto::Tracing::Initialize(args);
perfetto::TrackEvent::Register();
#endif
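// Alternative transport: uncomment to listen on TCP instead of the Unix socket below.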
// int sockfd = getListenFd("0.0.0.0", "4569");
int sockfd = getUnixSocket("multithread-epoll.socket");
std::vector<std::thread> threads;
int epollfd = epoll_create(/*ignored*/ 1);
if (epollfd == -1) {
perror("epoll_create");
abort();
}
ThreadPipeline<std::unique_ptr<Connection>> pipeline{kConnectionQueueLgDepth,
{1, 1, 1}};
// Pipeline stage 0: turn each parsed request into a queued response
threads.emplace_back([&pipeline]() {
pthread_setname_np(pthread_self(), "pipeline-0-0");
for (;;) {
auto guard = pipeline.acquire(0, 0, kDefaultPipelineBatchSize);
for (auto &conn : guard.batch) {
assert(!conn->requests.empty());
while (!conn->requests.empty()) {
auto &front = conn->requests.front();
TRACE_EVENT("network", "forward", perfetto::Flow::Global(front.id));
conn->responses.emplace_back(front.id, front.closeConnection);
conn->requests.pop_front();
}
}
}
});
// Pipeline stage 1: hold each batch for 50 ms before passing it on
threads.emplace_back([&pipeline]() {
pthread_setname_np(pthread_self(), "pipeline-1-0");
std::deque<
std::tuple<ThreadPipeline<std::unique_ptr<Connection>>::StageGuard,
double, int64_t>>
queue;
int64_t batchNum = 0;
double lastBatch = now();
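// Each batch is parked with a release deadline 50 ms out; destroying the
// StageGuard (when its queue entry is popped) releases the batch downstream.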
for (;;) {
if (queue.size() < 10) {
while (now() - lastBatch < 5e-3) {
usleep(1000);
}
auto guard = pipeline.acquire(1, 0, /*maxBatch=*/0,
/*mayBlock*/ queue.empty());
lastBatch = now();
if (!guard.batch.empty()) {
queue.emplace_back(std::move(guard), now() + 50e-3, batchNum);
TRACE_EVENT("network", "startBatch",
perfetto::Flow::ProcessScoped(batchNum));
for ([[maybe_unused]] auto &conn : std::get<0>(queue.back()).batch) {
for (auto const &r : conn->responses) {
TRACE_EVENT("network", "start", perfetto::Flow::Global(r.id));
}
}
++batchNum;
}
}
if (queue.size() > 0 && std::get<1>(queue.front()) <= now()) {
TRACE_EVENT("network", "finishBatch",
perfetto::Flow::ProcessScoped(std::get<2>(queue.front())));
for ([[maybe_unused]] auto &conn : std::get<0>(queue.front()).batch) {
for (auto const &r : conn->responses) {
TRACE_EVENT("network", "finish", perfetto::Flow::Global(r.id));
}
}
queue.pop_front();
}
}
});
// Pipeline stage 2: re-arm each connection's fd in epoll so a network thread
// can write the queued responses
threads.emplace_back([epollfd, &pipeline]() {
pthread_setname_np(pthread_self(), "pipeline-2-0");
for (;;) {
auto guard = pipeline.acquire(2, 0, kDefaultPipelineBatchSize);
for (auto &conn : guard.batch) {
for (auto const &r : conn->responses) {
TRACE_EVENT("network", "forward", perfetto::Flow::Global(r.id));
}
struct epoll_event event{};
assert(conn->requests.empty());
assert(!conn->responses.empty());
event.events = EPOLLIN | EPOLLOUT | kMandatoryEpollFlags;
const int fd = conn->fd;
auto *c = conn.release();
c->tsan_release();
event.data.ptr = c;
int e = epoll_ctl(epollfd, EPOLL_CTL_MOD, fd, &event);
if (e == -1) {
perror("epoll_ctl");
abort();
}
}
}
});
// Network threads
for (int i = 0; i < kNetworkThreads; ++i) {
threads.emplace_back([epollfd, i, &pipeline]() {
pthread_setname_np(pthread_self(),
("network-" + std::to_string(i)).c_str());
std::vector<std::unique_ptr<Connection>> batch;
int64_t connectionsDropped = 0;
for (;;) {
batch.clear();
struct epoll_event events[kEventBatchSize]{};
int eventCount;
for (;;) {
eventCount =
epoll_wait(epollfd, events, kEventBatchSize, /*no timeout*/ -1);
if (eventCount == -1) {
if (errno == EINTR) {
continue;
}
perror("epoll_wait");
abort();
}
break;
}
for (int i = 0; i < eventCount; ++i) {
std::unique_ptr<Connection> conn{
static_cast<Connection *>(events[i].data.ptr)};
conn->tsan_acquire();
events[i].data.ptr = nullptr;
const int fd = conn->fd;
if (events[i].events & EPOLLERR) {
// Done with connection
continue;
}
if (events[i].events & EPOLLOUT) {
// Write bytes, maybe close connection
bool finished = conn->writeBytes();
if (finished) {
// Done with connection
continue;
}
}
if (events[i].events & EPOLLIN) {
conn->readBytes();
}
if (conn->requests.empty()) {
// Transfer back to epoll instance. This thread or another thread
// will wake when fd is ready
events[i].events = EPOLLIN | kMandatoryEpollFlags;
if (!conn->responses.empty()) {
events[i].events |= EPOLLOUT;
}
conn->tsan_release();
events[i].data.ptr = conn.release();
int e = epoll_ctl(epollfd, EPOLL_CTL_MOD, fd, &events[i]);
if (e == -1) {
perror("epoll_ctl");
abort();
}
continue;
}
assert(!conn->requests.empty());
// Transfer to request processing pipeline
batch.push_back(std::move(conn));
}
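// Hand the batch to the pipeline. Connections that cannot be queued are
// destroyed on the next batch.clear() (closing their fds), so a full queue
// drops whole connections, not just requests.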
for (int moved = 0; moved < int(batch.size());) {
auto guard = pipeline.push(batch.size() - moved, /*block=*/false);
if (guard.batch.empty()) {
connectionsDropped += batch.size() - moved;
printf("Network thread %d: Queue full. Total connections dropped: "
"%" PRId64 "\n",
i, connectionsDropped);
break;
}
std::move(batch.data() + moved,
batch.data() + moved + guard.batch.size(),
guard.batch.begin());
moved += guard.batch.size();
}
}
});
}
std::atomic<int64_t> connectionId{0};
for (int i = 0; i < kAcceptThreads; ++i) {
threads.emplace_back([epollfd, i, sockfd, &connectionId]() {
pthread_setname_np(pthread_self(),
("accept-" + std::to_string(i)).c_str());
// Call accept in a loop
for (;;) {
struct sockaddr addr;
int fd = getAcceptFd(sockfd, &addr);
if (fd == -1) {
perror("accept4");
continue;
}
auto conn = std::make_unique<Connection>(
addr, fd, connectionId.fetch_add(1, std::memory_order_relaxed));
// Post to epoll instance
struct epoll_event event{};
event.events = EPOLLIN | kMandatoryEpollFlags;
conn->tsan_release();
event.data.ptr = conn.release();
int e = epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &event);
if (e == -1) {
perror("epoll_ctl");
abort();
}
}
});
}
for (auto &t : threads) {
t.join();
}
}
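For manual smoke testing of the new server, a client along these lines should work. This is a sketch, not part of the commit: the header name X-Request-Id is arbitrary (the parser folds every header value it sees into the request id), and it assumes the server was started from the same directory so the relative socket path matches.

// test_client.cpp (hypothetical): send one request over the Unix socket and
// print the reply. Build: g++ -o test_client test_client.cpp
#include <cstdio>
#include <cstring>
#include <string>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>

int main() {
  int fd = socket(AF_UNIX, SOCK_STREAM, 0);
  if (fd == -1) { perror("socket"); return 1; }
  struct sockaddr_un addr;
  memset(&addr, 0, sizeof(addr));
  addr.sun_family = AF_UNIX;
  strncpy(addr.sun_path, "multithread-epoll.socket", sizeof(addr.sun_path) - 1);
  if (connect(fd, (struct sockaddr *)&addr, sizeof(addr)) == -1) {
    perror("connect");
    return 1;
  }
  // The server parses every header value as decimal digits, so any extra
  // headers would corrupt the id; send it alone.
  std::string req = "GET / HTTP/1.1\r\nX-Request-Id: 42\r\n\r\n";
  if (write(fd, req.data(), req.size()) == -1) { perror("write"); return 1; }
  char buf[512];
  // The response arrives after the pipeline's ~50 ms batching delay.
  int r = read(fd, buf, sizeof(buf) - 1);
  if (r > 0) {
    buf[r] = '\0';
    printf("%s", buf); // Expect: HTTP/1.1 204 No Content, X-Response-Id: 42
  }
  close(fd);
  return 0;
}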


@@ -31,6 +31,7 @@ void signal_handler(int sig) {
if (sig == SIGTERM || sig == SIGINT) {
if (shutdown_eventfd != -1) {
uint64_t val = 1;
+ // write() is async-signal-safe per POSIX - safe to use in signal handler
if (write(shutdown_eventfd, &val, sizeof(val)) == -1) {
abort(); // Critical failure - can't signal shutdown
}
@@ -281,7 +282,11 @@ int main(int argc, char *argv[]) {
struct epoll_event shutdown_event;
shutdown_event.events = EPOLLIN;
shutdown_event.data.fd = shutdown_eventfd;
- epoll_ctl(network_epollfd, EPOLL_CTL_ADD, shutdown_eventfd, &shutdown_event);
+ if (epoll_ctl(network_epollfd, EPOLL_CTL_ADD, shutdown_eventfd,
+ &shutdown_event) == -1) {
+ perror("epoll_ctl add shutdown event");
+ abort();
+ }
std::atomic<int64_t> connectionId{0};
@@ -289,10 +294,11 @@ int main(int argc, char *argv[]) {
int networkThreads = config->server.network_threads;
for (int i = 0; i < networkThreads; ++i) {
- threads.emplace_back(
- [network_epollfd, i,
- max_request_size = config->server.max_request_size_bytes,
- event_batch_size = config->server.event_batch_size]() {
+ threads.emplace_back([network_epollfd, i,
+ max_request_size =
+ config->server.max_request_size_bytes,
+ event_batch_size =
+ config->server.event_batch_size]() {
pthread_setname_np(pthread_self(),
("network-" + std::to_string(i)).c_str());
std::vector<struct epoll_event> events(event_batch_size);
@@ -310,7 +316,7 @@ int main(int argc, char *argv[]) {
for (int i = 0; i < eventCount; ++i) {
// Check for shutdown event
if (events[i].data.fd == shutdown_eventfd) {
- // Don't read - let other threads see it too
+ // Don't read eventfd - all threads need to see shutdown signal
return;
}
@@ -318,8 +324,7 @@ int main(int argc, char *argv[]) {
std::unique_ptr<Connection> conn{
static_cast<Connection *>(events[i].data.ptr)};
conn->tsan_acquire();
- events[i].data.ptr =
- nullptr; // Clear epoll pointer (we own it now)
+ events[i].data.ptr = nullptr; // Clear epoll pointer (we own it now)
const int fd = conn->fd;
if (events[i].events & (EPOLLERR | EPOLLHUP | EPOLLRDHUP)) {
@@ -346,12 +351,11 @@ int main(int argc, char *argv[]) {
}
// Transfer ownership back to epoll: unique_ptr -> raw pointer
conn->tsan_release();
- events[i].data.ptr =
- conn.release(); // epoll now owns the connection
+ events[i].data.ptr = conn.release(); // epoll now owns the connection
int e = epoll_ctl(network_epollfd, EPOLL_CTL_MOD, fd, &events[i]);
if (e == -1) {
perror("epoll_ctl");
- abort(); // Process termination - OS cleans up leaked connection
+ abort(); // Connection leaked on abort() is acceptable - OS cleanup
}
}
}
@@ -403,7 +407,7 @@ int main(int argc, char *argv[]) {
for (int j = 0; j < ready; ++j) {
if (events[j].data.fd == shutdown_eventfd) {
- // Don't read - let other threads see it too
+ // Don't read eventfd - all threads need to see shutdown signal
return;
}
@@ -425,7 +429,8 @@ int main(int argc, char *argv[]) {
}
// Check connection limit (0 means unlimited). Limiting
- // connections is best effort but it's the best we can do.
+ // connections is best effort - race condition between check and
+ // increment is acceptable for this use case
if (max_connections > 0 &&
activeConnections.load(std::memory_order_relaxed) >=
max_connections) {
@@ -447,14 +452,13 @@ int main(int argc, char *argv[]) {
int e = epoll_ctl(network_epollfd, EPOLL_CTL_ADD, fd, &event);
if (e == -1) {
perror("epoll_ctl");
- abort();
+ abort(); // Connection leaked on abort() is acceptable - OS
+ // cleanup
}
}
}
}
}
close(accept_epollfd);
});
}