diff --git a/store/secretsmanagerstore.go b/store/secretsmanagerstore.go index d9abc41..1740018 100644 --- a/store/secretsmanagerstore.go +++ b/store/secretsmanagerstore.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "os" "reflect" "sort" "strconv" @@ -17,6 +18,12 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sts" ) +const ( + // CustomSecretsManagerEndpointEnvVar is the name of the environment variable specifying a custom + // base Secrets Manager endpoint. + CustomSecretsManagerEndpointEnvVar = "CHAMBER_AWS_SECRETS_MANAGER_ENDPOINT" +) + // We store all Chamber metadata in a stringified JSON format, // in a field named "_chamber_metadata" const metadataKey = "_chamber_metadata" @@ -85,6 +92,16 @@ func NewSecretsManagerStore(ctx context.Context, numRetries int) (*SecretsManage if err != nil { return nil, err } + customSecretsManagerEndpoint, ok := os.LookupEnv(CustomSecretsManagerEndpointEnvVar) + if ok { + cfg.BaseEndpoint = aws.String(customSecretsManagerEndpoint) + } else { + // Preserving incorrect and deprecated use of the SSM environment variable from v2 + customSecretsManagerEndpoint, ok = os.LookupEnv(CustomSSMEndpointEnvVar) + if ok { + cfg.BaseEndpoint = aws.String(customSecretsManagerEndpoint) + } + } svc := secretsmanager.NewFromConfig(cfg) diff --git a/store/secretsmanagerstore_test.go b/store/secretsmanagerstore_test.go index ed14e95..7406c9d 100644 --- a/store/secretsmanagerstore_test.go +++ b/store/secretsmanagerstore_test.go @@ -194,23 +194,33 @@ func TestNewSecretsManagerStore(t *testing.T) { assert.Equal(t, "us-west-1", s.config.Region) }) - t.Run("Should use CHAMBER_AWS_SSM_ENDPOINT if set", func(t *testing.T) { + t.Run("Should use CHAMBER_AWS_SECRETS_MANAGER_ENDPOINT if set", func(t *testing.T) { + os.Setenv("CHAMBER_AWS_SECRETS_MANAGER_ENDPOINT", "mycustomendpoint") + defer os.Unsetenv("CHAMBER_AWS_SECRETS_MANAGER_ENDPOINT") + + s, err := NewSecretsManagerStore(context.Background(), 1) + assert.Nil(t, err) + secretsmanagerClient := s.svc.(*secretsmanager.Client) + assert.Equal(t, "mycustomendpoint", *secretsmanagerClient.Options().BaseEndpoint) + // default endpoint resolution (v2) uses the client's BaseEndpoint + }) + + t.Run("Should use CHAMBER_AWS_SSM_ENDPOINT if set (deprecated)", func(t *testing.T) { os.Setenv("CHAMBER_AWS_SSM_ENDPOINT", "mycustomendpoint") defer os.Unsetenv("CHAMBER_AWS_SSM_ENDPOINT") s, err := NewSecretsManagerStore(context.Background(), 1) assert.Nil(t, err) - endpoint, err := s.config.EndpointResolverWithOptions.ResolveEndpoint(secretsmanager.ServiceID, "us-west-2") - assert.Nil(t, err) - assert.Equal(t, "mycustomendpoint", endpoint.URL) + secretsmanagerClient := s.svc.(*secretsmanager.Client) + assert.Equal(t, "mycustomendpoint", *secretsmanagerClient.Options().BaseEndpoint) + // default endpoint resolution (v2) uses the client's BaseEndpoint }) - t.Run("Should use default AWS SSM endpoint if CHAMBER_AWS_SSM_ENDPOINT not set", func(t *testing.T) { + t.Run("Should use default AWS secrets manager endpoint if CHAMBER_AWS_SECRETS_MANAGER_ENDPOINT not set", func(t *testing.T) { s, err := NewSecretsManagerStore(context.Background(), 1) assert.Nil(t, err) - _, err = s.config.EndpointResolverWithOptions.ResolveEndpoint(secretsmanager.ServiceID, "us-west-2") - var notFoundError *aws.EndpointNotFoundError - assert.ErrorAs(t, err, ¬FoundError) + secretsmanagerClient := s.svc.(*secretsmanager.Client) + assert.Nil(t, secretsmanagerClient.Options().BaseEndpoint) }) } diff --git a/store/shared.go b/store/shared.go index d40a6ba..937c301 100644 --- a/store/shared.go +++ b/store/shared.go @@ -10,23 +10,10 @@ import ( ) const ( - RegionEnvVar = "CHAMBER_AWS_REGION" - CustomSSMEndpointEnvVar = "CHAMBER_AWS_SSM_ENDPOINT" + RegionEnvVar = "CHAMBER_AWS_REGION" ) func getConfig(ctx context.Context, numRetries int, retryMode aws.RetryMode) (aws.Config, string, error) { - endpointResolver := func(service, region string, options ...interface{}) (aws.Endpoint, error) { - customSsmEndpoint, ok := os.LookupEnv(CustomSSMEndpointEnvVar) - if ok { - return aws.Endpoint{ - URL: customSsmEndpoint, - Source: aws.EndpointSourceCustom, - }, nil - } - - return aws.Endpoint{}, &aws.EndpointNotFoundError{} - } - var region string if regionOverride, ok := os.LookupEnv(RegionEnvVar); ok { region = regionOverride @@ -36,7 +23,6 @@ func getConfig(ctx context.Context, numRetries int, retryMode aws.RetryMode) (aw config.WithRegion(region), config.WithRetryMaxAttempts(numRetries), config.WithRetryMode(retryMode), - config.WithEndpointResolverWithOptions(aws.EndpointResolverWithOptionsFunc(endpointResolver)), ) if err != nil { return aws.Config{}, "", err diff --git a/store/shared_test.go b/store/shared_test.go index 148fac7..f8e664c 100644 --- a/store/shared_test.go +++ b/store/shared_test.go @@ -10,14 +10,6 @@ import ( ) func TestGetConfig(t *testing.T) { - originalEndpoint := os.Getenv(CustomSSMEndpointEnvVar) - os.Setenv(CustomSSMEndpointEnvVar, "https://example.com/custom-endpoint") - if originalEndpoint != "" { - defer os.Setenv(CustomSSMEndpointEnvVar, originalEndpoint) - } else { - defer os.Unsetenv(CustomSSMEndpointEnvVar) - } - originalRegion := os.Getenv(RegionEnvVar) os.Setenv(RegionEnvVar, "us-west-2") if originalRegion != "" { @@ -31,11 +23,6 @@ func TestGetConfig(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "us-west-2", region) - endpoint, err := config.EndpointResolverWithOptions.ResolveEndpoint("ssm", "us-west-2") - assert.NoError(t, err) - assert.Equal(t, "https://example.com/custom-endpoint", endpoint.URL) - assert.Equal(t, aws.EndpointSourceCustom, endpoint.Source) - assert.Equal(t, 3, config.RetryMaxAttempts) assert.Equal(t, aws.RetryModeStandard, config.RetryMode) } diff --git a/store/ssmstore.go b/store/ssmstore.go index 04b7917..3a98506 100644 --- a/store/ssmstore.go +++ b/store/ssmstore.go @@ -16,6 +16,10 @@ import ( ) const ( + // CustomSSMEndpointEnvVar is the name of the environment variable specifying a custom base SSM + // endpoint. + CustomSSMEndpointEnvVar = "CHAMBER_AWS_SSM_ENDPOINT" + // DefaultKeyID is the default alias for the KMS key used to encrypt/decrypt secrets DefaultKeyID = "alias/parameter_store_key" @@ -61,10 +65,13 @@ func NewSSMStoreWithRetryMode(ctx context.Context, numRetries int, retryMode aws func ssmStoreUsingRetryer(ctx context.Context, numRetries int, retryMode aws.RetryMode) (*SSMStore, error) { cfg, _, err := getConfig(ctx, numRetries, retryMode) - if err != nil { return nil, err } + customSsmEndpoint, ok := os.LookupEnv(CustomSSMEndpointEnvVar) + if ok { + cfg.BaseEndpoint = aws.String(customSsmEndpoint) + } svc := ssm.NewFromConfig(cfg) diff --git a/store/ssmstore_test.go b/store/ssmstore_test.go index 2adaa9d..83c000d 100644 --- a/store/ssmstore_test.go +++ b/store/ssmstore_test.go @@ -377,17 +377,16 @@ func TestNewSSMStore(t *testing.T) { s, err := NewSSMStore(context.Background(), 1) assert.Nil(t, err) - endpoint, err := s.config.EndpointResolverWithOptions.ResolveEndpoint(ssm.ServiceID, "us-west-2") - assert.Nil(t, err) - assert.Equal(t, "mycustomendpoint", endpoint.URL) + ssmClient := s.svc.(*ssm.Client) + assert.Equal(t, "mycustomendpoint", *ssmClient.Options().BaseEndpoint) + // default endpoint resolution (v2) uses the client's BaseEndpoint }) t.Run("Should use default AWS SSM endpoint if CHAMBER_AWS_SSM_ENDPOINT not set", func(t *testing.T) { s, err := NewSSMStore(context.Background(), 1) assert.Nil(t, err) - _, err = s.config.EndpointResolverWithOptions.ResolveEndpoint(ssm.ServiceID, "us-west-2") - var notFoundError *aws.EndpointNotFoundError - assert.ErrorAs(t, err, ¬FoundError) + ssmClient := s.svc.(*ssm.Client) + assert.Nil(t, ssmClient.Options().BaseEndpoint) }) t.Run("Should set AWS SDK retry mode to default", func(t *testing.T) {