diff --git a/pkg/curatedpackages/packagecontrollerclient.go b/pkg/curatedpackages/packagecontrollerclient.go index 8369aadfa3df2..f21818fb0c343 100644 --- a/pkg/curatedpackages/packagecontrollerclient.go +++ b/pkg/curatedpackages/packagecontrollerclient.go @@ -40,8 +40,6 @@ const ( type PackageControllerClientOpt func(client *PackageControllerClient) -type registryAccessTester func(ctx context.Context, accessKey, secret, registry, region string) error - type PackageControllerClient struct { kubeConfig string chart *releasev1.Image @@ -76,7 +74,7 @@ type PackageControllerClient struct { mu sync.Mutex // registryAccessTester test if the aws credential has access to registry - registryAccessTester registryAccessTester + registryAccessTester RegistryAccessTester } // ClientBuilder returns a k8s client for the specified cluster. @@ -112,7 +110,7 @@ func NewPackageControllerClientFullLifecycle(logger logr.Logger, chartManager Ch skipWaitForPackageBundle: true, eksaRegion: eksaDefaultRegion, clientBuilder: clientBuilder, - registryAccessTester: TestRegistryAccess, + registryAccessTester: &DefaultRegistryAccessTester{}, } } @@ -171,7 +169,7 @@ func NewPackageControllerClient(chartManager ChartManager, kubectl KubectlRunner kubectl: kubectl, registryMirror: registryMirror, eksaRegion: eksaDefaultRegion, - registryAccessTester: TestRegistryAccess, + registryAccessTester: &DefaultRegistryAccessTester{}, } for _, o := range options { @@ -269,7 +267,7 @@ func (pc *PackageControllerClient) GetCuratedPackagesRegistries(ctx context.Cont } regionalRegistry := GetRegionalRegistry(defaultRegistry, pc.eksaRegion) - if err := pc.registryAccessTester(ctx, pc.eksaAccessKeyID, pc.eksaSecretAccessKey, regionalRegistry, pc.eksaRegion); err == nil { + if err := pc.registryAccessTester.Test(ctx, pc.eksaAccessKeyID, pc.eksaSecretAccessKey, pc.eksaRegion, pc.eksaAwsConfig, regionalRegistry); err == nil { // use regional registry when the above credential is good logger.V(6).Info("Using regional registry") defaultRegistry = regionalRegistry @@ -619,7 +617,7 @@ func WithClusterSpec(clusterSpec *cluster.Spec) func(client *PackageControllerCl } // WithRegistryAccessTester sets the registryTester. -func WithRegistryAccessTester(registryTester registryAccessTester) func(client *PackageControllerClient) { +func WithRegistryAccessTester(registryTester RegistryAccessTester) func(client *PackageControllerClient) { return func(config *PackageControllerClient) { config.registryAccessTester = registryTester } diff --git a/pkg/curatedpackages/packagecontrollerclient_test.go b/pkg/curatedpackages/packagecontrollerclient_test.go index 3f8e04e31d821..70e0527deb73a 100644 --- a/pkg/curatedpackages/packagecontrollerclient_test.go +++ b/pkg/curatedpackages/packagecontrollerclient_test.go @@ -1178,6 +1178,12 @@ func TestEnableFullLifecyclePath(t *testing.T) { } } +type stubRegistryAccessTester struct{} + +func (s *stubRegistryAccessTester) Test(ctx context.Context, accessKey, secret, registry, region, awsConfig string) error { + return nil +} + func TestGetCuratedPackagesRegistries(s *testing.T) { s.Run("substitutes a region if set", func(t *testing.T) { ctrl := gomock.NewController(t) @@ -1248,9 +1254,7 @@ func TestGetCuratedPackagesRegistries(s *testing.T) { cm, k, clusterName, kubeConfig, chart, nil, curatedpackages.WithManagementClusterName(clusterName), curatedpackages.WithValuesFileWriter(writer), - curatedpackages.WithRegistryAccessTester(func(ctx context.Context, accessKey, secret, registry, region string) error { - return nil - }), + curatedpackages.WithRegistryAccessTester(&stubRegistryAccessTester{}), ) expected := "TODO.dkr.ecr.us-west-2.amazonaws.com" diff --git a/pkg/curatedpackages/regional_registry.go b/pkg/curatedpackages/regional_registry.go index 569ee105da7e6..64fe6e4ec9866 100644 --- a/pkg/curatedpackages/regional_registry.go +++ b/pkg/curatedpackages/regional_registry.go @@ -5,8 +5,10 @@ import ( "fmt" "io" "net/http" + "os" "strings" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/ecr" @@ -22,28 +24,33 @@ var prodRegionalECRMap = map[string]string{ "us-east-2": "TODO.dkr.ecr.us-east-2.amazonaws.com", } -// TestRegistryAccess test if the packageControllerClient has valid credential to access registry. -func TestRegistryAccess(ctx context.Context, accessKey, secret, registry, region string) error { - cfg, err := config.LoadDefaultConfig(ctx, - config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(accessKey, secret, "")), - config.WithRegion(region), - ) - if err != nil { - return err - } +// RegistryAccessTester test if AWS credentials has valid permission to access an ECR registry. +type RegistryAccessTester interface { + Test(ctx context.Context, accessKey, secret, region, awsConfig, registry string) error +} - ecrClient := ecr.NewFromConfig(cfg) - out, err := ecrClient.GetAuthorizationToken(context.Background(), &ecr.GetAuthorizationTokenInput{}) +// DefaultRegistryAccessTester the default implementation of RegistryAccessTester. +type DefaultRegistryAccessTester struct{} + +// Test if the AWS static credential or sharedConfig has valid permission to access an ECR registry. +func (r *DefaultRegistryAccessTester) Test(ctx context.Context, accessKey, secret, region, awsConfig, registry string) (err error) { + authTokenProvider := &DefaultRegistryAuthTokenProvider{} + + var authToken string + if len(awsConfig) > 0 { + authToken, err = authTokenProvider.GetTokenByAWSConfig(ctx, awsConfig) + } else { + authToken, err = authTokenProvider.GetTokenByAWSKeySecret(ctx, accessKey, secret, region) + } if err != nil { return err } - authToken := out.AuthorizationData[0].AuthorizationToken - return TestRegistryWithAuthToken(*authToken, registry, http.DefaultClient.Do) + return TestRegistryWithAuthToken(authToken, registry, http.DefaultClient.Do) } // TestRegistryWithAuthToken test if the registry can be acccessed with auth token. -func TestRegistryWithAuthToken(authToken, registry string, getResponse func(req *http.Request) (*http.Response, error)) error { +func TestRegistryWithAuthToken(authToken, registry string, do Do) error { manifestPath := "/v2/eks-anywhere-packages/manifests/latest" req, err := http.NewRequest("GET", "https://"+registry+manifestPath, nil) @@ -52,7 +59,7 @@ func TestRegistryWithAuthToken(authToken, registry string, getResponse func(req } req.Header.Add("Authorization", "Basic "+authToken) - resp2, err := getResponse(req) + resp2, err := do(req) if err != nil { return err } @@ -76,3 +83,80 @@ func GetRegionalRegistry(defaultRegistry, region string) string { } return prodRegionalECRMap[region] } + +// RegistryAuthTokenProvider provides auth token for registry access. +type RegistryAuthTokenProvider interface { + GetTokenByAWSConfig(ctx context.Context, awsConfig string) (string, error) + GetTokenByAWSKeySecret(ctx context.Context, key, secret, region string) (string, error) +} + +// DefaultRegistryAuthTokenProvider provides auth token for AWS ECR registry access. +type DefaultRegistryAuthTokenProvider struct{} + +// GetTokenByAWSConfig get auth token by AWS config. +func (d *DefaultRegistryAuthTokenProvider) GetTokenByAWSConfig(ctx context.Context, awsConfig string) (string, error) { + file, err := os.CreateTemp("", "eksa-temp-aws-config-*") + if err != nil { + return "", err + } + if _, err := file.Write([]byte(awsConfig)); err != nil { + return "", err + } + defer os.Remove(file.Name()) + if err != nil { + return "", err + } + + cfg, err := config.LoadDefaultConfig(ctx, + config.WithSharedConfigFiles([]string{file.Name()}), + ) + if err != nil { + return "", err + } + + return getAuthorizationToken(cfg) +} + +// GetTokenByAWSKeySecret get auth token by AWS key and secret. +func (d *DefaultRegistryAuthTokenProvider) GetTokenByAWSKeySecret(ctx context.Context, key, secret, region string) (string, error) { + cfg, err := config.LoadDefaultConfig(ctx, + config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(key, secret, "")), + config.WithRegion(region), + ) + if err != nil { + return "", err + } + + return getAuthorizationToken(cfg) +} + +func getAuthorizationToken(cfg aws.Config) (string, error) { + ecrClient := ecr.NewFromConfig(cfg) + out, err := ecrClient.GetAuthorizationToken(context.Background(), &ecr.GetAuthorizationTokenInput{}) + if err != nil { + return "", fmt.Errorf("ecrClient cannot get authorization token: %w", err) + } + authToken := out.AuthorizationData[0].AuthorizationToken + return *authToken, nil +} + +// Do is a function type that takes a http request and returns a http response. +type Do func(req *http.Request) (*http.Response, error) + +// TestRegistryAccessWithAWSConfig test if the AWS config has valid permission to access container registry. +func TestRegistryAccessWithAWSConfig(ctx context.Context, awsConfig, registry string, tokenProvider RegistryAuthTokenProvider, do Do) error { + token, err := tokenProvider.GetTokenByAWSConfig(ctx, awsConfig) + if err != nil { + return err + } + return TestRegistryWithAuthToken(token, registry, do) +} + +// TestRegistryAccessWithAWSKeySecret test if the AWS key and secret has valid permission to access container registry. +func TestRegistryAccessWithAWSKeySecret(ctx context.Context, key, secret, region, registry string, tokenProvider RegistryAuthTokenProvider, do Do) error { + token, err := tokenProvider.GetTokenByAWSKeySecret(ctx, key, secret, region) + if err != nil { + return err + } + return TestRegistryWithAuthToken(token, registry, do) +} diff --git a/pkg/curatedpackages/regional_registry_test.go b/pkg/curatedpackages/regional_registry_test.go index d573b6486e1fa..c051405de7b6e 100644 --- a/pkg/curatedpackages/regional_registry_test.go +++ b/pkg/curatedpackages/regional_registry_test.go @@ -2,41 +2,126 @@ package curatedpackages_test import ( "bytes" + "context" "io" "net/http" + "net/http/httptest" + "strings" "testing" "github.com/aws/eks-anywhere/pkg/curatedpackages" ) -func TestTestRegistry(t *testing.T) { - err := curatedpackages.TestRegistryWithAuthToken("authToken", "registry_url", func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 200, - Body: io.NopCloser(bytes.NewReader(nil)), - }, nil - }) - if err != nil { - t.Errorf("Registry is good, but error has been returned %v\n", err) +func TestTestRegistryWithAuthToken(t *testing.T) { + cases := []struct { + description string + statusCode int + hasError bool + }{ + {"200 status code does not cause error", 200, false}, + {"404 status code does not cause error", 404, false}, + {"400 status code causes error", 400, true}, } - err = curatedpackages.TestRegistryWithAuthToken("authToken", "registry_url", func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 404, - Body: io.NopCloser(bytes.NewReader(nil)), - }, nil - }) - if err != nil { - t.Errorf("Registry is good, but error has been returned %v\n", err) + for _, test := range cases { + t.Run(test.description, func(t *testing.T) { + err := curatedpackages.TestRegistryWithAuthToken("authToken", "registry_url", func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: test.statusCode, + Body: io.NopCloser(bytes.NewReader(nil)), + }, nil + }) + if test.hasError && err == nil { + t.Errorf("Error should have been returned") + } + if !test.hasError && err != nil { + t.Errorf("Registry is good, but error has been returned %v\n", err) + } + }) } +} - err = curatedpackages.TestRegistryWithAuthToken("authToken", "registry_url", func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 400, - Body: io.NopCloser(bytes.NewReader(nil)), - }, nil - }) - if err == nil { - t.Errorf("Error should have been returned") +var ( + goodAwsConfig = "good-aws-config" + badAwsConfig = "bad-aws-config" + + goodAwsKey = "good-aws-key" + goodAwsSecret = "good-aws-secret" + badAwsKey = "bad-aws-key" + badAwsSecret = "bad-aws-secret" + + goodAuthToken = "good-auth-token" + badAuthToken = "bad-auth-token" +) + +type mockRegistryAuthTokenProvider struct{} + +func (m *mockRegistryAuthTokenProvider) GetTokenByAWSConfig(ctx context.Context, awsConfig string) (string, error) { + if awsConfig == goodAwsConfig { + return goodAuthToken, nil + } + return badAuthToken, nil +} + +func (m *mockRegistryAuthTokenProvider) GetTokenByAWSKeySecret(ctx context.Context, key, secret, region string) (string, error) { + if key == goodAwsKey && secret == goodAwsSecret { + return goodAuthToken, nil } + return badAuthToken, nil +} + +func TestTestAWSConfigRegistryAccessWithAWSConfig(t *testing.T) { + registryServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if strings.Contains(auth, goodAuthToken) { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusForbidden) + } + })) + defer registryServer.Close() + + registry := strings.TrimPrefix(registryServer.URL, "https://") + mockProvider := &mockRegistryAuthTokenProvider{} + + t.Run("return error for bad auth token", func(t *testing.T) { + err := curatedpackages.TestRegistryAccessWithAWSConfig(context.Background(), badAwsConfig, registry, mockProvider, registryServer.Client().Do) + if err == nil { + t.Errorf("Error should have been returned") + } + }) + t.Run("return no error for good auth token", func(t *testing.T) { + err := curatedpackages.TestRegistryAccessWithAWSConfig(context.Background(), goodAwsConfig, registry, mockProvider, registryServer.Client().Do) + if err != nil { + t.Errorf("Error should not have been returned for good AWS Config: %s", err) + } + }) +} + +func TestTestAWSConfigRegistryAccessWithAWSKeySecret(t *testing.T) { + registryServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if strings.Contains(auth, goodAuthToken) { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusForbidden) + } + })) + defer registryServer.Close() + + registry := strings.TrimPrefix(registryServer.URL, "https://") + mockProvider := &mockRegistryAuthTokenProvider{} + + t.Run("return error for bad auth token", func(t *testing.T) { + err := curatedpackages.TestRegistryAccessWithAWSKeySecret(context.Background(), badAwsKey, badAwsSecret, "us-west-2", registry, mockProvider, registryServer.Client().Do) + if err == nil { + t.Errorf("Error should have been returned") + } + }) + t.Run("return no error for good auth token", func(t *testing.T) { + err := curatedpackages.TestRegistryAccessWithAWSKeySecret(context.Background(), goodAwsKey, goodAwsSecret, "us-west-2", registry, mockProvider, registryServer.Client().Do) + if err != nil { + t.Errorf("Error should not have been returned for good AWS Config: %s", err) + } + }) }