From 66b47ee328feef8cad42bfe704f6f26368def25e Mon Sep 17 00:00:00 2001 From: Vincent Bernat Date: Sat, 19 Aug 2023 04:38:00 +0200 Subject: [PATCH] decoders: replace binary.Read with a version without reflection and allocations (#141) Instead of allocating small slices, we rely on the fact that most call sites are providing a `bytes.Buffer` and use the `Next()` method. For sFlow decoding, in my case, we get a 33% speedup. A `bytes.Reader` would be even more efficient, but unfortunately, it doesn't have a `Next()` method. While Go should be smart enough to make the allocation of `bs` on the stack, it does not, even when `io.ReadFull()` is inlined. ``` decoders/utils/utils.go:23:13: make([]byte, n) escapes to heap ``` --- decoders/netflow/netflow.go | 4 +- decoders/netflowlegacy/netflow.go | 23 ++++- decoders/sflow/sflow.go | 82 ++++++++++++------ decoders/utils/utils.go | 116 ++++++++++++++++++++++++- decoders/utils/utils_test.go | 138 ++++++++++++++++++++++++++++++ producer/producer_nf.go | 7 +- 6 files changed, 338 insertions(+), 32 deletions(-) create mode 100644 decoders/utils/utils_test.go diff --git a/decoders/netflow/netflow.go b/decoders/netflow/netflow.go index cd17bc81..e01529f1 100644 --- a/decoders/netflow/netflow.go +++ b/decoders/netflow/netflow.go @@ -349,7 +349,7 @@ func DecodeMessageContext(ctx context.Context, payload *bytes.Buffer, templateKe var version uint16 var obsDomainId uint32 - if err := binary.Read(payload, binary.BigEndian, &version); err != nil { + if err := utils.BinaryRead(payload, binary.BigEndian, &version); err != nil { return nil, fmt.Errorf("Error decoding version: %v", err) } @@ -377,7 +377,7 @@ func DecodeMessageContext(ctx context.Context, payload *bytes.Buffer, templateKe for i := 0; ((i < int(size) && version == 9) || version == 10) && payload.Len() > 0; i++ { fsheader := FlowSetHeader{} - if err := utils.BinaryDecoder(payload, &fsheader); err != nil { + if err := utils.BinaryDecoder(payload, &fsheader.Id, &fsheader.Length); 
err != nil { return returnItem, fmt.Errorf("Error decoding FlowSet header: %v", err) } diff --git a/decoders/netflowlegacy/netflow.go b/decoders/netflowlegacy/netflow.go index 68f9d4d2..9329f703 100644 --- a/decoders/netflowlegacy/netflow.go +++ b/decoders/netflowlegacy/netflow.go @@ -47,7 +47,28 @@ func DecodeMessage(payload *bytes.Buffer) (interface{}, error) { packet.Records = make([]RecordsNetFlowV5, int(packet.Count)) for i := 0; i < int(packet.Count) && payload.Len() >= 48; i++ { record := RecordsNetFlowV5{} - err := utils.BinaryDecoder(payload, &record) + err := utils.BinaryDecoder(payload, + &record.SrcAddr, + &record.DstAddr, + &record.NextHop, + &record.Input, + &record.Output, + &record.DPkts, + &record.DOctets, + &record.First, + &record.Last, + &record.SrcPort, + &record.DstPort, + &record.Pad1, + &record.TCPFlags, + &record.Proto, + &record.Tos, + &record.SrcAS, + &record.DstAS, + &record.SrcMask, + &record.DstMask, + &record.Pad2, + ) if err != nil { return packet, err } diff --git a/decoders/sflow/sflow.go b/decoders/sflow/sflow.go index 42514438..dd096301 100644 --- a/decoders/sflow/sflow.go +++ b/decoders/sflow/sflow.go @@ -81,14 +81,48 @@ func DecodeCounterRecord(header *RecordHeader, payload *bytes.Buffer) (CounterRe switch (*header).DataFormat { case 1: ifCounters := IfCounters{} - err := utils.BinaryDecoder(payload, &ifCounters) + err := utils.BinaryDecoder(payload, + &ifCounters.IfIndex, + &ifCounters.IfType, + &ifCounters.IfSpeed, + &ifCounters.IfDirection, + &ifCounters.IfStatus, + &ifCounters.IfInOctets, + &ifCounters.IfInUcastPkts, + &ifCounters.IfInMulticastPkts, + &ifCounters.IfInBroadcastPkts, + &ifCounters.IfInDiscards, + &ifCounters.IfInErrors, + &ifCounters.IfInUnknownProtos, + &ifCounters.IfOutOctets, + &ifCounters.IfOutUcastPkts, + &ifCounters.IfOutMulticastPkts, + &ifCounters.IfOutBroadcastPkts, + &ifCounters.IfOutDiscards, + &ifCounters.IfOutErrors, + &ifCounters.IfPromiscuousMode, + ) if err != nil { return counterRecord, err } 
counterRecord.Data = ifCounters case 2: ethernetCounters := EthernetCounters{} - err := utils.BinaryDecoder(payload, &ethernetCounters) + err := utils.BinaryDecoder(payload, + &ethernetCounters.Dot3StatsAlignmentErrors, + &ethernetCounters.Dot3StatsFCSErrors, + &ethernetCounters.Dot3StatsSingleCollisionFrames, + &ethernetCounters.Dot3StatsMultipleCollisionFrames, + &ethernetCounters.Dot3StatsSQETestErrors, + &ethernetCounters.Dot3StatsDeferredTransmissions, + &ethernetCounters.Dot3StatsLateCollisions, + &ethernetCounters.Dot3StatsExcessiveCollisions, + &ethernetCounters.Dot3StatsInternalMacTransmitErrors, + &ethernetCounters.Dot3StatsCarrierSenseErrors, + &ethernetCounters.Dot3StatsFrameTooLongs, + &ethernetCounters.Dot3StatsInternalMacReceiveErrors, + &ethernetCounters.Dot3StatsSymbolErrors, + ) if err != nil { return counterRecord, err } @@ -117,7 +151,7 @@ func DecodeIP(payload *bytes.Buffer) (uint32, []byte, error) { return ipVersion, ip, NewErrorIPVersion(ipVersion) } if payload.Len() >= len(ip) { - err := utils.BinaryDecoder(payload, &ip) + err := utils.BinaryDecoder(payload, ip) if err != nil { return 0, nil, err } @@ -134,14 +168,14 @@ func DecodeFlowRecord(header *RecordHeader, payload *bytes.Buffer) (FlowRecord, switch (*header).DataFormat { case FORMAT_EXT_SWITCH: extendedSwitch := ExtendedSwitch{} - err := utils.BinaryDecoder(payload, &extendedSwitch) + err := utils.BinaryDecoder(payload, &extendedSwitch.SrcVlan, &extendedSwitch.SrcPriority, &extendedSwitch.DstVlan, &extendedSwitch.DstPriority) if err != nil { return flowRecord, err } flowRecord.Data = extendedSwitch case FORMAT_RAW_PKT: sampledHeader := SampledHeader{} - err := utils.BinaryDecoder(payload, &(sampledHeader.Protocol), &(sampledHeader.FrameLength), &(sampledHeader.Stripped), &(sampledHeader.OriginalLength)) + err := utils.BinaryDecoder(payload, &sampledHeader.Protocol, &sampledHeader.FrameLength, &sampledHeader.Stripped, &sampledHeader.OriginalLength) if err != nil { return flowRecord, err } @@ -152,7 +186,7 @@ func 
DecodeFlowRecord(header *RecordHeader, payload *bytes.Buffer) (FlowRecord, SrcIP: make([]byte, 4), DstIP: make([]byte, 4), } - err := utils.BinaryDecoder(payload, &sampledIPBase) + err := utils.BinaryDecoder(payload, &sampledIPBase.Length, &sampledIPBase.Protocol, sampledIPBase.SrcIP, sampledIPBase.DstIP, &sampledIPBase.SrcPort, &sampledIPBase.DstPort, &sampledIPBase.TcpFlags) if err != nil { return flowRecord, err } @@ -169,14 +203,14 @@ func DecodeFlowRecord(header *RecordHeader, payload *bytes.Buffer) (FlowRecord, SrcIP: make([]byte, 16), DstIP: make([]byte, 16), } - err := utils.BinaryDecoder(payload, &sampledIPBase) + err := utils.BinaryDecoder(payload, &sampledIPBase.Length, &sampledIPBase.Protocol, sampledIPBase.SrcIP, sampledIPBase.DstIP, &sampledIPBase.SrcPort, &sampledIPBase.DstPort, &sampledIPBase.TcpFlags) if err != nil { return flowRecord, err } sampledIPv6 := SampledIPv6{ Base: sampledIPBase, } - err = utils.BinaryDecoder(payload, &(sampledIPv6.Priority)) + err = utils.BinaryDecoder(payload, &sampledIPv6.Priority) if err != nil { return flowRecord, err } @@ -190,7 +224,7 @@ func DecodeFlowRecord(header *RecordHeader, payload *bytes.Buffer) (FlowRecord, } extendedRouter.NextHopIPVersion = ipVersion extendedRouter.NextHop = ip - err = utils.BinaryDecoder(payload, &(extendedRouter.SrcMaskLen), &(extendedRouter.DstMaskLen)) + err = utils.BinaryDecoder(payload, &extendedRouter.SrcMaskLen, &extendedRouter.DstMaskLen) if err != nil { return flowRecord, err } @@ -203,14 +237,14 @@ func DecodeFlowRecord(header *RecordHeader, payload *bytes.Buffer) (FlowRecord, } extendedGateway.NextHopIPVersion = ipVersion extendedGateway.NextHop = ip - err = utils.BinaryDecoder(payload, &(extendedGateway.AS), &(extendedGateway.SrcAS), &(extendedGateway.SrcPeerAS), - &(extendedGateway.ASDestinations)) + err = utils.BinaryDecoder(payload, &extendedGateway.AS, &extendedGateway.SrcAS, &extendedGateway.SrcPeerAS, + &extendedGateway.ASDestinations) if err != nil { return 
flowRecord, err } var asPath []uint32 if extendedGateway.ASDestinations != 0 { - err := utils.BinaryDecoder(payload, &(extendedGateway.ASPathType), &(extendedGateway.ASPathLength)) + err := utils.BinaryDecoder(payload, &extendedGateway.ASPathType, &extendedGateway.ASPathLength) if err != nil { return flowRecord, err } @@ -227,7 +261,7 @@ func DecodeFlowRecord(header *RecordHeader, payload *bytes.Buffer) (FlowRecord, } extendedGateway.ASPath = asPath - err = utils.BinaryDecoder(payload, &(extendedGateway.CommunitiesLength)) + err = utils.BinaryDecoder(payload, &extendedGateway.CommunitiesLength) if err != nil { return flowRecord, err } @@ -241,7 +275,7 @@ func DecodeFlowRecord(header *RecordHeader, payload *bytes.Buffer) (FlowRecord, return flowRecord, err } } - err = utils.BinaryDecoder(payload, &(extendedGateway.LocalPref)) + err = utils.BinaryDecoder(payload, &extendedGateway.LocalPref) if err != nil { return flowRecord, err } @@ -258,10 +292,10 @@ func DecodeFlowRecord(header *RecordHeader, payload *bytes.Buffer) (FlowRecord, } func DecodeSample(header *SampleHeader, payload *bytes.Buffer) (interface{}, error) { - format := (*header).Format + format := header.Format var sample interface{} - err := utils.BinaryDecoder(payload, &((*header).SampleSequenceNumber)) + err := utils.BinaryDecoder(payload, &header.SampleSequenceNumber) if err != nil { return sample, err } @@ -275,7 +309,7 @@ func DecodeSample(header *SampleHeader, payload *bytes.Buffer) (interface{}, err (*header).SourceIdType = sourceId >> 24 (*header).SourceIdValue = sourceId & 0x00ffffff } else if format == FORMAT_IPV4 || format == FORMAT_IPV6 { - err = utils.BinaryDecoder(payload, &((*header).SourceIdType), &((*header).SourceIdValue)) + err = utils.BinaryDecoder(payload, &header.SourceIdType, &header.SourceIdValue) if err != nil { return sample, err } @@ -291,8 +325,8 @@ func DecodeSample(header *SampleHeader, payload *bytes.Buffer) (interface{}, err flowSample = FlowSample{ Header: *header, } - err 
= utils.BinaryDecoder(payload, &(flowSample.SamplingRate), &(flowSample.SamplePool), - &(flowSample.Drops), &(flowSample.Input), &(flowSample.Output), &(flowSample.FlowRecordsCount)) + err = utils.BinaryDecoder(payload, &flowSample.SamplingRate, &flowSample.SamplePool, + &flowSample.Drops, &flowSample.Input, &flowSample.Output, &flowSample.FlowRecordsCount) if err != nil { return sample, err } @@ -314,9 +348,9 @@ func DecodeSample(header *SampleHeader, payload *bytes.Buffer) (interface{}, err expandedFlowSample = ExpandedFlowSample{ Header: *header, } - err = utils.BinaryDecoder(payload, &(expandedFlowSample.SamplingRate), &(expandedFlowSample.SamplePool), - &(expandedFlowSample.Drops), &(expandedFlowSample.InputIfFormat), &(expandedFlowSample.InputIfValue), - &(expandedFlowSample.OutputIfFormat), &(expandedFlowSample.OutputIfValue), &(expandedFlowSample.FlowRecordsCount)) + err = utils.BinaryDecoder(payload, &expandedFlowSample.SamplingRate, &expandedFlowSample.SamplePool, + &expandedFlowSample.Drops, &expandedFlowSample.InputIfFormat, &expandedFlowSample.InputIfValue, + &expandedFlowSample.OutputIfFormat, &expandedFlowSample.OutputIfValue, &expandedFlowSample.FlowRecordsCount) if err != nil { return sample, err } @@ -326,7 +360,7 @@ func DecodeSample(header *SampleHeader, payload *bytes.Buffer) (interface{}, err } for i := 0; i < int(recordsCount) && payload.Len() >= 8; i++ { recordHeader := RecordHeader{} - err = utils.BinaryDecoder(payload, &(recordHeader.DataFormat), &(recordHeader.Length)) + err = utils.BinaryDecoder(payload, &recordHeader.DataFormat, &recordHeader.Length) if err != nil { return sample, err } @@ -386,7 +420,7 @@ func DecodeMessage(payload *bytes.Buffer) (interface{}, error) { } packetV5.AgentIP = ip - err = utils.BinaryDecoder(payload, &(packetV5.SubAgentId), &(packetV5.SequenceNumber), &(packetV5.Uptime), &(packetV5.SamplesCount)) + err = utils.BinaryDecoder(payload, &packetV5.SubAgentId, &packetV5.SequenceNumber, &packetV5.Uptime, 
&packetV5.SamplesCount) if err != nil { return packetV5, err } diff --git a/decoders/utils/utils.go b/decoders/utils/utils.go index a36e3b2b..9c8e597f 100644 --- a/decoders/utils/utils.go +++ b/decoders/utils/utils.go @@ -1,16 +1,128 @@ package utils import ( + "bytes" "encoding/binary" + "errors" "io" + "reflect" ) -func BinaryDecoder(payload io.Reader, dests ...interface{}) error { +type BytesBuffer interface { + io.Reader + Next(int) []byte +} + +func BinaryDecoder(payload *bytes.Buffer, dests ...interface{}) error { for _, dest := range dests { - err := binary.Read(payload, binary.BigEndian, dest) + err := BinaryRead(payload, binary.BigEndian, dest) if err != nil { return err } } return nil } +func BinaryRead(payload BytesBuffer, order binary.ByteOrder, data any) error { + // Fast path for basic types and slices. + if n := intDataSize(data); n != 0 { + bs := payload.Next(n) + if len(bs) < n { + return io.ErrUnexpectedEOF + } + switch data := data.(type) { + case *bool: + *data = bs[0] != 0 + case *int8: + *data = int8(bs[0]) + case *uint8: + *data = bs[0] + case *int16: + *data = int16(order.Uint16(bs)) + case *uint16: + *data = order.Uint16(bs) + case *int32: + *data = int32(order.Uint32(bs)) + case *uint32: + *data = order.Uint32(bs) + case *int64: + *data = int64(order.Uint64(bs)) + case *uint64: + *data = order.Uint64(bs) + case []bool: + for i, x := range bs { // Easier to loop over the input for 8-bit values. 
+ data[i] = x != 0 + } + case []int8: + for i, x := range bs { + data[i] = int8(x) + } + case []uint8: + copy(data, bs) + case []int16: + for i := range data { + data[i] = int16(order.Uint16(bs[2*i:])) + } + case []uint16: + for i := range data { + data[i] = order.Uint16(bs[2*i:]) + } + case []int32: + for i := range data { + data[i] = int32(order.Uint32(bs[4*i:])) + } + case []uint32: + for i := range data { + data[i] = order.Uint32(bs[4*i:]) + } + case []int64: + for i := range data { + data[i] = int64(order.Uint64(bs[8*i:])) + } + case []uint64: + for i := range data { + data[i] = order.Uint64(bs[8*i:]) + } + default: + n = 0 // fast path doesn't apply + } + if n != 0 { + return nil + } + } + + return errors.New("binary.Read: invalid type " + reflect.TypeOf(data).String()) +} + +// intDataSize returns the size of the data required to represent the data when encoded. +// It returns zero if the type cannot be implemented by the fast path in Read or Write. +func intDataSize(data any) int { + switch data := data.(type) { + case bool, int8, uint8, *bool, *int8, *uint8: + return 1 + case []bool: + return len(data) + case []int8: + return len(data) + case []uint8: + return len(data) + case int16, uint16, *int16, *uint16: + return 2 + case []int16: + return 2 * len(data) + case []uint16: + return 2 * len(data) + case int32, uint32, *int32, *uint32: + return 4 + case []int32: + return 4 * len(data) + case []uint32: + return 4 * len(data) + case int64, uint64, *int64, *uint64: + return 8 + case []int64: + return 8 * len(data) + case []uint64: + return 8 * len(data) + } + return 0 +} diff --git a/decoders/utils/utils_test.go b/decoders/utils/utils_test.go new file mode 100644 index 00000000..c9f598b4 --- /dev/null +++ b/decoders/utils/utils_test.go @@ -0,0 +1,138 @@ +package utils + +import ( + "encoding/binary" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testBinaryRead(buf BytesBuffer, data any) error { + 
order := binary.BigEndian + return BinaryRead(buf, order, data) +} + +func testBinaryReadComparison(buf BytesBuffer, data any) error { + order := binary.BigEndian + return binary.Read(buf, order, data) +} + +type benchFct func(buf BytesBuffer, data any) error + +func TestBinaryReadInteger(t *testing.T) { + buf := newTestBuf([]byte{1, 2, 3, 4}) + var dest uint32 + err := testBinaryRead(buf, &dest) + require.NoError(t, err) + assert.Equal(t, uint32(0x1020304), dest) +} + +func TestBinaryReadBytes(t *testing.T) { + buf := newTestBuf([]byte{1, 2, 3, 4}) + dest := make([]byte, 4) + err := testBinaryRead(buf, dest) + require.NoError(t, err) +} + +func TestBinaryReadUints(t *testing.T) { + buf := newTestBuf([]byte{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}) + dest := make([]uint32, 4) + err := testBinaryRead(buf, dest) + require.NoError(t, err) + assert.Equal(t, uint32(0x1020304), dest[0]) +} + +type testBuf struct { + buf []byte + off int +} + +func newTestBuf(data []byte) *testBuf { + return &testBuf{ + buf: data, + } +} + +func (b *testBuf) Next(n int) []byte { + if n > len(b.buf) { + return b.buf + } + return b.buf[0:n] +} + +func (b *testBuf) Reset() { + b.off = 0 +} + +func (b *testBuf) Read(p []byte) (int, error) { + if len(b.buf) == 0 || b.off >= len(b.buf) { + return 0, io.EOF + } + + n := copy(p, b.buf[b.off:]) + b.off += n + return n, nil +} + +func benchBinaryRead(b *testing.B, buf *testBuf, dest any, cmp bool) { + var fct benchFct + if cmp { + fct = testBinaryReadComparison + } else { + fct = testBinaryRead + } + for n := 0; n < b.N; n++ { + fct(buf, dest) + buf.Reset() + } +} + +func BenchmarkBinaryReadIntegerBase(b *testing.B) { + buf := newTestBuf([]byte{1, 2, 3, 4}) + var dest uint32 + benchBinaryRead(b, buf, &dest, false) +} + +func BenchmarkBinaryReadIntegerComparison(b *testing.B) { + buf := newTestBuf([]byte{1, 2, 3, 4}) + var dest uint32 + benchBinaryRead(b, buf, &dest, true) +} + +func BenchmarkBinaryReadByteBase(b *testing.B) { + buf := 
newTestBuf([]byte{1, 2, 3, 4}) + var dest byte + benchBinaryRead(b, buf, &dest, false) +} + +func BBenchmarkBinaryReadByteComparison(b *testing.B) { + buf := newTestBuf([]byte{1, 2, 3, 4}) + var dest byte + benchBinaryRead(b, buf, &dest, true) +} + +func BenchmarkBinaryReadBytesBase(b *testing.B) { + buf := newTestBuf([]byte{1, 2, 3, 4}) + dest := make([]byte, 4) + benchBinaryRead(b, buf, dest, false) +} + +func BenchmarkBinaryReadBytesComparison(b *testing.B) { + buf := newTestBuf([]byte{1, 2, 3, 4}) + dest := make([]byte, 4) + benchBinaryRead(b, buf, dest, true) +} + +func BenchmarkBinaryReadUintsBase(b *testing.B) { + buf := newTestBuf([]byte{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}) + dest := make([]uint32, 4) + benchBinaryRead(b, buf, dest, false) +} + +func BenchmarkBinaryReadUintsComparison(b *testing.B) { + buf := newTestBuf([]byte{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}) + dest := make([]uint32, 4) + benchBinaryRead(b, buf, dest, true) +} diff --git a/producer/producer_nf.go b/producer/producer_nf.go index 07a00f31..d1493655 100644 --- a/producer/producer_nf.go +++ b/producer/producer_nf.go @@ -10,6 +10,7 @@ import ( "time" "github.com/netsampler/goflow2/decoders/netflow" + "github.com/netsampler/goflow2/decoders/utils" flowmessage "github.com/netsampler/goflow2/pb" ) @@ -79,18 +80,18 @@ func NetFlowPopulate(dataFields []netflow.DataField, typeId uint16, addr interfa exists, value := NetFlowLookFor(dataFields, typeId) if exists && value != nil { valueBytes, ok := value.([]byte) - valueReader := bytes.NewReader(valueBytes) + valueReader := bytes.NewBuffer(valueBytes) if ok { switch addrt := addr.(type) { case *(net.IP): *addrt = valueBytes case *(time.Time): t := uint64(0) - binary.Read(valueReader, binary.BigEndian, &t) + utils.BinaryRead(valueReader, binary.BigEndian, &t) t64 := int64(t / 1000) *addrt = time.Unix(t64, 0) default: - binary.Read(valueReader, binary.BigEndian, addr) + utils.BinaryRead(valueReader, binary.BigEndian, addr) } } }