diff --git a/pkg/util/collectionutil/map.go b/pkg/util/collectionutil/map.go index 0d55042f58..d214844920 100644 --- a/pkg/util/collectionutil/map.go +++ b/pkg/util/collectionutil/map.go @@ -16,6 +16,14 @@ limitations under the License. package fnutil +func Keys[K comparable, V any](m map[K]V) []K { + rv := make([]K, 0, len(m)) + for k := range m { + rv = append(rv, k) + } + return rv +} + func Values[K comparable, V any](m map[K]V) []V { rv := make([]V, 0, len(m)) for _, v := range m { diff --git a/pkg/util/iputil/bits.go b/pkg/util/iputil/bits.go new file mode 100644 index 0000000000..dbad47343d --- /dev/null +++ b/pkg/util/iputil/bits.go @@ -0,0 +1,37 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package iputil + +// setBitAt sets the bit at the i-th position in the byte slice to the given value. +// Panics if the index is out of bounds. +// For example, +// - setBitAt([0x00, 0x00], 8, 1) returns [0x00, 0b1000_0000]. +// - setBitAt([0xff, 0xff], 0, 0) returns [0b0111_1111, 0xff]. +func setBitAt(bytes []byte, i int, bit uint8) { + if bit == 1 { + bytes[i/8] |= 1 << (7 - i%8) + } else { + bytes[i/8] &^= 1 << (7 - i%8) + } +} + +// bitAt returns the bit at the i-th position in the byte slice. +// The return value is either 0 or 1 as uint8. +// Panics if the index is out of bounds. +func bitAt(bytes []byte, i int) uint8 { + return bytes[i/8] >> (7 - i%8) & 1 +} diff --git a/pkg/util/iputil/bits_test.go b/pkg/util/iputil/bits_test.go new file mode 100644 index 0000000000..351e83eb86 --- /dev/null +++ b/pkg/util/iputil/bits_test.go @@ -0,0 +1,103 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package iputil + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_bitAt(t *testing.T) { + bytes := []byte{0b1010_1010, 0b0101_0101} + assert.Equal(t, uint8(1), bitAt(bytes, 0)) + assert.Equal(t, uint8(0), bitAt(bytes, 1)) + assert.Equal(t, uint8(1), bitAt(bytes, 2)) + assert.Equal(t, uint8(0), bitAt(bytes, 3)) + + assert.Equal(t, uint8(1), bitAt(bytes, 4)) + assert.Equal(t, uint8(0), bitAt(bytes, 5)) + assert.Equal(t, uint8(1), bitAt(bytes, 6)) + assert.Equal(t, uint8(0), bitAt(bytes, 7)) + + assert.Equal(t, uint8(0), bitAt(bytes, 8)) + assert.Equal(t, uint8(1), bitAt(bytes, 9)) + assert.Equal(t, uint8(0), bitAt(bytes, 10)) + assert.Equal(t, uint8(1), bitAt(bytes, 11)) + + assert.Equal(t, uint8(0), bitAt(bytes, 12)) + assert.Equal(t, uint8(1), bitAt(bytes, 13)) + assert.Equal(t, uint8(0), bitAt(bytes, 14)) + assert.Equal(t, uint8(1), bitAt(bytes, 15)) + + assert.Panics(t, func() { bitAt(bytes, 16) }) +} + +func Test_setBitAt(t *testing.T) { + tests := []struct { + name string + initial []byte + index int + bit uint8 + expected []byte + }{ + { + name: "Set first bit to 1", + initial: []byte{0b0000_0000}, + index: 0, + bit: 1, + expected: []byte{0b1000_0000}, + }, + { + name: "Set last bit to 1", + initial: []byte{0b0000_0000}, + index: 7, + bit: 1, + expected: []byte{0b0000_0001}, + }, + { + name: "Set middle bit to 1", + initial: []byte{0b0000_0000}, + index: 4, + bit: 1, + expected: []byte{0b0000_1000}, + }, + { + name: "Set bit to 0", + initial: []byte{0b1111_1111}, + index: 3, + bit: 0, + expected: []byte{0b1110_1111}, + }, + { + name: "Set bit in second byte", + initial: []byte{0b0000_0000, 0b0000_0000}, + index: 9, + bit: 1, + expected: []byte{0b0000_0000, 0b0100_0000}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setBitAt(tt.initial, tt.index, tt.bit) + assert.Equal(t, tt.expected, tt.initial) + }) + } + + assert.Panics(t, func() { setBitAt([]byte{0x00}, 8, 1) }) +} diff --git a/pkg/util/iputil/prefix.go b/pkg/util/iputil/prefix.go index 77395c37a8..2c4a8e0298 100644 --- a/pkg/util/iputil/prefix.go +++ b/pkg/util/iputil/prefix.go @@ -17,8 +17,10 @@ limitations under the License. package iputil import ( + "bytes" "fmt" "net/netip" + "sort" ) // IsPrefixesAllowAll returns true if one of the prefixes allows all addresses. @@ -61,9 +63,108 @@ func GroupPrefixesByFamily(vs []netip.Prefix) ([]netip.Prefix, []netip.Prefix) { return v4, v6 } -// AggregatePrefixes aggregates prefixes. -// Overlapping prefixes are merged. +// ContainsPrefix checks if prefix p fully contains prefix o. +// It returns true if o is a subset of p, meaning all addresses in o are also in p. +// This is true when p overlaps with o and p has fewer or equal number of bits than o. +func ContainsPrefix(p netip.Prefix, o netip.Prefix) bool { + return p.Bits() <= o.Bits() && p.Overlaps(o) +} + +// mergeAdjacentPrefixes attempts to merge two adjacent prefixes into a single prefix. +// It returns the merged prefix and a boolean indicating success. +// Note: This function only merges adjacent prefixes, not overlapping ones. +func mergeAdjacentPrefixes(p1, p2 netip.Prefix) (netip.Prefix, bool) { + // Merge neighboring prefixes if possible + if p1.Bits() != p2.Bits() || p1.Bits() == 0 { + return netip.Prefix{}, false + } + + var ( + bits = p1.Bits() + p1Bytes = p1.Addr().AsSlice() + p2Bytes = p2.Addr().AsSlice() + ) + if bitAt(p1Bytes, bits-1) == 0 { + setBitAt(p1Bytes, bits-1, 1) + } else { + setBitAt(p2Bytes, bits-1, 1) + } + if !bytes.Equal(p1Bytes, p2Bytes) { + return netip.Prefix{}, false + } + + rv, _ := p1.Addr().Prefix(bits - 1) + return rv, true +} + +// aggregatePrefixesForSingleIPFamily merges overlapping or adjacent prefixes into a single prefix. +// The input prefixes must be the same IP family (IPv4 or IPv6). +// For example, +// - [192.168.0.0/32, 192.168.0.1/32] -> [192.168.0.0/31] (adjacent) +// - [192.168.0.0/24, 192.168.0.1/32] -> [192.168.1.0/24] (overlapping) +func aggregatePrefixesForSingleIPFamily(prefixes []netip.Prefix) []netip.Prefix { + if len(prefixes) <= 1 { + return prefixes + } + + sort.Slice(prefixes, func(i, j int) bool { + addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr()) + if addrCmp == 0 { + return prefixes[i].Bits() < prefixes[j].Bits() + } + return addrCmp < 0 + }) + + var rv = []netip.Prefix{prefixes[0]} + + for i := 1; i < len(prefixes); i++ { + last, p := rv[len(rv)-1], prefixes[i] + if ContainsPrefix(last, p) { + // Skip overlapping prefixes + continue + } + rv = append(rv, p) + + // Merge adjacent prefixes if possible + for len(rv) >= 2 { + // Merge the last two prefixes if they are adjacent + p, ok := mergeAdjacentPrefixes(rv[len(rv)-2], rv[len(rv)-1]) + if !ok { + break + } + + // Replace the last two prefixes with the merged prefix + rv = rv[:len(rv)-2] + rv = append(rv, p) + } + } + return rv +} + +// AggregatePrefixes merges overlapping or adjacent prefixes into a single prefix. +// It combines prefixes that can be represented by a larger, more inclusive prefix. +// +// Examples: +// - Adjacent: [192.168.0.0/32, 192.168.0.1/32] -> [192.168.0.0/31] +// - Overlapping: [192.168.0.0/24, 192.168.0.1/32] -> [192.168.0.0/24] func AggregatePrefixes(prefixes []netip.Prefix) []netip.Prefix { + var ( + v4, v6 = GroupPrefixesByFamily(prefixes) + ) + + return append(aggregatePrefixesForSingleIPFamily(v4), aggregatePrefixesForSingleIPFamily(v6)...) +} + +// AggregatePrefixesWithPrefixTree merges overlapping or adjacent prefixes into a single prefix. +// +// This function uses a prefix tree to aggregate the input prefixes. While it achieves +// the same result as AggregatePrefixes, it is less efficient. For better performance, +// use AggregatePrefixes instead. +// +// Examples: +// - Adjacent: [192.168.0.0/32, 192.168.0.1/32] -> [192.168.0.0/31] +// - Overlapping: [192.168.0.0/24, 192.168.0.1/32] -> [192.168.0.0/24] +func AggregatePrefixesWithPrefixTree(prefixes []netip.Prefix) []netip.Prefix { var ( v4, v6 = GroupPrefixesByFamily(prefixes) v4Tree = newPrefixTreeForIPv4() diff --git a/pkg/util/iputil/prefix_test.go b/pkg/util/iputil/prefix_test.go index 30661c4a4e..0875e5eb0a 100644 --- a/pkg/util/iputil/prefix_test.go +++ b/pkg/util/iputil/prefix_test.go @@ -23,6 +23,8 @@ import ( "testing" "github.com/stretchr/testify/assert" + + fnutil "sigs.k8s.io/cloud-provider-azure/pkg/util/collectionutil" ) func TestIsPrefixesAllowAll(t *testing.T) { @@ -266,80 +268,250 @@ func TestAggregatePrefixes(t *testing.T) { return tt.Output[i].String() < tt.Output[j].String() }) assert.Equal(t, tt.Output, got) + + { + // Test the prefix tree implementation + var got = AggregatePrefixesWithPrefixTree(tt.Input) + + sort.Slice(got, func(i, j int) bool { + return got[i].String() < got[j].String() + }) + sort.Slice(tt.Output, func(i, j int) bool { + return tt.Output[i].String() < tt.Output[j].String() + }) + assert.Equal(t, tt.Output, got) + } }) } } -func BenchmarkAggregatePrefixes(b *testing.B) { - fixtureIPv4Prefixes := func(n int64) []netip.Prefix { - prefixes := make([]netip.Prefix, 0, n) - for i := int64(0); i < n; i++ { - addr := netip.AddrFrom4([4]byte{ - byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i), - }) - prefix, err := addr.Prefix(32) - assert.NoError(b, err) - prefixes = append(prefixes, prefix) - } - - return prefixes +func TestContainsPrefix(t *testing.T) { + tests := []struct { + name string + p netip.Prefix + o netip.Prefix + expected bool + }{ + { + name: "IPv4: Exact match", + p: netip.MustParsePrefix("192.168.0.0/24"), + o: netip.MustParsePrefix("192.168.0.0/24"), + expected: true, + }, + { + name: "IPv4: Larger contains smaller", + p: netip.MustParsePrefix("192.168.0.0/16"), + o: netip.MustParsePrefix("192.168.1.0/24"), + expected: true, + }, + { + name: "IPv4: Smaller doesn't contain larger", + p: netip.MustParsePrefix("192.168.1.0/24"), + o: netip.MustParsePrefix("192.168.0.0/16"), + expected: false, + }, + { + name: "IPv4: Non-overlapping", + p: netip.MustParsePrefix("192.168.0.0/24"), + o: netip.MustParsePrefix("192.169.0.0/24"), + expected: false, + }, + { + name: "IPv6: Exact match", + p: netip.MustParsePrefix("2001:db8::/32"), + o: netip.MustParsePrefix("2001:db8::/32"), + expected: true, + }, + { + name: "IPv6: Larger contains smaller", + p: netip.MustParsePrefix("2001:db8::/32"), + o: netip.MustParsePrefix("2001:db8:1::/48"), + expected: true, + }, + { + name: "IPv6: Smaller doesn't contain larger", + p: netip.MustParsePrefix("2001:db8:1::/48"), + o: netip.MustParsePrefix("2001:db8::/32"), + expected: false, + }, + { + name: "IPv6: Non-overlapping", + p: netip.MustParsePrefix("2001:db8::/32"), + o: netip.MustParsePrefix("2001:db9::/32"), + expected: false, + }, } - fixtureIPv6Prefixes := func(n int64) []netip.Prefix { - prefixes := make([]netip.Prefix, 0, n) - for i := int64(0); i < n; i++ { - addr := netip.AddrFrom16([16]byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - byte(i >> 56), byte(i >> 48), byte(i >> 40), byte(i >> 32), - byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i), - }) - prefix, err := addr.Prefix(128) - assert.NoError(b, err) - prefixes = append(prefixes, prefix) - } - return prefixes + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ContainsPrefix(tt.p, tt.o) + assert.Equal(t, tt.expected, result) + }) } +} - runIPv4Tests := func(b *testing.B, n int64) { - b.Run(fmt.Sprintf("IPv4-%d", n), func(b *testing.B) { - b.StopTimer() - prefixes := fixtureIPv4Prefixes(n) - b.StartTimer() +func TestMergePrefixes(t *testing.T) { + tests := []struct { + name string + p1 netip.Prefix + p2 netip.Prefix + expected netip.Prefix + ok bool + }{ + { + name: "IPv4: Overlapping prefixes", + p1: netip.MustParsePrefix("192.168.0.0/24"), + p2: netip.MustParsePrefix("192.168.0.0/25"), + expected: netip.Prefix{}, + ok: false, + }, + { + name: "IPv4: Adjacent prefixes", + p1: netip.MustParsePrefix("192.168.0.0/25"), + p2: netip.MustParsePrefix("192.168.0.128/25"), + expected: netip.MustParsePrefix("192.168.0.0/24"), + ok: true, + }, + { + name: "IPv4: Non-mergeable prefixes", + p1: netip.MustParsePrefix("192.168.0.0/24"), + p2: netip.MustParsePrefix("192.168.2.0/24"), + expected: netip.Prefix{}, + ok: false, + }, + { + name: "IPv6: Overlapping prefixes", + p1: netip.MustParsePrefix("2001:db8::/32"), + p2: netip.MustParsePrefix("2001:db8::/48"), + expected: netip.Prefix{}, + ok: false, + }, + { + name: "IPv6: Adjacent prefixes", + p1: netip.MustParsePrefix("2001:db8::/33"), + p2: netip.MustParsePrefix("2001:db8:8000::/33"), + expected: netip.MustParsePrefix("2001:db8::/32"), + ok: true, + }, + { + name: "IPv6: Non-mergeable prefixes", + p1: netip.MustParsePrefix("2001:db8::/32"), + p2: netip.MustParsePrefix("2001:db10::/32"), + expected: netip.Prefix{}, + ok: false, + }, + } - for i := 0; i < b.N; i++ { - AggregatePrefixes(prefixes) - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, ok := mergeAdjacentPrefixes(tt.p1, tt.p2) + assert.Equal(t, tt.ok, ok) + assert.Equal(t, tt.expected, result) }) } +} - runIPv6Tests := func(b *testing.B, n int64) { - b.Run(fmt.Sprintf("IPv6-%d", n), func(b *testing.B) { - b.StopTimer() - prefixes := fixtureIPv4Prefixes(n) - b.StartTimer() +// BenchmarkPrefixFixtures generates a list of prefixes for aggregation benchmarks. +// The second return value is the expected result of the aggregation. +func benchmarkPrefixFixtures() ([]netip.Prefix, []netip.Prefix) { + var rv []netip.Prefix + for i := 0; i <= 255; i++ { + for j := 0; j <= 255; j++ { + rv = append(rv, netip.MustParsePrefix(fmt.Sprintf("192.168.%d.%d/32", i, j))) + } + } - for i := 0; i < b.N; i++ { - AggregatePrefixes(prefixes) - } - }) + return rv, []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), } +} - runMixedTests := func(b *testing.B, n int64) { - b.Run(fmt.Sprintf("IPv4-IPv6-%d", 2*n), func(b *testing.B) { - b.StopTimer() - prefixes := append(fixtureIPv4Prefixes(n), fixtureIPv6Prefixes(n)...) - b.StartTimer() +func BenchmarkAggregatePrefixes(b *testing.B) { + prefixes, expected := benchmarkPrefixFixtures() + b.ResetTimer() + for i := 0; i < b.N; i++ { + actual := AggregatePrefixes(prefixes) + assert.Len(b, actual, 1) + assert.Equal(b, expected, actual) + } +} - for i := 0; i < b.N; i++ { - AggregatePrefixes(prefixes) - } - }) +func BenchmarkAggregatePrefixesWithPrefixTree(b *testing.B) { + prefixes, expected := benchmarkPrefixFixtures() + b.ResetTimer() + for i := 0; i < b.N; i++ { + actual := AggregatePrefixesWithPrefixTree(prefixes) + assert.Len(b, actual, 1) + assert.Equal(b, expected, actual) } +} - for _, n := range []int64{100, 1_000, 10_000} { - runIPv4Tests(b, n) - runIPv6Tests(b, n) - runMixedTests(b, n) +func FuzzAggregatePrefixesIPv4(f *testing.F) { + f.Add( + netip.MustParseAddr("192.168.0.0").AsSlice(), + 24, + netip.MustParseAddr("192.168.1.0").AsSlice(), + 24, + netip.MustParseAddr("10.0.0.0").AsSlice(), + 8, + ) + + parsePrefix := func(bytes []byte, bits int) (netip.Prefix, error) { + if bits < 0 || bits > 32 { + return netip.Prefix{}, fmt.Errorf("invalid bits") + } + + addr, ok := netip.AddrFromSlice(bytes) + if !ok { + return netip.Prefix{}, fmt.Errorf("invalid address") + } + + return addr.Prefix(bits) } + + listAddressesAsString := func(prefixes ...netip.Prefix) []string { + rv := make(map[string]struct{}) + for _, p := range prefixes { + for addr := p.Addr(); p.Contains(addr); addr = addr.Next() { + rv[addr.String()] = struct{}{} + } + } + return fnutil.Keys(rv) + } + + f.Fuzz(func( + t *testing.T, + p1Bytes []byte, p1Bits int, + p2Bytes []byte, p2Bits int, + p3Bytes []byte, p3Bits int, + ) { + + p1, err := parsePrefix(p1Bytes, p1Bits) + if err != nil { + return + } + p2, err := parsePrefix(p2Bytes, p2Bits) + if err != nil { + return + } + p3, err := parsePrefix(p3Bytes, p3Bits) + if err != nil { + return + } + + input := []netip.Prefix{p1, p2, p3} + output := AggregatePrefixes(input) + + prefixAsString := func(p netip.Prefix) string { return p.String() } + t.Logf("input: %s", fnutil.Map(prefixAsString, input)) + t.Logf("output: %s", fnutil.Map(prefixAsString, output)) + + expectedAddresses := listAddressesAsString(input...) + actualAddresses := listAddressesAsString(output...) + assert.Equal(t, len(expectedAddresses), len(actualAddresses)) + + sort.Strings(expectedAddresses) + sort.Strings(actualAddresses) + assert.Equal(t, expectedAddresses, actualAddresses) + }) } diff --git a/pkg/util/iputil/prefix_tree.go b/pkg/util/iputil/prefix_tree.go index 2399c09ebd..a7515b42fe 100644 --- a/pkg/util/iputil/prefix_tree.go +++ b/pkg/util/iputil/prefix_tree.go @@ -29,10 +29,26 @@ type prefixTreeNode struct { r *prefixTreeNode // right child node } -// pruneToRoot prunes the tree to the root. -// If a node's left and right children are both masked, -// it is masked and its children are pruned. -// This is done recursively up to the root. +// pruneToRoot checks if the current node and its sibling are masked, +// and if so, marks their parent as masked and removes both children. +// This process is repeated up the tree until a node with an unmasked sibling is found. +// +// The process can be visualized as follows: +// +// Before: After: +// P P (masked) +// / \ / \ +// A B -> X X +// (M) (M) +// +// Where: +// +// P: Parent node +// A, B: Child nodes +// M: Masked +// X: Removed +// +// This method helps to optimize the tree structure by condensing fully masked subtrees. func (n *prefixTreeNode) pruneToRoot() { var node = n for node.p != nil { @@ -49,6 +65,31 @@ func (n *prefixTreeNode) pruneToRoot() { } } +// prefixTree represents a tree structure for storing and managing IP prefixes. +// It efficiently handles prefix aggregation, merging of overlapping prefixes, +// and collapsing of neighboring prefixes. +// +// The tree is structured as follows: +// - Each node represents a bit in the IP address +// - Left child represents a 0 bit, right child represents a 1 bit +// - Masked nodes indicate the end of a prefix +// - Unused branches are represented by nil pointers +// +// Example tree for 128.0.0.0/4 (binary 1000 0000): +// +// 0 (0.0.0.0/0) +// / \ +// X 1 (128.0.0.0/1) +// / \ +// 0 X +// / \ +// 0 X +// / \ +// 0* X +// +// Where: +// * denotes a masked node (prefix end) +// X denotes an unused branch (nil pointer) type prefixTree struct { maxBits int root *prefixTreeNode