diff --git a/bpf/include/util.h b/bpf/include/util.h index 9cc2059..3347c99 100644 --- a/bpf/include/util.h +++ b/bpf/include/util.h @@ -12,4 +12,11 @@ Copyright (C) Kubeshark #define memcpy(dest, src, n) __builtin_memcpy((dest), (src), (n)) #endif +#ifndef likely + #define likely(x) __builtin_expect((x), 1) +#endif +#ifndef unlikely + #define unlikely(x) __builtin_expect((x), 0) +#endif + #endif /* __UTIL__ */ diff --git a/bpf/packet_sniffer.c b/bpf/packet_sniffer.c index ca28fb9..1799381 100644 --- a/bpf/packet_sniffer.c +++ b/bpf/packet_sniffer.c @@ -118,14 +118,14 @@ static __always_inline int filter_packets(struct __sk_buff *skb, void *cgrpctxma if (side == PACKET_DIRECTION_RECEIVED) { TRACE_PACKET("cg/in", true, skb->local_ip4, skb->remote_ip4, skb->local_port & 0xffff, skb->remote_port & 0xffff, cgroup_id); + save_packet(skb, src_ip, skb->remote_port>>16, dst_ip, bpf_htons(skb->local_port), cgroup_id, side); } else { TRACE_PACKET("cg/out", true, skb->local_ip4, skb->remote_ip4, skb->local_port & 0xffff, skb->remote_port & 0xffff, cgroup_id); + save_packet(skb, src_ip, bpf_htons(skb->local_port), dst_ip, skb->remote_port>>16, cgroup_id, side); } - save_packet(skb, src_ip, src_port, dst_ip, dst_port, cgroup_id, side); - return 1; } @@ -211,6 +211,14 @@ static __noinline void _save_packet(struct pkt_sniffer_ctx *ctx) packet_id = pkt_id_ptr->id++; bpf_spin_unlock(&pkt_id_ptr->lock); + // send initial chunk before the first packet + if (unlikely(packet_id == 0)) { + if (bpf_perf_event_output(skb, &pkts_buffer, BPF_F_CURRENT_CPU, p, 0)) + { + log_error(skb, LOG_ERROR_PKT_SNIFFER, 11, 0l, 0l); + } + } + if (bpf_map_update_elem(&packet_context, &packet_id, p, BPF_NOEXIST)) { log_error(skb, LOG_ERROR_PKT_SNIFFER, 5, 0l, 0l); diff --git a/main.go b/main.go index 38ef1a6..1a288ae 100644 --- a/main.go +++ b/main.go @@ -41,6 +41,8 @@ var procfs = flag.String("procfs", "/proc", "The procfs directory, used when map // development var debug = flag.Bool("debug", false, 
"Enable debug mode") +var initBPF = flag.Bool("init-bpf", false, "Use to initialize bpf filesystem. Common usage is from init containers.") + var disableEbpfCapture = flag.Bool("disable-ebpf", false, "Disable capture packet via eBPF") var disableTlsLog = flag.Bool("disable-tls-log", false, "Disable tls logging") @@ -103,6 +105,10 @@ func main() { } }() + if *initBPF { + initBPFSubsystem() + return + } run() } @@ -120,11 +126,11 @@ func run() { return } - tcpMap, err := resolver.GatherPidsTCPMap(*procfs, isCgroupsV2) - if err != nil { - log.Error().Err(err).Msg("tcp map lookup failed") - return - } + tcpMap, err := resolver.GatherPidsTCPMap(*procfs, isCgroupsV2) + if err != nil { + log.Error().Err(err).Msg("tcp map lookup failed") + return + } tracer = &Tracer{ procfs: *procfs, diff --git a/pkg/bpf/bpf.go b/pkg/bpf/bpf.go index ad107f4..949d46d 100644 --- a/pkg/bpf/bpf.go +++ b/pkg/bpf/bpf.go @@ -2,6 +2,8 @@ package bpf import ( "fmt" + "path/filepath" + "strings" "bytes" "os" @@ -11,11 +13,18 @@ import ( "github.com/cilium/ebpf/asm" "github.com/cilium/ebpf/features" "github.com/go-errors/errors" + "github.com/kubeshark/tracer/misc" "github.com/kubeshark/tracer/pkg/utils" "github.com/moby/moby/pkg/parsers/kernel" "github.com/rs/zerolog/log" ) +const ( + PinPath = "/sys/fs/bpf/kubeshark" + PinNamePlainPackets = "packets_plain" + PinNameTLSPackets = "packets_tls" +) + // TODO: cilium/ebpf does not support .kconfig Therefore; for now, we build object files per kernel version. 
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go@v0.12.3 -target $BPF_TARGET -cflags $BPF_CFLAGS -type tls_chunk -type goid_offsets Tracer ../../bpf/tracer.c @@ -27,7 +36,7 @@ type BpfObjectsImpl struct { specs *ebpf.CollectionSpec } -func (objs *BpfObjectsImpl) loadBpfObjects(bpfConstants map[string]uint64, reader *bytes.Reader) error { +func (objs *BpfObjectsImpl) loadBpfObjects(bpfConstants map[string]uint64, mapReplacements map[string]*ebpf.Map, reader *bytes.Reader) error { var err error objs.specs, err = ebpf.LoadCollectionSpecFromReader(reader) @@ -44,7 +53,10 @@ func (objs *BpfObjectsImpl) loadBpfObjects(bpfConstants map[string]uint64, reade return err } - err = objs.specs.LoadAndAssign(objs.bpfObjs, nil) + opts := ebpf.CollectionOptions{ + MapReplacements: mapReplacements, + } + err = objs.specs.LoadAndAssign(objs.bpfObjs, &opts) if err != nil { var ve *ebpf.VerifierError if errors.As(err, &ve) { @@ -59,8 +71,7 @@ func (objs *BpfObjectsImpl) loadBpfObjects(bpfConstants map[string]uint64, reade } type BpfObjects struct { - BpfObjs TracerObjects - IsCgroupV2 bool + BpfObjs TracerObjects } func programHelperExists(pt ebpf.ProgramType, helper asm.BuiltinFunc) uint64 { @@ -82,15 +93,41 @@ func NewBpfObjects(disableEbpfCapture bool) (*BpfObjects, error) { } cgroupV1 := uint64(1) - objs.IsCgroupV2, err = utils.IsCgroupV2() + isCgroupV2, err := utils.IsCgroupV2() if err != nil { log.Error().Err(err).Msg("read cgroups information failed:") } - if objs.IsCgroupV2 { + if isCgroupV2 { cgroupV1 = 0 } - log.Info().Msg(fmt.Sprintf("Detected Linux kernel version: %s cgroups version2: %v", kernelVersion, objs.IsCgroupV2)) + mapReplacements := make(map[string]*ebpf.Map) + plainPath := filepath.Join(PinPath, PinNamePlainPackets) + tlsPath := filepath.Join(PinPath, PinNameTLSPackets) + + if !kernel.CheckKernelVersion(5, 4, 0) { + disableEbpfCapture = true + } + + markDisabledEBPF := func() error { + pathNoEbpf := filepath.Join(misc.GetDataDir(), "noebpf") + file, err := 
os.Create(pathNoEbpf) + if err != nil { + return err + } + file.Close() + return nil + } + + ebpfBackendStatus := "enabled" + if disableEbpfCapture { + ebpfBackendStatus = "disabled" + if err = markDisabledEBPF(); err != nil { + return nil, err + } + } + + log.Info().Msg(fmt.Sprintf("Detected Linux kernel version: %s cgroups version2: %v, eBPF backend %v", kernelVersion, isCgroupV2, ebpfBackendStatus)) kernelVersionInt := uint64(1_000_000)*uint64(kernelVersion.Kernel) + uint64(1_000)*uint64(kernelVersion.Major) + uint64(kernelVersion.Minor) // TODO: cilium/ebpf does not support .kconfig Therefore; for now, we load object files according to kernel version. @@ -127,16 +164,70 @@ func NewBpfObjects(disableEbpfCapture bool) (*BpfObjects, error) { "DISABLE_EBPF_CAPTURE": disableCapture, } - err = objects.loadBpfObjects(bpfConsts, bytes.NewReader(_TracerBytes)) + pktsBuffer, err := ebpf.LoadPinnedMap(plainPath, nil) if err == nil { - objs.BpfObjs = *objects.bpfObjs.(*TracerObjects) + mapReplacements["pkts_buffer"] = pktsBuffer + log.Info().Str("path", plainPath).Msg("loaded plain packets buffer") + } else if !errors.Is(err, os.ErrNotExist) { + log.Error().Msg(fmt.Sprintf("load plain packets map failed: %v", err)) } - if err != nil { + chunksBuffer, err := ebpf.LoadPinnedMap(tlsPath, nil) + if err == nil { + mapReplacements["chunks_buffer"] = chunksBuffer + log.Info().Str("path", tlsPath).Msg("loaded tls packets buffer") + } else if !errors.Is(err, os.ErrNotExist) { + log.Error().Msg(fmt.Sprintf("load tls packets map failed: %v", err)) + } + + err = objects.loadBpfObjects(bpfConsts, mapReplacements, bytes.NewReader(_TracerBytes)) + if err == nil { + objs.BpfObjs = *objects.bpfObjs.(*TracerObjects) + } else if err != nil { log.Error().Msg(fmt.Sprintf("load bpf objects failed: %v", err)) return nil, err } } + // Pin packet perf maps: + + defer func() { + if os.IsPermission(err) || strings.Contains(fmt.Sprintf("%v", err), "permission") { + log.Warn().Msg(fmt.Sprintf("There are 
no enough permissions to activate eBPF. Error: %v", err)) + if err = markDisabledEBPF(); err != nil { + log.Error().Err(err).Msg("disable ebpf failed") + } else { + err = nil + } + } + }() + + if err = os.MkdirAll(PinPath, 0700); err != nil { + log.Error().Msg(fmt.Sprintf("mkdir pin path failed: %v", err)) + return nil, err + } + + pinMap := func(mapName, path string, mapObj *ebpf.Map) error { + if _, ok := mapReplacements[mapName]; !ok { + if err = mapObj.Pin(path); err != nil { + log.Error().Err(err).Str("path", path).Msg("pin perf buffer failed") + return err + } else { + log.Info().Str("path", path).Msg("pinned perf buffer") + } + } + return nil + } + + if !disableEbpfCapture { + if err = pinMap("pkts_buffer", plainPath, objs.BpfObjs.PktsBuffer); err != nil { + return nil, err + } + } + + if err = pinMap("chunks_buffer", tlsPath, objs.BpfObjs.ChunksBuffer); err != nil { + return nil, err + } + return &objs, nil } diff --git a/pkg/bpf/tls_poller.go b/pkg/bpf/tls_poller.go index fbad671..9749b9b 100644 --- a/pkg/bpf/tls_poller.go +++ b/pkg/bpf/tls_poller.go @@ -6,10 +6,13 @@ import ( "fmt" "os" "strconv" + "time" + "github.com/cilium/ebpf" "github.com/cilium/ebpf/perf" "github.com/go-errors/errors" "github.com/hashicorp/golang-lru/simplelru" + "github.com/kubeshark/gopacket" "github.com/kubeshark/tracer/misc" "github.com/kubeshark/tracer/pkg/utils" "github.com/rs/zerolog/log" @@ -20,24 +23,32 @@ const ( fdCacheMaxItems = 500000 / fdCachedItemAvgSize ) +type RawWriter func(timestamp uint64, cgroupId uint64, direction uint8, firstLayerType gopacket.LayerType, l ...gopacket.SerializableLayer) (err error) +type GopacketWriter func(packet gopacket.Packet) (err error) + type TlsPoller struct { - streams map[string]*TlsStream - closeStreams chan string - chunksReader *perf.Reader - fdCache *simplelru.LRU // Actual type is map[string]addressPair - evictedCounter int - Sorter *PacketSorter + streams map[string]*TlsStream + closeStreams chan string + chunksReader 
*perf.Reader + fdCache *simplelru.LRU // Actual type is map[string]addressPair + evictedCounter int + rawWriter RawWriter + gopacketWriter GopacketWriter + receivedPackets uint64 + lostChunks uint64 } func NewTlsPoller( - bpfObjs *BpfObjects, - sorter *PacketSorter, + perfBuffer *ebpf.Map, + rawWriter RawWriter, + gopacketWriter GopacketWriter, ) (*TlsPoller, error) { poller := &TlsPoller{ - streams: make(map[string]*TlsStream), - closeStreams: make(chan string, misc.TlsCloseChannelBufferSize), - chunksReader: nil, - Sorter: sorter, + streams: make(map[string]*TlsStream), + closeStreams: make(chan string, misc.TlsCloseChannelBufferSize), + chunksReader: nil, + rawWriter: rawWriter, + gopacketWriter: gopacketWriter, } fdCache, err := simplelru.NewLRU(fdCacheMaxItems, poller.fdCacheEvictCallback) @@ -46,7 +57,7 @@ func NewTlsPoller( } poller.fdCache = fdCache - poller.chunksReader, err = perf.NewReader(bpfObjs.BpfObjs.ChunksBuffer, os.Getpagesize()*10000) + poller.chunksReader, err = perf.NewReader(perfBuffer, os.Getpagesize()*10000) if err != nil { return nil, errors.Wrap(err, 0) @@ -84,9 +95,30 @@ func (p *TlsPoller) Start() { }() } +func (p *TlsPoller) GetLostChunks() uint64 { + return p.lostChunks +} + +func (p *TlsPoller) GetReceivedPackets() uint64 { + return p.receivedPackets +} + func (p *TlsPoller) pollChunksPerfBuffer(chunks chan<- *TracerTlsChunk) { log.Info().Msg("Start polling for tls events") + p.chunksReader.SetDeadline(time.Unix(1, 0)) + var emptyRecord perf.Record + for { + err := p.chunksReader.ReadInto(&emptyRecord) + if errors.Is(err, os.ErrDeadlineExceeded) { + break + } else if err != nil { + log.Error().Err(err).Msg("Error reading chunks from pkts perf, aborting!") + return + } + } + p.chunksReader.SetDeadline(time.Time{}) + for { record, err := p.chunksReader.Read() @@ -97,12 +129,13 @@ func (p *TlsPoller) pollChunksPerfBuffer(chunks chan<- *TracerTlsChunk) { return } - utils.LogError(errors.Errorf("Error reading chunks from tls perf, aborting 
TLS! %v", err)) + log.Error().Err(err).Msg("Error reading chunks from pkts perf, aborting!") return } if record.LostSamples != 0 { log.Info().Msg(fmt.Sprintf("Buffer is full, dropped %d chunks", record.LostSamples)) + p.lostChunks += record.LostSamples continue } @@ -111,7 +144,7 @@ func (p *TlsPoller) pollChunksPerfBuffer(chunks chan<- *TracerTlsChunk) { var chunk TracerTlsChunk if err := binary.Read(buffer, binary.LittleEndian, &chunk); err != nil { - utils.LogError(errors.Errorf("Error parsing chunk %v", err)) + log.Error().Err(err).Msg("Error parsing chunk") continue } diff --git a/pkg/bpf/tls_stream.go b/pkg/bpf/tls_stream.go index 4c7a613..7b43bd1 100644 --- a/pkg/bpf/tls_stream.go +++ b/pkg/bpf/tls_stream.go @@ -4,11 +4,13 @@ import ( "net" "strconv" "sync" + "time" "github.com/kubeshark/gopacket" "github.com/kubeshark/gopacket/layers" "github.com/kubeshark/tracer/misc" "github.com/kubeshark/tracer/misc/ethernet" + "github.com/kubeshark/tracerproto/pkg/unixpacket" "github.com/rs/zerolog/log" ) @@ -25,19 +27,33 @@ func (l *tlsLayers) swap() { } type TlsStream struct { - poller *TlsPoller - key string - id int64 - Client *tlsReader - Server *tlsReader - layers *tlsLayers + serializeOptions gopacket.SerializeOptions + ipv4Decoder gopacket.Decoder + poller *TlsPoller + key string + id int64 + Client *tlsReader + Server *tlsReader + layers *tlsLayers sync.Mutex } func NewTlsStream(poller *TlsPoller, key string) *TlsStream { + ipv4Decoder := gopacket.DecodersByLayerName["IPv4"] + if ipv4Decoder == nil { + log.Error().Msg("Failed to get IPv4 decoder") + return nil + } + serializeOptions := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + return &TlsStream{ - poller: poller, - key: key, + serializeOptions: serializeOptions, + ipv4Decoder: ipv4Decoder, + poller: poller, + key: key, } } @@ -93,26 +109,50 @@ func (t *TlsStream) writeData(timestamp uint64, cgroupId uint64, direction uint8 } func (t *TlsStream) writeLayers(timestamp uint64, 
cgroupId uint64, direction uint8, data []byte, isClient bool, sentLen uint32) { - t.writePacket( - timestamp, - cgroupId, - direction, - layers.LayerTypeEthernet, - t.layers.ethernet, - t.layers.ipv4, - t.layers.tcp, - gopacket.Payload(data), - ) - t.doTcpSeqAckWalk(isClient, sentLen) -} + t.poller.receivedPackets++ + if t.poller.rawWriter != nil { + err := t.poller.rawWriter( + timestamp, + cgroupId, + direction, + layers.LayerTypeEthernet, + t.layers.ethernet, + t.layers.ipv4, + t.layers.tcp, + gopacket.Payload(data), + ) + if err != nil { + log.Error().Err(err).Msg("Error writing PCAP:") + return + } + } -func (t *TlsStream) writePacket(timestamp uint64, cgroupId uint64, direction uint8, firstLayerType gopacket.LayerType, l ...gopacket.SerializableLayer) { + if t.poller.gopacketWriter != nil { + buf := gopacket.NewSerializeBuffer() - err := t.poller.Sorter.WriteTLSPacket(timestamp, cgroupId, direction, firstLayerType, l...) - if err != nil { - log.Error().Err(err).Msg("Error writing PCAP:") - return + err := gopacket.SerializeLayers(buf, t.serializeOptions, t.layers.ipv4, t.layers.tcp, gopacket.Payload(data)) + if err != nil { + log.Error().Err(err).Msg("Error serializing packet:") + return + } + + bufBytes := buf.Bytes() + pkt := gopacket.NewPacket(bufBytes, t.ipv4Decoder, gopacket.NoCopy, cgroupId, unixpacket.PacketDirection(direction)) + m := pkt.Metadata() + ci := &m.CaptureInfo + ci.Timestamp = time.Unix(0, int64(timestamp)) + ci.CaptureLength = len(bufBytes) + ci.Length = len(bufBytes) + ci.CaptureBackend = gopacket.CaptureBackendEbpfTls + + err = t.poller.gopacketWriter(pkt) + if err != nil { + log.Error().Err(err).Msg("Error writing gopacket:") + return + } } + + t.doTcpSeqAckWalk(isClient, sentLen) } func (t *TlsStream) loadSecNumbers(isClient bool) { diff --git a/pkg/discoverer/discoverer.go b/pkg/discoverer/discoverer.go index 0eb423d..a90525a 100644 --- a/pkg/discoverer/discoverer.go +++ b/pkg/discoverer/discoverer.go @@ -28,7 +28,6 @@ type 
InternalEventsDiscoverer interface { type InternalEventsDiscovererImpl struct { bpfObjects *bpf.BpfObjects - isCgroupV2 bool sslHooks map[string]sslHooks.SslHooks perfFoundOpenssl *ebpf.Map perfFoundCgroup *ebpf.Map @@ -42,7 +41,6 @@ type InternalEventsDiscovererImpl struct { func NewInternalEventsDiscoverer(procfs string, bpfObjects *bpf.BpfObjects, cgroupsController cgroup.CgroupsController) InternalEventsDiscoverer { impl := InternalEventsDiscovererImpl{ bpfObjects: bpfObjects, - isCgroupV2: bpfObjects.IsCgroupV2, perfFoundOpenssl: bpfObjects.BpfObjs.PerfFoundOpenssl, perfFoundCgroup: bpfObjects.BpfObjs.PerfFoundCgroup, sslHooks: make(map[string]sslHooks.SslHooks), @@ -88,7 +86,7 @@ func (e *InternalEventsDiscovererImpl) Start() error { e.scanExistingCgroups(isCgroupV2) - if err = e.pids.scanExistingPIDs(e.isCgroupV2); err != nil { + if err = e.pids.scanExistingPIDs(isCgroupV2); err != nil { return errors.Wrap(err, 0) } diff --git a/pkg/packet/packet.go b/pkg/packet/packet.go new file mode 100644 index 0000000..fbb8309 --- /dev/null +++ b/pkg/packet/packet.go @@ -0,0 +1,120 @@ +package packet + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/cilium/ebpf" + "github.com/kubeshark/gopacket" + "github.com/kubeshark/tracer/pkg/bpf" + "github.com/kubeshark/tracer/pkg/poller/packets" +) + +var ( + ErrNotSupported = errors.New("source is not supported") +) + +type PacketData struct { + Timestamp uint64 + Data []byte +} + +type PacketSource interface { + NextPacket() (gopacket.Packet, error) + Start() error + Stop() error + Stats() (packetsGot, packetsLost uint64) +} + +type PacketsPoller interface { + Start() + Stop() error + GetReceivedPackets() uint64 + GetLostChunks() uint64 +} + +type PacketSourceImpl struct { + perfBuffer *ebpf.Map + poller PacketsPoller + pktCh chan gopacket.Packet +} + +type createPollerFunc func(*ebpf.Map, bpf.RawWriter, bpf.GopacketWriter) (PacketsPoller, error) + +func newPacketSource(perfName string, 
createPoller createPollerFunc, pathNotSupported string) (PacketSource, error) { + path := filepath.Join(bpf.PinPath, perfName) + + var err error + var perfBuffer *ebpf.Map + for { + perfBuffer, err = ebpf.LoadPinnedMap(path, nil) + if errors.Is(err, os.ErrNotExist) { + if file, errStat := os.Open(pathNotSupported); errStat == nil { + file.Close() + return nil, ErrNotSupported + } + time.Sleep(100 * time.Millisecond) + } else if err != nil { + return nil, err + } else { + break + } + } + + p := PacketSourceImpl{ + perfBuffer: perfBuffer, + pktCh: make(chan gopacket.Packet), + } + + if p.poller, err = createPoller(p.perfBuffer, nil, p.WritePacket); err != nil { + return nil, fmt.Errorf("poller create failed: %v", err) + } + + return &p, nil +} + +func NewTLSPacketSource(dataDir string) (PacketSource, error) { + poller := func(m *ebpf.Map, wr bpf.RawWriter, goWr bpf.GopacketWriter) (PacketsPoller, error) { + return bpf.NewTlsPoller(m, wr, goWr) + } + + return newPacketSource(bpf.PinNameTLSPackets, poller, "") +} + +func NewPlainPacketSource(dataDir string) (PacketSource, error) { + poller := func(m *ebpf.Map, wr bpf.RawWriter, goWr bpf.GopacketWriter) (PacketsPoller, error) { + return packets.NewPacketsPoller(m, wr, goWr) + } + + return newPacketSource(bpf.PinNamePlainPackets, poller, filepath.Join(dataDir, "noebpf")) +} + +func (p *PacketSourceImpl) WritePacket(pkt gopacket.Packet) error { + p.pktCh <- pkt + return nil +} + +func (p *PacketSourceImpl) Start() error { + p.poller.Start() + return nil +} + +func (p *PacketSourceImpl) Stop() error { + return p.poller.Stop() +} + +func (p *PacketSourceImpl) Stats() (packetsGot, packetsLost uint64) { + packetsGot = p.poller.GetReceivedPackets() + // Using chunks instead of packets: + packetsLost = p.poller.GetLostChunks() + return +} + +func (p *PacketSourceImpl) NextPacket() (gopacket.Packet, error) { + pkt := <-p.pktCh + return pkt, nil +} diff --git a/pkg/poller/packets/packets_poller.go 
b/pkg/poller/packets/packets_poller.go index beb7471..926783f 100644 --- a/pkg/poller/packets/packets_poller.go +++ b/pkg/poller/packets/packets_poller.go @@ -7,6 +7,7 @@ import ( "time" "unsafe" + "github.com/cilium/ebpf" "github.com/cilium/ebpf/perf" "github.com/go-errors/errors" @@ -14,7 +15,7 @@ import ( "github.com/kubeshark/gopacket/layers" "github.com/kubeshark/tracer/misc/ethernet" "github.com/kubeshark/tracer/pkg/bpf" - "github.com/kubeshark/tracer/pkg/utils" + "github.com/kubeshark/tracerproto/pkg/unixpacket" "github.com/rs/zerolog/log" ) @@ -44,25 +45,38 @@ type pktBuffer struct { } type PacketsPoller struct { - ethhdr *layers.Ethernet - mtx sync.Mutex - chunksReader *perf.Reader - sorter *bpf.PacketSorter - pktsMap map[uint64]*pktBuffer // packet id to packet + ipv4Decoder gopacket.Decoder + ethhdr *layers.Ethernet + mtx sync.Mutex + chunksReader *perf.Reader + rawWriter bpf.RawWriter + gopacketWriter bpf.GopacketWriter + pktsMap map[uint64]*pktBuffer // packet id to packet + receivedPackets uint64 + lostChunks uint64 } func NewPacketsPoller( - bpfObjs *bpf.BpfObjects, - sorter *bpf.PacketSorter, + perfBuffer *ebpf.Map, + rawWriter bpf.RawWriter, + gopacketWriter bpf.GopacketWriter, ) (*PacketsPoller, error) { var err error + + ipv4Decoder := gopacket.DecodersByLayerName["IPv4"] + if ipv4Decoder == nil { + return nil, errors.New("Failed to get IPv4 decoder") + } + poller := &PacketsPoller{ - ethhdr: ethernet.NewEthernetLayer(layers.EthernetTypeIPv4), - sorter: sorter, - pktsMap: make(map[uint64]*pktBuffer), + ipv4Decoder: ipv4Decoder, + ethhdr: ethernet.NewEthernetLayer(layers.EthernetTypeIPv4), + rawWriter: rawWriter, + gopacketWriter: gopacketWriter, + pktsMap: make(map[uint64]*pktBuffer), } - poller.chunksReader, err = perf.NewReader(bpfObjs.BpfObjs.PktsBuffer, os.Getpagesize()*10000) + poller.chunksReader, err = perf.NewReader(perfBuffer, os.Getpagesize()*10000) if err != nil { return nil, errors.Wrap(err, 0) @@ -79,6 +93,14 @@ func (p 
*PacketsPoller) Start() { go p.poll() } +func (p *PacketsPoller) GetLostChunks() uint64 { + return p.lostChunks +} + +func (p *PacketsPoller) GetReceivedPackets() uint64 { + return p.receivedPackets +} + func (p *PacketsPoller) poll() { // tracerPktsChunk is generated by bpf2go. @@ -90,8 +112,14 @@ func (p *PacketsPoller) handlePktChunk(chunk tracerPktChunk) error { p.mtx.Lock() defer p.mtx.Unlock() - const expectedChunkSize = 4148 data := chunk.buf + if len(data) == 4 { + // zero packet to reset + log.Info().Msg("Resetting plain packets buffer") + p.pktsMap = make(map[uint64]*pktBuffer) + return nil + } + const expectedChunkSize = 4148 if len(data) != expectedChunkSize { return fmt.Errorf("bad pkt chunk: size %v expected: %v", len(data), expectedChunkSize) } @@ -111,9 +139,27 @@ func (p *PacketsPoller) handlePktChunk(chunk tracerPktChunk) error { pkts.len += uint32(ptr.Len) if ptr.Last != 0 { - err := p.sorter.WritePlanePacket(ptr.Timestamp, ptr.CgroupID, ptr.Direction, layers.LayerTypeEthernet, p.ethhdr, gopacket.Payload(pkts.buf[:pkts.len])) - if err != nil { - return err + p.receivedPackets++ + if p.rawWriter != nil { + err := p.rawWriter(ptr.Timestamp, ptr.CgroupID, ptr.Direction, layers.LayerTypeEthernet, p.ethhdr, gopacket.Payload(pkts.buf[:pkts.len])) + if err != nil { + return err + } + } + + if p.gopacketWriter != nil { + pkt := gopacket.NewPacket(pkts.buf[:pkts.len], p.ipv4Decoder, gopacket.NoCopy, ptr.CgroupID, unixpacket.PacketDirection(ptr.Direction)) + m := pkt.Metadata() + ci := &m.CaptureInfo + ci.Timestamp = time.Unix(0, int64(ptr.Timestamp)) + ci.CaptureLength = int(pkts.len) + ci.Length = int(pkts.len) + ci.CaptureBackend = gopacket.CaptureBackendEbpf + + err := p.gopacketWriter(pkt) + if err != nil { + return err + } } delete(p.pktsMap, ptr.ID) @@ -125,7 +171,22 @@ func (p *PacketsPoller) handlePktChunk(chunk tracerPktChunk) error { } func (p *PacketsPoller) pollChunksPerfBuffer() { - log.Info().Msg("Start polling for tls events") + 
log.Info().Msg("Start polling for packet events") + + // remove all existing records + + p.chunksReader.SetDeadline(time.Unix(1, 0)) + var emptyRecord perf.Record + for { + err := p.chunksReader.ReadInto(&emptyRecord) + if errors.Is(err, os.ErrDeadlineExceeded) { + break + } else if err != nil { + log.Error().Err(err).Msg("Error reading chunks from pkts perf, aborting!") + return + } + } + p.chunksReader.SetDeadline(time.Time{}) for { record, err := p.chunksReader.Read() @@ -135,11 +196,12 @@ return } - utils.LogError(errors.Errorf("Error reading chunks from pkts perf, aborting! %v", err)) + log.Error().Err(err).Msg("Error reading chunks from pkts perf, aborting!") return } if record.LostSamples != 0 { log.Info().Msg(fmt.Sprintf("Buffer is full, dropped %d pkt chunks", record.LostSamples)) + p.lostChunks += record.LostSamples continue } diff --git a/pkg/poller/poller.go b/pkg/poller/poller.go index 7b37312..98c4210 100644 --- a/pkg/poller/poller.go +++ b/pkg/poller/poller.go @@ -6,7 +6,6 @@ import ( "github.com/kubeshark/tracer/pkg/bpf" "github.com/kubeshark/tracer/pkg/cgroup" logPoller "github.com/kubeshark/tracer/pkg/poller/log" - packetsPoller "github.com/kubeshark/tracer/pkg/poller/packets" syscallPoller "github.com/kubeshark/tracer/pkg/poller/syscall" ) @@ -16,9 +15,7 @@ type BpfPoller interface { } type BpfPollerImpl struct { - tlsPoller *bpf.TlsPoller syscallPoller *syscallPoller.SyscallEventsTracer - packetsPoller *packetsPoller.PacketsPoller logPoller *logPoller.BpfLogger } @@ -26,18 +23,10 @@ func NewBpfPoller(bpfObjs *bpf.BpfObjects, sorter *bpf.PacketSorter, cgroupsCont var err error p := BpfPollerImpl{} - if p.tlsPoller, err = bpf.NewTlsPoller(bpfObjs, sorter); err != nil { - return nil, fmt.Errorf("create tls poller failed: %v", err) - } - if p.syscallPoller, err = syscallPoller.NewSyscallEventsTracer(bpfObjs, cgroupsController); err != nil { return nil, fmt.Errorf("create syscall poller failed: %v", err) } - if 
p.packetsPoller, err = packetsPoller.NewPacketsPoller(bpfObjs, sorter); err != nil { - return nil, fmt.Errorf("create packets poller failed: %v", err) - } - if p.logPoller, err = logPoller.NewBpfLogger(&bpfObjs.BpfObjs, tlsLogDisabled); err != nil { return nil, fmt.Errorf("create log poller failed: %v", err) } @@ -46,27 +35,17 @@ func NewBpfPoller(bpfObjs *bpf.BpfObjects, sorter *bpf.PacketSorter, cgroupsCont } func (p *BpfPollerImpl) Start() { - p.tlsPoller.Start() p.syscallPoller.Start() - p.packetsPoller.Start() p.logPoller.Start() } func (p *BpfPollerImpl) Stop() error { var err error - if err = p.tlsPoller.Stop(); err != nil { - return fmt.Errorf("stop tls poller failed: %v", err) - } - if err = p.syscallPoller.Stop(); err != nil { return fmt.Errorf("stop syscall poller failed: %v", err) } - if err = p.packetsPoller.Stop(); err != nil { - return fmt.Errorf("stop packets poller failed: %v", err) - } - if err = p.logPoller.Stop(); err != nil { return fmt.Errorf("stop log poller failed: %v", err) } diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 9e04d09..8d807cc 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -2,10 +2,12 @@ package utils import ( + "fmt" "github.com/go-errors/errors" "github.com/rs/zerolog/log" "golang.org/x/sys/unix" "os" + "path/filepath" "syscall" ) @@ -37,3 +39,23 @@ func GetInode(path string) (uint64, error) { return stat_t.Ino, nil } + +func RemoveAllFilesInDir(dir string) (removedFiles []string, err error) { + files, err := os.ReadDir(dir) + if err != nil { + err = fmt.Errorf("failed to read directory %s: %w", dir, err) + return + } + + for _, file := range files { + filePath := filepath.Join(dir, file.Name()) + if err = os.Remove(filePath); err != nil { + err = fmt.Errorf("failed to remove %s: %w", filePath, err) + return + } else { + removedFiles = append(removedFiles, filePath) + } + } + + return +} diff --git a/tracer.go b/tracer.go index 867eda6..5789916 100644 --- a/tracer.go +++ b/tracer.go @@ -14,6 +14,7 @@ 
import ( packetHooks "github.com/kubeshark/tracer/pkg/hooks/packet" syscallHooks "github.com/kubeshark/tracer/pkg/hooks/syscall" "github.com/kubeshark/tracer/pkg/poller" + "github.com/kubeshark/tracer/pkg/utils" "github.com/rs/zerolog/log" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" @@ -67,6 +68,7 @@ func (t *Tracer) Init( if err != nil { return fmt.Errorf("creating bpf failed: %v", err) } + t.eventsDiscoverer = discoverer.NewInternalEventsDiscoverer(procfs, t.bpfObjects, t.cgroupsController) if err := t.eventsDiscoverer.Start(); err != nil { log.Error().Msg(fmt.Sprintf("start internal discovery failed: %v", err)) @@ -211,3 +213,20 @@ func getContainerIDs(pod *v1.Pod) []string { return containerIDs } + +func initBPFSubsystem() { + // Cleanup is required in case map set or format is changed in the new tracer version + if files, err := utils.RemoveAllFilesInDir(bpf.PinPath); err != nil { + log.Error().Str("path", bpf.PinPath).Err(err).Msg("directory cleanup failed") + } else { + for _, file := range files { + log.Info().Str("path", file).Msg("removed bpf entry") + } + } + + _, err := bpf.NewBpfObjects(false) + if err != nil { + log.Error().Err(err).Msg("create objects failed") + } + +}