diff --git a/broadcast/http.go b/broadcast/http.go index 733935b9..2e86494d 100644 --- a/broadcast/http.go +++ b/broadcast/http.go @@ -22,18 +22,18 @@ const ( // HTTPConfig contains HTTP pubsub adapter configuration type HTTPConfig struct { // Port to listen on - Port int + Port int `toml:"port"` // Path for HTTP broadast - Path string + Path string `toml:"path"` // Secret token to authorize requests - Secret string + Secret string `toml:"secret"` // SecretBase is a secret used to generate a token if none provided SecretBase string // AddCORSHeaders enables adding CORS headers (so you can perform broadcast requests from the browser) // (We mostly need it for Stackblitz) - AddCORSHeaders bool + AddCORSHeaders bool `toml:"cors_headers"` // CORSHosts contains a list of hostnames for CORS (comma-separated) - CORSHosts string + CORSHosts string `toml:"cors_hosts"` } // NewHTTPConfig builds a new config for HTTP pub/sub @@ -47,6 +47,43 @@ func (c *HTTPConfig) IsSecured() bool { return c.Secret != "" || c.SecretBase != "" } +func (c HTTPConfig) ToToml() string { + var result strings.Builder + + result.WriteString("# HTTP server port (can be the same as the main server port)\n") + result.WriteString(fmt.Sprintf("port = %d\n", c.Port)) + + result.WriteString("# HTTP endpoint path for broadcasts\n") + result.WriteString(fmt.Sprintf("path = \"%s\"\n", c.Path)) + + result.WriteString("# Secret token to authenticate broadcasting requests\n") + if c.Secret != "" { + result.WriteString(fmt.Sprintf("secret = \"%s\"\n", c.Secret)) + } else { + result.WriteString("# secret = \"\"\n") + } + + result.WriteString("# Enable CORS headers\n") + if c.AddCORSHeaders { + result.WriteString("cors_headers = true\n") + + result.WriteString("# Allowed hosts for CORS (comma-separated)\n") + if c.CORSHosts != "" { + result.WriteString(fmt.Sprintf("cors_hosts = \"%s\"\n", c.CORSHosts)) + } else { + result.WriteString("# cors_hosts = \"\"\n") + } + } else { + result.WriteString("# cors_headers = false\n") + result.WriteString("# Allowed hosts for CORS (comma-separated)\n") + result.WriteString("# cors_hosts = \"\"\n") + } + + result.WriteString("\n") + + return result.String() +} + // HTTPBroadcaster represents HTTP broadcaster type HTTPBroadcaster struct { port int diff --git a/broadcast/http_test.go b/broadcast/http_test.go index e27c70e5..b3d4d9cf 100644 --- a/broadcast/http_test.go +++ b/broadcast/http_test.go @@ -9,6 +9,7 @@ import ( "strings" "testing" + "github.com/BurntSushi/toml" "github.com/anycable/anycable-go/mocks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -89,3 +90,28 @@ func TestHttpHandler(t *testing.T) { assert.Equal(t, http.StatusCreated, rr.Code) }) } + +func TestHTTPConfig__ToToml(t *testing.T) { + conf := NewHTTPConfig() + conf.Port = 8080 + conf.Path = "/broadcast" + conf.Secret = "" + conf.AddCORSHeaders = true + conf.CORSHosts = "example.com,test.com" + + tomlStr := conf.ToToml() + + assert.Contains(t, tomlStr, "port = 8080") + assert.Contains(t, tomlStr, "path = \"/broadcast\"") + assert.Contains(t, tomlStr, "# secret = \"\"") + assert.Contains(t, tomlStr, "cors_headers = true") + assert.Contains(t, tomlStr, "cors_hosts = \"example.com,test.com\"") + + // Round-trip test + conf2 := NewHTTPConfig() + + _, err := toml.Decode(tomlStr, &conf2) + require.NoError(t, err) + + assert.Equal(t, conf, conf2) +} diff --git a/broker/config.go b/broker/config.go index 683eb2b0..e058147d 100644 --- a/broker/config.go +++ b/broker/config.go @@ -1,14 +1,19 @@ package broker +import ( + "fmt" + "strings" +) + type Config struct { // Adapter name - Adapter string + Adapter string `toml:"adapter"` // For how long to keep history in seconds - HistoryTTL int64 + HistoryTTL int64 `toml:"history_ttl"` // Max size of messages to keep in the history per stream - HistoryLimit int + HistoryLimit int `toml:"history_limit"` // Sessions cache TTL in seconds (after disconnect) - SessionsTTL int64 + SessionsTTL int64 `toml:"sessions_ttl"` } func NewConfig() Config { @@ -21,3 +26,27 @@ func NewConfig() Config { SessionsTTL: 5 * 60, } } + +func (c Config) ToToml() string { + var result strings.Builder + + result.WriteString("# Broker backend adapter\n") + if c.Adapter == "" { + result.WriteString("# adapter = \"memory\"\n") + } else { + result.WriteString(fmt.Sprintf("adapter = \"%s\"\n", c.Adapter)) + } + + result.WriteString("# For how long to keep streams history (seconds)\n") + result.WriteString(fmt.Sprintf("history_ttl = %d\n", c.HistoryTTL)) + + result.WriteString("# Max number of messages to keep in a stream history\n") + result.WriteString(fmt.Sprintf("history_limit = %d\n", c.HistoryLimit)) + + result.WriteString("# For how long to store sessions state for resumeability (seconds)\n") + result.WriteString(fmt.Sprintf("sessions_ttl = %d\n", c.SessionsTTL)) + + result.WriteString("\n") + + return result.String() +} diff --git a/broker/config_test.go b/broker/config_test.go new file mode 100644 index 00000000..7d03d028 --- /dev/null +++ b/broker/config_test.go @@ -0,0 +1,50 @@ +package broker + +import ( + "testing" + + "github.com/BurntSushi/toml" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConfig__DecodeToml(t *testing.T) { + tomlString := ` + adapter = "nats" + history_ttl = 100 + history_limit = 1000 + sessions_ttl = 600 + ` + + conf := NewConfig() + _, err := toml.Decode(tomlString, &conf) + require.NoError(t, err) + + assert.Equal(t, "nats", conf.Adapter) + assert.Equal(t, int64(100), conf.HistoryTTL) + assert.Equal(t, 1000, conf.HistoryLimit) + assert.Equal(t, int64(600), conf.SessionsTTL) +} + +func TestConfig__ToToml(t *testing.T) { + conf := NewConfig() + conf.Adapter = "nats" + conf.HistoryTTL = 100 + conf.HistoryLimit = 1000 + conf.SessionsTTL = 600 + + tomlStr := conf.ToToml() + + assert.Contains(t, tomlStr, "adapter = \"nats\"") + assert.Contains(t, tomlStr, "history_ttl = 100") + assert.Contains(t, tomlStr, "history_limit = 1000") + assert.Contains(t, tomlStr, "sessions_ttl = 600") + + // Round-trip test + conf2 := Config{} + + _, err := toml.Decode(tomlStr, &conf2) + require.NoError(t, err) + + assert.Equal(t, conf, conf2) +} diff --git a/cli/options.go b/cli/options.go index 0805a40e..a8e23ac8 100644 --- a/cli/options.go +++ b/cli/options.go @@ -380,7 +380,7 @@ Use broadcast_key instead.`) } if shouldPrintConfig { - fmt.Print(c.Display()) + fmt.Print(c.ToToml()) return &c, nil, true } diff --git a/config/config.go b/config/config.go index 21724191..a8dd2efb 100644 --- a/config/config.go +++ b/config/config.go @@ -105,7 +105,7 @@ func (c *Config) LoadFromFile() error { return nil } -func (c Config) Display() string { +func (c Config) ToToml() string { var result strings.Builder result.WriteString("# AnyCable server configuration.\n# Read more at https://docs.anycable.io/anycable-go/configuration\n\n") @@ -168,5 +168,50 @@ func (c Config) Display() string { result.WriteString(fmt.Sprintf("presets = [\"%s\"]\n\n", strings.Join(c.UserPresets, "\", \""))) } + result.WriteString("# Server configuration\n[server]\n") + result.WriteString(c.Server.ToToml()) + + result.WriteString("# Logging configuration\n[logging]\n") + result.WriteString(c.Log.ToToml()) + + result.WriteString("# RPC configuration\n[rpc]\n") + result.WriteString(c.RPC.ToToml()) + + result.WriteString("# Broker configuration\n[broker]\n") + result.WriteString(c.Broker.ToToml()) + + result.WriteString("# JWT configuration\n[jwt]\n") + result.WriteString(c.JWT.ToToml()) + + result.WriteString("# Pub/sub (signed) streams configuration\n[streams]\n") + result.WriteString(c.Streams.ToToml()) + + result.WriteString("# WebSockets configuration\n[ws]\n") + result.WriteString(c.WS.ToToml()) + + result.WriteString("# SSE configuration\n[sse]\n") + result.WriteString(c.SSE.ToToml()) + + result.WriteString("# Redis configuration\n[redis]\n") + result.WriteString(c.Redis.ToToml()) + + result.WriteString("# NATS configuration\n[nats]\n") + result.WriteString(c.NATS.ToToml()) + + result.WriteString("# Broadcasting configuration\n[http_broadcast]\n") + result.WriteString(c.HTTPBroadcast.ToToml()) + + result.WriteString("# Metrics configuration\n[metrics]\n") + result.WriteString(c.Metrics.ToToml()) + + result.WriteString("# App configuration\n[app]\n") + result.WriteString(c.App.ToToml()) + + result.WriteString("# Disconnector configuration\n[disconnector]\n") + result.WriteString(c.DisconnectQueue.ToToml()) + + result.WriteString("# Embedded NATS configuration\n[embedded_nats]\n") + result.WriteString(c.EmbeddedNats.ToToml()) + return result.String() } diff --git a/enats/config.go b/enats/config.go index ffad9f25..03c012cd 100644 --- a/enats/config.go +++ b/enats/config.go @@ -1,20 +1,109 @@ package enats +import ( + "fmt" + "strings" +) + // Config represents NATS service configuration type Config struct { - Enabled bool - Debug bool - Trace bool - Name string - ServiceAddr string - ClusterAddr string - ClusterName string - GatewayAddr string - GatewayAdvertise string - Gateways []string - Routes []string - JetStream bool - StoreDir string + Enabled bool `toml:"enabled"` + Debug bool `toml:"debug"` + Trace bool `toml:"trace"` + Name string `toml:"name"` + ServiceAddr string `toml:"service_addr"` + ClusterAddr string `toml:"cluster_addr"` + ClusterName string `toml:"cluster_name"` + GatewayAddr string `toml:"gateway_addr"` + GatewayAdvertise string `toml:"gateway_advertise"` + Gateways []string `toml:"gateways"` + Routes []string `toml:"routes"` + JetStream bool `toml:"jetstream"` + StoreDir string `toml:"jetstream_store_dir"` // Seconds to wait for JetStream to become ready (can take a lot of time when connecting to a cluster) - JetStreamReadyTimeout int + JetStreamReadyTimeout int `toml:"jetstream_ready_timeout"` +} + +func (c Config) ToToml() string { + var result strings.Builder + + if c.Enabled { + result.WriteString("enabled = true\n") + } else { + result.WriteString("# enabled = true\n") + } + + result.WriteString("#\n# Verbose logging settings\n") + if c.Debug { + result.WriteString("debug = true\n") + } else { + result.WriteString("# debug = true\n") + } + if c.Trace { + result.WriteString("trace = true\n") + } else { + result.WriteString("# trace = true\n") + } + + result.WriteString("#\n# Service name\n") + result.WriteString(fmt.Sprintf("name = \"%s\"\n", c.Name)) + + result.WriteString("#\n# Service address\n") + result.WriteString(fmt.Sprintf("service_addr = \"%s\"\n", c.ServiceAddr)) + + result.WriteString("#\n# Cluster configuration\n#\n") + if c.ClusterAddr != "" { + result.WriteString(fmt.Sprintf("cluster_addr = \"%s\"\n", c.ClusterAddr)) + } else { + result.WriteString("# cluster_addr = \"\"\n") + } + if c.ClusterName != "" { + if c.ClusterAddr == "" { + result.WriteString(fmt.Sprintf("# cluster_name = \"%s\"\n", c.ClusterName)) + } else { + result.WriteString(fmt.Sprintf("cluster_name = \"%s\"\n", c.ClusterName)) + } + } else { + result.WriteString("# cluster_name = \"\"\n") + } + if c.GatewayAddr != "" { + result.WriteString(fmt.Sprintf("gateway_addr = \"%s\"\n", c.GatewayAddr)) + } else { + result.WriteString("# gateway_addr = \"\"\n") + } + if c.GatewayAdvertise != "" { + result.WriteString(fmt.Sprintf("gateway_advertise = \"%s\"\n", c.GatewayAdvertise)) + } else { + result.WriteString("# gateway_advertise = \"\"\n") + } + if len(c.Gateways) != 0 { + result.WriteString(fmt.Sprintf("gateways = [\"%s\"]\n", strings.Join(c.Gateways, "\", \""))) + } else { + result.WriteString("# gateways = []\n") + } + if len(c.Routes) != 0 { + result.WriteString(fmt.Sprintf("routes = [\"%s\"]\n", strings.Join(c.Routes, "\", \""))) + } else { + result.WriteString("# routes = []\n") + } + + result.WriteString("#\n# JetStream configuration\n#\n") + if c.JetStream { + result.WriteString("jetstream = true\n") + } else { + result.WriteString("# jetstream = true\n") + } + + if c.StoreDir == "" { + result.WriteString("# jetstream_store_dir = \"\"\n") + } else { + result.WriteString(fmt.Sprintf("jetstream_store_dir = \"%s\"\n", c.StoreDir)) + } + if c.JetStream { + result.WriteString(fmt.Sprintf("jetstream_ready_timeout = %d\n", c.JetStreamReadyTimeout)) + } else { + result.WriteString(fmt.Sprintf("# jetstream_ready_timeout = %d\n", c.JetStreamReadyTimeout)) + } + + return result.String() } diff --git a/enats/config_test.go b/enats/config_test.go new file mode 100644 index 00000000..c3c29cca --- /dev/null +++ b/enats/config_test.go @@ -0,0 +1,52 @@ +package enats + +import ( + "testing" + + "github.com/BurntSushi/toml" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConfig_ToToml(t *testing.T) { + conf := Config{ + Enabled: true, + Debug: false, + Trace: true, + Name: "test-service", + ServiceAddr: "localhost:4222", + ClusterAddr: "localhost:6222", + ClusterName: "test-cluster", + GatewayAddr: "localhost:7222", + GatewayAdvertise: "public.example.com:7222", + Gateways: []string{"nats://gateway1:7222", "nats://gateway2:7222"}, + Routes: []string{"nats://route1:6222", "nats://route2:6222"}, + JetStream: true, + StoreDir: "/tmp/nats-store", + JetStreamReadyTimeout: 30, + } + + tomlStr := conf.ToToml() + + assert.Contains(t, tomlStr, "enabled = true") + assert.Contains(t, tomlStr, "# debug = true") + assert.Contains(t, tomlStr, "trace = true") + assert.Contains(t, tomlStr, "name = \"test-service\"") + assert.Contains(t, tomlStr, "service_addr = \"localhost:4222\"") + assert.Contains(t, tomlStr, "cluster_addr = \"localhost:6222\"") + assert.Contains(t, tomlStr, "cluster_name = \"test-cluster\"") + assert.Contains(t, tomlStr, "gateway_addr = \"localhost:7222\"") + assert.Contains(t, tomlStr, "gateway_advertise = \"public.example.com:7222\"") + assert.Contains(t, tomlStr, "gateways = [\"nats://gateway1:7222\", \"nats://gateway2:7222\"]") + assert.Contains(t, tomlStr, "routes = [\"nats://route1:6222\", \"nats://route2:6222\"]") + assert.Contains(t, tomlStr, "jetstream = true") + assert.Contains(t, tomlStr, "jetstream_store_dir = \"/tmp/nats-store\"") + assert.Contains(t, tomlStr, "jetstream_ready_timeout = 30") + + // Round-trip test + var conf2 Config + _, err := toml.Decode(tomlStr, &conf2) + require.NoError(t, err) + + assert.Equal(t, conf, conf2) +} diff --git a/features/file_config.testfile b/features/file_config.testfile index 6e2b4a6d..ec584898 100644 --- a/features/file_config.testfile +++ b/features/file_config.testfile @@ -1,6 +1,6 @@ # Generate a configuration file with provided values run :anycable_gen_config, - ["sh", "-c", './dist/anycable-go --noauth --port 2024 --broadcast_adapter=http,redisx --print-config > ./anycable.toml'], + ["sh", "-c", './dist/anycable-go --noauth --port 2024 --broadcast_adapter=http,redisx --sse --metrics_tags=env:production,node_id:xyz --print-config > ./anycable.toml'], env: {"ANYCABLE_SECRET" => "file-secret", "ANYCABLE_NODE_ID" => "node-1"}, clean_env: true unless File.exist?("anycable.toml") @@ -25,8 +25,10 @@ assert_equal("noauth", true, config["noauth"]) assert_equal("secret", "file-secret", config["secret"]) assert_equal("broadcast adapters", %w[http redisx], config["broadcast_adapters"]) -# nested params: TODO -# assert_equal("server.port", 2024, config.dig("server", "port")) +# nested params +assert_equal("server.port", 2024, config.dig("server", "port")) +assert_equal("sse.enabled", true, config.dig("sse", "enabled")) +assert_equal("metrics.tags", {"env" => "production", "node_id" => "xyz"}, config.dig("metrics", "tags")) if $errors.any? fail $errors.join("\n") diff --git a/features/runner.rb b/features/runner.rb index 03f9bfde..550c5f79 100644 --- a/features/runner.rb +++ b/features/runner.rb @@ -40,6 +40,7 @@ class BenchRunner def initialize @processes = {} + @teardowns = [] @pipes = {} @log_level = ENV["DEBUG"] == "true" ? LOG_LEVEL_TO_NUM[:debug] : LOG_LEVEL_TO_NUM[:info] end @@ -160,6 +161,9 @@ def shutdown processes.each_value do |process| process.stop end + + teardowns.each(&:call) + teardowns.clear end def retrying(delay: 1, attempts: 2, &block) @@ -176,9 +180,13 @@ def retrying(delay: 1, attempts: 2, &block) end end + def at_exit(&block) + teardowns << block + end + private - attr_reader :processes, :pipes, :log_level + attr_reader :processes, :pipes, :log_level, :teardowns def log(level, &block) return unless log_level >= LOG_LEVEL_TO_NUM[level] diff --git a/identity/jwt.go b/identity/jwt.go index 0e3c08ee..91361b6a 100644 --- a/identity/jwt.go +++ b/identity/jwt.go @@ -15,10 +15,10 @@ const ( ) type JWTConfig struct { - Secret string - Param string + Secret string `toml:"secret"` + Param string `toml:"param"` Algo jwt.SigningMethod - Force bool + Force bool `toml:"force"` } var ( @@ -33,6 +33,27 @@ func (c JWTConfig) Enabled() bool { return c.Secret != "" } +func (c JWTConfig) ToToml() string { + var result strings.Builder + + result.WriteString("# Secret key\n") + result.WriteString(fmt.Sprintf("secret = \"%s\"\n", c.Secret)) + + result.WriteString("# Parameter name (an URL query or a header name carrying a token, e.g., `x-`)\n") + result.WriteString(fmt.Sprintf("param = \"%s\"\n", c.Param)) + + result.WriteString("# Enfore JWT authentication\n") + if c.Force { + result.WriteString("force = true\n") + } else { + result.WriteString("# force = true\n") + } + + result.WriteString("\n") + + return result.String() +} + type JWTIdentifier struct { secret []byte paramName string diff --git a/identity/jwt_test.go b/identity/jwt_test.go index 33eacbf6..b857e4e3 100644 --- a/identity/jwt_test.go +++ b/identity/jwt_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/BurntSushi/toml" "github.com/anycable/anycable-go/common" "github.com/golang-jwt/jwt" "github.com/stretchr/testify/assert" @@ -164,3 +165,23 @@ func TestJWTIdentifierIdentify(t *testing.T) { assert.Equal(t, []string{"{\"type\":\"disconnect\",\"reason\":\"unauthorized\",\"reconnect\":false}"}, res.Transmissions) }) } + +func TestConfig__ToToml(t *testing.T) { + conf := NewJWTConfig("jwt-secret") + conf.Force = false + conf.Param = "token" + + tomlStr := conf.ToToml() + + assert.Contains(t, tomlStr, "param = \"token\"") + assert.Contains(t, tomlStr, "secret = \"jwt-secret\"") + assert.Contains(t, tomlStr, "# force = true") + + // Round-trip test + conf2 := NewJWTConfig("bla") + + _, err := toml.Decode(tomlStr, &conf2) + require.NoError(t, err) + + assert.Equal(t, conf, conf2) +} diff --git a/logger/config.go b/logger/config.go index f237f6b0..9ea601f3 100644 --- a/logger/config.go +++ b/logger/config.go @@ -1,9 +1,14 @@ package logger +import ( + "fmt" + "strings" +) + type Config struct { - LogLevel string - LogFormat string - Debug bool + LogLevel string `toml:"level"` + LogFormat string `toml:"format"` + Debug bool `toml:"debug"` } func NewConfig() Config { @@ -12,3 +17,24 @@ func NewConfig() Config { LogFormat: "text", } } + +func (c Config) ToToml() string { + var result strings.Builder + + result.WriteString("# Logging level (debug, info, warn, error)\n") + result.WriteString(fmt.Sprintf("level = \"%s\"\n", c.LogLevel)) + + result.WriteString("# Logs formatting (e.g., 'text' or 'json')\n") + result.WriteString(fmt.Sprintf("format = \"%s\"\n", c.LogFormat)) + + result.WriteString("# Enable debug (verbose) logging\n") + if c.Debug { + result.WriteString("debug = true\n") + } else { + result.WriteString("# debug = true\n") + } + + result.WriteString("\n") + + return result.String() +} diff --git a/logger/config_test.go b/logger/config_test.go new file mode 100644 index 00000000..7b74c0be --- /dev/null +++ b/logger/config_test.go @@ -0,0 +1,46 @@ +package logger + +import ( + "testing" + + "github.com/BurntSushi/toml" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConfig__DecodeToml(t *testing.T) { + tomlString := ` + level = "warn" + format = "json" + debug = true + ` + + conf := NewConfig() + _, err := toml.Decode(tomlString, &conf) + require.NoError(t, err) + + assert.Equal(t, "warn", conf.LogLevel) + assert.Equal(t, "json", conf.LogFormat) + assert.True(t, conf.Debug) +} + +func TestConfig__ToToml(t *testing.T) { + conf := NewConfig() + conf.LogLevel = "warn" + conf.LogFormat = "json" + conf.Debug = false + + tomlStr := conf.ToToml() + + assert.Contains(t, tomlStr, "level = \"warn\"") + assert.Contains(t, tomlStr, "format = \"json\"") + assert.Contains(t, tomlStr, "# debug = true") + + // Round-trip test + conf2 := Config{} + + _, err := toml.Decode(tomlStr, &conf2) + require.NoError(t, err) + + assert.Equal(t, conf, conf2) +} diff --git a/metrics/config.go b/metrics/config.go index dd85c14d..af73dd7e 100644 --- a/metrics/config.go +++ b/metrics/config.go @@ -1,18 +1,23 @@ package metrics +import ( + "fmt" + "strings" +) + // Config contains metrics configuration type Config struct { - Log bool - LogInterval int // Deprecated - RotateInterval int + Log bool `toml:"log"` + LogInterval int // Deprecated + RotateInterval int `toml:"rotate_interval"` LogFormatter string // Print only specified metrics - LogFilter []string - HTTP string - Host string - Port int - Tags map[string]string - Statsd StatsdConfig + LogFilter []string `toml:"log_filter"` + HTTP string `toml:"http_path"` + Host string `toml:"host"` + Port int `toml:"port"` + Tags map[string]string `toml:"tags"` + Statsd StatsdConfig `toml:"statsd"` } // NewConfig creates an empty Config struct @@ -37,3 +42,59 @@ func (c *Config) HTTPEnabled() bool { func (c *Config) LogFormatterEnabled() bool { return c.LogFormatter != "" } + +// ToToml converts the Config to a TOML string representation +func (c Config) ToToml() string { + var result strings.Builder + + result.WriteString("# HTTP endpoint (Prometheus)\n") + if c.HTTP != "" { + result.WriteString(fmt.Sprintf("http = \"%s\"\n", c.HTTP)) + } else { + result.WriteString("# http = \"/metrics\"\n") + } + + result.WriteString("# Standalone metrics HTTP server host to bind to\n") + if c.Host != "" { + result.WriteString(fmt.Sprintf("host = \"%s\"\n", c.Host)) + } else { + result.WriteString("# host = \"localhost\"\n") + } + + result.WriteString("# Metrics HTTP server port to listen on\n# (can be the same as the main server's port)\n") + if c.Port != 0 { + result.WriteString(fmt.Sprintf("port = %d\n", c.Port)) + } else { + result.WriteString("# port = 8082\n") + } + + result.WriteString("# Enable metrics logging\n") + if c.Log { + result.WriteString("log = true\n") + } else { + result.WriteString("# log = true\n") + } + + result.WriteString("# Log rotation interval (seconds)\n") + result.WriteString(fmt.Sprintf("rotate_interval = %d\n", c.RotateInterval)) + + result.WriteString("# Log filter (show only selected metrics)\n") + if len(c.LogFilter) > 0 { + result.WriteString(fmt.Sprintf("log_filter = [ \"%s\" ]\n", strings.Join(c.LogFilter, "\", \""))) + } else { + result.WriteString("# log_filter = []\n") + } + + result.WriteString("# Metrics tags\n") + if len(c.Tags) > 0 { + for key, value := range c.Tags { + result.WriteString(fmt.Sprintf("tags.%s = \"%s\"\n", key, value)) + } + } else { + result.WriteString("# tags.key = \"value\"\n") + } + + result.WriteString("\n") + + return result.String() +} diff --git a/metrics/config_test.go b/metrics/config_test.go index bd4342b8..51e8dc5d 100644 --- a/metrics/config_test.go +++ b/metrics/config_test.go @@ -3,7 +3,9 @@ package metrics import ( "testing" + "github.com/BurntSushi/toml" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestLogEnabled(t *testing.T) { @@ -28,3 +30,31 @@ func TestHTTPEnabled(t *testing.T) { config.HTTP = "/metrics" assert.True(t, config.HTTPEnabled()) } + +func TestConfig_ToToml(t *testing.T) { + conf := NewConfig() + conf.Log = true + conf.RotateInterval = 30 + conf.LogFilter = []string{"metric1", "metric2"} + conf.Host = "example.com" + conf.Port = 9090 + conf.Tags = map[string]string{"env": "prod", "region": "us-west"} + + tomlStr := conf.ToToml() + + assert.Contains(t, tomlStr, "host = \"example.com\"") + assert.Contains(t, tomlStr, "port = 9090") + assert.Contains(t, tomlStr, "log = true") + assert.Contains(t, tomlStr, "rotate_interval = 30") + assert.Contains(t, tomlStr, "log_filter = [ \"metric1\", \"metric2\" ]") + assert.Contains(t, tomlStr, "tags.env = \"prod\"") + assert.Contains(t, tomlStr, "tags.region = \"us-west\"") + + // Round-trip test + conf2 := NewConfig() + + _, err := toml.Decode(tomlStr, &conf2) + require.NoError(t, err) + + assert.Equal(t, conf, conf2) +} diff --git a/nats/config.go b/nats/config.go index ea848e28..e278596b 100644 --- a/nats/config.go +++ b/nats/config.go @@ -1,16 +1,19 @@ package nats import ( + "fmt" + "strings" + natsgo "github.com/nats-io/nats.go" ) type NATSConfig struct { - Servers string - Channel string - DontRandomizeServers bool - MaxReconnectAttempts int + Servers string `toml:"servers"` + Channel string `toml:"channel"` + DontRandomizeServers bool `toml:"dont_randomize_servers"` + MaxReconnectAttempts int `toml:"max_reconnect_attempts"` // Internal channel name for node-to-node broadcasting - InternalChannel string + InternalChannel string `toml:"internal_channel"` } func NewNATSConfig() NATSConfig { @@ -21,3 +24,30 @@ func NewNATSConfig() NATSConfig { InternalChannel: "__anycable_internal__", } } + +func (c NATSConfig) ToToml() string { + var result strings.Builder + + result.WriteString("# NATS server URLs (comma-separated)\n") + result.WriteString(fmt.Sprintf("servers = \"%s\"\n", c.Servers)) + + result.WriteString("# Channel name for legacy broadasting\n") + result.WriteString(fmt.Sprintf("channel = \"%s\"\n", c.Channel)) + + result.WriteString("# Don't randomize servers during connection\n") + if c.DontRandomizeServers { + result.WriteString("dont_randomize_servers = true\n") + } else { + result.WriteString("# dont_randomize_servers = true\n") + } + + result.WriteString("# Max number of reconnect attempts\n") + result.WriteString(fmt.Sprintf("max_reconnect_attempts = %d\n", c.MaxReconnectAttempts)) + + result.WriteString("# Channel name for pub/sub (node-to-node)\n") + result.WriteString(fmt.Sprintf("internal_channel = \"%s\"\n", c.InternalChannel)) + + result.WriteString("\n") + + return result.String() +} diff --git a/nats/config_test.go b/nats/config_test.go new file mode 100644 index 00000000..58926188 --- /dev/null +++ b/nats/config_test.go @@ -0,0 +1,34 @@ +package nats + +import ( + "testing" + + "github.com/BurntSushi/toml" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNATSConfig_ToToml(t *testing.T) { + conf := NewNATSConfig() + conf.Servers = "nats://localhost:4222" + conf.Channel = "test_channel" + conf.DontRandomizeServers = true + conf.MaxReconnectAttempts = 10 + conf.InternalChannel = "test_internal_channel" + + tomlStr := conf.ToToml() + + assert.Contains(t, tomlStr, "servers = \"nats://localhost:4222\"") + assert.Contains(t, tomlStr, "channel = \"test_channel\"") + assert.Contains(t, tomlStr, "dont_randomize_servers = true") + assert.Contains(t, tomlStr, "max_reconnect_attempts = 10") + assert.Contains(t, tomlStr, "internal_channel = \"test_internal_channel\"") + + // Round-trip test + conf2 := NewNATSConfig() + + _, err := toml.Decode(tomlStr, &conf2) + require.NoError(t, err) + + assert.Equal(t, conf, conf2) +} diff --git a/node/config.go b/node/config.go index 9862ada3..986615b3 100644 --- a/node/config.go +++ b/node/config.go @@ -1,5 +1,10 @@ package node +import ( + "fmt" + "strings" +) + const ( DISCONNECT_MODE_ALWAYS = "always" DISCONNECT_MODE_AUTO = "auto" @@ -11,21 +16,21 @@ var DISCONNECT_MODES = []string{DISCONNECT_MODE_ALWAYS, DISCONNECT_MODE_AUTO, DI // Config contains general application/node settings type Config struct { // Define when to invoke Disconnect callback - DisconnectMode string + DisconnectMode string `toml:"disconnect_mode"` // The number of goroutines to use for disconnect calls on shutdown - ShutdownDisconnectPoolSize int + ShutdownDisconnectPoolSize int `toml:"shutdown_disconnect_pool_size"` // How often server should send Action Cable ping messages (seconds) - PingInterval int + PingInterval int `toml:"ping_interval"` // How ofter to refresh node stats (seconds) - StatsRefreshInterval int + StatsRefreshInterval int `toml:"stats_refresh_interval"` // The max size of the Go routines pool for hub - HubGopoolSize int + HubGopoolSize int `toml:"broadcast_gopool_size"` // How should ping message timestamp be formatted? ('s' => seconds, 'ms' => milli seconds, 'ns' => nano seconds) - PingTimestampPrecision string + PingTimestampPrecision string `toml:"ping_timestamp_precision"` // For how long to wait for pong message before disconnecting (seconds) - PongTimeout int + PongTimeout int `toml:"pong_timeout"` // For how long to wait for disconnect callbacks to be processed before exiting (seconds) - ShutdownTimeout int + ShutdownTimeout int `toml:"shutdown_timeout"` } // NewConfig builds a new config @@ -40,3 +45,39 @@ func NewConfig() Config { ShutdownTimeout: 30, } } + +func (c Config) ToToml() string { + var result strings.Builder + + result.WriteString("# Server-to-client heartbeat interval (seconds)\n") + result.WriteString(fmt.Sprintf("ping_interval = %d\n", c.PingInterval)) + + result.WriteString("# Timestamp format for ping messages (s, ms, or ns)\n") + result.WriteString(fmt.Sprintf("ping_timestamp_precision = \"%s\"\n", c.PingTimestampPrecision)) + + result.WriteString("# Client-to-server pong timeout (seconds)\n") + if c.PongTimeout == 0 { + result.WriteString("# pong_timeout = 6\n") + } else { + result.WriteString(fmt.Sprintf("pong_timeout = %d\n", c.PongTimeout)) + } + + result.WriteString("# Define when to invoke Disconnect RPC callback\n") + result.WriteString(fmt.Sprintf("disconnect_mode = \"%s\"\n", c.DisconnectMode)) + + result.WriteString("# Graceful shutdown period (seconds)\n") + result.WriteString(fmt.Sprintf("shutdown_timeout = %d\n", c.ShutdownTimeout)) + + result.WriteString("# How often to refresh system-wide metrics (seconds)\n") + result.WriteString(fmt.Sprintf("stats_refresh_interval = %d\n", c.StatsRefreshInterval)) + + result.WriteString("# The number of Go routines to use for broadcasting (server-to-client fan-out)\n") + result.WriteString(fmt.Sprintf("broadcast_gopool_size = %d\n", c.HubGopoolSize)) + + result.WriteString("# The number of goroutines to use for Disconnect RPC calls on shutdown\n") + result.WriteString(fmt.Sprintf("shutdown_disconnect_pool_size = %d\n", c.ShutdownDisconnectPoolSize)) + + result.WriteString("\n") + + return result.String() +} diff --git a/node/config_test.go b/node/config_test.go new file mode 100644 index 00000000..e064413b --- /dev/null +++ b/node/config_test.go @@ -0,0 +1,31 @@ +package node + +import ( + "testing" + + "github.com/BurntSushi/toml" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConfig_ToToml(t *testing.T) { + conf := NewConfig() + conf.DisconnectMode = "always" + conf.HubGopoolSize = 100 + conf.PingTimestampPrecision = "ns" + + tomlStr := conf.ToToml() + + assert.Contains(t, tomlStr, "disconnect_mode = \"always\"") + assert.Contains(t, tomlStr, "broadcast_gopool_size = 100") + assert.Contains(t, tomlStr, "ping_timestamp_precision = \"ns\"") + assert.Contains(t, tomlStr, "# pong_timeout = 6") + + // Round-trip test + conf2 := NewConfig() + + _, err := toml.Decode(tomlStr, &conf2) + require.NoError(t, err) + + assert.Equal(t, conf, conf2) +} diff --git a/node/disconnect_queue.go b/node/disconnect_queue.go index 87181ade..b1b13a74 100644 --- a/node/disconnect_queue.go +++ b/node/disconnect_queue.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log/slog" + "strings" "sync" "time" ) @@ -23,6 +24,20 @@ func NewDisconnectQueueConfig() DisconnectQueueConfig { return DisconnectQueueConfig{Rate: 100, Backlog: 4096} } +func (c DisconnectQueueConfig) ToToml() string { + var result strings.Builder + + result.WriteString("# Limit the number of Disconnect RPC calls per second\n") + result.WriteString(fmt.Sprintf("rate = %d\n", c.Rate)) + + result.WriteString("# Queue size for disconnect requests\n") + result.WriteString(fmt.Sprintf("backlog = %d\n", c.Backlog)) + + result.WriteString("\n") + + return result.String() +} + // DisconnectQueue is a rate-limited executor type DisconnectQueue struct { node *Node diff --git a/node/disconnect_queue_test.go b/node/disconnect_queue_test.go index 8c538e6c..c1733531 100644 --- a/node/disconnect_queue_test.go +++ b/node/disconnect_queue_test.go @@ -6,7 +6,9 @@ import ( "runtime" "testing" + "github.com/BurntSushi/toml" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestDisconnectQueue_Run(t *testing.T) { @@ -84,3 +86,22 @@ func newQueue() *DisconnectQueue { return q } + +func TestDisconnectQueueConfig_ToToml(t *testing.T) { + conf := NewDisconnectQueueConfig() + conf.Rate = 50 + conf.Backlog = 2048 + + tomlStr := conf.ToToml() + + assert.Contains(t, tomlStr, "rate = 50") + assert.Contains(t, tomlStr, "backlog = 2048") + + // Round-trip test + conf2 := NewDisconnectQueueConfig() + + _, err := toml.Decode(tomlStr, &conf2) + require.NoError(t, err) + + assert.Equal(t, conf, conf2) +} diff --git a/redis/config.go b/redis/config.go index 6c5db315..1bf4e73c 100644 --- a/redis/config.go +++ b/redis/config.go @@ -1,6 +1,7 @@ package redis import ( + "fmt" "net/url" "strings" "time" @@ -12,27 +13,27 @@ import ( type RedisConfig struct { // Redis instance URL or master name in case of sentinels usage // or list of URLs if cluster usage - URL string + URL string `toml:"url"` // Redis channel to subscribe to (legacy pub/sub) - Channel string + Channel string `toml:"channel"` // Redis stream consumer group name - Group string + Group string `toml:"group"` // Redis stream read wait time in milliseconds - StreamReadBlockMilliseconds int64 + StreamReadBlockMilliseconds int64 `toml:"stream_read_block_milliseconds"` // Internal channel name for node-to-node broadcasting - InternalChannel string + InternalChannel string `toml:"internal_channel"` // List of Redis Sentinel addresses - Sentinels string + Sentinels string `toml:"sentinels"` // Redis Sentinel discovery interval (seconds) - SentinelDiscoveryInterval int + SentinelDiscoveryInterval int `toml:"sentinel_discovery_interval"` // Redis keepalive ping interval (seconds) - KeepalivePingInterval int + KeepalivePingInterval int `toml:"keepalive_ping_interval"` // Whether to check server's certificate for validity (in case of rediss:// protocol) - TLSVerify bool + TLSVerify bool `toml:"tls_verify"` // Max number of reconnect attempts - MaxReconnectAttempts int + MaxReconnectAttempts int `toml:"max_reconnect_attempts"` // Disable client-side caching - DisableCache bool + DisableCache bool `toml:"disable_cache"` // List of hosts to connect hosts []string @@ -121,6 +122,56 @@ func (config *RedisConfig) parseSentinels() (*rueidis.ClientOption, error) { return options, nil } +func (config RedisConfig) ToToml() string { + var result strings.Builder + + result.WriteString("# Redis instance URL or master name in case of sentinels usage\n") + result.WriteString("# or list of URLs if cluster usage\n") + result.WriteString(fmt.Sprintf("url = \"%s\"\n", config.URL)) + + result.WriteString("# Channel name for legacy broadcasting\n") + result.WriteString(fmt.Sprintf("channel = \"%s\"\n", config.Channel)) + + result.WriteString("# Stream consumer group name for RedisX broadcasting\n") + result.WriteString(fmt.Sprintf("group = \"%s\"\n", config.Group)) + + result.WriteString("# Streams read wait time in milliseconds\n") + result.WriteString(fmt.Sprintf("stream_read_block_milliseconds = %d\n", config.StreamReadBlockMilliseconds)) + + result.WriteString("# Channel name for pub/sub (node-to-node)\n") + result.WriteString(fmt.Sprintf("internal_channel = \"%s\"\n", config.InternalChannel)) + + result.WriteString("# Sentinel addresses (comma-separated list)\n") + result.WriteString(fmt.Sprintf("sentinels = \"%s\"\n", config.Sentinels)) + + result.WriteString("# Sentinel discovery interval (seconds)\n") + result.WriteString(fmt.Sprintf("sentinel_discovery_interval = %d\n", config.SentinelDiscoveryInterval)) + + result.WriteString("# Keepalive ping interval (seconds)\n") + result.WriteString(fmt.Sprintf("keepalive_ping_interval = %d\n", config.KeepalivePingInterval)) + + result.WriteString("# Enable TLS Verify\n") + if config.TLSVerify { + result.WriteString(fmt.Sprintf("tls_verify = %t\n", config.TLSVerify)) + } else { + result.WriteString("# tls_verify = true\n") + } + + result.WriteString("# Max number of reconnect attempts\n") + result.WriteString(fmt.Sprintf("max_reconnect_attempts = %d\n", config.MaxReconnectAttempts)) + + result.WriteString("# Disable client-side caching\n") + if config.DisableCache { + result.WriteString(fmt.Sprintf("disable_cache = %t\n", config.DisableCache)) + } else { + result.WriteString("# disable_cache = true\n") + } + + result.WriteString("\n") + + return result.String() +} + func parseRedisURL(url string) (options *rueidis.ClientOption, err error) { urls := strings.Split(url, ",") diff --git a/redis/config_test.go b/redis/config_test.go index 83163cca..31b90119 100644 --- a/redis/config_test.go +++ b/redis/config_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/BurntSushi/toml" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -165,3 +166,40 @@ func TestInvalidURL(t *testing.T) { _, err := config.ToRueidisOptions() require.Error(t, err) } + +func TestRedisConfig__ToToml(t *testing.T) { + config := NewRedisConfig() + config.URL = "redis://example.com:6379" + config.Channel = "test_channel" + config.Group = "test_group" + config.StreamReadBlockMilliseconds = 3000 + config.InternalChannel = "test_internal" + config.Sentinels = "sentinel1:26379,sentinel2:26379" + config.SentinelDiscoveryInterval = 60 + config.KeepalivePingInterval = 45 + config.TLSVerify = true + config.MaxReconnectAttempts = 10 + config.DisableCache = true + + tomlStr := config.ToToml() + + assert.Contains(t, tomlStr, "url = \"redis://example.com:6379\"") + assert.Contains(t, tomlStr, "channel = \"test_channel\"") + assert.Contains(t, tomlStr, "group = \"test_group\"") + assert.Contains(t, tomlStr, "stream_read_block_milliseconds = 3000") + assert.Contains(t, tomlStr, "internal_channel = \"test_internal\"") + assert.Contains(t, tomlStr, "sentinels = \"sentinel1:26379,sentinel2:26379\"") + assert.Contains(t, tomlStr, "sentinel_discovery_interval = 60") + assert.Contains(t, tomlStr, "keepalive_ping_interval = 45") + assert.Contains(t, tomlStr, "tls_verify = true") + assert.Contains(t, tomlStr, "max_reconnect_attempts = 10") + assert.Contains(t, tomlStr, "disable_cache = true") + + // Round-trip test + config2 := NewRedisConfig() + + _, err := toml.Decode(tomlStr, &config2) + require.NoError(t, err) + + assert.Equal(t, config, config2) +} diff --git a/rpc/config.go b/rpc/config.go index 9b27db3d..d3bbe453 100644 --- a/rpc/config.go +++ b/rpc/config.go @@ -33,33 +33,33 @@ type Dialer = func(c *Config, l *slog.Logger) (pb.RPCClient, ClientHelper, error // Config contains RPC controller configuration type Config struct { // RPC instance host - Host string + Host string `toml:"host"` // ProxyHeaders to add to RPC request env - ProxyHeaders []string + ProxyHeaders []string `toml:"proxy_headers"` // ProxyCookies to add to RPC request env - ProxyCookies []string + ProxyCookies []string `toml:"proxy_cookies"` // The max number of simultaneous requests. // Should be slightly less than the RPC server concurrency to avoid // ResourceExhausted errors - Concurrency int + Concurrency int `toml:"concurrency"` // Enable client-side TLS on RPC connections? - EnableTLS bool + EnableTLS bool `toml:"enable_tls"` // Whether to verify the RPC server's certificate chain and host name - TLSVerify bool + TLSVerify bool `toml:"tls_verify"` // CA root TLS certificate path - TLSRootCA string + TLSRootCA string `toml:"tls_root_ca_path"` // Max receive msg size (bytes) - MaxRecvSize int + MaxRecvSize int `toml:"max_recv_size"` // Max send msg size (bytes) - MaxSendSize int + MaxSendSize int `toml:"max_send_size"` // Underlying implementation (grpc, http, or none) - Implementation string + Implementation string `toml:"implementation"` // Alternative dialer implementation DialFun Dialer // Secret for HTTP RPC authentication - Secret string + Secret string `toml:"secret"` // Timeout for HTTP RPC requests (in ms) - RequestTimeout int + RequestTimeout int `toml:"http_request_timeout"` // SecretBase is a secret used to generate authentication token SecretBase string } @@ -145,3 +145,73 @@ func ensureGrpcScheme(url string) string { return "grpc://" + url } + +func (c Config) ToToml() string { + var result strings.Builder + + result.WriteString("# RPC implementation (grpc, http, or none)\n") + result.WriteString(fmt.Sprintf("implementation = \"%s\"\n", c.Implementation)) + + result.WriteString("# RPC service hostname (including port, e.g., 'anycable-rpc:50051')\n") + result.WriteString(fmt.Sprintf("host = \"%s\"\n", c.Host)) + + result.WriteString("# Specify HTTP headers that must be proxied to the RPC service\n") + if len(c.ProxyHeaders) > 0 { + result.WriteString(fmt.Sprintf("proxy_headers = [\"%s\"]\n", strings.Join(c.ProxyHeaders, "\", \""))) + } else { + result.WriteString("# proxy_headers = [\"cookie\"]\n") + } + + result.WriteString("# Specify which cookies must be kept in the proxied Cookie header\n") + if len(c.ProxyCookies) > 0 { + result.WriteString(fmt.Sprintf("proxy_cookies = [\"%s\"]\n", strings.Join(c.ProxyCookies, "\", \""))) + } else { + result.WriteString("# proxy_cookies = [\"_session_id\"]\n") + } + + result.WriteString("# RPC concurrency (max number of concurrent RPC requests)\n") + result.WriteString(fmt.Sprintf("concurrency = %d\n", c.Concurrency)) + + result.WriteString("# Enable client-side TLS on RPC connections\n") + if c.EnableTLS { + result.WriteString(fmt.Sprintf("enable_tls = %v\n", c.EnableTLS)) + } else { + result.WriteString("# enable_tls = true\n") + } + + result.WriteString("# Enable TLS Verify for RPC connections\n") + if c.TLSVerify { + result.WriteString(fmt.Sprintf("tls_verify = %v\n", c.TLSVerify)) + } else { + result.WriteString("# tls_verify = true\n") + } + + result.WriteString("# CA root TLS certificate path\n") + if c.TLSRootCA == "" { + result.WriteString(fmt.Sprintf("tls_root_ca_path = \"%s\"\n", c.TLSRootCA)) + } else { + result.WriteString("# tls_root_ca_path =\n") + } + + result.WriteString("# HTTP RPC specific settings\n") + result.WriteString("# Secret for HTTP RPC authentication\n") + if c.Secret != "" { + result.WriteString(fmt.Sprintf("secret = \"%s\"\n", c.Secret)) + } else { + result.WriteString("# secret =\n") + } + + result.WriteString("# Timeout for HTTP RPC requests (in ms)\n") + result.WriteString(fmt.Sprintf("http_request_timeout = %d\n", c.RequestTimeout)) + + result.WriteString("# GRPC fine-tuning\n") + result.WriteString("# Max allowed incoming message size (bytes)\n") + result.WriteString(fmt.Sprintf("max_recv_size = %d\n", c.MaxRecvSize)) + + result.WriteString("# Max allowed outgoing message size (bytes)\n") + result.WriteString(fmt.Sprintf("max_send_size = %d\n", c.MaxSendSize)) + + result.WriteString("\n") + + return result.String() +} diff --git a/rpc/config_test.go b/rpc/config_test.go index 6f9d7ce3..20495ae8 100644 --- a/rpc/config_test.go +++ b/rpc/config_test.go @@ -3,7 +3,9 @@ package rpc import ( "testing" + "github.com/BurntSushi/toml" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestConfig_Impl(t *testing.T) { @@ -39,3 +41,28 @@ func TestConfig_Impl(t *testing.T) { c.Host = "invalid://:+" assert.Equal(t, "", c.Impl()) } + +func TestConfig__ToToml(t *testing.T) { + conf := NewConfig() + conf.Host = "rpc.test" + conf.Concurrency = 10 + conf.Implementation = "http" + conf.ProxyHeaders = []string{"Cookie", "X-Api-Key"} + conf.ProxyCookies = []string{"_session_id", "_csrf_token"} + + tomlStr := conf.ToToml() + + assert.Contains(t, tomlStr, "implementation = \"http\"") + assert.Contains(t, tomlStr, "host = \"rpc.test\"") + assert.Contains(t, tomlStr, "concurrency = 10") + assert.Contains(t, tomlStr, "proxy_headers = [\"Cookie\", \"X-Api-Key\"]") + assert.Contains(t, tomlStr, "proxy_cookies = [\"_session_id\", \"_csrf_token\"]") + + // Round-trip test + conf2 := Config{} + + _, err := toml.Decode(tomlStr, &conf2) + require.NoError(t, err) + + assert.Equal(t, conf, conf2) +} diff --git a/server/config.go b/server/config.go index 55cdbd12..6ccdc8a6 100644 --- a/server/config.go +++ b/server/config.go @@ -1,11 +1,16 @@ package server +import ( + "fmt" + "strings" +) + type Config struct { - Host string - Port int - MaxConn int - HealthPath string - SSL SSLConfig + Host string `toml:"host"` + Port int `toml:"port"` + MaxConn int `toml:"max_conn"` + HealthPath string `toml:"health_path"` + SSL SSLConfig `toml:"ssl"` } func NewConfig() Config { @@ -16,3 +21,38 @@ func NewConfig() Config { SSL: NewSSLConfig(), } } + +func (c Config) ToToml() string { + var result strings.Builder + + result.WriteString("# Host address to bind to\n") + result.WriteString(fmt.Sprintf("host = %q\n", c.Host)) + result.WriteString("# Port to listen on\n") + result.WriteString(fmt.Sprintf("port = %d\n", c.Port)) + result.WriteString("# Maximum number of allowed concurrent connections\n") + if c.MaxConn == 0 { + result.WriteString("# max_conn = 1000\n") + } else { + result.WriteString(fmt.Sprintf("max_conn = %d\n", c.MaxConn)) + } + result.WriteString("# Health check endpoint path\n") + result.WriteString(fmt.Sprintf("health_path = %q\n", c.HealthPath)) + + result.WriteString("# SSL configuration\n") + + if c.SSL.CertPath != "" { + result.WriteString(fmt.Sprintf("ssl.cert_path = %q\n", c.SSL.CertPath)) + } else { + result.WriteString("# ssl.cert_path =\n") + } + + if c.SSL.KeyPath != "" { + result.WriteString(fmt.Sprintf("ssl.key_path = %q\n", c.SSL.KeyPath)) + } else { + result.WriteString("# ssl.key_path =\n") + } + + result.WriteString("\n") + + return result.String() +} diff --git a/server/config_test.go b/server/config_test.go new file mode 100644 index 00000000..e1f1d50e --- /dev/null +++ b/server/config_test.go @@ -0,0 +1,53 @@ +package server + +import ( + "testing" + + "github.com/BurntSushi/toml" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConfig__DecodeToml(t *testing.T) { + tomlString := ` + host = "0.0.0.0" + port = 8081 + max_conn = 100 + health_path = "/healthz" + ` + + conf := NewConfig() + _, err := toml.Decode(tomlString, &conf) + require.NoError(t, err) + + assert.Equal(t, "0.0.0.0", conf.Host) + assert.Equal(t, 8081, conf.Port) + assert.Equal(t, 100, conf.MaxConn) + assert.Equal(t, "/healthz", conf.HealthPath) +} + +func TestConfig__ToToml(t *testing.T) { + conf := NewConfig() + conf.Host = "local.test" + conf.Port = 8082 + conf.HealthPath = "/healthz" + conf.SSL.CertPath = "/path/to/cert" + conf.SSL.KeyPath = "/path/to/key" + + tomlStr := conf.ToToml() + + assert.Contains(t, tomlStr, "host = \"local.test\"") + assert.Contains(t, tomlStr, "port = 8082") + assert.Contains(t, tomlStr, "# max_conn = 1000") + assert.Contains(t, tomlStr, "health_path = \"/healthz\"") + assert.Contains(t, tomlStr, "ssl.cert_path = \"/path/to/cert\"") + assert.Contains(t, tomlStr, "ssl.key_path = \"/path/to/key\"") + + // Round-trip test + conf2 := Config{} + + _, err := toml.Decode(tomlStr, &conf2) + require.NoError(t, err) + + assert.Equal(t, conf, conf2) +} diff --git a/server/ssl_config.go b/server/ssl_config.go index 31fcd029..e60bffe0 100644 --- a/server/ssl_config.go +++ b/server/ssl_config.go @@ -2,8 +2,8 @@ package server // SSLConfig contains SSL parameters type SSLConfig struct { - CertPath string - KeyPath string + CertPath string `toml:"cert_path"` + KeyPath string `toml:"key_path"` } // NewSSLConfig build a new SSLConfig struct diff --git a/sse/config.go b/sse/config.go index 387ce203..6d3c36ae 100644 --- a/sse/config.go +++ b/sse/config.go @@ -1,14 +1,19 @@ package sse +import ( + "fmt" + "strings" +) + const ( defaultMaxBodySize = 65536 // 64 kB ) -// Long-polling configuration +// Server-sent events configuration type Config struct { - Enabled bool + Enabled bool `toml:"enabled"` // Path is the URL path to handle SSE requests - Path string + Path string `toml:"path"` // List of allowed origins for CORS requests // We inherit it from the ws.Config AllowedOrigins string @@ -21,3 +26,22 @@ func NewConfig() Config { Path: "/events", } } + +// ToToml converts the Config struct to a TOML string representation +func (c Config) ToToml() string { + var result strings.Builder + + result.WriteString("# Enable Server-sent events support\n") + if c.Enabled { + result.WriteString("enabled = true\n") + } else { + result.WriteString("# enabled = true\n") + } + + result.WriteString("# Server-sent events endpoint path\n") + result.WriteString(fmt.Sprintf("path = \"%s\"\n", c.Path)) + + result.WriteString("\n") + + return result.String() +} diff --git a/sse/config_test.go b/sse/config_test.go new file mode 100644 index 00000000..e20f6ed1 --- /dev/null +++ b/sse/config_test.go @@ -0,0 +1,27 @@ +package sse + +import ( + "testing" + + "github.com/BurntSushi/toml" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConfig_ToToml(t *testing.T) { + conf := NewConfig() + conf.Path = "/events" + + tomlStr := conf.ToToml() + + assert.Contains(t, tomlStr, "path = \"/events\"") + assert.Contains(t, tomlStr, "# enabled = true") + + // Round-trip test + conf2 := Config{} + + _, err := toml.Decode(tomlStr, &conf2) + require.NoError(t, err) + + assert.Equal(t, conf, conf2) +} diff --git a/streams/config.go b/streams/config.go index 0dfe0849..aff615bf 100644 --- a/streams/config.go +++ b/streams/config.go @@ -2,30 +2,35 @@ // without using channels (a simplified pub/sub mode) package streams +import ( + "fmt" + "strings" +) + type Config struct { // Secret is a key used to sign and verify streams - Secret string + Secret string `toml:"secret"` // Public determines if public (unsigned) streams are allowed - Public bool + Public bool `toml:"public"` // Whisper determines if whispering is enabled for pub/sub streams - Whisper bool + Whisper bool `toml:"whisper"` // PubSubChannel is the channel name used for direct pub/sub - PubSubChannel string + PubSubChannel string `toml:"pubsub_channel"` // Turbo is a flag to enable Turbo Streams support - Turbo bool + Turbo bool `toml:"turbo"` // TurboSecret is a custom secret key used to verify Turbo Streams - TurboSecret string + TurboSecret string `toml:"turbo_secret"` // CableReady is a flag to enable CableReady support - CableReady bool + CableReady bool `toml:"cable_ready"` // CableReadySecret is a custom secret key used to verify CableReady streams - CableReadySecret string + CableReadySecret string `toml:"cable_ready_secret"` } // NewConfig returns a new Config with the given key @@ -50,3 +55,63 @@ func (c Config) GetCableReadySecret() string { return c.Secret } + +func (c Config) ToToml() string { + var result strings.Builder + + result.WriteString("# Secret key used to sign and verify pub/sub streams\n") + if c.Secret != "" { + result.WriteString(fmt.Sprintf("secret = \"%s\"\n", c.Secret)) + } else { + result.WriteString("# secret = \"\"\n") + } + + result.WriteString("# Enable public (unsigned) streams\n") + if c.Public { + result.WriteString("public = true\n") + } else { + result.WriteString("# public = true\n") + } + + result.WriteString("# Enable whispering support for pub/sub streams\n") + if c.Whisper { + result.WriteString("whisper = true\n") + } else { + result.WriteString("# whisper = true\n") + } + + result.WriteString("# Name of the channel used for pub/sub\n") + result.WriteString(fmt.Sprintf("pubsub_channel = \"%s\"\n", c.PubSubChannel)) + + result.WriteString("# Enable Turbo Streams support\n") + if c.Turbo { + result.WriteString("turbo = true\n") + } else { + result.WriteString("# turbo = true\n") + } + + result.WriteString("# Custom secret key used to verify Turbo Streams\n") + if c.TurboSecret != "" { + result.WriteString(fmt.Sprintf("turbo_secret = \"%s\"\n", c.TurboSecret)) + } else { + result.WriteString("# turbo_secret = \"\"\n") + } + + result.WriteString("# Enable CableReady support\n") + if c.CableReady { + result.WriteString("cable_ready = true\n") + } else { + result.WriteString("# cable_ready = true\n") + } + + result.WriteString("# Custom secret key used to verify CableReady streams\n") + if c.CableReadySecret != "" { + result.WriteString(fmt.Sprintf("cable_ready_secret = \"%s\"\n", c.CableReadySecret)) + } else { + result.WriteString("# cable_ready_secret = \"\"\n") + } + + result.WriteString("\n") + + return result.String() +} diff --git a/streams/config_test.go b/streams/config_test.go new file mode 100644 index 00000000..e99a4dfc --- /dev/null +++ b/streams/config_test.go @@ -0,0 +1,40 @@ +package streams + +import ( + "testing" + + "github.com/BurntSushi/toml" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConfig_ToToml(t *testing.T) { + conf := NewConfig() + conf.Secret = "test-secret" + conf.Public = true + conf.Whisper = false + conf.PubSubChannel = "test-channel" + conf.Turbo = true + conf.TurboSecret = "turbo-secret" + conf.CableReady = false + conf.CableReadySecret = "cable-ready-secret" + + tomlStr := conf.ToToml() + + assert.Contains(t, tomlStr, "secret = \"test-secret\"") + assert.Contains(t, tomlStr, "public = true") + assert.Contains(t, tomlStr, "# whisper = true") + assert.Contains(t, tomlStr, "pubsub_channel = \"test-channel\"") + assert.Contains(t, tomlStr, "turbo = true") + assert.Contains(t, tomlStr, "turbo_secret = \"turbo-secret\"") + assert.Contains(t, tomlStr, "# cable_ready = true") + assert.Contains(t, tomlStr, "cable_ready_secret = \"cable-ready-secret\"") + + // Round-trip test + conf2 := Config{} + + _, err := toml.Decode(tomlStr, &conf2) + require.NoError(t, err) + + assert.Equal(t, conf, conf2) +} diff --git a/ws/config.go b/ws/config.go index 0e157c1b..810d6e6d 100644 --- a/ws/config.go +++ b/ws/config.go @@ -1,16 +1,51 @@ package ws +import ( + "fmt" + "strings" +) + // Config contains WebSocket connection configuration. type Config struct { - Paths []string - ReadBufferSize int - WriteBufferSize int - MaxMessageSize int64 - EnableCompression bool - AllowedOrigins string + Paths []string `toml:"paths"` + ReadBufferSize int `toml:"read_buffer_size"` + WriteBufferSize int `toml:"write_buffer_size"` + MaxMessageSize int64 `toml:"max_message_size"` + EnableCompression bool `toml:"enable_compression"` + AllowedOrigins string `toml:"allowed_origins"` } // NewConfig build a new Config struct func NewConfig() Config { return Config{Paths: []string{"/cable"}, ReadBufferSize: 1024, WriteBufferSize: 1024, MaxMessageSize: 65536} } + +// ToToml converts the Config struct to a TOML string representation +func (c Config) ToToml() string { + var result strings.Builder + + result.WriteString("# WebSocket endpoint paths\n") + result.WriteString(fmt.Sprintf("paths = [\"%s\"]\n", strings.Join(c.Paths, "\", \""))) + + result.WriteString("# Allowed origins (a comma-separated list)\n") + result.WriteString(fmt.Sprintf("allowed_origins = \"%s\"\n", c.AllowedOrigins)) + + result.WriteString("# Read buffer size\n") + result.WriteString(fmt.Sprintf("read_buffer_size = %d\n", c.ReadBufferSize)) + + result.WriteString("# Write buffer size\n") + result.WriteString(fmt.Sprintf("write_buffer_size = %d\n", c.WriteBufferSize)) + + result.WriteString("# Maximum message size\n") + result.WriteString(fmt.Sprintf("max_message_size = %d\n", c.MaxMessageSize)) + + if c.EnableCompression { + result.WriteString("# Enable compression (per-message deflate)\n") + result.WriteString("enable_compression = true\n") + result.WriteString("# enable_compression = true\n") + } + + result.WriteString("\n") + + return result.String() +} diff --git a/ws/config_test.go b/ws/config_test.go new file mode 100644 index 00000000..83888b6c --- /dev/null +++ b/ws/config_test.go @@ -0,0 +1,36 @@ +package ws + +import ( + "testing" + + "github.com/BurntSushi/toml" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConfig_ToToml(t *testing.T) { + conf := NewConfig() + conf.Paths = []string{"/ws", "/socket"} + conf.ReadBufferSize = 2048 + conf.WriteBufferSize = 2048 + conf.MaxMessageSize = 131072 + conf.EnableCompression = true + conf.AllowedOrigins = "http://example.com" + + tomlStr := conf.ToToml() + + assert.Contains(t, tomlStr, "paths = [\"/ws\", \"/socket\"]") + assert.Contains(t, tomlStr, "read_buffer_size = 2048") + assert.Contains(t, tomlStr, "write_buffer_size = 2048") + assert.Contains(t, tomlStr, "max_message_size = 131072") + assert.Contains(t, tomlStr, "enable_compression = true") + assert.Contains(t, tomlStr, "allowed_origins = \"http://example.com\"") + + // Round-trip test + conf2 := Config{} + + _, err := toml.Decode(tomlStr, &conf2) + require.NoError(t, err) + + assert.Equal(t, conf, conf2) +}