From d12fe44e30af8450b6074f7260c43c55f9c30a2b Mon Sep 17 00:00:00 2001 From: Ilias Rinis Date: Wed, 4 Oct 2023 10:54:35 +0200 Subject: [PATCH] endpointaccessible: check if endpoint parameters changed at every sync If there are no changes in the endpoint parameters, skip the check. --- .../endpoint_accessible_controller.go | 65 ++++--- .../endpoint_accessible_controller_test.go | 161 +++++++++++++++++- 2 files changed, 199 insertions(+), 27 deletions(-) diff --git a/pkg/libs/endpointaccessible/endpoint_accessible_controller.go b/pkg/libs/endpointaccessible/endpoint_accessible_controller.go index 60783f5da..0dc08ac9d 100644 --- a/pkg/libs/endpointaccessible/endpoint_accessible_controller.go +++ b/pkg/libs/endpointaccessible/endpoint_accessible_controller.go @@ -3,6 +3,7 @@ package endpointaccessible import ( "context" "crypto/tls" + "crypto/x509" "fmt" "net/http" "strings" @@ -11,6 +12,7 @@ import ( apierrors "k8s.io/apimachinery/pkg/api/errors" utilerrors "k8s.io/apimachinery/pkg/util/errors" + "k8s.io/apimachinery/pkg/util/sets" operatorv1 "github.com/openshift/api/operator/v1" "github.com/openshift/library-go/pkg/controller/factory" @@ -23,6 +25,12 @@ type endpointAccessibleController struct { endpointListFn EndpointListFunc getTLSConfigFn EndpointTLSConfigFunc availableConditionName string + + maxCheckLatency time.Duration + lastCheckTime time.Time + lastEndpoints sets.Set[string] + lastServerName string + lastCA *x509.CertPool } type EndpointListFunc func() ([]string, error) @@ -47,6 +55,8 @@ func NewEndpointAccessibleController( endpointListFn: endpointListFn, getTLSConfigFn: getTLSConfigFn, availableConditionName: name + "EndpointAccessibleControllerAvailable", + maxCheckLatency: resyncInterval - 5*time.Second, + lastEndpoints: sets.New[string](), } return factory.New(). @@ -70,26 +80,41 @@ func humanizeError(err error) error { func (c *endpointAccessibleController) sync(ctx context.Context, syncCtx factory.SyncContext) error { endpoints, err := c.endpointListFn() - if err != nil { - if apierrors.IsNotFound(err) { - _, _, statusErr := v1helpers.UpdateStatus(ctx, c.operatorClient, v1helpers.UpdateConditionFn( - operatorv1.OperatorCondition{ - Type: c.availableConditionName, - Status: operatorv1.ConditionFalse, - Reason: "ResourceNotFound", - Message: err.Error(), - })) - - return statusErr - } + if apierrors.IsNotFound(err) { + _, _, statusErr := v1helpers.UpdateStatus(ctx, c.operatorClient, v1helpers.UpdateConditionFn( + operatorv1.OperatorCondition{ + Type: c.availableConditionName, + Status: operatorv1.ConditionFalse, + Reason: "ResourceNotFound", + Message: err.Error(), + })) + + return statusErr + } else if err != nil { + return err + } + + newEndpoints := sets.New(endpoints...) + endpointsChanged := !c.lastEndpoints.Equal(newEndpoints) + tlsConfig, err := c.getTLSConfigFn() + if err != nil { return err } + tlsChanged := c.lastServerName != tlsConfig.ServerName || !tlsConfig.RootCAs.Equal(c.lastCA) - client, err := c.buildTLSClient() + isPastTimeForCheck := time.Since(c.lastCheckTime) > c.maxCheckLatency + if !endpointsChanged && !tlsChanged && !isPastTimeForCheck { + return nil + } + c.lastCheckTime = time.Now() + c.lastEndpoints = newEndpoints + + client, err := c.buildTLSClient(tlsConfig) if err != nil { return err } + // check all the endpoints in parallel. This matters for pods. errCh := make(chan error, len(endpoints)) wg := sync.WaitGroup{} @@ -155,20 +180,22 @@ func (c *endpointAccessibleController) sync(ctx context.Context, syncCtx factory return utilerrors.NewAggregate(errors) } -func (c *endpointAccessibleController) buildTLSClient() (*http.Client, error) { +func (c *endpointAccessibleController) buildTLSClient(tlsConfig *tls.Config) (*http.Client, error) { transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, }, } - if c.getTLSConfigFn != nil { - tlsConfig, err := c.getTLSConfigFn() - if err != nil { - return nil, err - } + + if tlsConfig != nil { transport.TLSClientConfig = tlsConfig + + // these are the fields that are set by our getTLSConfigFn funcs + c.lastServerName = tlsConfig.ServerName + c.lastCA = tlsConfig.RootCAs } + return &http.Client{ Timeout: 5 * time.Second, Transport: transport, diff --git a/pkg/libs/endpointaccessible/endpoint_accessible_controller_test.go b/pkg/libs/endpointaccessible/endpoint_accessible_controller_test.go index e80041907..58827aef9 100644 --- a/pkg/libs/endpointaccessible/endpoint_accessible_controller_test.go +++ b/pkg/libs/endpointaccessible/endpoint_accessible_controller_test.go @@ -2,59 +2,204 @@ package endpointaccessible import ( "context" + "crypto/tls" + "crypto/x509" "fmt" "testing" + "time" operatorv1 "github.com/openshift/api/operator/v1" "github.com/openshift/library-go/pkg/operator/v1helpers" + "k8s.io/apimachinery/pkg/util/sets" "github.com/openshift/library-go/pkg/controller/factory" "github.com/openshift/library-go/pkg/operator/events" ) func Test_endpointAccessibleController_sync(t *testing.T) { + maxCheckLatency := 55 * time.Second + + systemRootCAs, err := x509.SystemCertPool() + if err != nil { + t.Errorf("unexpected error when getting system cert pool: %v", err) + } + + getTLSConfigFn := func(serverName string, returnErr error) func() (*tls.Config, error) { + return func() (*tls.Config, error) { + return &tls.Config{ + RootCAs: systemRootCAs, + ServerName: serverName, + }, returnErr + } + } + + getTLSConfigFnEmptyRootCAs := func(serverName string, returnErr error) func() (*tls.Config, error) { + return func() (*tls.Config, error) { + return &tls.Config{ + RootCAs: x509.NewCertPool(), + ServerName: serverName, + }, returnErr + } + } + tests := []struct { - name string - endpointListFn EndpointListFunc - wantErr bool + name string + endpointListFn EndpointListFunc + getTLSConfigFn EndpointTLSConfigFunc + lastCheckTime time.Time + lastEndpoints sets.Set[string] + lastServerName string + lastCA *x509.CertPool + wantCheckExecuted bool + wantErr bool }{ { name: "all endpoints working", endpointListFn: func() ([]string, error) { return []string{"https://google.com"}, nil }, + wantCheckExecuted: true, + }, + { + name: "all endpoints working with tls config", + getTLSConfigFn: getTLSConfigFn("google.com", nil), + endpointListFn: func() ([]string, error) { + return []string{"https://google.com"}, nil + }, + wantCheckExecuted: true, + }, + { + name: "check working when endpoints change", + getTLSConfigFn: getTLSConfigFn("google.com", nil), + endpointListFn: func() ([]string, error) { + return []string{"https://google.com"}, nil + }, + lastEndpoints: sets.New[string]("https://www.google.com"), + lastCheckTime: time.Now().Add(-1 * time.Second), + lastServerName: "google.com", + lastCA: systemRootCAs, + wantCheckExecuted: true, + }, + { + name: "check working when check is due", + getTLSConfigFn: getTLSConfigFn("google.com", nil), + endpointListFn: func() ([]string, error) { + return []string{"https://google.com"}, nil + }, + lastEndpoints: sets.New[string]("https://google.com"), + lastCheckTime: time.Now().Add(-2 * maxCheckLatency), + lastServerName: "google.com", + lastCA: systemRootCAs, + wantCheckExecuted: true, + }, + { + name: "check working when tls server name changes", + getTLSConfigFn: getTLSConfigFn("google.com", nil), + endpointListFn: func() ([]string, error) { + return []string{"https://google.com"}, nil + }, + lastEndpoints: sets.New[string]("https://google.com"), + lastCheckTime: time.Now().Add(-1 * time.Second), + lastServerName: "redhat.com", + lastCA: systemRootCAs, + wantCheckExecuted: true, + }, + { + name: "check working when tls root CAs change", + getTLSConfigFn: getTLSConfigFn("google.com", nil), + endpointListFn: func() ([]string, error) { + return []string{"https://google.com"}, nil + }, + lastEndpoints: sets.New[string]("https://google.com"), + lastCheckTime: time.Now().Add(-1 * time.Second), + lastServerName: "google.com", + lastCA: x509.NewCertPool(), + wantCheckExecuted: true, + }, + { + name: "check skipped when no changes in parameters and check is not due", + getTLSConfigFn: getTLSConfigFn("google.com", nil), + endpointListFn: func() ([]string, error) { + return []string{"https://google.com"}, nil + }, + lastEndpoints: sets.New[string]("https://google.com"), + lastCheckTime: time.Now().Add(-1 * time.Second), + lastServerName: "google.com", + lastCA: systemRootCAs, + wantCheckExecuted: false, + wantErr: false, + }, + { + name: "check fails when tls config fails", + getTLSConfigFn: getTLSConfigFn("google.com", fmt.Errorf("tls config error")), + endpointListFn: func() ([]string, error) { + return []string{"https://google.com"}, nil + }, + wantCheckExecuted: false, + wantErr: true, + }, + { + name: "check fails when tls server name invalid", + getTLSConfigFn: getTLSConfigFn("g00gle.com", nil), + endpointListFn: func() ([]string, error) { + return []string{"https://google.com"}, nil + }, + wantCheckExecuted: true, + wantErr: true, + }, + { + name: "check fails when tls rootCAs invalid", + getTLSConfigFn: getTLSConfigFnEmptyRootCAs("google.com", nil), + endpointListFn: func() ([]string, error) { + return []string{"https://google.com"}, nil + }, + wantCheckExecuted: true, + wantErr: true, }, { name: "endpoints lister error", endpointListFn: func() ([]string, error) { return nil, fmt.Errorf("some error") }, - wantErr: true, + wantCheckExecuted: false, + wantErr: true, }, { name: "non working endpoints", endpointListFn: func() ([]string, error) { return []string{"https://google.com", "https://nonexistenturl.com"}, nil }, - wantErr: true, + wantCheckExecuted: true, + wantErr: true, }, { name: "invalid url", endpointListFn: func() ([]string, error) { return []string{"htt//bad`string"}, nil }, - wantErr: true, + wantCheckExecuted: true, + wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &endpointAccessibleController{ - operatorClient: v1helpers.NewFakeOperatorClient(&operatorv1.OperatorSpec{}, &operatorv1.OperatorStatus{}, nil), - endpointListFn: tt.endpointListFn, + operatorClient: v1helpers.NewFakeOperatorClient(&operatorv1.OperatorSpec{}, &operatorv1.OperatorStatus{}, nil), + getTLSConfigFn: tt.getTLSConfigFn, + endpointListFn: tt.endpointListFn, + maxCheckLatency: maxCheckLatency, + lastEndpoints: tt.lastEndpoints, + lastCheckTime: tt.lastCheckTime, + lastServerName: tt.lastServerName, + lastCA: tt.lastCA, } + prevLastCheckTime := c.lastCheckTime if err := c.sync(context.Background(), factory.NewSyncContext(tt.name, events.NewInMemoryRecorder(tt.name))); (err != nil) != tt.wantErr { t.Errorf("sync() error = %v, wantErr %v", err, tt.wantErr) } + if tt.wantCheckExecuted != (!prevLastCheckTime.Equal(c.lastCheckTime)) { + t.Errorf("sync() check was executed when it should have been skipped") + } }) } }