From 819c4d8f1eb04e4616adb667fa03ef831a194b5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Robin=20G=C3=B6gge?= Date: Tue, 17 Dec 2024 15:55:06 +0100 Subject: [PATCH] Support --filter-mark mark/[/mask] MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit allows specifying a mask when filtering for packet marks with pwru. Fixes: #296 Signed-off-by: Robin Gögge --- bpf/kprobe_pwru.c | 3 ++- internal/pwru/config.go | 10 ++++--- internal/pwru/types.go | 59 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 66 insertions(+), 6 deletions(-) diff --git a/bpf/kprobe_pwru.c b/bpf/kprobe_pwru.c index 73dbec0c..42fca157 100644 --- a/bpf/kprobe_pwru.c +++ b/bpf/kprobe_pwru.c @@ -136,6 +136,7 @@ struct { struct config { u32 netns; u32 mark; + u32 mask; u32 ifindex; u8 output_meta: 1; u8 output_tuple: 1; @@ -218,7 +219,7 @@ filter_meta(struct sk_buff *skb) { if (cfg->netns && get_netns(skb) != cfg->netns) { return false; } - if (cfg->mark && BPF_CORE_READ(skb, mark) != cfg->mark) { + if (cfg->mark && cfg->mask && (BPF_CORE_READ(skb, mark) & cfg->mask) != cfg->mark) { return false; } if (cfg->ifindex != 0 && BPF_CORE_READ(skb, dev, ifindex) != cfg->ifindex) { diff --git a/internal/pwru/config.go b/internal/pwru/config.go index 91684b2a..3fff62c4 100644 --- a/internal/pwru/config.go +++ b/internal/pwru/config.go @@ -36,9 +36,10 @@ const ( var Version string = "version unknown" type FilterCfg struct { - FilterNetns uint32 - FilterMark uint32 - FilterIfindex uint32 + FilterNetns uint32 + FilterMark uint32 + FilterMarkMask uint32 + FilterIfindex uint32 OutputFlags uint8 FilterFlags uint8 @@ -49,7 +50,8 @@ type FilterCfg struct { func GetConfig(flags *Flags) (cfg FilterCfg, err error) { cfg = FilterCfg{ - FilterMark: flags.FilterMark, + FilterMark: flags.FilterMark, + FilterMarkMask: flags.FilterMarkMask, } cfg.FilterFlags |= IsSetMask if flags.OutputSkb { diff --git a/internal/pwru/types.go b/internal/pwru/types.go index 2abb1372..71be9de1 100644 --- a/internal/pwru/types.go +++ b/internal/pwru/types.go @@ -7,6 +7,7 @@ package pwru import ( "fmt" "os" + "strconv" "strings" flag "github.com/spf13/pflag" @@ -27,6 +28,7 @@ type Flags struct { FilterNetns string FilterMark uint32 + FilterMarkMask uint32 FilterFunc string FilterNonSkbFuncs []string FilterTrackSkb bool @@ -67,7 +69,7 @@ func (f *Flags) SetFlags() { flag.StringVar(&f.FilterFunc, "filter-func", "", "filter kernel functions to be probed by name (exact match, supports RE2 regular expression)") flag.StringSliceVar(&f.FilterNonSkbFuncs, "filter-non-skb-funcs", nil, "filter non-skb kernel functions to be probed (--filter-track-skb-by-stackid will be enabled)") flag.StringVar(&f.FilterNetns, "filter-netns", "", "filter netns (\"/proc//ns/net\", \"inode:\")") - flag.Uint32Var(&f.FilterMark, "filter-mark", 0, "filter skb mark") + flag.Var(newMarkFlagValue(&f.FilterMark, &f.FilterMarkMask), "filter-mark", "filter skb mark (format: mark[/mask], e.g., 0xa00/0xf00)") flag.BoolVar(&f.FilterTrackSkb, "filter-track-skb", false, "trace a packet even if it does not match given filters (e.g., after NAT or tunnel decapsulation)") flag.BoolVar(&f.FilterTrackSkbByStackid, "filter-track-skb-by-stackid", false, "trace a packet even after it is kfreed (e.g., traffic going through bridge)") flag.BoolVar(&f.FilterTraceTc, "filter-trace-tc", false, "trace TC bpf progs") @@ -155,3 +157,58 @@ type Event struct { ParamThird uint64 CPU uint32 } + +type markFlagValue struct { + mark *uint32 + mask *uint32 +} + +func newMarkFlagValue(mark, mask *uint32) *markFlagValue { + return &markFlagValue{mark: mark, mask: mask} +} + +func (f *markFlagValue) String() string { + if *f.mask == 0 { + return fmt.Sprintf("0x%x", *f.mark) + } + return fmt.Sprintf("0x%x/0x%x", *f.mark, *f.mask) +} + +func (f *markFlagValue) Set(value string) error { + parts := strings.Split(value, "/") + + mark, err := parseUint32HexOrDecimal(parts[0]) + if err != nil { + return fmt.Errorf("invalid mark value: %v", err) + } + *f.mark = mark + *f.mask = 0xffffffff + + if len(parts) > 1 { + mask, err := parseUint32HexOrDecimal(parts[1]) + if err != nil { + return fmt.Errorf("invalid mask value: %v", err) + } + *f.mask = mask + } + + return nil +} + +func (f *markFlagValue) Type() string { + return "mark[/mask]" +} + +func parseUint32HexOrDecimal(s string) (uint32, error) { + base := 10 + if strings.HasPrefix(strings.ToLower(s), "0x") { + s = s[2:] + base = 16 + } + + val, err := strconv.ParseUint(s, base, 32) + if err != nil { + return 0, err + } + return uint32(val), nil +}