Skip to content

Commit

Permalink
Merge pull request #90 from projectdiscovery/use-bart-algorithm-speedup
Browse files Browse the repository at this point in the history
feat: use bart algorithm for less memory usage + speed improvements
  • Loading branch information
Mzack9999 authored Jun 17, 2024
2 parents fad8266 + 779f2ec commit 496a450
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 30 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ module github.com/projectdiscovery/networkpolicy
go 1.21

require (
github.com/gaissmai/bart v0.9.5
github.com/projectdiscovery/utils v0.0.82
github.com/stretchr/testify v1.9.0
github.com/yl2chen/cidranger v1.0.2
)

require (
github.com/aymerick/douceur v0.2.0 // indirect
github.com/bits-and-blooms/bitset v1.13.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/gorilla/css v1.0.0 // indirect
github.com/kr/text v0.2.0 // indirect
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE=
github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gaissmai/bart v0.9.5 h1:vy+r4Px6bjZ+v2QYXAsg63vpz9IfzdW146A8Cn4GPIo=
github.com/gaissmai/bart v0.9.5/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg=
github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
Expand Down
64 changes: 34 additions & 30 deletions networkpolicy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ package networkpolicy

import (
"net"
"net/netip"
"regexp"
"strconv"

"github.com/gaissmai/bart"
iputil "github.com/projectdiscovery/utils/ip"
urlutil "github.com/projectdiscovery/utils/url"
"github.com/yl2chen/cidranger"
)

func init() {
Expand All @@ -31,10 +32,12 @@ type Options struct {
var DefaultOptions Options

type NetworkPolicy struct {
Options *Options
hasFilters bool
DenyRanger cidranger.Ranger
AllowRanger cidranger.Ranger
Options *Options
hasFilters bool

DenyRanger *bart.Table[net.IP]
AllowRanger *bart.Table[net.IP]

AllowRules map[string]*regexp.Regexp
DenyRules map[string]*regexp.Regexp
AllowSchemeList map[string]struct{}
Expand All @@ -58,16 +61,15 @@ func New(options Options) (*NetworkPolicy, error) {
allowRules := make(map[string]*regexp.Regexp)
denyRules := make(map[string]*regexp.Regexp)

var allowRanger cidranger.Ranger
var allowRanger *bart.Table[net.IP]
if len(options.AllowList) > 0 {
allowRanger = cidranger.NewPCTrieRanger()
allowRanger = new(bart.Table[net.IP])

for _, r := range options.AllowList {
// handle if ip/cidr
cidr, err := asCidr(r)
if err == nil {
if err := allowRanger.Insert(cidranger.NewBasicRangerEntry(*cidr)); err != nil {
return nil, err
}
allowRanger.Insert(cidr, nil)
continue
}

Expand All @@ -80,16 +82,15 @@ func New(options Options) (*NetworkPolicy, error) {
}
}

var denyRanger cidranger.Ranger
var denyRanger *bart.Table[net.IP]
if len(options.DenyList) > 0 {
denyRanger = cidranger.NewPCTrieRanger()
denyRanger = new(bart.Table[net.IP])

for _, r := range options.DenyList {
// handle if ip/cidr
cidr, err := asCidr(r)
if err == nil {
if err := denyRanger.Insert(cidranger.NewBasicRangerEntry(*cidr)); err != nil {
return nil, err
}
denyRanger.Insert(cidr, nil)
continue
}

Expand Down Expand Up @@ -125,13 +126,17 @@ func (r NetworkPolicy) Validate(host string) bool {

// check if it's an ip
if iputil.IsIP(host) {
IP := net.ParseIP(host)
if r.DenyRanger != nil && r.DenyRanger.Len() > 0 && rangerContains(r.DenyRanger, IP) {
parsed, err := netip.ParseAddr(host)
if err != nil {
return false
}

if r.DenyRanger != nil && r.DenyRanger.Size() > 0 && rangerContains(r.DenyRanger, parsed) {
return false
}

if r.AllowRanger != nil && r.AllowRanger.Len() > 0 {
return rangerContains(r.AllowRanger, IP)
if r.AllowRanger != nil && r.AllowRanger.Size() > 0 {
return rangerContains(r.AllowRanger, parsed)
}

return true
Expand Down Expand Up @@ -209,15 +214,15 @@ func (r NetworkPolicy) ValidateURLWithIP(host string, ip string) bool {
}

func (r NetworkPolicy) ValidateAddress(IP string) bool {
IPDest := net.ParseIP(IP)
if IPDest == nil {
IPDest, err := netip.ParseAddr(IP)
if err != nil {
return false
}
if r.DenyRanger != nil && r.DenyRanger.Len() > 0 && rangerContains(r.DenyRanger, IPDest) {
if r.DenyRanger != nil && r.DenyRanger.Size() > 0 && rangerContains(r.DenyRanger, IPDest) {
return false
}

if r.AllowRanger != nil && r.AllowRanger.Len() > 0 {
if r.AllowRanger != nil && r.AllowRanger.Size() > 0 {
return rangerContains(r.AllowRanger, IPDest)
}

Expand All @@ -240,21 +245,20 @@ func (r NetworkPolicy) ValidatePort(port int) bool {
return true
}

func asCidr(s string) (*net.IPNet, error) {
func asCidr(s string) (netip.Prefix, error) {
if iputil.IsIP(s) {
s += "/32"
}
_, cidr, err := net.ParseCIDR(s)
cidr, err := netip.ParsePrefix(s)
if err != nil {
return nil, err
return cidr, err
}

return cidr, nil
}

func rangerContains(ranger cidranger.Ranger, IP net.IP) bool {
ok, err := ranger.Contains(IP)
return ok && err == nil
func rangerContains(ranger *bart.Table[net.IP], IP netip.Addr) bool {
_, ok := ranger.Lookup(IP)
return ok
}

func portIsListed(list map[int]struct{}, port int) bool {
Expand Down
33 changes: 33 additions & 0 deletions networkpolicy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@ package networkpolicy

import (
"log"
"net"
"net/netip"
"testing"

"github.com/gaissmai/bart"
"github.com/stretchr/testify/require"
"github.com/yl2chen/cidranger"
)

func TestValidateAddress(t *testing.T) {
Expand Down Expand Up @@ -49,3 +53,32 @@ func TestMultipleCases(t *testing.T) {
require.Equal(t, tc.expectedValid, ok, "Unexpected result for address: "+tc.address)
}
}

func Benchmark_Networkpolicy_CIDRRanger(b *testing.B) {
for i := 0; i < b.N; i++ {
ranger := cidranger.NewPCTrieRanger()
for _, r := range DefaultIPv4DenylistRanges {
_, cidr, _ := net.ParseCIDR(r)
_ = ranger.Insert(cidranger.NewBasicRangerEntry(*cidr))
}
contains, err := ranger.Contains(net.ParseIP("127.0.0.1"))
if err != nil || !contains {
b.Fatalf("unexpected error: %v %v", err, contains)
}
}
}

func Benchmark_Networkpolicy_BartAlgorithm(b *testing.B) {
for i := 0; i < b.N; i++ {
rtbl := new(bart.Table[net.IP])
for _, r := range DefaultIPv4DenylistRanges {
parsed, _ := netip.ParsePrefix(r)
rtbl.Insert(parsed, nil)
}

_, contains := rtbl.Lookup(netip.MustParseAddr("127.0.0.1"))
if !contains {
b.Fatalf("expected to contain")
}
}
}

0 comments on commit 496a450

Please sign in to comment.