diff --git a/cmd/root.go b/cmd/root.go index 168834dc..5359703b 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,7 +1,6 @@ package cmd import ( - "bytes" "context" "flag" "fmt" @@ -20,8 +19,10 @@ import ( csbouncer "github.com/crowdsecurity/go-cs-bouncer" "github.com/crowdsecurity/go-cs-lib/csdaemon" + "github.com/crowdsecurity/go-cs-lib/csstring" "github.com/crowdsecurity/go-cs-lib/version" + "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/cs-firewall-bouncer/pkg/backend" @@ -152,17 +153,19 @@ func Execute() error { return fmt.Errorf("configuration file is required") } - configBytes, err := cfg.MergedConfig(*configPath) + configMerged, err := cfg.MergedConfig(*configPath) if err != nil { return fmt.Errorf("unable to read config file: %w", err) } if *showConfig { - fmt.Println(string(configBytes)) + fmt.Println(string(configMerged)) return nil } - config, err := cfg.NewConfig(bytes.NewReader(configBytes)) + configExpanded := csstring.StrictExpand(string(configMerged), os.LookupEnv) + + config, err := cfg.NewConfig(strings.NewReader(configExpanded)) if err != nil { return fmt.Errorf("unable to load configuration: %w", err) } @@ -186,7 +189,7 @@ func Execute() error { bouncer := &csbouncer.StreamBouncer{} - err = bouncer.ConfigReader(bytes.NewReader(configBytes)) + err = bouncer.ConfigReader(strings.NewReader(configExpanded)) if err != nil { return err } diff --git a/pkg/cfg/config.go b/pkg/cfg/config.go index 0251d913..6656876f 100644 --- a/pkg/cfg/config.go +++ b/pkg/cfg/config.go @@ -1,14 +1,13 @@ package cfg import ( + "errors" "fmt" "io" - "os" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" - "github.com/crowdsecurity/go-cs-lib/csstring" "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/yamlpatch" ) @@ -86,9 +85,7 @@ func NewConfig(reader io.Reader) (*BouncerConfig, error) { return nil, err } - configBuff := csstring.StrictExpand(string(fcontent), os.LookupEnv) - - err = yaml.Unmarshal([]byte(configBuff), &config) + err = yaml.Unmarshal(fcontent, &config) if err != nil { return nil, fmt.Errorf("failed to unmarshal: %w", err) } @@ -98,7 +95,7 @@ func NewConfig(reader io.Reader) (*BouncerConfig, error) { } if config.Mode == "" { - return nil, fmt.Errorf("config does not contain 'mode'") + return nil, errors.New("config does not contain 'mode'") } if len(config.SupportedDecisionsTypes) == 0 { @@ -191,7 +188,7 @@ func nftablesConfig(config *BouncerConfig) error { } if !*config.Nftables.Ipv4.Enabled && !*config.Nftables.Ipv6.Enabled { - return fmt.Errorf("both IPv4 and IPv6 disabled, doing nothing") + return errors.New("both IPv4 and IPv6 disabled, doing nothing") } if config.NftablesHooks == nil || len(config.NftablesHooks) == 0 {