diff --git a/pkg/iptables/iptables.go b/pkg/iptables/iptables.go index 7775b961..723444b3 100644 --- a/pkg/iptables/iptables.go +++ b/pkg/iptables/iptables.go @@ -50,13 +50,15 @@ func NewIPTables(config *cfg.BouncerConfig) (types.Backend, error) { return nil, fmt.Errorf("invalid deny_action '%s', must be one of %s", config.DenyAction, strings.Join(allowedActions, ", ")) } + v4Sets := make(map[string]*ipsetcmd.IPSet) + v6Sets := make(map[string]*ipsetcmd.IPSet) + ipv4Ctx := &ipTablesContext{ version: "v4", SetName: config.BlacklistsIpv4, SetType: config.SetType, SetSize: config.SetSize, Chains: []string{}, - ipsets: make(map[string]*ipsetcmd.IPSet), defaultSet: defaultSet, target: target, } @@ -66,7 +68,6 @@ func NewIPTables(config *cfg.BouncerConfig) (types.Backend, error) { SetType: config.SetType, SetSize: config.SetSize, Chains: []string{}, - ipsets: make(map[string]*ipsetcmd.IPSet), defaultSet: defaultSet, target: target, } @@ -75,6 +76,11 @@ func NewIPTables(config *cfg.BouncerConfig) (types.Backend, error) { if config.Mode == cfg.IpsetMode { ipv4Ctx.ipsetContentOnly = true + set, err := ipsetcmd.NewIPSet(config.BlacklistsIpv4) + if err != nil { + return nil, err + } + v4Sets["ipset"] = set } else { ipv4Ctx.iptablesBin, err = exec.LookPath("iptables") if err != nil { @@ -83,6 +89,7 @@ func NewIPTables(config *cfg.BouncerConfig) (types.Backend, error) { ipv4Ctx.Chains = config.IptablesChains } + ipv4Ctx.ipsets = v4Sets ret.v4 = ipv4Ctx if config.DisableIPV6 { return ret, nil @@ -90,6 +97,11 @@ func NewIPTables(config *cfg.BouncerConfig) (types.Backend, error) { if config.Mode == cfg.IpsetMode { ipv6Ctx.ipsetContentOnly = true + set, err := ipsetcmd.NewIPSet(config.BlacklistsIpv6) + if err != nil { + return nil, err + } + v6Sets["ipset"] = set } else { ipv6Ctx.iptablesBin, err = exec.LookPath("ip6tables") if err != nil { @@ -98,6 +110,7 @@ func NewIPTables(config *cfg.BouncerConfig) (types.Backend, error) { ipv6Ctx.Chains = config.IptablesChains } + ipv6Ctx.ipsets = v6Sets ret.v6 = ipv6Ctx return ret, nil diff --git a/pkg/iptables/iptables_context.go b/pkg/iptables/iptables_context.go index 8fc5687a..e8d2af63 100644 --- a/pkg/iptables/iptables_context.go +++ b/pkg/iptables/iptables_context.go @@ -88,16 +88,23 @@ func (ctx *ipTablesContext) commit() error { }() for _, decision := range ctx.toDel { + + var set *ipsetcmd.IPSet + var ok bool + origin := *decision.Origin if origin == "lists" { origin = origin + ":" + *decision.Scenario } - set, ok := ctx.ipsets[origin] - - if !ok { - //No set for this origin, skip, as there's nothing to delete - continue + if ctx.ipsetContentOnly { + set = ctx.ipsets["ipset"] + } else { + set, ok = ctx.ipsets[origin] + if !ok { + //No set for this origin, skip, as there's nothing to delete + continue + } } delCmd := fmt.Sprintf("del %s %s -exist\n", set.Name(), *decision.Value) @@ -119,6 +126,9 @@ func (ctx *ipTablesContext) commit() error { continue } + var set *ipsetcmd.IPSet + var ok bool + if banDuration.Seconds() > 2147483 { log.Warnf("Ban duration too long (%d seconds), maximum for ipset is 2147483, setting duration to 2147482", int(banDuration.Seconds())) banDuration = time.Duration(2147482) * time.Second @@ -130,50 +140,56 @@ func (ctx *ipTablesContext) commit() error { origin = origin + ":" + *decision.Scenario } - set, ok := ctx.ipsets[origin] + if ctx.ipsetContentOnly { + set = ctx.ipsets["ipset"] + } else { + set, ok = ctx.ipsets[origin] - if !ok { + if !ok { - idx := slices.Index(ctx.originSetMapping, origin) + idx := slices.Index(ctx.originSetMapping, origin) - if idx == -1 { - ctx.originSetMapping = append(ctx.originSetMapping, origin) - idx = len(ctx.originSetMapping) - 1 - } + if idx == -1 { + ctx.originSetMapping = append(ctx.originSetMapping, origin) + idx = len(ctx.originSetMapping) - 1 + } - setName := fmt.Sprintf("%s-%d", ctx.SetName, idx) + setName := fmt.Sprintf("%s-%d", ctx.SetName, idx) - log.Infof("Using %s as set for origin %s", setName, origin) + log.Infof("Using %s as set for origin %s", setName, origin) - set, err = ipsetcmd.NewIPSet(setName) + set, err = ipsetcmd.NewIPSet(setName) - if err != nil { - log.Errorf("error while creating ipset : %s", err) - continue - } + if err != nil { + log.Errorf("error while creating ipset : %s", err) + continue + } - family := "inet" + family := "inet" - if ctx.version == "v6" { - family = "inet6" - } + if ctx.version == "v6" { + family = "inet6" + } - err = set.Create(ipsetcmd.CreateOptions{ - Family: family, - Timeout: "300", - MaxElem: strconv.Itoa(ctx.SetSize), - Type: ctx.SetType, - }) + err = set.Create(ipsetcmd.CreateOptions{ + Family: family, + Timeout: "300", + MaxElem: strconv.Itoa(ctx.SetSize), + Type: ctx.SetType, + }) - if err != nil { - log.Errorf("error while creating ipset : %s", err) - continue - } + if err != nil { + log.Errorf("error while creating ipset : %s", err) + continue + } - ctx.ipsets[origin] = set + ctx.ipsets[origin] = set - //Create the rule to use the set - ctx.createRule(set.Name()) + if !ctx.ipsetContentOnly { + //Create the rule to use the set + ctx.createRule(set.Name()) + } + } } addCmd := fmt.Sprintf("add %s %s timeout %d -exist\n", set.Name(), *decision.Value, int(banDuration.Seconds()))