From 2bd71db3adedbafa3eded7ab5539d1bb548772fc Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Mon, 23 Dec 2024 12:42:34 -0800 Subject: [PATCH] Update awsconfig * Add a Cache for caching credentials, similar to SDK v1 session cache. * Add a Provider interface that provides aws.Config * Simplified role chaining options Unlike our SDK v1 session cache, the SDK v2 implementation in this PR does not include region as a cache key. There are regional AWS STS endpoints for lower latency calls, but the lowest latency path is to just grab credentials from the cache if we already have them - the region they were originally taken from doesn't matter. --- lib/cloud/awsconfig/awsconfig.go | 174 ++++++++++++++++---------- lib/cloud/awsconfig/awsconfig_test.go | 131 ++++++++++++++++++- lib/cloud/awsconfig/cache.go | 149 ++++++++++++++++++++++ lib/cloud/awsconfig/provider.go | 37 ++++++ 4 files changed, 417 insertions(+), 74 deletions(-) create mode 100644 lib/cloud/awsconfig/cache.go create mode 100644 lib/cloud/awsconfig/provider.go diff --git a/lib/cloud/awsconfig/awsconfig.go b/lib/cloud/awsconfig/awsconfig.go index 92f7e8aa96e86..8be00483f4012 100644 --- a/lib/cloud/awsconfig/awsconfig.go +++ b/lib/cloud/awsconfig/awsconfig.go @@ -47,16 +47,23 @@ const ( // This is used to generate aws configs for clients that must use an integration instead of ambient credentials. type IntegrationCredentialProviderFunc func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) +// AssumeRoleClientProviderFunc provides an AWS STS assume role API client. +type AssumeRoleClientProviderFunc func(aws.Config) stscreds.AssumeRoleAPIClient + +// AssumeRole is an AWS role to assume, optionally with an external ID. +type AssumeRole struct { + // RoleARN is the ARN of the role to assume. + RoleARN string `json:"role_arn"` + // ExternalID is an optional ID to include when assuming the role. + ExternalID string `json:"external_id"` +} + // options is a struct of additional options for assuming an AWS role // when construction an underlying AWS config. type options struct { - // baseConfigis a config to use instead of the default config for an - // AWS region, which is used to enable role chaining. - baseConfig *aws.Config - // assumeRoleARN is the AWS IAM Role ARN to assume. - assumeRoleARN string - // assumeRoleExternalID is used to assume an external AWS IAM Role. - assumeRoleExternalID string + // assumeRoles are AWS IAM roles that should be assumed one by one in order, + // as a chain of assumed roles. + assumeRoles []AssumeRole // credentialsSource describes which source to use to fetch credentials. credentialsSource credentialsSource // integration is the name of the integration to be used to fetch the credentials. @@ -67,22 +74,45 @@ type options struct { customRetryer func() aws.Retryer // maxRetries is the maximum number of retries to use for the config. maxRetries *int + // assumeRoleClientProvider sets the STS assume role client provider func. + assumeRoleClientProvider AssumeRoleClientProviderFunc } -func (a *options) checkAndSetDefaults() error { - switch a.credentialsSource { +func buildOptions(optFns ...OptionsFn) (*options, error) { + var opts options + for _, optFn := range optFns { + optFn(&opts) + } + if err := opts.checkAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + return &opts, nil +} + +func (o *options) checkAndSetDefaults() error { + switch o.credentialsSource { case credentialsSourceAmbient: - if a.integration != "" { + if o.integration != "" { return trace.BadParameter("integration and ambient credentials cannot be used at the same time") } case credentialsSourceIntegration: - if a.integration == "" { + if o.integration == "" { return trace.BadParameter("missing integration name") } default: return trace.BadParameter("missing credentials source (ambient or integration)") } + if len(o.assumeRoles) > 2 { + return trace.BadParameter("role chain contains more than 2 roles") + } + if o.assumeRoleClientProvider == nil { + o.assumeRoleClientProvider = func(cfg aws.Config) stscreds.AssumeRoleAPIClient { + return sts.NewFromConfig(cfg, func(o *sts.Options) { + o.TracerProvider = smithyoteltracing.Adapt(otel.GetTracerProvider()) + }) + } + } return nil } @@ -93,8 +123,14 @@ type OptionsFn func(*options) // WithAssumeRole configures options needed for assuming an AWS role. func WithAssumeRole(roleARN, externalID string) OptionsFn { return func(options *options) { - options.assumeRoleARN = roleARN - options.assumeRoleExternalID = externalID + if roleARN == "" { + // ignore empty role ARN for caller convenience. + return + } + options.assumeRoles = append(options.assumeRoles, AssumeRole{ + RoleARN: roleARN, + ExternalID: externalID, + }) } } @@ -146,96 +182,98 @@ func WithIntegrationCredentialProvider(cred IntegrationCredentialProviderFunc) O } } +// WithAssumeRoleClientProviderFunc sets the STS API client factory func used to +// assume roles. +func WithAssumeRoleClientProviderFunc(fn AssumeRoleClientProviderFunc) OptionsFn { + return func(options *options) { + options.assumeRoleClientProvider = fn + } +} + // GetConfig returns an AWS config for the specified region, optionally // assuming AWS IAM Roles. -func GetConfig(ctx context.Context, region string, opts ...OptionsFn) (aws.Config, error) { - var options options - for _, opt := range opts { - opt(&options) - } - if options.baseConfig == nil { - cfg, err := getConfigForRegion(ctx, region, options) - if err != nil { - return aws.Config{}, trace.Wrap(err) - } - options.baseConfig = &cfg +func GetConfig(ctx context.Context, region string, optFns ...OptionsFn) (aws.Config, error) { + opts, err := buildOptions(optFns...) + if err != nil { + return aws.Config{}, trace.Wrap(err) } - if options.assumeRoleARN == "" { - return *options.baseConfig, nil + + cfg, err := getBaseConfig(ctx, region, opts) + if err != nil { + return aws.Config{}, trace.Wrap(err) } - return getConfigForRole(ctx, region, options) + return getConfigForRoleChain(ctx, cfg, opts.assumeRoles, opts.assumeRoleClientProvider) } -// ambientConfigProvider loads a new config using the environment variables. -func ambientConfigProvider(region string, cred aws.CredentialsProvider, options options) (aws.Config, error) { - opts := buildConfigOptions(region, cred, options) - cfg, err := config.LoadDefaultConfig(context.Background(), opts...) +// loadDefaultConfig loads a new config. +func loadDefaultConfig(ctx context.Context, region string, cred aws.CredentialsProvider, opts *options) (aws.Config, error) { + configOpts := buildConfigOptions(region, cred, opts) + cfg, err := config.LoadDefaultConfig(ctx, configOpts...) return cfg, trace.Wrap(err) } -func buildConfigOptions(region string, cred aws.CredentialsProvider, options options) []func(*config.LoadOptions) error { - opts := []func(*config.LoadOptions) error{ +func buildConfigOptions(region string, cred aws.CredentialsProvider, opts *options) []func(*config.LoadOptions) error { + configOpts := []func(*config.LoadOptions) error{ config.WithDefaultRegion(defaultRegion), config.WithRegion(region), config.WithCredentialsProvider(cred), } if modules.GetModules().IsBoringBinary() { - opts = append(opts, config.WithUseFIPSEndpoint(aws.FIPSEndpointStateEnabled)) + configOpts = append(configOpts, config.WithUseFIPSEndpoint(aws.FIPSEndpointStateEnabled)) } - if options.customRetryer != nil { - opts = append(opts, config.WithRetryer(options.customRetryer)) + if opts.customRetryer != nil { + configOpts = append(configOpts, config.WithRetryer(opts.customRetryer)) } - if options.maxRetries != nil { - opts = append(opts, config.WithRetryMaxAttempts(*options.maxRetries)) + if opts.maxRetries != nil { + configOpts = append(configOpts, config.WithRetryMaxAttempts(*opts.maxRetries)) } - return opts + return configOpts } -// getConfigForRegion returns AWS config for the specified region. -func getConfigForRegion(ctx context.Context, region string, options options) (aws.Config, error) { - if err := options.checkAndSetDefaults(); err != nil { - return aws.Config{}, trace.Wrap(err) - } - +// getBaseConfig returns an AWS config without assuming any roles. +func getBaseConfig(ctx context.Context, region string, opts *options) (aws.Config, error) { var cred aws.CredentialsProvider - if options.credentialsSource == credentialsSourceIntegration { - if options.integrationCredentialsProvider == nil { + if opts.credentialsSource == credentialsSourceIntegration { + if opts.integrationCredentialsProvider == nil { return aws.Config{}, trace.BadParameter("missing aws integration credential provider") } - slog.DebugContext(ctx, "Initializing AWS config with integration", "region", region, "integration", options.integration) + slog.DebugContext(ctx, "Initializing AWS config with integration", "region", region, "integration", opts.integration) var err error - cred, err = options.integrationCredentialsProvider(ctx, region, options.integration) + cred, err = opts.integrationCredentialsProvider(ctx, region, opts.integration) if err != nil { return aws.Config{}, trace.Wrap(err) } } else { - slog.DebugContext(ctx, "Initializing AWS config from environment", "region", region) + slog.DebugContext(ctx, "Initializing AWS config from default credential chain", "region", region) } - cfg, err := ambientConfigProvider(region, cred, options) + cfg, err := loadDefaultConfig(ctx, region, cred, opts) return cfg, trace.Wrap(err) } -// getConfigForRole returns an AWS config for the specified region and role. -func getConfigForRole(ctx context.Context, region string, options options) (aws.Config, error) { - if err := options.checkAndSetDefaults(); err != nil { - return aws.Config{}, trace.Wrap(err) +func getConfigForRoleChain(ctx context.Context, cfg aws.Config, roles []AssumeRole, newCltFn AssumeRoleClientProviderFunc) (aws.Config, error) { + for _, r := range roles { + cfg.Credentials = getAssumeRoleProvider(ctx, newCltFn(cfg), r) } - - stsClient := sts.NewFromConfig(*options.baseConfig, func(o *sts.Options) { - o.TracerProvider = smithyoteltracing.Adapt(otel.GetTracerProvider()) - }) - cred := stscreds.NewAssumeRoleProvider(stsClient, options.assumeRoleARN, func(aro *stscreds.AssumeRoleOptions) { - if options.assumeRoleExternalID != "" { - aro.ExternalID = aws.String(options.assumeRoleExternalID) + if len(roles) > 0 { + // no point caching every assumed role in the chain, we can just cache + // the last one. + cfg.Credentials = aws.NewCredentialsCache(cfg.Credentials, awsCredentialsCacheOptions) + if _, err := cfg.Credentials.Retrieve(ctx); err != nil { + return aws.Config{}, trace.Wrap(err) } - }) - if _, err := cred.Retrieve(ctx); err != nil { - return aws.Config{}, trace.Wrap(err) } + return cfg, nil +} - opts := buildConfigOptions(region, cred, options) - cfg, err := config.LoadDefaultConfig(ctx, opts...) - return cfg, trace.Wrap(err) +func getAssumeRoleProvider(ctx context.Context, clt stscreds.AssumeRoleAPIClient, role AssumeRole) aws.CredentialsProvider { + slog.DebugContext(ctx, "Initializing AWS session for assumed role", + "assumed_role", role.RoleARN, + ) + return stscreds.NewAssumeRoleProvider(clt, role.RoleARN, func(aro *stscreds.AssumeRoleOptions) { + if role.ExternalID != "" { + aro.ExternalID = aws.String(role.ExternalID) + } + }) } diff --git a/lib/cloud/awsconfig/awsconfig_test.go b/lib/cloud/awsconfig/awsconfig_test.go index 5c0ab10ed6abb..3cb2c4eda3123 100644 --- a/lib/cloud/awsconfig/awsconfig_test.go +++ b/lib/cloud/awsconfig/awsconfig_test.go @@ -18,9 +18,15 @@ package awsconfig import ( "context" + "fmt" + "strings" "testing" + "time" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/sts" + ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" "github.com/gravitational/trace" "github.com/stretchr/testify/require" ) @@ -29,18 +35,60 @@ type mockCredentialProvider struct { cred aws.Credentials } -func (m *mockCredentialProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { +func (m *mockCredentialProvider) Retrieve(_ context.Context) (aws.Credentials, error) { return m.cred, nil } +type mockAssumeRoleAPIClient struct{} + +func (m *mockAssumeRoleAPIClient) AssumeRole(_ context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { + fakeKeyID := fmt.Sprintf("role: %s, externalID: %s", aws.ToString(params.RoleArn), aws.ToString(params.ExternalId)) + return &sts.AssumeRoleOutput{ + AssumedRoleUser: &ststypes.AssumedRoleUser{ + Arn: params.RoleArn, + AssumedRoleId: aws.String("role-id"), + }, + Credentials: &ststypes.Credentials{ + AccessKeyId: aws.String(fakeKeyID), + Expiration: aws.Time(time.Time{}), + SecretAccessKey: aws.String("fake-secret-access-key"), + SessionToken: aws.String("fake-session-token"), + }, + }, nil +} + func TestGetConfigIntegration(t *testing.T) { t.Parallel() + + cache, err := NewCache() + require.NoError(t, err) + tests := []struct { + desc string + Provider + }{ + { + desc: "uncached", + Provider: ProviderFunc(GetConfig), + }, + { + desc: "cached", + Provider: cache, + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + testGetConfigIntegration(t, test.Provider) + }) + } +} + +func testGetConfigIntegration(t *testing.T, provider Provider) { dummyIntegration := "integration-test" dummyRegion := "test-region-123" t.Run("without an integration credential provider, must return missing credential provider error", func(t *testing.T) { ctx := context.Background() - _, err := GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration)) + _, err := provider.GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration)) require.True(t, trace.IsBadParameter(err), "unexpected error: %v", err) require.ErrorContains(t, err, "missing aws integration credential provider") }) @@ -48,7 +96,7 @@ func TestGetConfigIntegration(t *testing.T) { t.Run("with an integration credential provider, must return the credentials", func(t *testing.T) { ctx := context.Background() - cfg, err := GetConfig(ctx, dummyRegion, + cfg, err := provider.GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration), WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { if region == dummyRegion && integration == dummyIntegration { @@ -66,10 +114,68 @@ func TestGetConfigIntegration(t *testing.T) { require.Equal(t, "foo-bar", creds.SessionToken) }) + t.Run("with an integration credential provider assuming a role, must return assumed role credentials", func(t *testing.T) { + ctx := context.Background() + + cfg, err := provider.GetConfig(ctx, dummyRegion, + WithCredentialsMaybeIntegration(dummyIntegration), + WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { + if region == dummyRegion && integration == dummyIntegration { + return &mockCredentialProvider{ + cred: aws.Credentials{ + SessionToken: "foo-bar", + }, + }, nil + } + return nil, trace.NotFound("no creds in region %q with integration %q", region, integration) + }), + WithAssumeRole("roleA", "abc123"), + WithAssumeRoleClientProviderFunc(func(cfg aws.Config) stscreds.AssumeRoleAPIClient { + creds, err := cfg.Credentials.Retrieve(context.Background()) + require.NoError(t, err) + require.Equal(t, "foo-bar", creds.SessionToken) + return &mockAssumeRoleAPIClient{} + }), + ) + require.NoError(t, err) + creds, err := cfg.Credentials.Retrieve(ctx) + require.NoError(t, err) + require.Equal(t, "role: roleA, externalID: abc123", creds.AccessKeyID) + require.Equal(t, "fake-session-token", creds.SessionToken) + }) + + t.Run("with an integration credential provider assuming a role, must limit role chain length", func(t *testing.T) { + ctx := context.Background() + _, err := provider.GetConfig(ctx, dummyRegion, + WithCredentialsMaybeIntegration(dummyIntegration), + WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { + if region == dummyRegion && integration == dummyIntegration { + return &mockCredentialProvider{ + cred: aws.Credentials{ + SessionToken: "foo-bar", + }, + }, nil + } + return nil, trace.NotFound("no creds in region %q with integration %q", region, integration) + }), + WithAssumeRole("roleA", "abc123"), + WithAssumeRole("roleB", "abc123"), + WithAssumeRole("roleC", "abc123"), + WithAssumeRoleClientProviderFunc(func(cfg aws.Config) stscreds.AssumeRoleAPIClient { + creds, err := cfg.Credentials.Retrieve(context.Background()) + require.NoError(t, err) + require.Equal(t, "foo-bar", creds.SessionToken) + return &mockAssumeRoleAPIClient{} + }), + ) + require.Error(t, err) + require.ErrorContains(t, err, "role chain contains more than 2 roles") + }) + t.Run("with an integration credential provider, but using an empty integration falls back to ambient credentials", func(t *testing.T) { ctx := context.Background() - _, err := GetConfig(ctx, dummyRegion, + _, err := provider.GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(""), WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { require.Fail(t, "this function should not be called") @@ -81,7 +187,7 @@ func TestGetConfigIntegration(t *testing.T) { t.Run("with an integration credential provider, but using ambient credentials", func(t *testing.T) { ctx := context.Background() - _, err := GetConfig(ctx, dummyRegion, + _, err := provider.GetConfig(ctx, dummyRegion, WithAmbientCredentials(), WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { require.Fail(t, "this function should not be called") @@ -93,7 +199,7 @@ func TestGetConfigIntegration(t *testing.T) { t.Run("with an integration credential provider, but no credential source", func(t *testing.T) { ctx := context.Background() - _, err := GetConfig(ctx, dummyRegion, + _, err := provider.GetConfig(ctx, dummyRegion, WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { require.Fail(t, "this function should not be called") return nil, nil @@ -102,3 +208,16 @@ func TestGetConfigIntegration(t *testing.T) { require.ErrorContains(t, err, "missing credentials source") }) } + +func TestNewCacheKey(t *testing.T) { + roleChain := []AssumeRole{ + {RoleARN: "roleA"}, + {RoleARN: "roleB", ExternalID: "abc123"}, + } + got, err := newCacheKey("integration-name", roleChain...) + require.NoError(t, err) + want := strings.TrimSpace(` +{"integration":"integration-name","role_chain":[{"role_arn":"roleA","external_id":""},{"role_arn":"roleB","external_id":"abc123"}]} +`) + require.Equal(t, want, got) +} diff --git a/lib/cloud/awsconfig/cache.go b/lib/cloud/awsconfig/cache.go new file mode 100644 index 0000000000000..15c98cecb854d --- /dev/null +++ b/lib/cloud/awsconfig/cache.go @@ -0,0 +1,149 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package awsconfig + +import ( + "context" + "encoding/json" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/lib/utils" +) + +func awsCredentialsCacheOptions(opts *aws.CredentialsCacheOptions) { + // expire early to avoid expiration race. + opts.ExpiryWindow = 2 * time.Minute +} + +// Cache is an AWS config [Provider] that caches credentials by integration and +// role. +type Cache struct { + awsConfigCache *utils.FnCache +} + +var _ Provider = (*Cache)(nil) + +// NewCache returns a new [Cache]. +func NewCache() (*Cache, error) { + c, err := utils.NewFnCache(utils.FnCacheConfig{ + TTL: 15 * time.Minute, + ReloadOnErr: true, + }) + if err != nil { + return nil, trace.Wrap(err) + } + return &Cache{ + awsConfigCache: c, + }, nil +} + +// GetConfig returns an [aws.Config] for the given region and options. +func (c *Cache) GetConfig(ctx context.Context, region string, optFns ...OptionsFn) (aws.Config, error) { + opts, err := buildOptions(optFns...) + if err != nil { + return aws.Config{}, trace.Wrap(err) + } + + cfg, err := c.getBaseConfig(ctx, region, opts) + if err != nil { + return aws.Config{}, trace.Wrap(err) + } + cfg, err = c.getConfigForRoleChain(ctx, cfg, opts) + if err != nil { + return aws.Config{}, trace.Wrap(err) + } + return cfg, nil +} + +func (c *Cache) getBaseConfig(ctx context.Context, region string, opts *options) (aws.Config, error) { + // The AWS SDK combines config loading with default credential chain + // loading. + // We cache the entire config by integration name, which is empty for + // non-integration config, but only use credentials from it on cache hit. + cacheKey, err := newCacheKey(opts.integration) + if err != nil { + return aws.Config{}, trace.Wrap(err) + } + var reloaded bool + cfg, err := utils.FnCacheGet(ctx, c.awsConfigCache, cacheKey, + func(ctx context.Context) (aws.Config, error) { + reloaded = true + cfg, err := getBaseConfig(ctx, region, opts) + return cfg, trace.Wrap(err) + }) + if err != nil { + return aws.Config{}, trace.Wrap(err) + } + + if reloaded { + // If the cache reload func was called, then the config we got back has + // already applied our options so we can return the config itself. + return cfg, nil + } + + // On cache hit we just take the credentials from the cached config. + // Then, we apply those credentials while loading config with current + // options. + cfg, err = loadDefaultConfig(ctx, region, cfg.Credentials, opts) + return cfg, trace.Wrap(err) +} + +func (c *Cache) getConfigForRoleChain(ctx context.Context, cfg aws.Config, opts *options) (aws.Config, error) { + for i, r := range opts.assumeRoles { + // cache credentials by integration and assumed-role chain. + cacheKey, err := newCacheKey(opts.integration, opts.assumeRoles[:i+1]...) + if err != nil { + return aws.Config{}, trace.Wrap(err) + } + credProvider, err := utils.FnCacheGet(ctx, c.awsConfigCache, cacheKey, + func(ctx context.Context) (aws.CredentialsProvider, error) { + clt := opts.assumeRoleClientProvider(cfg) + credProvider := getAssumeRoleProvider(ctx, clt, r) + cc := aws.NewCredentialsCache(credProvider, + awsCredentialsCacheOptions, + ) + if _, err := cc.Retrieve(ctx); err != nil { + return nil, trace.Wrap(err) + } + return cc, nil + }) + if err != nil { + return aws.Config{}, trace.Wrap(err) + } + cfg.Credentials = credProvider + } + return cfg, nil +} + +// newCacheKey returns a cache key for AWS credentials. +// The cache key can be used to get role credentials without calling AWS STS. +// Therefore, we marshal the key as JSON to be sure the input cannot be +// manipulated to retrieve other credentials. +func newCacheKey(integrationName string, roleChain ...AssumeRole) (string, error) { + type configCacheKey struct { + Integration string `json:"integration"` + RoleChain []AssumeRole `json:"role_chain"` + } + out, err := json.Marshal(configCacheKey{ + Integration: integrationName, + RoleChain: roleChain, + }) + return string(out), trace.Wrap(err) +} diff --git a/lib/cloud/awsconfig/provider.go b/lib/cloud/awsconfig/provider.go new file mode 100644 index 0000000000000..cff06964ce785 --- /dev/null +++ b/lib/cloud/awsconfig/provider.go @@ -0,0 +1,37 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package awsconfig + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/aws" +) + +// Provider provides an [aws.Config]. +type Provider interface { + // GetConfig returns an [aws.Config] for the given region and options. + GetConfig(ctx context.Context, region string, optFns ...OptionsFn) (aws.Config, error) +} + +// ProviderFunc is a [Provider] adapter for functions. +type ProviderFunc func(ctx context.Context, region string, optFns ...OptionsFn) (aws.Config, error) + +// GetConfig returns an [aws.Config] for the given region and options. +func (fn ProviderFunc) GetConfig(ctx context.Context, region string, optFns ...OptionsFn) (aws.Config, error) { + return fn(ctx, region, optFns...) +}