Skip to content

Commit

Permalink
nftables: create one rule/set per decision origin
Browse files Browse the repository at this point in the history
  • Loading branch information
blotus committed Apr 22, 2024
1 parent aa7eddd commit 261f693
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 128 deletions.
7 changes: 4 additions & 3 deletions pkg/iptables/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strings"
"time"

"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"

"github.com/crowdsecurity/cs-firewall-bouncer/pkg/metrics"
Expand Down Expand Up @@ -74,8 +75,8 @@ func (ipt *iptables) CollectMetrics() {
}

if (ipt.v4 != nil && !ipt.v4.ipsetContentOnly) || (ipt.v6 != nil && !ipt.v6.ipsetContentOnly) {
metrics.TotalDroppedPackets.Set(ip4DroppedPackets + ip6DroppedPackets)
metrics.TotalDroppedBytes.Set(ip6DroppedBytes + ip4DroppedBytes)
metrics.TotalDroppedPackets.With(prometheus.Labels{"ip_type": "ipv4", "origin": ""}).Set(ip4DroppedPackets + ip6DroppedPackets)
metrics.TotalDroppedBytes.With(prometheus.Labels{"ip_type": "ipv4", "origin": ""}).Set(ip6DroppedBytes + ip4DroppedBytes)
}

out, err := exec.Command(ipt.v4.ipsetBin, "list", "-o", "xml").CombinedOutput()
Expand Down Expand Up @@ -109,6 +110,6 @@ func (ipt *iptables) CollectMetrics() {
}
}

metrics.TotalActiveBannedIPs.Set(newCount)
metrics.TotalActiveBannedIPs.With(prometheus.Labels{"ip_type": "ipv4", "origin": ""}).Set(newCount)
}
}
12 changes: 6 additions & 6 deletions pkg/metrics/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@ import (

const MetricCollectionInterval = time.Second * 10

var TotalDroppedPackets = prometheus.NewGauge(prometheus.GaugeOpts{
var TotalDroppedPackets = prometheus.NewGaugeVec(prometheus.GaugeOpts{
Name: "fw_bouncer_dropped_packets",
Help: "Denotes the number of total dropped packets because of rule(s) created by crowdsec",
})
}, []string{"origin", "ip_type"})

var TotalDroppedBytes = prometheus.NewGauge(prometheus.GaugeOpts{
var TotalDroppedBytes = prometheus.NewGaugeVec(prometheus.GaugeOpts{
Name: "fw_bouncer_dropped_bytes",
Help: "Denotes the number of total dropped bytes because of rule(s) created by crowdsec",
})
}, []string{"origin", "ip_type"})

var TotalActiveBannedIPs = prometheus.NewGauge(prometheus.GaugeOpts{
var TotalActiveBannedIPs = prometheus.NewGaugeVec(prometheus.GaugeOpts{
Name: "fw_bouncer_banned_ips",
Help: "Denotes the number of IPs which are currently banned",
})
}, []string{"origin", "ip_type"})
109 changes: 63 additions & 46 deletions pkg/nftables/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@ package nftables

import (
"fmt"
"strings"
"time"

"github.com/crowdsecurity/cs-firewall-bouncer/pkg/metrics"
"github.com/google/nftables/expr"
"github.com/prometheus/client_golang/prometheus"

log "github.com/sirupsen/logrus"
)

func (c *nftContext) collectDroppedPackets(chain string) (int, int, error) {
droppedPackets := 0
droppedBytes := 0
func (c *nftContext) collectDroppedPackets() (map[string]int, map[string]int, error) {
droppedPackets := make(map[string]int)
droppedBytes := make(map[string]int)
//setName := ""
for chainName, chain := range c.chains {
rules, err := c.conn.GetRules(c.table, chain)
Expand All @@ -24,64 +26,59 @@ func (c *nftContext) collectDroppedPackets(chain string) (int, int, error) {
continue
}
for _, rule := range rules {
origin := ""
pkts := 0
bytes := 0
for _, xpr := range rule.Exprs {
switch obj := xpr.(type) {
case *expr.Counter:
log.Infof("rule %d (%s): packets %d, bytes %d", rule.Position, rule.Table.Name, obj.Packets, obj.Bytes)
droppedPackets += int(obj.Packets)
droppedBytes += int(obj.Bytes)
log.Debugf("rule %d (%s): packets %d, bytes %d", rule.Position, rule.Table.Name, obj.Packets, obj.Bytes)
pkts += int(obj.Packets)
bytes += int(obj.Bytes)
case *expr.Lookup:
log.Infof("rule %d (%s): lookup %s", rule.Position, rule.Table.Name, obj.SetName)
//setName = obj.SetName
log.Debugf("rule %d (%s): lookup %s", rule.Position, rule.Table.Name, obj.SetName)
origin, _ = strings.CutPrefix(obj.SetName, fmt.Sprintf("%s-", c.blacklists))
}
}
if origin != "" {
droppedPackets[origin] += pkts
droppedBytes[origin] += bytes
}
}
}

return droppedPackets, droppedBytes, nil
}

func (c *nftContext) collectActiveBannedIPs() (int, error) {
func (c *nftContext) collectActiveBannedIPs() (map[string]int, error) {
//Find the size of the set we have created
set, err := c.conn.GetSetByName(c.table, c.set.Name)

if err != nil {
return 0, fmt.Errorf("can't get set %s: %w", c.set.Name, err)
}
ret := make(map[string]int)

setContent, err := c.conn.GetSetElements(set)

if err != nil {
return 0, fmt.Errorf("can't get set elements for %s: %w", c.set.Name, err)
for origin, set := range c.sets {
setContent, err := c.conn.GetSetElements(set)
if err != nil {
return nil, fmt.Errorf("can't get set elements for %s: %w", set.Name, err)
}
if c.setOnly {
ret[c.blacklists] = len(setContent)
} else {
ret[origin] = len(setContent)
}
return ret, nil
}

return len(setContent), nil
return ret, nil
}

func (c *nftContext) collectDropped(hooks []string) (int, int, int) {
func (c *nftContext) collectDropped() (map[string]int, map[string]int, map[string]int) {
if c.conn == nil {
return 0, 0, 0
return nil, nil, nil
}

var droppedPackets, droppedBytes, banned int

if c.setOnly {
pkt, byt, err := c.collectDroppedPackets(c.chainName)
if err != nil {
log.Errorf("can't collect dropped packets for ip%s from nft: %s", c.version, err)
}
droppedPackets, droppedBytes, err := c.collectDroppedPackets()

droppedPackets += pkt
droppedBytes += byt
} else {
for _, hook := range hooks {
pkt, byt, err := c.collectDroppedPackets(c.chainName + "-" + hook)
if err != nil {
log.Errorf("can't collect dropped packets for ip%s from nft: %s", c.version, err)
}
droppedPackets += pkt
droppedBytes += byt
}
if err != nil {
log.Errorf("can't collect dropped packets for ip%s from nft: %s", c.version, err)
}

banned, err := c.collectActiveBannedIPs()
Expand All @@ -98,15 +95,35 @@ func (n *nft) CollectMetrics() {

for range t.C {
startTime := time.Now()
ip4DroppedPackets, ip4DroppedBytes, bannedIP4 := n.v4.collectDropped(n.Hooks)
ip6DroppedPackets, ip6DroppedBytes, bannedIP6 := n.v6.collectDropped(n.Hooks)
ip4DroppedPackets, ip4DroppedBytes, bannedIP4 := n.v4.collectDropped()
ip6DroppedPackets, ip6DroppedBytes, bannedIP6 := n.v6.collectDropped()

log.Debugf("metrics collection took %s", time.Since(startTime))
log.Debugf("ip4: dropped packets: %d, dropped bytes: %d, banned IPs: %d", ip4DroppedPackets, ip4DroppedBytes, bannedIP4)
log.Debugf("ip6: dropped packets: %d, dropped bytes: %d, banned IPs: %d", ip6DroppedPackets, ip6DroppedBytes, bannedIP6)
log.Debugf("ip4: dropped packets: %+v, dropped bytes: %+v, banned IPs: %+v", ip4DroppedPackets, ip4DroppedBytes, bannedIP4)
log.Debugf("ip6: dropped packets: %+v, dropped bytes: %+v, banned IPs: %+v", ip6DroppedPackets, ip6DroppedBytes, bannedIP6)

for origin, count := range bannedIP4 {
metrics.TotalActiveBannedIPs.With(prometheus.Labels{"origin": origin, "ip_type": "ip4"}).Set(float64(count))
}

metrics.TotalDroppedPackets.Set(float64(ip4DroppedPackets + ip6DroppedPackets))
metrics.TotalDroppedBytes.Set(float64(ip6DroppedBytes + ip4DroppedBytes))
metrics.TotalActiveBannedIPs.Set(float64(bannedIP4 + bannedIP6))
for origin, count := range bannedIP6 {
metrics.TotalActiveBannedIPs.With(prometheus.Labels{"origin": origin, "ip_type": "ip6"}).Set(float64(count))
}

for origin, count := range ip4DroppedPackets {
metrics.TotalDroppedPackets.With(prometheus.Labels{"origin": origin, "ip_type": "ip4"}).Set(float64(count))
}

for origin, count := range ip6DroppedPackets {
metrics.TotalDroppedPackets.With(prometheus.Labels{"origin": origin, "ip_type": "ip6"}).Set(float64(count))
}

for origin, count := range ip4DroppedBytes {
metrics.TotalDroppedBytes.With(prometheus.Labels{"origin": origin, "ip_type": "ip4"}).Set(float64(count))
}

for origin, count := range ip6DroppedBytes {
metrics.TotalDroppedBytes.With(prometheus.Labels{"origin": origin, "ip_type": "ip6"}).Set(float64(count))
}
}
}
83 changes: 70 additions & 13 deletions pkg/nftables/nftables.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
log "github.com/sirupsen/logrus"

"github.com/crowdsecurity/crowdsec/pkg/models"
Expand Down Expand Up @@ -49,11 +50,11 @@ func NewNFTables(config *cfg.BouncerConfig) (*nft, error) {
func (n *nft) Init() error {
log.Debug("nftables: Init()")

if err := n.v4.init(n.Hooks, n.DenyLog, n.DenyLogPrefix, n.DenyAction); err != nil {
if err := n.v4.init(n.Hooks); err != nil {
return err
}

if err := n.v6.init(n.Hooks, n.DenyLog, n.DenyLogPrefix, n.DenyAction); err != nil {
if err := n.v6.init(n.Hooks); err != nil {
return err
}

Expand Down Expand Up @@ -139,14 +140,42 @@ func (n *nft) commitDeletedDecisions() error {
return nil
}

func (n *nft) createSetAndRuleForOrigin(ctx *nftContext, origin string) error {
if _, ok := ctx.sets[origin]; !ok {
//First time we see this origin, create the rule/set for all hooks
set := &nftables.Set{
Name: fmt.Sprintf("%s-%s", ctx.blacklists, origin),
Table: ctx.table,
KeyType: ctx.typeIPAddr,
KeyByteOrder: binaryutil.BigEndian,
HasTimeout: true,
}

ctx.sets[origin] = set

if err := ctx.conn.AddSet(set, []nftables.SetElement{}); err != nil {
return err
}
for _, chain := range ctx.chains {
rule, err := ctx.createRule(chain, set, n.DenyLog, n.DenyLogPrefix, n.DenyAction)
if err != nil {
return err
}
ctx.conn.AddRule(rule)
log.Infof("Created set and rule for origin %s and type %s", origin, ctx.typeIPAddr.Name)
}
}
return nil
}

func (n *nft) commitAddedDecisions() error {
banned, err := n.getBannedState()
if err != nil {
return fmt.Errorf("failed to get current state: %w", err)
}

ip4 := []nftables.SetElement{}
ip6 := []nftables.SetElement{}
ip4 := make(map[string][]nftables.SetElement, 0)
ip6 := make(map[string][]nftables.SetElement, 0)

n.decisionsToAdd = normalizedDecisions(n.decisionsToAdd)

Expand All @@ -159,20 +188,37 @@ func (n *nft) commitAddedDecisions() error {

t, _ := time.ParseDuration(*decision.Duration)

origin := *decision.Origin

if strings.Contains(ip.String(), ":") {
if n.v6.conn != nil {
if n.v6.setOnly {
origin = n.v6.blacklists
}
log.Tracef("adding %s to buffer", ip)

ip6 = append(ip6, nftables.SetElement{Timeout: t, Key: ip.To16()})
if _, ok := ip6[origin]; !ok {
ip6[origin] = make([]nftables.SetElement, 0)
}
ip6[origin] = append(ip6[origin], nftables.SetElement{Timeout: t, Key: ip.To16()})
if !n.v6.setOnly {
n.createSetAndRuleForOrigin(n.v6, origin)

Check failure on line 204 in pkg/nftables/nftables.go

View workflow job for this annotation

GitHub Actions / golangci-lint + codeql

Error return value of `n.createSetAndRuleForOrigin` is not checked (errcheck)
}
}

continue
}

if n.v4.conn != nil {
if n.v4.setOnly {
origin = n.v4.blacklists
}
log.Tracef("adding %s to buffer", ip)

ip4 = append(ip4, nftables.SetElement{Timeout: t, Key: ip.To4()})
if _, ok := ip4[origin]; !ok {
ip4[origin] = make([]nftables.SetElement, 0)
}
ip4[origin] = append(ip4[origin], nftables.SetElement{Timeout: t, Key: ip.To4()})
if !n.v4.setOnly {
n.createSetAndRuleForOrigin(n.v4, origin)

Check failure on line 220 in pkg/nftables/nftables.go

View workflow job for this annotation

GitHub Actions / golangci-lint + codeql

Error return value of `n.createSetAndRuleForOrigin` is not checked (errcheck)
}
}
}

Expand Down Expand Up @@ -201,9 +247,14 @@ func (n *nft) Commit() error {
return nil
}

type tmpDecisions struct {
duration time.Duration
origin string
}

// remove duplicates, normalize decision timeouts, keep the longest decision when dups are present.
func normalizedDecisions(decisions []*models.Decision) []*models.Decision {
vals := make(map[string]time.Duration)
vals := make(map[string]tmpDecisions)
finalDecisions := make([]*models.Decision, 0)

for _, d := range decisions {
Expand All @@ -213,16 +264,22 @@ func normalizedDecisions(decisions []*models.Decision) []*models.Decision {
}

*d.Value = strings.Split(*d.Value, "/")[0]
vals[*d.Value] = maxTime(t, vals[*d.Value])
if max, ok := vals[*d.Value]; !ok || t > max.duration {
vals[*d.Value] = tmpDecisions{
duration: t,
origin: *d.Origin,
}
}
}

for ip, duration := range vals {
d := duration.String()
for ip, decision := range vals {
d := decision.duration.String()
i := ip // copy it because we don't same value for all decisions as `ip` is same pointer :)

finalDecisions = append(finalDecisions, &models.Decision{
Duration: &d,
Value: &i,
Origin: &decision.origin,

Check failure on line 282 in pkg/nftables/nftables.go

View workflow job for this annotation

GitHub Actions / golangci-lint + codeql

exporting a pointer for the loop variable decision (exportloopref)
})
}

Expand Down
Loading

0 comments on commit 261f693

Please sign in to comment.