diff --git a/go.mod b/go.mod index 38005c8..ed188dc 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/projectdiscovery/networkpolicy go 1.21 require ( + github.com/gaissmai/bart v0.9.5 github.com/projectdiscovery/utils v0.0.82 github.com/stretchr/testify v1.9.0 github.com/yl2chen/cidranger v1.0.2 @@ -10,6 +11,7 @@ require ( require ( github.com/aymerick/douceur v0.2.0 // indirect + github.com/bits-and-blooms/bitset v1.13.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/gorilla/css v1.0.0 // indirect github.com/kr/text v0.2.0 // indirect diff --git a/go.sum b/go.sum index 3facd31..fca923b 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,13 @@ github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= +github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE= +github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gaissmai/bart v0.9.5 h1:vy+r4Px6bjZ+v2QYXAsg63vpz9IfzdW146A8Cn4GPIo= +github.com/gaissmai/bart v0.9.5/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg= github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY= github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= diff --git a/networkpolicy.go b/networkpolicy.go index 7ba7d45..4ba2fe7 100644 --- a/networkpolicy.go +++ b/networkpolicy.go @@ -2,12 +2,13 @@ package networkpolicy import ( "net" + "net/netip" "regexp" "strconv" + "github.com/gaissmai/bart" iputil "github.com/projectdiscovery/utils/ip" urlutil "github.com/projectdiscovery/utils/url" - "github.com/yl2chen/cidranger" ) func init() { @@ -31,10 +32,12 @@ type Options struct { var DefaultOptions Options type NetworkPolicy struct { - Options *Options - hasFilters bool - DenyRanger cidranger.Ranger - AllowRanger cidranger.Ranger + Options *Options + hasFilters bool + + DenyRanger *bart.Table[net.IP] + AllowRanger *bart.Table[net.IP] + AllowRules map[string]*regexp.Regexp DenyRules map[string]*regexp.Regexp AllowSchemeList map[string]struct{} @@ -58,16 +61,15 @@ func New(options Options) (*NetworkPolicy, error) { allowRules := make(map[string]*regexp.Regexp) denyRules := make(map[string]*regexp.Regexp) - var allowRanger cidranger.Ranger + var allowRanger *bart.Table[net.IP] if len(options.AllowList) > 0 { - allowRanger = cidranger.NewPCTrieRanger() + allowRanger = new(bart.Table[net.IP]) + for _, r := range options.AllowList { // handle if ip/cidr cidr, err := asCidr(r) if err == nil { - if err := allowRanger.Insert(cidranger.NewBasicRangerEntry(*cidr)); err != nil { - return nil, err - } + allowRanger.Insert(cidr, nil) continue } @@ -80,16 +82,15 @@ func New(options Options) (*NetworkPolicy, error) { } } - var denyRanger cidranger.Ranger + var denyRanger *bart.Table[net.IP] if len(options.DenyList) > 0 { - denyRanger = cidranger.NewPCTrieRanger() + denyRanger = new(bart.Table[net.IP]) + for _, r := range options.DenyList { // handle if ip/cidr cidr, err := asCidr(r) if err == nil { - if err := denyRanger.Insert(cidranger.NewBasicRangerEntry(*cidr)); err != nil { - return nil, err - } + denyRanger.Insert(cidr, nil) continue } @@ -125,13 +126,17 @@ func (r NetworkPolicy) Validate(host string) bool { // check if it's an ip if iputil.IsIP(host) { - IP := net.ParseIP(host) - if r.DenyRanger != nil && r.DenyRanger.Len() > 0 && rangerContains(r.DenyRanger, IP) { + parsed, err := netip.ParseAddr(host) + if err != nil { + return false + } + + if r.DenyRanger != nil && r.DenyRanger.Size() > 0 && rangerContains(r.DenyRanger, parsed) { return false } - if r.AllowRanger != nil && r.AllowRanger.Len() > 0 { - return rangerContains(r.AllowRanger, IP) + if r.AllowRanger != nil && r.AllowRanger.Size() > 0 { + return rangerContains(r.AllowRanger, parsed) } return true @@ -209,15 +214,15 @@ func (r NetworkPolicy) ValidateURLWithIP(host string, ip string) bool { } func (r NetworkPolicy) ValidateAddress(IP string) bool { - IPDest := net.ParseIP(IP) - if IPDest == nil { + IPDest, err := netip.ParseAddr(IP) + if err != nil { return false } - if r.DenyRanger != nil && r.DenyRanger.Len() > 0 && rangerContains(r.DenyRanger, IPDest) { + if r.DenyRanger != nil && r.DenyRanger.Size() > 0 && rangerContains(r.DenyRanger, IPDest) { return false } - if r.AllowRanger != nil && r.AllowRanger.Len() > 0 { + if r.AllowRanger != nil && r.AllowRanger.Size() > 0 { return rangerContains(r.AllowRanger, IPDest) } @@ -240,21 +245,20 @@ func (r NetworkPolicy) ValidatePort(port int) bool { return true } -func asCidr(s string) (*net.IPNet, error) { +func asCidr(s string) (netip.Prefix, error) { if iputil.IsIP(s) { s += "/32" } - _, cidr, err := net.ParseCIDR(s) + cidr, err := netip.ParsePrefix(s) if err != nil { - return nil, err + return cidr, err } - return cidr, nil } -func rangerContains(ranger cidranger.Ranger, IP net.IP) bool { - ok, err := ranger.Contains(IP) - return ok && err == nil +func rangerContains(ranger *bart.Table[net.IP], IP netip.Addr) bool { + _, ok := ranger.Lookup(IP) + return ok } func portIsListed(list map[int]struct{}, port int) bool { diff --git a/networkpolicy_test.go b/networkpolicy_test.go index 0e9515d..8a84d96 100644 --- a/networkpolicy_test.go +++ b/networkpolicy_test.go @@ -2,9 +2,13 @@ package networkpolicy import ( "log" + "net" + "net/netip" "testing" + "github.com/gaissmai/bart" "github.com/stretchr/testify/require" + "github.com/yl2chen/cidranger" ) func TestValidateAddress(t *testing.T) { @@ -49,3 +53,32 @@ func TestMultipleCases(t *testing.T) { require.Equal(t, tc.expectedValid, ok, "Unexpected result for address: "+tc.address) } } + +func Benchmark_Networkpolicy_CIDRRanger(b *testing.B) { + for i := 0; i < b.N; i++ { + ranger := cidranger.NewPCTrieRanger() + for _, r := range DefaultIPv4DenylistRanges { + _, cidr, _ := net.ParseCIDR(r) + _ = ranger.Insert(cidranger.NewBasicRangerEntry(*cidr)) + } + contains, err := ranger.Contains(net.ParseIP("127.0.0.1")) + if err != nil || !contains { + b.Fatalf("unexpected error: %v %v", err, contains) + } + } +} + +func Benchmark_Networkpolicy_BartAlgorithm(b *testing.B) { + for i := 0; i < b.N; i++ { + rtbl := new(bart.Table[net.IP]) + for _, r := range DefaultIPv4DenylistRanges { + parsed, _ := netip.ParsePrefix(r) + rtbl.Insert(parsed, nil) + } + + _, contains := rtbl.Lookup(netip.MustParseAddr("127.0.0.1")) + if !contains { + b.Fatalf("expected to contain") + } + } +}