Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use bart algorithm for less memory usage + speed improvements #90

Merged
merged 2 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ module github.com/projectdiscovery/networkpolicy
go 1.21

require (
github.com/gaissmai/bart v0.9.5
github.com/projectdiscovery/utils v0.0.82
github.com/stretchr/testify v1.9.0
github.com/yl2chen/cidranger v1.0.2
)

require (
github.com/aymerick/douceur v0.2.0 // indirect
github.com/bits-and-blooms/bitset v1.13.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/gorilla/css v1.0.0 // indirect
github.com/kr/text v0.2.0 // indirect
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE=
github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gaissmai/bart v0.9.5 h1:vy+r4Px6bjZ+v2QYXAsg63vpz9IfzdW146A8Cn4GPIo=
github.com/gaissmai/bart v0.9.5/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg=
github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
Expand Down
64 changes: 34 additions & 30 deletions networkpolicy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ package networkpolicy

import (
"net"
"net/netip"
"regexp"
"strconv"

"github.com/gaissmai/bart"
iputil "github.com/projectdiscovery/utils/ip"
urlutil "github.com/projectdiscovery/utils/url"
"github.com/yl2chen/cidranger"
)

func init() {
Expand All @@ -31,10 +32,12 @@ type Options struct {
var DefaultOptions Options

type NetworkPolicy struct {
Options *Options
hasFilters bool
DenyRanger cidranger.Ranger
AllowRanger cidranger.Ranger
Options *Options
hasFilters bool

DenyRanger *bart.Table[net.IP]
AllowRanger *bart.Table[net.IP]

AllowRules map[string]*regexp.Regexp
DenyRules map[string]*regexp.Regexp
AllowSchemeList map[string]struct{}
Expand All @@ -58,16 +61,15 @@ func New(options Options) (*NetworkPolicy, error) {
allowRules := make(map[string]*regexp.Regexp)
denyRules := make(map[string]*regexp.Regexp)

var allowRanger cidranger.Ranger
var allowRanger *bart.Table[net.IP]
if len(options.AllowList) > 0 {
allowRanger = cidranger.NewPCTrieRanger()
allowRanger = new(bart.Table[net.IP])

for _, r := range options.AllowList {
// handle if ip/cidr
cidr, err := asCidr(r)
if err == nil {
if err := allowRanger.Insert(cidranger.NewBasicRangerEntry(*cidr)); err != nil {
return nil, err
}
allowRanger.Insert(cidr, nil)
continue
}

Expand All @@ -80,16 +82,15 @@ func New(options Options) (*NetworkPolicy, error) {
}
}

var denyRanger cidranger.Ranger
var denyRanger *bart.Table[net.IP]
if len(options.DenyList) > 0 {
denyRanger = cidranger.NewPCTrieRanger()
denyRanger = new(bart.Table[net.IP])

for _, r := range options.DenyList {
// handle if ip/cidr
cidr, err := asCidr(r)
if err == nil {
if err := denyRanger.Insert(cidranger.NewBasicRangerEntry(*cidr)); err != nil {
return nil, err
}
denyRanger.Insert(cidr, nil)
continue
}

Expand Down Expand Up @@ -125,13 +126,17 @@ func (r NetworkPolicy) Validate(host string) bool {

// check if it's an ip
if iputil.IsIP(host) {
IP := net.ParseIP(host)
if r.DenyRanger != nil && r.DenyRanger.Len() > 0 && rangerContains(r.DenyRanger, IP) {
parsed, err := netip.ParseAddr(host)
if err != nil {
return false
}

if r.DenyRanger != nil && r.DenyRanger.Size() > 0 && rangerContains(r.DenyRanger, parsed) {
return false
}

if r.AllowRanger != nil && r.AllowRanger.Len() > 0 {
return rangerContains(r.AllowRanger, IP)
if r.AllowRanger != nil && r.AllowRanger.Size() > 0 {
return rangerContains(r.AllowRanger, parsed)
}

return true
Expand Down Expand Up @@ -209,15 +214,15 @@ func (r NetworkPolicy) ValidateURLWithIP(host string, ip string) bool {
}

func (r NetworkPolicy) ValidateAddress(IP string) bool {
IPDest := net.ParseIP(IP)
if IPDest == nil {
IPDest, err := netip.ParseAddr(IP)
if err != nil {
return false
}
if r.DenyRanger != nil && r.DenyRanger.Len() > 0 && rangerContains(r.DenyRanger, IPDest) {
if r.DenyRanger != nil && r.DenyRanger.Size() > 0 && rangerContains(r.DenyRanger, IPDest) {
return false
}

if r.AllowRanger != nil && r.AllowRanger.Len() > 0 {
if r.AllowRanger != nil && r.AllowRanger.Size() > 0 {
return rangerContains(r.AllowRanger, IPDest)
}

Expand All @@ -240,21 +245,20 @@ func (r NetworkPolicy) ValidatePort(port int) bool {
return true
}

func asCidr(s string) (*net.IPNet, error) {
func asCidr(s string) (netip.Prefix, error) {
if iputil.IsIP(s) {
s += "/32"
}
_, cidr, err := net.ParseCIDR(s)
cidr, err := netip.ParsePrefix(s)
if err != nil {
return nil, err
return cidr, err
}

return cidr, nil
}

func rangerContains(ranger cidranger.Ranger, IP net.IP) bool {
ok, err := ranger.Contains(IP)
return ok && err == nil
func rangerContains(ranger *bart.Table[net.IP], IP netip.Addr) bool {
_, ok := ranger.Lookup(IP)
return ok
}

func portIsListed(list map[int]struct{}, port int) bool {
Expand Down
33 changes: 33 additions & 0 deletions networkpolicy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@ package networkpolicy

import (
"log"
"net"
"net/netip"
"testing"

"github.com/gaissmai/bart"
"github.com/stretchr/testify/require"
"github.com/yl2chen/cidranger"
)

func TestValidateAddress(t *testing.T) {
Expand Down Expand Up @@ -49,3 +53,32 @@ func TestMultipleCases(t *testing.T) {
require.Equal(t, tc.expectedValid, ok, "Unexpected result for address: "+tc.address)
}
}

func Benchmark_Networkpolicy_CIDRRanger(b *testing.B) {
for i := 0; i < b.N; i++ {
ranger := cidranger.NewPCTrieRanger()
for _, r := range DefaultIPv4DenylistRanges {
_, cidr, _ := net.ParseCIDR(r)
_ = ranger.Insert(cidranger.NewBasicRangerEntry(*cidr))
}
contains, err := ranger.Contains(net.ParseIP("127.0.0.1"))
if err != nil || !contains {
b.Fatalf("unexpected error: %v %v", err, contains)
}
}
}

func Benchmark_Networkpolicy_BartAlgorithm(b *testing.B) {
for i := 0; i < b.N; i++ {
rtbl := new(bart.Table[net.IP])
for _, r := range DefaultIPv4DenylistRanges {
parsed, _ := netip.ParsePrefix(r)
rtbl.Insert(parsed, nil)
}

_, contains := rtbl.Lookup(netip.MustParseAddr("127.0.0.1"))
if !contains {
b.Fatalf("expected to contain")
}
}
}
Loading