diff --git a/internal/testutil/fixture/azure_securitygroup.go b/internal/testutil/fixture/azure_securitygroup.go index 721e381aad..a979fa3646 100644 --- a/internal/testutil/fixture/azure_securitygroup.go +++ b/internal/testutil/fixture/azure_securitygroup.go @@ -27,9 +27,9 @@ import ( "sigs.k8s.io/cloud-provider-azure/internal/testutil" "sigs.k8s.io/cloud-provider-azure/pkg/consts" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/fnutil" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/iputil" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/securitygroup" + "sigs.k8s.io/cloud-provider-azure/pkg/provider/securitygroup" + fnutil "sigs.k8s.io/cloud-provider-azure/pkg/util/collectionutil" + "sigs.k8s.io/cloud-provider-azure/pkg/util/iputil" ) // NoiseSecurityRules returns 3 non cloud-provider-specific security rules. diff --git a/internal/testutil/fixture/fixture.go b/internal/testutil/fixture/fixture.go index 3278676ad8..92aeba488f 100644 --- a/internal/testutil/fixture/fixture.go +++ b/internal/testutil/fixture/fixture.go @@ -24,7 +24,7 @@ import ( "math/big" "net/netip" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/fnutil" + fnutil "sigs.k8s.io/cloud-provider-azure/pkg/util/collectionutil" ) type Fixture struct{} diff --git a/pkg/provider/azure.go b/pkg/provider/azure.go index 4eb6adcc4b..4f2b178fcc 100644 --- a/pkg/provider/azure.go +++ b/pkg/provider/azure.go @@ -80,6 +80,7 @@ import ( azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" ratelimitconfig "sigs.k8s.io/cloud-provider-azure/pkg/provider/config" + "sigs.k8s.io/cloud-provider-azure/pkg/provider/securitygroup" "sigs.k8s.io/cloud-provider-azure/pkg/retry" utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" "sigs.k8s.io/cloud-provider-azure/pkg/util/taints" @@ -423,10 +424,10 @@ type Cloud struct { routeUpdater batchProcessor backendPoolUpdater batchProcessor - vmCache azcache.Resource - lbCache azcache.Resource - nsgCache azcache.Resource - rtCache azcache.Resource + vmCache azcache.Resource + lbCache azcache.Resource + nsgRepo securitygroup.Repository + rtCache azcache.Resource // public ip cache // key: [resourceGroupName] // Value: sync.Map of [pipName]*PublicIPAddress @@ -727,8 +728,16 @@ func (az *Cloud) InitializeCloudFromConfig(ctx context.Context, config *Config, if err != nil { return err } - } + networkClientFactory := az.NetworkClientFactory + if networkClientFactory == nil { + networkClientFactory = az.ComputeClientFactory + } + az.nsgRepo, err = securitygroup.NewSecurityGroupRepo(az.SecurityGroupResourceGroup, az.SecurityGroupName, az.NsgCacheTTLInSeconds, az.DisableAPICallCache, networkClientFactory.GetSecurityGroupClient()) + if err != nil { + return err + } + } err = az.initCaches() if err != nil { return err @@ -841,11 +850,6 @@ func (az *Cloud) initCaches() (err error) { return err } - az.nsgCache, err = az.newNSGCache() - if err != nil { - return err - } - az.rtCache, err = az.newRouteTableCache() if err != nil { return err diff --git a/pkg/provider/azure_fakes.go b/pkg/provider/azure_fakes.go index b86c4fcae4..c7c42ab3f6 100644 --- a/pkg/provider/azure_fakes.go +++ b/pkg/provider/azure_fakes.go @@ -46,6 +46,7 @@ import ( azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" "sigs.k8s.io/cloud-provider-azure/pkg/provider/config" + "sigs.k8s.io/cloud-provider-azure/pkg/provider/securitygroup" utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) @@ -138,7 +139,7 @@ func GetTestCloud(ctrl *gomock.Controller) (az *Cloud) { az.VMSet, _ = newAvailabilitySet(az) az.vmCache, _ = az.newVMCache() az.lbCache, _ = az.newLBCache() - az.nsgCache, _ = az.newNSGCache() + az.nsgRepo, _ = securitygroup.NewSecurityGroupRepo(az.SecurityGroupResourceGroup, az.SecurityGroupName, az.NsgCacheTTLInSeconds, az.Config.DisableAPICallCache, securtyGrouptrack2Client) az.rtCache, _ = az.newRouteTableCache() az.pipCache, _ = az.newPIPCache() az.plsCache, _ = az.newPLSCache() diff --git a/pkg/provider/azure_loadbalancer.go b/pkg/provider/azure_loadbalancer.go index 931f06b248..8dc08ebe40 100644 --- a/pkg/provider/azure_loadbalancer.go +++ b/pkg/provider/azure_loadbalancer.go @@ -46,10 +46,10 @@ import ( "sigs.k8s.io/cloud-provider-azure/pkg/log" "sigs.k8s.io/cloud-provider-azure/pkg/metrics" "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/iputil" "sigs.k8s.io/cloud-provider-azure/pkg/retry" "sigs.k8s.io/cloud-provider-azure/pkg/trace" "sigs.k8s.io/cloud-provider-azure/pkg/trace/attributes" + "sigs.k8s.io/cloud-provider-azure/pkg/util/iputil" utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) @@ -2924,7 +2924,7 @@ func (az *Cloud) reconcileSecurityGroup( var accessControl *loadbalancer.AccessControl { - sg, err := az.getSecurityGroup(ctx, azcache.CacheReadTypeDefault) + sg, err := az.nsgRepo.GetSecurityGroup(ctx) if err != nil { return nil, err } @@ -3017,13 +3017,12 @@ func (az *Cloud) reconcileSecurityGroup( if updated { logger.V(2).Info("Preparing to update security group") logger.V(5).Info("CreateOrUpdateSecurityGroup begin") - err := az.CreateOrUpdateSecurityGroup(rv) + err := az.nsgRepo.CreateOrUpdateSecurityGroup(ctx, rv) if err != nil { logger.Error(err, "Failed to update security group") return nil, err } logger.V(5).Info("CreateOrUpdateSecurityGroup end") - _ = az.nsgCache.Delete(ptr.Deref(rv.Name, "")) } return rv, nil } diff --git a/pkg/provider/azure_loadbalancer_accesscontrol.go b/pkg/provider/azure_loadbalancer_accesscontrol.go index 60394aa77d..e2e48a7993 100644 --- a/pkg/provider/azure_loadbalancer_accesscontrol.go +++ b/pkg/provider/azure_loadbalancer_accesscontrol.go @@ -28,7 +28,7 @@ import ( "sigs.k8s.io/cloud-provider-azure/pkg/consts" "sigs.k8s.io/cloud-provider-azure/pkg/log" "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/fnutil" + fnutil "sigs.k8s.io/cloud-provider-azure/pkg/util/collectionutil" ) func filterServicesByIngressIPs(services []*v1.Service, ips []netip.Addr) []*v1.Service { diff --git a/pkg/provider/azure_loadbalancer_accesscontrol_test.go b/pkg/provider/azure_loadbalancer_accesscontrol_test.go index 3f6cdfed26..cd853b50e7 100644 --- a/pkg/provider/azure_loadbalancer_accesscontrol_test.go +++ b/pkg/provider/azure_loadbalancer_accesscontrol_test.go @@ -42,9 +42,9 @@ import ( "sigs.k8s.io/cloud-provider-azure/pkg/consts" "sigs.k8s.io/cloud-provider-azure/pkg/log" "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/iputil" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/securitygroup" + "sigs.k8s.io/cloud-provider-azure/pkg/provider/securitygroup" "sigs.k8s.io/cloud-provider-azure/pkg/retry" + "sigs.k8s.io/cloud-provider-azure/pkg/util/iputil" ) func TestCloud_reconcileSecurityGroup(t *testing.T) { diff --git a/pkg/provider/azure_wrap.go b/pkg/provider/azure_wrap.go index 1e270b9986..6a819329a8 100644 --- a/pkg/provider/azure_wrap.go +++ b/pkg/provider/azure_wrap.go @@ -17,14 +17,11 @@ limitations under the License. package provider import ( - "errors" "fmt" "net/http" "regexp" "strings" - "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "sigs.k8s.io/cloud-provider-azure/pkg/consts" "sigs.k8s.io/cloud-provider-azure/pkg/retry" ) @@ -32,7 +29,6 @@ import ( var ( vmCacheTTLDefaultInSeconds = 60 loadBalancerCacheTTLDefaultInSeconds = 120 - nsgCacheTTLDefaultInSeconds = 120 routeTableCacheTTLDefaultInSeconds = 120 publicIPCacheTTLDefaultInSeconds = 120 plsCacheTTLDefaultInSeconds = 120 @@ -56,20 +52,6 @@ func checkResourceExistsFromError(err *retry.Error) (bool, *retry.Error) { return false, err } -func checkResourceExistsFromAzcoreError(err error) (bool, error) { - if err == nil { - return true, nil - } - var respError *azcore.ResponseError - if errors.As(err, &respError) && respError != nil { - if respError.StatusCode == http.StatusNotFound { - return false, nil - } - } - - return false, err -} - func (az *Cloud) useStandardLoadBalancer() bool { return strings.EqualFold(az.LoadBalancerSku, consts.LoadBalancerSkuStandard) } diff --git a/pkg/provider/loadbalancer/accesscontrol.go b/pkg/provider/loadbalancer/accesscontrol.go index 7ada58579f..71c1ad72f9 100644 --- a/pkg/provider/loadbalancer/accesscontrol.go +++ b/pkg/provider/loadbalancer/accesscontrol.go @@ -27,9 +27,9 @@ import ( "k8s.io/utils/ptr" "sigs.k8s.io/cloud-provider-azure/pkg/consts" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/fnutil" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/iputil" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/securitygroup" + "sigs.k8s.io/cloud-provider-azure/pkg/provider/securitygroup" + fnutil "sigs.k8s.io/cloud-provider-azure/pkg/util/collectionutil" + "sigs.k8s.io/cloud-provider-azure/pkg/util/iputil" ) var ( diff --git a/pkg/provider/loadbalancer/accesscontrol_test.go b/pkg/provider/loadbalancer/accesscontrol_test.go index 17284b7955..e1eb2eb426 100644 --- a/pkg/provider/loadbalancer/accesscontrol_test.go +++ b/pkg/provider/loadbalancer/accesscontrol_test.go @@ -30,9 +30,9 @@ import ( "sigs.k8s.io/cloud-provider-azure/internal/testutil" "sigs.k8s.io/cloud-provider-azure/internal/testutil/fixture" "sigs.k8s.io/cloud-provider-azure/pkg/log" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/fnutil" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/iputil" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/securitygroup" + "sigs.k8s.io/cloud-provider-azure/pkg/provider/securitygroup" + fnutil "sigs.k8s.io/cloud-provider-azure/pkg/util/collectionutil" + "sigs.k8s.io/cloud-provider-azure/pkg/util/iputil" ) func TestAccessControl_IsAllowFromInternet(t *testing.T) { diff --git a/pkg/provider/loadbalancer/configuration.go b/pkg/provider/loadbalancer/configuration.go index 95d2dc6bff..d1f3d46832 100644 --- a/pkg/provider/loadbalancer/configuration.go +++ b/pkg/provider/loadbalancer/configuration.go @@ -25,7 +25,7 @@ import ( v1 "k8s.io/api/core/v1" "sigs.k8s.io/cloud-provider-azure/pkg/consts" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/iputil" + "sigs.k8s.io/cloud-provider-azure/pkg/util/iputil" ) // IsInternal returns true if the given service is internal load balancer. diff --git a/pkg/provider/azure_securitygroup_repo.go b/pkg/provider/securitygroup/azure_securitygroup_repo.go similarity index 55% rename from pkg/provider/azure_securitygroup_repo.go rename to pkg/provider/securitygroup/azure_securitygroup_repo.go index c33c3e012f..0c602b8ee2 100644 --- a/pkg/provider/azure_securitygroup_repo.go +++ b/pkg/provider/securitygroup/azure_securitygroup_repo.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package provider +package securitygroup import ( "context" @@ -31,20 +31,66 @@ import ( "k8s.io/klog/v2" "k8s.io/utils/ptr" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/securitygroupclient" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" + "sigs.k8s.io/cloud-provider-azure/pkg/util/errutils" ) -// CreateOrUpdateSecurityGroup invokes az.SecurityGroupsClient.CreateOrUpdate with exponential backoff retry -func (az *Cloud) CreateOrUpdateSecurityGroup(sg *armnetwork.SecurityGroup) error { - ctx, cancel := getContextWithCancel() - defer cancel() - clientFactory := az.NetworkClientFactory - if clientFactory == nil { - clientFactory = az.ComputeClientFactory +const ( + nsgCacheTTLDefaultInSeconds = 120 +) + +type Repository interface { + GetSecurityGroup(ctx context.Context) (*armnetwork.SecurityGroup, error) + CreateOrUpdateSecurityGroup(ctx context.Context, sg *armnetwork.SecurityGroup) error +} + +type securityGroupRepo struct { + securityGroupResourceGroup string + securityGroupName string + nsgCacheTTLInSeconds int + securigyGroupClient securitygroupclient.Interface + nsgCache azcache.Resource +} + +func NewSecurityGroupRepo(securityGroupResourceGroup string, securityGroupName string, nsgCacheTTLInSeconds int, disableAPICallCache bool, securityGroupClient securitygroupclient.Interface) (Repository, error) { + getter := func(ctx context.Context, key string) (interface{}, error) { + nsg, err := securityGroupClient.Get(ctx, securityGroupResourceGroup, key) + exists, rerr := errutils.CheckResourceExistsFromAzcoreError(err) + if rerr != nil { + return nil, err + } + + if !exists { + klog.V(2).Infof("Security group %q not found", key) + return nil, nil + } + + return nsg, nil + } + + if nsgCacheTTLInSeconds == 0 { + nsgCacheTTLInSeconds = nsgCacheTTLDefaultInSeconds + } + cache, err := azcache.NewTimedCache(time.Duration(nsgCacheTTLInSeconds)*time.Second, getter, disableAPICallCache) + if err != nil { + klog.Errorf("Failed to create cache for security group %q: %v", securityGroupName, err) + return nil, err } - sgClient := clientFactory.GetSecurityGroupClient() - _, rerr := sgClient.CreateOrUpdate(ctx, az.SecurityGroupResourceGroup, *sg.Name, *sg) + + return &securityGroupRepo{ + securityGroupResourceGroup: securityGroupResourceGroup, + securityGroupName: securityGroupName, + nsgCacheTTLInSeconds: nsgCacheTTLDefaultInSeconds, + securigyGroupClient: securityGroupClient, + nsgCache: cache, + }, nil +} + +// CreateOrUpdateSecurityGroup invokes az.SecurityGroupsClient.CreateOrUpdate with exponential backoff retry +func (az *securityGroupRepo) CreateOrUpdateSecurityGroup(ctx context.Context, sg *armnetwork.SecurityGroup) error { + _, rerr := az.securigyGroupClient.CreateOrUpdate(ctx, az.securityGroupResourceGroup, *sg.Name, *sg) klog.V(10).Infof("SecurityGroupsClient.CreateOrUpdate(%s): end", *sg.Name) if rerr == nil { // Invalidate the cache right after updating @@ -71,47 +117,19 @@ func (az *Cloud) CreateOrUpdateSecurityGroup(sg *armnetwork.SecurityGroup) error return rerr } -func (az *Cloud) newNSGCache() (azcache.Resource, error) { - getter := func(ctx context.Context, key string) (interface{}, error) { - clientFactory := az.NetworkClientFactory - if clientFactory == nil { - clientFactory = az.ComputeClientFactory - } - sgClient := clientFactory.GetSecurityGroupClient() - - nsg, err := sgClient.Get(ctx, az.SecurityGroupResourceGroup, key) - exists, rerr := checkResourceExistsFromAzcoreError(err) - if rerr != nil { - return nil, err - } - - if !exists { - klog.V(2).Infof("Security group %q not found", key) - return nil, nil - } - - return nsg, nil - } - - if az.NsgCacheTTLInSeconds == 0 { - az.NsgCacheTTLInSeconds = nsgCacheTTLDefaultInSeconds - } - return azcache.NewTimedCache(time.Duration(az.NsgCacheTTLInSeconds)*time.Second, getter, az.Config.DisableAPICallCache) -} - -func (az *Cloud) getSecurityGroup(ctx context.Context, crt azcache.AzureCacheReadType) (*armnetwork.SecurityGroup, error) { +func (az *securityGroupRepo) GetSecurityGroup(ctx context.Context) (*armnetwork.SecurityGroup, error) { nsg := &armnetwork.SecurityGroup{} - if az.SecurityGroupName == "" { + if az.securityGroupName == "" { return nsg, fmt.Errorf("securityGroupName is not configured") } - securityGroup, err := az.nsgCache.GetWithDeepCopy(ctx, az.SecurityGroupName, crt) + securityGroup, err := az.nsgCache.GetWithDeepCopy(ctx, az.securityGroupName, azcache.CacheReadTypeDefault) if err != nil { return nsg, err } if securityGroup == nil { - return nsg, fmt.Errorf("nsg %q not found", az.SecurityGroupName) + return nsg, fmt.Errorf("nsg %q not found", az.securityGroupName) } return securityGroup.(*armnetwork.SecurityGroup), nil diff --git a/pkg/provider/securitygroup/azure_securitygroup_repo_mock.go b/pkg/provider/securitygroup/azure_securitygroup_repo_mock.go new file mode 100644 index 0000000000..017bc064b4 --- /dev/null +++ b/pkg/provider/securitygroup/azure_securitygroup_repo_mock.go @@ -0,0 +1,87 @@ +// /* +// Copyright The Kubernetes Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// */ +// + +// Code generated by MockGen. DO NOT EDIT. +// Source: azure_securitygroup_repo.go +// +// Generated by this command: +// +// mockgen -package securitygroup -source azure_securitygroup_repo.go -self_package sigs.k8s.io/cloud-provider-azure/pkg/provider/securitygroup -copyright_file ../../../hack/boilerplate/boilerplate.generatego.txt +// + +// Package securitygroup is a generated GoMock package. +package securitygroup + +import ( + context "context" + reflect "reflect" + + armnetwork "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" + gomock "go.uber.org/mock/gomock" +) + +// MockRepository is a mock of Repository interface. +type MockRepository struct { + ctrl *gomock.Controller + recorder *MockRepositoryMockRecorder +} + +// MockRepositoryMockRecorder is the mock recorder for MockRepository. +type MockRepositoryMockRecorder struct { + mock *MockRepository +} + +// NewMockRepository creates a new mock instance. +func NewMockRepository(ctrl *gomock.Controller) *MockRepository { + mock := &MockRepository{ctrl: ctrl} + mock.recorder = &MockRepositoryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRepository) EXPECT() *MockRepositoryMockRecorder { + return m.recorder +} + +// CreateOrUpdateSecurityGroup mocks base method. +func (m *MockRepository) CreateOrUpdateSecurityGroup(ctx context.Context, sg *armnetwork.SecurityGroup) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateOrUpdateSecurityGroup", ctx, sg) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateOrUpdateSecurityGroup indicates an expected call of CreateOrUpdateSecurityGroup. +func (mr *MockRepositoryMockRecorder) CreateOrUpdateSecurityGroup(ctx, sg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateOrUpdateSecurityGroup", reflect.TypeOf((*MockRepository)(nil).CreateOrUpdateSecurityGroup), ctx, sg) +} + +// GetSecurityGroup mocks base method. +func (m *MockRepository) GetSecurityGroup(ctx context.Context) (*armnetwork.SecurityGroup, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSecurityGroup", ctx) + ret0, _ := ret[0].(*armnetwork.SecurityGroup) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSecurityGroup indicates an expected call of GetSecurityGroup. +func (mr *MockRepositoryMockRecorder) GetSecurityGroup(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSecurityGroup", reflect.TypeOf((*MockRepository)(nil).GetSecurityGroup), ctx) +} diff --git a/pkg/provider/azure_securitygroup_repo_test.go b/pkg/provider/securitygroup/azure_securitygroup_repo_test.go similarity index 66% rename from pkg/provider/azure_securitygroup_repo_test.go rename to pkg/provider/securitygroup/azure_securitygroup_repo_test.go index 4d6be443ca..7bfd7ae117 100644 --- a/pkg/provider/azure_securitygroup_repo_test.go +++ b/pkg/provider/securitygroup/azure_securitygroup_repo_test.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package provider +package securitygroup import ( "context" @@ -31,7 +31,6 @@ import ( "k8s.io/utils/ptr" - "sigs.k8s.io/cloud-provider-azure/pkg/azclient/mock_azclient" "sigs.k8s.io/cloud-provider-azure/pkg/azclient/securitygroupclient/mock_securitygroupclient" "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" @@ -40,24 +39,23 @@ import ( func TestCreateOrUpdateSecurityGroupCanceled(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() + mockSGClient := mock_securitygroupclient.NewMockInterface(ctrl) + az, err := NewSecurityGroupRepo("rg", "sg", 120, false, mockSGClient) + assert.NoError(t, err) + az.(*securityGroupRepo).nsgCache.Set("sg", "test") - az := GetTestCloud(ctrl) - az.nsgCache.Set("sg", "test") - clientFactory := az.NetworkClientFactory.(*mock_azclient.MockClientFactory) - mockSGClient := clientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - - mockSGClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, gomock.Any(), gomock.Any()).Return(nil, &azcore.ResponseError{ + mockSGClient.EXPECT().CreateOrUpdate(gomock.Any(), "rg", "sg", gomock.Any()).Return(nil, &azcore.ResponseError{ RawResponse: &http.Response{ Body: io.NopCloser(strings.NewReader(consts.OperationCanceledErrorMessage)), }, }) - mockSGClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, "sg").Return(&armnetwork.SecurityGroup{}, nil) + mockSGClient.EXPECT().Get(gomock.Any(), "rg", "sg").Return(&armnetwork.SecurityGroup{}, nil) - err := az.CreateOrUpdateSecurityGroup(&armnetwork.SecurityGroup{Name: ptr.To("sg")}) + err = az.CreateOrUpdateSecurityGroup(context.TODO(), &armnetwork.SecurityGroup{Name: ptr.To("sg")}) assert.Contains(t, err.Error(), "canceledandsupersededduetoanotheroperation") // security group should be removed from cache if the operation is canceled - shouldBeEmpty, err := az.nsgCache.GetWithDeepCopy(context.TODO(), "sg", cache.CacheReadTypeDefault) + shouldBeEmpty, err := az.(*securityGroupRepo).nsgCache.GetWithDeepCopy(context.TODO(), "sg", cache.CacheReadTypeDefault) assert.NoError(t, err) assert.Empty(t, shouldBeEmpty) } diff --git a/pkg/provider/loadbalancer/securitygroup/securitygroup.go b/pkg/provider/securitygroup/securitygroup.go similarity index 99% rename from pkg/provider/loadbalancer/securitygroup/securitygroup.go rename to pkg/provider/securitygroup/securitygroup.go index 8fb6213fcf..7ed4be70ff 100644 --- a/pkg/provider/loadbalancer/securitygroup/securitygroup.go +++ b/pkg/provider/securitygroup/securitygroup.go @@ -30,8 +30,8 @@ import ( "k8s.io/utils/ptr" "sigs.k8s.io/cloud-provider-azure/pkg/consts" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/fnutil" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/iputil" + fnutil "sigs.k8s.io/cloud-provider-azure/pkg/util/collectionutil" + "sigs.k8s.io/cloud-provider-azure/pkg/util/iputil" ) const ( diff --git a/pkg/provider/loadbalancer/securitygroup/securitygroup_test.go b/pkg/provider/securitygroup/securitygroup_test.go similarity index 99% rename from pkg/provider/loadbalancer/securitygroup/securitygroup_test.go rename to pkg/provider/securitygroup/securitygroup_test.go index dd7c3e5cb4..90b5dfbc47 100644 --- a/pkg/provider/loadbalancer/securitygroup/securitygroup_test.go +++ b/pkg/provider/securitygroup/securitygroup_test.go @@ -29,9 +29,9 @@ import ( "sigs.k8s.io/cloud-provider-azure/internal/testutil" "sigs.k8s.io/cloud-provider-azure/internal/testutil/fixture" "sigs.k8s.io/cloud-provider-azure/pkg/log" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/fnutil" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/iputil" - . "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/securitygroup" //nolint:revive + . "sigs.k8s.io/cloud-provider-azure/pkg/provider/securitygroup" //nolint:revive + fnutil "sigs.k8s.io/cloud-provider-azure/pkg/util/collectionutil" + "sigs.k8s.io/cloud-provider-azure/pkg/util/iputil" ) func ExpectNewSecurityGroupHelper(t *testing.T, sg *armnetwork.SecurityGroup) *RuleHelper { diff --git a/pkg/provider/loadbalancer/securitygroup/securityrule.go b/pkg/provider/securitygroup/securityrule.go similarity index 97% rename from pkg/provider/loadbalancer/securitygroup/securityrule.go rename to pkg/provider/securitygroup/securityrule.go index 1cd843d44f..481c26f134 100644 --- a/pkg/provider/loadbalancer/securitygroup/securityrule.go +++ b/pkg/provider/securitygroup/securityrule.go @@ -27,8 +27,8 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" v1 "k8s.io/api/core/v1" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/fnutil" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/iputil" + fnutil "sigs.k8s.io/cloud-provider-azure/pkg/util/collectionutil" + "sigs.k8s.io/cloud-provider-azure/pkg/util/iputil" ) // GenerateAllowSecurityRuleName returns the AllowInbound rule name based on the given rule properties. diff --git a/pkg/provider/loadbalancer/fnutil/map.go b/pkg/util/collectionutil/map.go similarity index 100% rename from pkg/provider/loadbalancer/fnutil/map.go rename to pkg/util/collectionutil/map.go diff --git a/pkg/provider/loadbalancer/fnutil/slice.go b/pkg/util/collectionutil/slice.go similarity index 100% rename from pkg/provider/loadbalancer/fnutil/slice.go rename to pkg/util/collectionutil/slice.go diff --git a/pkg/util/errutils/err.go b/pkg/util/errutils/err.go new file mode 100644 index 0000000000..477b816e40 --- /dev/null +++ b/pkg/util/errutils/err.go @@ -0,0 +1,37 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package errutils + +import ( + "errors" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +func CheckResourceExistsFromAzcoreError(err error) (bool, error) { + if err == nil { + return true, nil + } + var respError *azcore.ResponseError + if errors.As(err, &respError) && respError != nil { + if respError.StatusCode == http.StatusNotFound { + return false, nil + } + } + return false, err +} diff --git a/pkg/provider/loadbalancer/iputil/addr.go b/pkg/util/iputil/addr.go similarity index 100% rename from pkg/provider/loadbalancer/iputil/addr.go rename to pkg/util/iputil/addr.go diff --git a/pkg/provider/loadbalancer/iputil/addr_test.go b/pkg/util/iputil/addr_test.go similarity index 100% rename from pkg/provider/loadbalancer/iputil/addr_test.go rename to pkg/util/iputil/addr_test.go diff --git a/pkg/provider/loadbalancer/iputil/family.go b/pkg/util/iputil/family.go similarity index 94% rename from pkg/provider/loadbalancer/iputil/family.go rename to pkg/util/iputil/family.go index 9ac42b92fd..65a8ce2203 100644 --- a/pkg/provider/loadbalancer/iputil/family.go +++ b/pkg/util/iputil/family.go @@ -19,7 +19,7 @@ package iputil import ( "net/netip" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/fnutil" + fnutil "sigs.k8s.io/cloud-provider-azure/pkg/util/collectionutil" ) type Family string diff --git a/pkg/provider/loadbalancer/iputil/family_test.go b/pkg/util/iputil/family_test.go similarity index 100% rename from pkg/provider/loadbalancer/iputil/family_test.go rename to pkg/util/iputil/family_test.go diff --git a/pkg/provider/loadbalancer/iputil/prefix.go b/pkg/util/iputil/prefix.go similarity index 100% rename from pkg/provider/loadbalancer/iputil/prefix.go rename to pkg/util/iputil/prefix.go diff --git a/pkg/provider/loadbalancer/iputil/prefix_test.go b/pkg/util/iputil/prefix_test.go similarity index 100% rename from pkg/provider/loadbalancer/iputil/prefix_test.go rename to pkg/util/iputil/prefix_test.go diff --git a/pkg/provider/loadbalancer/iputil/prefix_tree.go b/pkg/util/iputil/prefix_tree.go similarity index 100% rename from pkg/provider/loadbalancer/iputil/prefix_tree.go rename to pkg/util/iputil/prefix_tree.go diff --git a/pkg/provider/loadbalancer/iputil/prefix_tree_test.go b/pkg/util/iputil/prefix_tree_test.go similarity index 100% rename from pkg/provider/loadbalancer/iputil/prefix_tree_test.go rename to pkg/util/iputil/prefix_tree_test.go