Files
weaseldb/src/main.cpp

393 lines
12 KiB
C++

#include "config.hpp"
#include <algorithm>
#include <atomic>
#include <cassert>
#include <cerrno>
#include <chrono>
#include <csignal>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <deque>
#include <fcntl.h>
#include <inttypes.h>
#include <iostream>
#include <memory>
#include <netdb.h>
#include <netinet/tcp.h>
#include <pthread.h>
#include <string>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
#include <thread>
#include <unistd.h>
#include <vector>
// Process-wide shutdown flag. Set to true by signal_handler on SIGTERM/SIGINT
// and polled (relaxed loads) by the accept and network thread loops in main.
std::atomic<bool> shutdown_requested{false};
// Compilers that do not provide the __has_feature extension get a stub that
// reports every feature as absent, so `#if __has_feature(...)` checks later in
// this file still preprocess cleanly.
#ifndef __has_feature
#define __has_feature(x) 0
#endif
// Signal handler installed by main for SIGTERM and SIGINT.
// Only raises the global shutdown flag; any other signal number is ignored.
void signal_handler(int sig) {
  switch (sig) {
  case SIGTERM:
  case SIGINT:
    shutdown_requested.store(true, std::memory_order_relaxed);
    break;
  default:
    break;
  }
}
// Creates a non-blocking, listening TCP socket bound to node:service.
// Adapted from the getaddrinfo man page.
//
// node    - host name or numeric address to bind to; nullptr selects the
//           wildcard address (AI_PASSIVE is set in the hints).
// service - port number or service name, passed straight to getaddrinfo.
//
// Returns the listening file descriptor. Every failure (name resolution, no
// candidate address could be bound, listen) is reported on stderr and aborts
// the process, so callers never see an invalid descriptor.
int getListenFd(const char *node, const char *service) {
  struct addrinfo hints;
  // Zeroing also leaves ai_canonname/ai_addr/ai_next null, as getaddrinfo
  // requires for the hints structure.
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;     /* Allow IPv4 or IPv6 */
  hints.ai_socktype = SOCK_STREAM; /* stream socket */
  hints.ai_flags = AI_PASSIVE;     /* For wildcard IP address */
  hints.ai_protocol = 0;           /* Any protocol */

  struct addrinfo *result = nullptr;
  const int s = getaddrinfo(node, service, &hints, &result);
  if (s != 0) {
    fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(s));
    abort();
  }

  /* getaddrinfo() returns a list of address structures.
     Try each address until we successfully bind(2).
     If any step fails for a candidate, close its socket and try the next
     address, remembering which call failed and why for the final message. */
  int sfd = -1;
  const char *failedCall = nullptr;
  int failedErrno = 0;
  for (struct addrinfo *rp = result; rp != nullptr; rp = rp->ai_next) {
    const int candidate =
        socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
    if (candidate == -1) {
      failedCall = "socket";
      failedErrno = errno;
      continue;
    }

    // SO_REUSEADDR is a convenience; report a failure but keep going.
    const int val = 1;
    if (setsockopt(candidate, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val)) ==
        -1) {
      perror("setsockopt(SO_REUSEADDR)");
    }

    // The accept loop polls the shutdown flag between accept4 calls, which
    // only works if the listening socket is non-blocking. Treat a failure to
    // set O_NONBLOCK as a failure of this candidate instead of ignoring it.
    const int flags = fcntl(candidate, F_GETFL, 0);
    if (flags == -1 || fcntl(candidate, F_SETFL, flags | O_NONBLOCK) == -1) {
      failedCall = "fcntl(O_NONBLOCK)";
      failedErrno = errno; // saved before close() can overwrite errno
      close(candidate);
      continue;
    }

    if (bind(candidate, rp->ai_addr, rp->ai_addrlen) == 0) {
      sfd = candidate; /* Success */
      break;
    }
    failedCall = "bind";
    failedErrno = errno; // saved before close() can overwrite errno
    close(candidate);
  }
  freeaddrinfo(result); /* No longer needed */

  if (sfd == -1) { /* No address succeeded */
    if (failedCall != nullptr) {
      fprintf(stderr, "Could not bind: %s: %s\n", failedCall,
              strerror(failedErrno));
    } else {
      fprintf(stderr, "Could not bind\n");
    }
    abort();
  }

  if (listen(sfd, SOMAXCONN) != 0) {
    perror("listen");
    abort();
  }
  return sfd;
}
// Accepts one pending connection from listenFd as a non-blocking socket.
//
// addr receives the peer address; it is a sockaddr_storage (rather than a
// plain sockaddr) so that both IPv4 and IPv6 peers fit.
//
// Returns the new descriptor, or -1 with errno set by accept4.
int getAcceptFd(int listenFd, struct sockaddr_storage *addr) {
  socklen_t peerLen = sizeof(*addr);
  return accept4(listenFd, reinterpret_cast<struct sockaddr *>(addr), &peerLen,
                 SOCK_NONBLOCK);
}
// Connection lifecycle. A connection is in exactly one of these states at a
// time:
// - Created on an accept thread from a call to accept
// - Waiting on the connection fd to be readable/writable
// - Owned by a network thread, which drains readable and writable bytes
// - Owned by a thread in the request processing pipeline
// - Closed by a network thread according to the http protocol
//
// Since only one thread owns a connection at a time, no synchronization is
// necessary.
//
// Connection ownership model:
// - Created by an accept thread, transferred to epoll via a raw pointer
// - Network threads claim ownership by wrapping the raw pointer in a
//   unique_ptr
// - Network threads transfer back to epoll by releasing the unique_ptr to a
//   raw pointer
// - If a network thread does not transfer back, the unique_ptr destroys the
//   Connection, which closes the fd
struct Connection {
  const int fd;                 // socket descriptor, closed by the destructor
  const int64_t id;             // identifier assigned by the accept thread
  struct sockaddr_storage addr; // sockaddr_storage handles IPv4/IPv6

  Connection(struct sockaddr_storage addr, int fd, int64_t id)
      : fd(fd), id(id), addr(addr) {}

  // The destructor closes fd, so a copied or moved-from Connection would close
  // the same descriptor twice. Connections are only ever handed around by
  // pointer, so forbid copying and moving outright.
  Connection(const Connection &) = delete;
  Connection &operator=(const Connection &) = delete;
  Connection(Connection &&) = delete;
  Connection &operator=(Connection &&) = delete;

  // Closes the socket; a failing close() is treated as fatal.
  ~Connection() {
    int e = close(fd);
    if (e == -1) {
      perror("close");
      abort();
    }
  }

  // One queued unit of output. `s` holds the bytes to write and `written`
  // counts how many of them have already been sent. A task with
  // closeConnection set carries no data; it tells writeBytes that the
  // connection should be closed once everything queued before it was sent.
  struct Task {
    std::string s;
    bool closeConnection{false};
    int written = 0;
  };
  std::deque<Task> tasks;

  // Drains everything currently readable from fd and queues each chunk as a
  // Task (the bytes are queued unchanged for now: "pump parser", TODO revisit).
  // Returns once the socket would block. On end-of-stream or a read error a
  // closeConnection task is queued instead.
  //
  // max_request_size only caps the size of the read buffer (at most 4096
  // bytes per read). TODO revisit
  void readBytes(size_t max_request_size) {
    const size_t buf_size = std::min(size_t(4096), max_request_size);
    std::vector<char> buf(buf_size);
    for (;;) {
      const ssize_t r = read(fd, buf.data(), buf.size());
      if (r > 0) {
        tasks.push_back(
            Task{std::string(buf.data(), static_cast<size_t>(r)), false, 0});
        continue;
      }
      if (r == -1) {
        if (errno == EINTR) {
          continue; // interrupted by a signal: retry
        }
        if (errno == EAGAIN || errno == EWOULDBLOCK) {
          return; // nothing more to read right now
        }
        perror("read");
      }
      // r == 0 (peer closed) or an unrecoverable read error.
      tasks.push_back(Task{std::string{}, true, 0});
      return;
    }
  }

  // Writes queued tasks to fd in order.
  //
  // Returns true when the connection should be closed: either a
  // closeConnection task reached the front of the queue or write failed with
  // an unrecoverable error. Returns false when the queue was fully flushed or
  // the socket would block; in the latter case the unsent tasks stay queued.
  bool writeBytes() {
    while (!tasks.empty()) {
      Task &front = tasks.front();
      if (front.closeConnection) {
        return true;
      }
      const size_t offset = static_cast<size_t>(front.written);
      const ssize_t w =
          write(fd, front.s.data() + offset, front.s.size() - offset);
      if (w == -1) {
        if (errno == EINTR) {
          continue; // Standard practice: retry on signal interruption
        }
        if (errno == EAGAIN || errno == EWOULDBLOCK) {
          return false;
        }
        perror("write");
        return true;
      }
      assert(w != 0);
      front.written += static_cast<int>(w);
      if (static_cast<size_t>(front.written) == front.s.size()) {
        tasks.pop_front();
      }
    }
    return false;
  }

  // Hooks called around every ownership hand-off through epoll: tsan_release
  // before the pointer is given to epoll, tsan_acquire right after it is taken
  // back. When built with ThreadSanitizer they perform a release store /
  // acquire load on tsan_sync; otherwise they are no-ops.
  // NOTE(review): presumably this exists so the sanitizer sees an ordering
  // between successive owners - confirm.
#if __has_feature(thread_sanitizer)
  void tsan_acquire() { tsan_sync.load(std::memory_order_acquire); }
  void tsan_release() { tsan_sync.store(0, std::memory_order_release); }
  std::atomic<int> tsan_sync{0};
#else
  void tsan_acquire() {}
  void tsan_release() {}
#endif
};
// Entry point. Loads the configuration, opens the listening socket, then runs
// the accept and network thread pools until SIGTERM/SIGINT sets
// shutdown_requested.
//
// argv[1], when present, is the path of the config file (default
// "config.toml"). If loading fails, a default-constructed weaseldb::Config is
// used instead.
//
// Returns 0 after all threads have been joined. Fatal setup or epoll errors
// abort the process.
int main(int argc, char *argv[]) {
  std::string config_file = "config.toml";
  if (argc > 1) {
    config_file = argv[1];
  }

  auto config = weaseldb::ConfigParser::load_from_file(config_file);
  if (!config) {
    std::cerr << "Failed to load config from: " << config_file << std::endl;
    std::cerr << "Using default configuration..." << std::endl;
    config = weaseldb::Config{};
    // Do not claim a successful load when we fell back to the defaults.
    std::cout << "Default configuration in effect:" << std::endl;
  } else {
    std::cout << "Configuration loaded successfully:" << std::endl;
  }
  std::cout << "Server bind address: " << config->server.bind_address
            << std::endl;
  std::cout << "Server port: " << config->server.port << std::endl;
  std::cout << "Max request size: " << config->server.max_request_size_bytes
            << " bytes" << std::endl;
  std::cout << "Accept threads: " << config->server.accept_threads
            << std::endl;
  std::cout << "Network threads: " << config->server.network_threads
            << " (0 = auto)" << std::endl;
  std::cout << "Event batch size: " << config->server.event_batch_size
            << std::endl;
  std::cout << "Min request ID length: "
            << config->commit.min_request_id_length << std::endl;
  std::cout << "Request ID retention: "
            << config->commit.request_id_retention_hours.count() << " hours"
            << std::endl;
  std::cout << "Subscription buffer size: "
            << config->subscription.max_buffer_size_bytes << " bytes"
            << std::endl;
  std::cout << "Keepalive interval: "
            << config->subscription.keepalive_interval.count() << " seconds"
            << std::endl;

  // Writes to a closed peer should surface as a write error, not kill the
  // process; SIGTERM/SIGINT request a graceful shutdown.
  signal(SIGPIPE, SIG_IGN);
  signal(SIGTERM, signal_handler);
  signal(SIGINT, signal_handler);

  int sockfd = getListenFd(config->server.bind_address.c_str(),
                           std::to_string(config->server.port).c_str());
  std::vector<std::thread> threads;
  int epollfd = epoll_create(/*ignored*/ 1);
  if (epollfd == -1) {
    perror("epoll_create");
    abort();
  }

  // Network threads - use config value, fallback to hardware concurrency
  int networkThreads = config->server.network_threads;
  if (networkThreads == 0) {
    // TODO revisit
    networkThreads = std::thread::hardware_concurrency();
    if (networkThreads == 0)
      networkThreads = 1; // ultimate fallback
  }

  for (int threadIdx = 0; threadIdx < networkThreads; ++threadIdx) {
    threads.emplace_back(
        [epollfd, threadIdx,
         max_request_size = config->server.max_request_size_bytes,
         event_batch_size = config->server.event_batch_size]() {
          pthread_setname_np(pthread_self(),
                             ("network-" + std::to_string(threadIdx)).c_str());
          // Event buffer sized from configuration; allocated once per thread
          // and reused for every epoll_wait call.
          std::vector<struct epoll_event> events(event_batch_size);
          while (!shutdown_requested.load(std::memory_order_relaxed)) {
            int eventCount;
            for (;;) {
              eventCount = epoll_wait(epollfd, events.data(), event_batch_size,
                                      1000 /* 1 second timeout */);
              if (eventCount == -1) {
                if (errno == EINTR) {
                  continue;
                }
                perror("epoll_wait");
                abort();
              }
              break;
            }
            if (eventCount == 0) {
              // Timeout occurred, check shutdown flag again
              continue;
            }
            for (int n = 0; n < eventCount; ++n) {
              struct epoll_event &ev = events[n];
              // Take ownership from epoll: raw pointer -> unique_ptr
              std::unique_ptr<Connection> conn{
                  static_cast<Connection *>(ev.data.ptr)};
              conn->tsan_acquire();
              ev.data.ptr = nullptr; // Clear epoll pointer (we own it now)
              const int fd = conn->fd;
              if (ev.events & (EPOLLERR | EPOLLHUP | EPOLLRDHUP)) {
                // Connection closed or error occurred - unique_ptr destructor
                // cleans up
                continue;
              }
              if (ev.events & EPOLLIN) {
                conn->readBytes(max_request_size);
              }
              if (ev.events & EPOLLOUT) {
                const bool done = conn->writeBytes();
                if (done) {
                  continue; // unique_ptr destructor closes the connection
                }
              }
              if (conn->tasks.empty()) {
                // Nothing queued: wait for more input. This thread or another
                // thread will wake when fd is ready
                ev.events = EPOLLIN | EPOLLONESHOT | EPOLLRDHUP;
              } else {
                // Output pending: wait until the socket is writable
                ev.events = EPOLLOUT | EPOLLONESHOT | EPOLLRDHUP;
              }
              // Transfer ownership back to epoll: unique_ptr -> raw pointer.
              // conn must not be touched after the epoll_ctl call below,
              // because another thread may already own it.
              conn->tsan_release();
              ev.data.ptr = conn.release(); // epoll now owns the connection
              int e = epoll_ctl(epollfd, EPOLL_CTL_MOD, fd, &ev);
              if (e == -1) {
                perror("epoll_ctl");
                abort(); // Process termination - OS cleans up leaked connection
              }
            }
          }
        });
  }

  std::atomic<int64_t> connectionId{0};
  // Accept threads from configuration
  int acceptThreads = config->server.accept_threads;
  for (int threadIdx = 0; threadIdx < acceptThreads; ++threadIdx) {
    threads.emplace_back([epollfd, threadIdx, sockfd, &connectionId]() {
      pthread_setname_np(pthread_self(),
                         ("accept-" + std::to_string(threadIdx)).c_str());
      // Call accept in a loop
      while (!shutdown_requested.load(std::memory_order_relaxed)) {
        struct sockaddr_storage addr;
        int fd = getAcceptFd(sockfd, &addr);
        if (fd == -1) {
          const bool transient =
              errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK;
          if (!transient) {
            perror("accept4");
          }
          // Back off after every failure, not only the transient ones:
          // conditions such as EMFILE persist, and retrying immediately would
          // spin the CPU and flood stderr.
          // TODO revisit
          std::this_thread::sleep_for(std::chrono::milliseconds(10));
          continue;
        }
        auto conn = std::make_unique<Connection>(
            addr, fd, connectionId.fetch_add(1, std::memory_order_relaxed));
        // Transfer new connection to epoll ownership
        struct epoll_event event{};
        event.events = EPOLLIN | EPOLLONESHOT |
                       EPOLLRDHUP; // Listen for reads and disconnects
        conn->tsan_release();
        event.data.ptr = conn.release(); // epoll now owns the connection
        int e = epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &event);
        if (e == -1) {
          perror("epoll_ctl");
          abort(); // Process termination - OS cleans up leaked connection
        }
      }
    });
  }

  for (auto &t : threads) {
    t.join();
  }

  // Every thread has exited, so nothing uses these descriptors any more.
  // NOTE: Connection objects still parked in epoll at this point are not
  // destroyed here; the process is about to exit.
  close(epollfd);
  close(sockfd);
  return 0;
}