Skip to content

Commit

Permalink
Support --filter-mark mark/[/mask]
Browse files Browse the repository at this point in the history
This commit allows specifying a mask when filtering for packet marks
with pwru.

Fixes: #296

Signed-off-by: Robin Gögge <[email protected]>
  • Loading branch information
rgo3 authored and jschwinger233 committed Dec 19, 2024
1 parent f2c9e8b commit 99075ad
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 6 deletions.
3 changes: 2 additions & 1 deletion bpf/kprobe_pwru.c
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ struct {
struct config {
u32 netns;
u32 mark;
u32 mask;
u32 ifindex;
u8 output_meta: 1;
u8 output_tuple: 1;
Expand Down Expand Up @@ -218,7 +219,7 @@ filter_meta(struct sk_buff *skb) {
if (cfg->netns && get_netns(skb) != cfg->netns) {
return false;
}
if (cfg->mark && BPF_CORE_READ(skb, mark) != cfg->mark) {
if (cfg->mark && cfg->mask && (BPF_CORE_READ(skb, mark) & cfg->mask) != cfg->mark) {
return false;
}
if (cfg->ifindex != 0 && BPF_CORE_READ(skb, dev, ifindex) != cfg->ifindex) {
Expand Down
10 changes: 6 additions & 4 deletions internal/pwru/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ const (
var Version string = "version unknown"

type FilterCfg struct {
FilterNetns uint32
FilterMark uint32
FilterIfindex uint32
FilterNetns uint32
FilterMark uint32
FilterMarkMask uint32
FilterIfindex uint32

OutputFlags uint8
FilterFlags uint8
Expand All @@ -49,7 +50,8 @@ type FilterCfg struct {

func GetConfig(flags *Flags) (cfg FilterCfg, err error) {
cfg = FilterCfg{
FilterMark: flags.FilterMark,
FilterMark: flags.FilterMark,
FilterMarkMask: flags.FilterMarkMask,
}
cfg.FilterFlags |= IsSetMask
if flags.OutputSkb {
Expand Down
59 changes: 58 additions & 1 deletion internal/pwru/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package pwru
import (
"fmt"
"os"
"strconv"
"strings"

flag "github.com/spf13/pflag"
Expand All @@ -27,6 +28,7 @@ type Flags struct {

FilterNetns string
FilterMark uint32
FilterMarkMask uint32
FilterFunc string
FilterNonSkbFuncs []string
FilterTrackSkb bool
Expand Down Expand Up @@ -67,7 +69,7 @@ func (f *Flags) SetFlags() {
flag.StringVar(&f.FilterFunc, "filter-func", "", "filter kernel functions to be probed by name (exact match, supports RE2 regular expression)")
flag.StringSliceVar(&f.FilterNonSkbFuncs, "filter-non-skb-funcs", nil, "filter non-skb kernel functions to be probed (--filter-track-skb-by-stackid will be enabled)")
flag.StringVar(&f.FilterNetns, "filter-netns", "", "filter netns (\"/proc/<pid>/ns/net\", \"inode:<inode>\")")
flag.Uint32Var(&f.FilterMark, "filter-mark", 0, "filter skb mark")
flag.Var(newMarkFlagValue(&f.FilterMark, &f.FilterMarkMask), "filter-mark", "filter skb mark (format: mark[/mask], e.g., 0xa00/0xf00)")
flag.BoolVar(&f.FilterTrackSkb, "filter-track-skb", false, "trace a packet even if it does not match given filters (e.g., after NAT or tunnel decapsulation)")
flag.BoolVar(&f.FilterTrackSkbByStackid, "filter-track-skb-by-stackid", false, "trace a packet even after it is kfreed (e.g., traffic going through bridge)")
flag.BoolVar(&f.FilterTraceTc, "filter-trace-tc", false, "trace TC bpf progs")
Expand Down Expand Up @@ -155,3 +157,58 @@ type Event struct {
ParamThird uint64
CPU uint32
}

type markFlagValue struct {
mark *uint32
mask *uint32
}

func newMarkFlagValue(mark, mask *uint32) *markFlagValue {
return &markFlagValue{mark: mark, mask: mask}
}

func (f *markFlagValue) String() string {
if *f.mask == 0 {
return fmt.Sprintf("0x%x", *f.mark)
}
return fmt.Sprintf("0x%x/0x%x", *f.mark, *f.mask)
}

func (f *markFlagValue) Set(value string) error {
parts := strings.Split(value, "/")

mark, err := parseUint32HexOrDecimal(parts[0])
if err != nil {
return fmt.Errorf("invalid mark value: %v", err)
}
*f.mark = mark
*f.mask = 0xffffffff

if len(parts) > 1 {
mask, err := parseUint32HexOrDecimal(parts[1])
if err != nil {
return fmt.Errorf("invalid mask value: %v", err)
}
*f.mask = mask
}

return nil
}

func (f *markFlagValue) Type() string {
return "mark[/mask]"
}

func parseUint32HexOrDecimal(s string) (uint32, error) {
base := 10
if strings.HasPrefix(strings.ToLower(s), "0x") {
s = s[2:]
base = 16
}

val, err := strconv.ParseUint(s, base, 32)
if err != nil {
return 0, err
}
return uint32(val), nil
}

0 comments on commit 99075ad

Please sign in to comment.