Skip to content

Commit

Permalink
templates: retrofit to avoid issue when upgrading dep (#160)
Browse files Browse the repository at this point in the history
  • Loading branch information
lspgn authored Apr 21, 2023
1 parent e3f58f2 commit 69a6eaf
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 13 deletions.
113 changes: 106 additions & 7 deletions decoders/netflow/netflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,46 @@ import (
"context"
"encoding/binary"
"fmt"
"sync"

"github.com/netsampler/goflow2/decoders/netflow/templates"
"github.com/netsampler/goflow2/decoders/utils"
)

type FlowBaseTemplateSet map[uint16]map[uint32]map[uint16]interface{}

type NetFlowTemplateSystem interface {
GetTemplate(version uint16, obsDomainId uint32, templateId uint16) (interface{}, error)
AddTemplate(version uint16, obsDomainId uint32, template interface{})
}

// Transition structure to ease the conversion with the new template systems
type TemplateWrapper struct {
Ctx context.Context
Key string
Inner templates.TemplateInterface
}

func (w *TemplateWrapper) getTemplateId(template interface{}) (templateId uint16) {
switch templateIdConv := template.(type) {
case IPFIXOptionsTemplateRecord:
templateId = templateIdConv.TemplateId
case NFv9OptionsTemplateRecord:
templateId = templateIdConv.TemplateId
case TemplateRecord:
templateId = templateIdConv.TemplateId
}
return templateId
}

func (w TemplateWrapper) GetTemplate(version uint16, obsDomainId uint32, templateId uint16) (interface{}, error) {
return w.Inner.GetTemplate(w.Ctx, &templates.TemplateKey{w.Key, version, obsDomainId, templateId})
}

func (w TemplateWrapper) AddTemplate(version uint16, obsDomainId uint32, template interface{}) {
w.Inner.AddTemplate(w.Ctx, &templates.TemplateKey{w.Key, version, obsDomainId, w.getTemplateId(template)}, template)
}

func DecodeNFv9OptionsTemplateSet(payload *bytes.Buffer) ([]NFv9OptionsTemplateRecord, error) {
var records []NFv9OptionsTemplateRecord
var err error
Expand Down Expand Up @@ -243,11 +278,70 @@ func DecodeDataSet(version uint16, payload *bytes.Buffer, listFields []Field) ([
return records, nil
}

func DecodeMessage(payload *bytes.Buffer, templates templates.TemplateInterface) (interface{}, error) {
func (ts *BasicTemplateSystem) GetTemplates() map[uint16]map[uint32]map[uint16]interface{} {
ts.templateslock.RLock()
tmp := ts.templates
ts.templateslock.RUnlock()
return tmp
}

func (ts *BasicTemplateSystem) AddTemplate(version uint16, obsDomainId uint32, template interface{}) {
ts.templateslock.Lock()
defer ts.templateslock.Unlock()
_, exists := ts.templates[version]
if exists != true {
ts.templates[version] = make(map[uint32]map[uint16]interface{})
}
_, exists = ts.templates[version][obsDomainId]
if exists != true {
ts.templates[version][obsDomainId] = make(map[uint16]interface{})
}
var templateId uint16
switch templateIdConv := template.(type) {
case IPFIXOptionsTemplateRecord:
templateId = templateIdConv.TemplateId
case NFv9OptionsTemplateRecord:
templateId = templateIdConv.TemplateId
case TemplateRecord:
templateId = templateIdConv.TemplateId
}
ts.templates[version][obsDomainId][templateId] = template
}

func (ts *BasicTemplateSystem) GetTemplate(version uint16, obsDomainId uint32, templateId uint16) (interface{}, error) {
ts.templateslock.RLock()
defer ts.templateslock.RUnlock()
templatesVersion, okver := ts.templates[version]
if okver {
templatesObsDom, okobs := templatesVersion[obsDomainId]
if okobs {
template, okid := templatesObsDom[templateId]
if okid {
return template, nil
}
}
}
return nil, NewErrorTemplateNotFound(version, obsDomainId, templateId, "info")
}

type BasicTemplateSystem struct {
templates FlowBaseTemplateSet
templateslock *sync.RWMutex
}

func CreateTemplateSystem() *BasicTemplateSystem {
ts := &BasicTemplateSystem{
templates: make(FlowBaseTemplateSet),
templateslock: &sync.RWMutex{},
}
return ts
}

func DecodeMessage(payload *bytes.Buffer, templates NetFlowTemplateSystem) (interface{}, error) {
return DecodeMessageContext(context.Background(), payload, "", templates)
}

func DecodeMessageContext(ctx context.Context, payload *bytes.Buffer, templateKey string, tpli templates.TemplateInterface) (interface{}, error) {
func DecodeMessageContext(ctx context.Context, payload *bytes.Buffer, templateKey string, tpli NetFlowTemplateSystem) (interface{}, error) {
var size uint16
packetNFv9 := NFv9Packet{}
packetIPFIX := IPFIXPacket{}
Expand Down Expand Up @@ -309,7 +403,8 @@ func DecodeMessageContext(ctx context.Context, payload *bytes.Buffer, templateKe

if tpli != nil {
for _, record := range records {
tpli.AddTemplate(ctx, templates.NewTemplateKey(templateKey, version, obsDomainId, record.TemplateId), record)
tpli.AddTemplate(version, obsDomainId, record)
//tpli.AddTemplate(ctx, templates.NewTemplateKey(templateKey, version, obsDomainId, record.TemplateId), record)
}
}

Expand All @@ -327,7 +422,8 @@ func DecodeMessageContext(ctx context.Context, payload *bytes.Buffer, templateKe

if tpli != nil {
for _, record := range records {
tpli.AddTemplate(ctx, templates.NewTemplateKey(templateKey, version, obsDomainId, record.TemplateId), record)
tpli.AddTemplate(version, obsDomainId, record)
//tpli.AddTemplate(ctx, templates.NewTemplateKey(templateKey, version, obsDomainId, record.TemplateId), record)
}
}

Expand All @@ -345,7 +441,8 @@ func DecodeMessageContext(ctx context.Context, payload *bytes.Buffer, templateKe

if tpli != nil {
for _, record := range records {
tpli.AddTemplate(ctx, templates.NewTemplateKey(templateKey, version, obsDomainId, record.TemplateId), record)
tpli.AddTemplate(version, obsDomainId, record)
//tpli.AddTemplate(ctx, templates.NewTemplateKey(templateKey, version, obsDomainId, record.TemplateId), record)
}
}

Expand All @@ -363,7 +460,8 @@ func DecodeMessageContext(ctx context.Context, payload *bytes.Buffer, templateKe

if tpli != nil {
for _, record := range records {
tpli.AddTemplate(ctx, templates.NewTemplateKey(templateKey, version, obsDomainId, record.TemplateId), record)
tpli.AddTemplate(version, obsDomainId, record)
//tpli.AddTemplate(ctx, templates.NewTemplateKey(templateKey, version, obsDomainId, record.TemplateId), record)
}
}

Expand All @@ -374,7 +472,8 @@ func DecodeMessageContext(ctx context.Context, payload *bytes.Buffer, templateKe
continue
}

template, err := tpli.GetTemplate(ctx, templates.NewTemplateKey(templateKey, version, obsDomainId, fsheader.Id))
template, err := tpli.GetTemplate(version, obsDomainId, fsheader.Id)
//template, err := tpli.GetTemplate(ctx, templates.NewTemplateKey(templateKey, version, obsDomainId, fsheader.Id))

if err == nil {
switch templatec := template.(type) {
Expand Down
6 changes: 1 addition & 5 deletions decoders/netflow/netflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,13 @@ package netflow

import (
"bytes"
"context"
"testing"

"github.com/netsampler/goflow2/decoders/netflow/templates/memory"

"github.com/stretchr/testify/assert"
)

func TestDecodeNetFlowV9(t *testing.T) {
templates := &memory.MemoryDriver{}
templates.Init(context.Background())
templates := CreateTemplateSystem()

// Decode a template
template := []byte{
Expand Down
2 changes: 1 addition & 1 deletion utils/netflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func (s *StateNetFlow) DecodeFlow(msg interface{}) error {
}

timeTrackStart := time.Now()
msgDec, err := netflow.DecodeMessageContext(s.ctx, buf, key, s.TemplateSystem)
msgDec, err := netflow.DecodeMessageContext(s.ctx, buf, key, netflow.TemplateWrapper{s.ctx, key, s.TemplateSystem})
if err != nil {
switch err.(type) {
case *netflow.ErrorTemplateNotFound:
Expand Down

0 comments on commit 69a6eaf

Please sign in to comment.