Files
weaseldb/src/config.cpp

322 lines
12 KiB
C++

#include "config.hpp"

#include <iostream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>

#include <toml.hpp>
namespace weaseldb {
std::optional<Config>
ConfigParser::load_from_file(const std::string &file_path) {
try {
const auto toml_data = toml::parse(file_path);
Config config;
parse_server_config(toml_data, config.server);
parse_commit_config(toml_data, config.commit);
parse_subscription_config(toml_data, config.subscription);
parse_benchmark_config(toml_data, config.benchmark);
if (!validate_config(config)) {
return std::nullopt;
}
return config;
} catch (const std::exception &e) {
std::cerr << "Error parsing config file '" << file_path << "': " << e.what()
<< std::endl;
return std::nullopt;
}
}
std::optional<Config>
ConfigParser::parse_toml_string(const std::string &toml_content) {
try {
const auto toml_data = toml::parse_str(toml_content);
Config config;
parse_server_config(toml_data, config.server);
parse_commit_config(toml_data, config.commit);
parse_subscription_config(toml_data, config.subscription);
parse_benchmark_config(toml_data, config.benchmark);
if (!validate_config(config)) {
return std::nullopt;
}
return config;
} catch (const std::exception &e) {
std::cerr << "Error parsing TOML content: " << e.what() << std::endl;
return std::nullopt;
}
}
// Generic configuration parsing utilities
template <typename T>
void ConfigParser::parse_field(const auto &section,
const std::string &field_name, T &target) {
if (section.contains(field_name)) {
target = toml::get<T>(section.at(field_name));
}
}
template <typename Rep, typename Period>
void ConfigParser::parse_duration_field(
const auto &section, const std::string &field_name,
std::chrono::duration<Rep, Period> &target) {
if (section.contains(field_name)) {
auto value = toml::get<int>(section.at(field_name));
target = std::chrono::duration<Rep, Period>{value};
}
}
/// Invokes `parse_func` with the table named `section_name` if `toml_data`
/// contains it. A missing section is not an error; nothing is called and the
/// caller's target struct is left unchanged.
void ConfigParser::parse_section(const auto &toml_data,
                                 const std::string &section_name,
                                 auto parse_func) {
  if (!toml_data.contains(section_name)) {
    return;
  }
  parse_func(toml_data.at(section_name));
}
void ConfigParser::parse_server_config(const auto &toml_data,
ServerConfig &config) {
parse_section(toml_data, "server", [&](const auto &srv) {
// Parse interfaces array
if (srv.contains("interfaces")) {
auto interfaces = srv.at("interfaces");
if (interfaces.is_array()) {
for (const auto &iface : interfaces.as_array()) {
if (iface.contains("type")) {
std::string type = iface.at("type").as_string();
if (type == "tcp") {
std::string address = iface.at("address").as_string();
int port = iface.at("port").as_integer();
config.interfaces.push_back(ListenInterface::tcp(address, port));
} else if (type == "unix") {
std::string path = iface.at("path").as_string();
config.interfaces.push_back(ListenInterface::unix_socket(path));
}
}
}
}
}
// If no interfaces configured, use default TCP interface
if (config.interfaces.empty()) {
config.interfaces.push_back(ListenInterface::tcp("127.0.0.1", 8080));
}
parse_field(srv, "max_request_size_bytes", config.max_request_size_bytes);
parse_field(srv, "io_threads", config.io_threads);
// epoll_instances removed - now 1:1 with io_threads
parse_field(srv, "event_batch_size", config.event_batch_size);
parse_field(srv, "max_connections", config.max_connections);
parse_field(srv, "read_buffer_size", config.read_buffer_size);
// epoll_instances validation removed - now always equals io_threads
});
}
void ConfigParser::parse_commit_config(const auto &toml_data,
CommitConfig &config) {
parse_section(toml_data, "commit", [&](const auto &commit) {
parse_field(commit, "min_request_id_length", config.min_request_id_length);
parse_duration_field(commit, "request_id_retention_hours",
config.request_id_retention_hours);
parse_field(commit, "request_id_retention_versions",
config.request_id_retention_versions);
// Parse wait strategy
if (commit.contains("pipeline_wait_strategy")) {
std::string strategy_str =
toml::get<std::string>(commit.at("pipeline_wait_strategy"));
if (strategy_str == "WaitIfStageEmpty") {
config.pipeline_wait_strategy = WaitStrategy::WaitIfStageEmpty;
} else if (strategy_str == "WaitIfUpstreamIdle") {
config.pipeline_wait_strategy = WaitStrategy::WaitIfUpstreamIdle;
} else if (strategy_str == "Never") {
config.pipeline_wait_strategy = WaitStrategy::Never;
} else {
std::cerr << "Warning: Unknown pipeline_wait_strategy '" << strategy_str
<< "', using default (WaitIfUpstreamIdle)" << std::endl;
}
}
parse_field(commit, "pipeline_release_threads",
config.pipeline_release_threads);
});
}
/// Populates `config` from the optional [subscription] table. Absent keys
/// leave their fields untouched. keepalive_interval_seconds is read as an
/// integer count and stored into the `keepalive_interval` duration.
void ConfigParser::parse_subscription_config(const auto &toml_data,
                                             SubscriptionConfig &config) {
  const auto read_subscription = [&](const auto &table) {
    parse_field(table, "max_buffer_size_bytes", config.max_buffer_size_bytes);
    parse_duration_field(table, "keepalive_interval_seconds",
                         config.keepalive_interval);
  };
  parse_section(toml_data, "subscription", read_subscription);
}
/// Populates `config` from the optional [benchmark] table. An absent
/// ok_resolve_iterations key leaves the field untouched.
void ConfigParser::parse_benchmark_config(const auto &toml_data,
                                          BenchmarkConfig &config) {
  const auto read_benchmark = [&](const auto &table) {
    parse_field(table, "ok_resolve_iterations", config.ok_resolve_iterations);
  };
  parse_section(toml_data, "benchmark", read_benchmark);
}
bool ConfigParser::validate_config(const Config &config) {
bool valid = true;
// Validate server interfaces
if (config.server.interfaces.empty()) {
std::cerr << "Configuration error: no interfaces configured" << std::endl;
valid = false;
}
for (const auto &iface : config.server.interfaces) {
if (iface.type == ListenInterface::Type::TCP) {
if (iface.port <= 0 || iface.port > 65535) {
std::cerr << "Configuration error: TCP port must be between 1 and "
"65535, got "
<< iface.port << std::endl;
valid = false;
}
if (iface.address.empty()) {
std::cerr << "Configuration error: TCP address cannot be empty"
<< std::endl;
valid = false;
}
} else { // Unix socket
if (iface.path.empty()) {
std::cerr << "Configuration error: Unix socket path cannot be empty"
<< std::endl;
valid = false;
}
if (iface.path.length() > 107) { // UNIX_PATH_MAX is typically 108
std::cerr << "Configuration error: Unix socket path too long (max 107 "
"chars), got "
<< iface.path.length() << " chars" << std::endl;
valid = false;
}
}
}
if (config.server.max_request_size_bytes == 0) {
std::cerr << "Configuration error: server.max_request_size_bytes must be "
"greater than 0"
<< std::endl;
valid = false;
}
if (config.server.max_request_size_bytes > 100 * 1024 * 1024) { // 100MB limit
std::cerr << "Configuration error: server.max_request_size_bytes too large "
"(max 100MB), got "
<< config.server.max_request_size_bytes << " bytes" << std::endl;
valid = false;
}
if (config.server.io_threads < 1 || config.server.io_threads > 1000) {
std::cerr << "Configuration error: server.io_threads must be between 1 "
"and 1000, got "
<< config.server.io_threads << std::endl;
valid = false;
}
// epoll_instances validation removed - now always 1:1 with io_threads
if (config.server.event_batch_size < 1 ||
config.server.event_batch_size > 10000) {
std::cerr << "Configuration error: server.event_batch_size must be between "
"1 and 10000, got "
<< config.server.event_batch_size << std::endl;
valid = false;
}
if (config.server.max_connections < 0 ||
config.server.max_connections > 100000) {
std::cerr << "Configuration error: server.max_connections must be between "
"0 and 100000, got "
<< config.server.max_connections << std::endl;
valid = false;
}
if (config.server.read_buffer_size < 1024 ||
config.server.read_buffer_size > 1024 * 1024) { // 1KB to 1MB
std::cerr << "Configuration error: server.read_buffer_size must be between "
"1024 and 1048576 bytes, got "
<< config.server.read_buffer_size << std::endl;
valid = false;
}
// Validate commit configuration
if (config.commit.min_request_id_length < 8 ||
config.commit.min_request_id_length > 256) {
std::cerr << "Configuration error: commit.min_request_id_length must be "
"between 8 and 256, got "
<< config.commit.min_request_id_length << std::endl;
valid = false;
}
if (config.commit.request_id_retention_hours.count() < 1 ||
config.commit.request_id_retention_hours.count() > 8760) { // 1 year max
std::cerr << "Configuration error: commit.request_id_retention_hours must "
"be between 1 and 8760, got "
<< config.commit.request_id_retention_hours.count() << std::endl;
valid = false;
}
if (config.commit.request_id_retention_versions == 0) {
std::cerr << "Configuration error: commit.request_id_retention_versions "
"must be greater than 0"
<< std::endl;
valid = false;
}
if (config.commit.pipeline_release_threads < 1 ||
config.commit.pipeline_release_threads > 64) {
std::cerr << "Configuration error: commit.pipeline_release_threads must be "
"between 1 and 64, got "
<< config.commit.pipeline_release_threads << std::endl;
valid = false;
}
// Validate subscription configuration
if (config.subscription.max_buffer_size_bytes == 0) {
std::cerr << "Configuration error: subscription.max_buffer_size_bytes must "
"be greater than 0"
<< std::endl;
valid = false;
}
if (config.subscription.max_buffer_size_bytes >
1024 * 1024 * 1024) { // 1GB limit
std::cerr << "Configuration error: subscription.max_buffer_size_bytes too "
"large (max 1GB), got "
<< config.subscription.max_buffer_size_bytes << " bytes"
<< std::endl;
valid = false;
}
if (config.subscription.keepalive_interval.count() < 1 ||
config.subscription.keepalive_interval.count() > 3600) { // 1 hour max
std::cerr << "Configuration error: subscription.keepalive_interval must be "
"between 1 and 3600 seconds, got "
<< config.subscription.keepalive_interval.count() << std::endl;
valid = false;
}
// Cross-validation checks
if (config.server.max_request_size_bytes >
config.subscription.max_buffer_size_bytes) {
std::cerr << "Configuration warning: server.max_request_size_bytes ("
<< config.server.max_request_size_bytes
<< ") is larger than subscription.max_buffer_size_bytes ("
<< config.subscription.max_buffer_size_bytes << ")" << std::endl;
// This is just a warning, not a validation failure
}
return valid;
}
} // namespace weaseldb