diff --git a/lib/cloud/awsconfig/awsconfig.go b/lib/cloud/awsconfig/awsconfig.go
index 92f7e8aa96e86..8be00483f4012 100644
--- a/lib/cloud/awsconfig/awsconfig.go
+++ b/lib/cloud/awsconfig/awsconfig.go
@@ -47,16 +47,23 @@ const (
// This is used to generate aws configs for clients that must use an integration instead of ambient credentials.
type IntegrationCredentialProviderFunc func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error)
+// AssumeRoleClientProviderFunc provides an AWS STS assume role API client.
+type AssumeRoleClientProviderFunc func(aws.Config) stscreds.AssumeRoleAPIClient
+
+// AssumeRole is an AWS role to assume, optionally with an external ID.
+type AssumeRole struct {
+ // RoleARN is the ARN of the role to assume.
+ RoleARN string `json:"role_arn"`
+ // ExternalID is an optional ID to include when assuming the role.
+ ExternalID string `json:"external_id"`
+}
+
// options is a struct of additional options for assuming an AWS role
// when construction an underlying AWS config.
type options struct {
- // baseConfigis a config to use instead of the default config for an
- // AWS region, which is used to enable role chaining.
- baseConfig *aws.Config
- // assumeRoleARN is the AWS IAM Role ARN to assume.
- assumeRoleARN string
- // assumeRoleExternalID is used to assume an external AWS IAM Role.
- assumeRoleExternalID string
+ // assumeRoles are AWS IAM roles that should be assumed one by one in order,
+ // as a chain of assumed roles.
+ assumeRoles []AssumeRole
// credentialsSource describes which source to use to fetch credentials.
credentialsSource credentialsSource
// integration is the name of the integration to be used to fetch the credentials.
@@ -67,22 +74,45 @@ type options struct {
customRetryer func() aws.Retryer
// maxRetries is the maximum number of retries to use for the config.
maxRetries *int
+ // assumeRoleClientProvider sets the STS assume role client provider func.
+ assumeRoleClientProvider AssumeRoleClientProviderFunc
}
-func (a *options) checkAndSetDefaults() error {
- switch a.credentialsSource {
+func buildOptions(optFns ...OptionsFn) (*options, error) {
+ var opts options
+ for _, optFn := range optFns {
+ optFn(&opts)
+ }
+ if err := opts.checkAndSetDefaults(); err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return &opts, nil
+}
+
+func (o *options) checkAndSetDefaults() error {
+ switch o.credentialsSource {
case credentialsSourceAmbient:
- if a.integration != "" {
+ if o.integration != "" {
return trace.BadParameter("integration and ambient credentials cannot be used at the same time")
}
case credentialsSourceIntegration:
- if a.integration == "" {
+ if o.integration == "" {
return trace.BadParameter("missing integration name")
}
default:
return trace.BadParameter("missing credentials source (ambient or integration)")
}
+ if len(o.assumeRoles) > 2 {
+ return trace.BadParameter("role chain contains more than 2 roles")
+ }
+ if o.assumeRoleClientProvider == nil {
+ o.assumeRoleClientProvider = func(cfg aws.Config) stscreds.AssumeRoleAPIClient {
+ return sts.NewFromConfig(cfg, func(o *sts.Options) {
+ o.TracerProvider = smithyoteltracing.Adapt(otel.GetTracerProvider())
+ })
+ }
+ }
return nil
}
@@ -93,8 +123,14 @@ type OptionsFn func(*options)
// WithAssumeRole configures options needed for assuming an AWS role.
func WithAssumeRole(roleARN, externalID string) OptionsFn {
return func(options *options) {
- options.assumeRoleARN = roleARN
- options.assumeRoleExternalID = externalID
+ if roleARN == "" {
+ // ignore empty role ARN for caller convenience.
+ return
+ }
+ options.assumeRoles = append(options.assumeRoles, AssumeRole{
+ RoleARN: roleARN,
+ ExternalID: externalID,
+ })
}
}
@@ -146,96 +182,98 @@ func WithIntegrationCredentialProvider(cred IntegrationCredentialProviderFunc) O
}
}
+// WithAssumeRoleClientProviderFunc sets the STS API client factory func used to
+// assume roles.
+func WithAssumeRoleClientProviderFunc(fn AssumeRoleClientProviderFunc) OptionsFn {
+ return func(options *options) {
+ options.assumeRoleClientProvider = fn
+ }
+}
+
// GetConfig returns an AWS config for the specified region, optionally
// assuming AWS IAM Roles.
-func GetConfig(ctx context.Context, region string, opts ...OptionsFn) (aws.Config, error) {
- var options options
- for _, opt := range opts {
- opt(&options)
- }
- if options.baseConfig == nil {
- cfg, err := getConfigForRegion(ctx, region, options)
- if err != nil {
- return aws.Config{}, trace.Wrap(err)
- }
- options.baseConfig = &cfg
+func GetConfig(ctx context.Context, region string, optFns ...OptionsFn) (aws.Config, error) {
+ opts, err := buildOptions(optFns...)
+ if err != nil {
+ return aws.Config{}, trace.Wrap(err)
}
- if options.assumeRoleARN == "" {
- return *options.baseConfig, nil
+
+ cfg, err := getBaseConfig(ctx, region, opts)
+ if err != nil {
+ return aws.Config{}, trace.Wrap(err)
}
- return getConfigForRole(ctx, region, options)
+ return getConfigForRoleChain(ctx, cfg, opts.assumeRoles, opts.assumeRoleClientProvider)
}
-// ambientConfigProvider loads a new config using the environment variables.
-func ambientConfigProvider(region string, cred aws.CredentialsProvider, options options) (aws.Config, error) {
- opts := buildConfigOptions(region, cred, options)
- cfg, err := config.LoadDefaultConfig(context.Background(), opts...)
+// loadDefaultConfig loads a new config.
+func loadDefaultConfig(ctx context.Context, region string, cred aws.CredentialsProvider, opts *options) (aws.Config, error) {
+ configOpts := buildConfigOptions(region, cred, opts)
+ cfg, err := config.LoadDefaultConfig(ctx, configOpts...)
return cfg, trace.Wrap(err)
}
-func buildConfigOptions(region string, cred aws.CredentialsProvider, options options) []func(*config.LoadOptions) error {
- opts := []func(*config.LoadOptions) error{
+func buildConfigOptions(region string, cred aws.CredentialsProvider, opts *options) []func(*config.LoadOptions) error {
+ configOpts := []func(*config.LoadOptions) error{
config.WithDefaultRegion(defaultRegion),
config.WithRegion(region),
config.WithCredentialsProvider(cred),
}
if modules.GetModules().IsBoringBinary() {
- opts = append(opts, config.WithUseFIPSEndpoint(aws.FIPSEndpointStateEnabled))
+ configOpts = append(configOpts, config.WithUseFIPSEndpoint(aws.FIPSEndpointStateEnabled))
}
- if options.customRetryer != nil {
- opts = append(opts, config.WithRetryer(options.customRetryer))
+ if opts.customRetryer != nil {
+ configOpts = append(configOpts, config.WithRetryer(opts.customRetryer))
}
- if options.maxRetries != nil {
- opts = append(opts, config.WithRetryMaxAttempts(*options.maxRetries))
+ if opts.maxRetries != nil {
+ configOpts = append(configOpts, config.WithRetryMaxAttempts(*opts.maxRetries))
}
- return opts
+ return configOpts
}
-// getConfigForRegion returns AWS config for the specified region.
-func getConfigForRegion(ctx context.Context, region string, options options) (aws.Config, error) {
- if err := options.checkAndSetDefaults(); err != nil {
- return aws.Config{}, trace.Wrap(err)
- }
-
+// getBaseConfig returns an AWS config without assuming any roles.
+func getBaseConfig(ctx context.Context, region string, opts *options) (aws.Config, error) {
var cred aws.CredentialsProvider
- if options.credentialsSource == credentialsSourceIntegration {
- if options.integrationCredentialsProvider == nil {
+ if opts.credentialsSource == credentialsSourceIntegration {
+ if opts.integrationCredentialsProvider == nil {
return aws.Config{}, trace.BadParameter("missing aws integration credential provider")
}
- slog.DebugContext(ctx, "Initializing AWS config with integration", "region", region, "integration", options.integration)
+ slog.DebugContext(ctx, "Initializing AWS config with integration", "region", region, "integration", opts.integration)
var err error
- cred, err = options.integrationCredentialsProvider(ctx, region, options.integration)
+ cred, err = opts.integrationCredentialsProvider(ctx, region, opts.integration)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
} else {
- slog.DebugContext(ctx, "Initializing AWS config from environment", "region", region)
+ slog.DebugContext(ctx, "Initializing AWS config from default credential chain", "region", region)
}
- cfg, err := ambientConfigProvider(region, cred, options)
+ cfg, err := loadDefaultConfig(ctx, region, cred, opts)
return cfg, trace.Wrap(err)
}
-// getConfigForRole returns an AWS config for the specified region and role.
-func getConfigForRole(ctx context.Context, region string, options options) (aws.Config, error) {
- if err := options.checkAndSetDefaults(); err != nil {
- return aws.Config{}, trace.Wrap(err)
+func getConfigForRoleChain(ctx context.Context, cfg aws.Config, roles []AssumeRole, newCltFn AssumeRoleClientProviderFunc) (aws.Config, error) {
+ for _, r := range roles {
+ cfg.Credentials = getAssumeRoleProvider(ctx, newCltFn(cfg), r)
}
-
- stsClient := sts.NewFromConfig(*options.baseConfig, func(o *sts.Options) {
- o.TracerProvider = smithyoteltracing.Adapt(otel.GetTracerProvider())
- })
- cred := stscreds.NewAssumeRoleProvider(stsClient, options.assumeRoleARN, func(aro *stscreds.AssumeRoleOptions) {
- if options.assumeRoleExternalID != "" {
- aro.ExternalID = aws.String(options.assumeRoleExternalID)
+ if len(roles) > 0 {
+ // no point caching every assumed role in the chain, we can just cache
+ // the last one.
+ cfg.Credentials = aws.NewCredentialsCache(cfg.Credentials, awsCredentialsCacheOptions)
+ if _, err := cfg.Credentials.Retrieve(ctx); err != nil {
+ return aws.Config{}, trace.Wrap(err)
}
- })
- if _, err := cred.Retrieve(ctx); err != nil {
- return aws.Config{}, trace.Wrap(err)
}
+ return cfg, nil
+}
- opts := buildConfigOptions(region, cred, options)
- cfg, err := config.LoadDefaultConfig(ctx, opts...)
- return cfg, trace.Wrap(err)
+func getAssumeRoleProvider(ctx context.Context, clt stscreds.AssumeRoleAPIClient, role AssumeRole) aws.CredentialsProvider {
+ slog.DebugContext(ctx, "Initializing AWS session for assumed role",
+ "assumed_role", role.RoleARN,
+ )
+ return stscreds.NewAssumeRoleProvider(clt, role.RoleARN, func(aro *stscreds.AssumeRoleOptions) {
+ if role.ExternalID != "" {
+ aro.ExternalID = aws.String(role.ExternalID)
+ }
+ })
}
diff --git a/lib/cloud/awsconfig/awsconfig_test.go b/lib/cloud/awsconfig/awsconfig_test.go
index 5c0ab10ed6abb..3cb2c4eda3123 100644
--- a/lib/cloud/awsconfig/awsconfig_test.go
+++ b/lib/cloud/awsconfig/awsconfig_test.go
@@ -18,9 +18,15 @@ package awsconfig
import (
"context"
+ "fmt"
+ "strings"
"testing"
+ "time"
"github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/credentials/stscreds"
+ "github.com/aws/aws-sdk-go-v2/service/sts"
+ ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
)
@@ -29,18 +35,60 @@ type mockCredentialProvider struct {
cred aws.Credentials
}
-func (m *mockCredentialProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
+func (m *mockCredentialProvider) Retrieve(_ context.Context) (aws.Credentials, error) {
return m.cred, nil
}
+type mockAssumeRoleAPIClient struct{}
+
+func (m *mockAssumeRoleAPIClient) AssumeRole(_ context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) {
+ fakeKeyID := fmt.Sprintf("role: %s, externalID: %s", aws.ToString(params.RoleArn), aws.ToString(params.ExternalId))
+ return &sts.AssumeRoleOutput{
+ AssumedRoleUser: &ststypes.AssumedRoleUser{
+ Arn: params.RoleArn,
+ AssumedRoleId: aws.String("role-id"),
+ },
+ Credentials: &ststypes.Credentials{
+ AccessKeyId: aws.String(fakeKeyID),
+ Expiration: aws.Time(time.Time{}),
+ SecretAccessKey: aws.String("fake-secret-access-key"),
+ SessionToken: aws.String("fake-session-token"),
+ },
+ }, nil
+}
+
func TestGetConfigIntegration(t *testing.T) {
t.Parallel()
+
+ cache, err := NewCache()
+ require.NoError(t, err)
+ tests := []struct {
+ desc string
+ Provider
+ }{
+ {
+ desc: "uncached",
+ Provider: ProviderFunc(GetConfig),
+ },
+ {
+ desc: "cached",
+ Provider: cache,
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.desc, func(t *testing.T) {
+ testGetConfigIntegration(t, test.Provider)
+ })
+ }
+}
+
+func testGetConfigIntegration(t *testing.T, provider Provider) {
dummyIntegration := "integration-test"
dummyRegion := "test-region-123"
t.Run("without an integration credential provider, must return missing credential provider error", func(t *testing.T) {
ctx := context.Background()
- _, err := GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration))
+ _, err := provider.GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration))
require.True(t, trace.IsBadParameter(err), "unexpected error: %v", err)
require.ErrorContains(t, err, "missing aws integration credential provider")
})
@@ -48,7 +96,7 @@ func TestGetConfigIntegration(t *testing.T) {
t.Run("with an integration credential provider, must return the credentials", func(t *testing.T) {
ctx := context.Background()
- cfg, err := GetConfig(ctx, dummyRegion,
+ cfg, err := provider.GetConfig(ctx, dummyRegion,
WithCredentialsMaybeIntegration(dummyIntegration),
WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
if region == dummyRegion && integration == dummyIntegration {
@@ -66,10 +114,68 @@ func TestGetConfigIntegration(t *testing.T) {
require.Equal(t, "foo-bar", creds.SessionToken)
})
+ t.Run("with an integration credential provider assuming a role, must return assumed role credentials", func(t *testing.T) {
+ ctx := context.Background()
+
+ cfg, err := provider.GetConfig(ctx, dummyRegion,
+ WithCredentialsMaybeIntegration(dummyIntegration),
+ WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
+ if region == dummyRegion && integration == dummyIntegration {
+ return &mockCredentialProvider{
+ cred: aws.Credentials{
+ SessionToken: "foo-bar",
+ },
+ }, nil
+ }
+ return nil, trace.NotFound("no creds in region %q with integration %q", region, integration)
+ }),
+ WithAssumeRole("roleA", "abc123"),
+ WithAssumeRoleClientProviderFunc(func(cfg aws.Config) stscreds.AssumeRoleAPIClient {
+ creds, err := cfg.Credentials.Retrieve(context.Background())
+ require.NoError(t, err)
+ require.Equal(t, "foo-bar", creds.SessionToken)
+ return &mockAssumeRoleAPIClient{}
+ }),
+ )
+ require.NoError(t, err)
+ creds, err := cfg.Credentials.Retrieve(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "role: roleA, externalID: abc123", creds.AccessKeyID)
+ require.Equal(t, "fake-session-token", creds.SessionToken)
+ })
+
+ t.Run("with an integration credential provider assuming a role, must limit role chain length", func(t *testing.T) {
+ ctx := context.Background()
+ _, err := provider.GetConfig(ctx, dummyRegion,
+ WithCredentialsMaybeIntegration(dummyIntegration),
+ WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
+ if region == dummyRegion && integration == dummyIntegration {
+ return &mockCredentialProvider{
+ cred: aws.Credentials{
+ SessionToken: "foo-bar",
+ },
+ }, nil
+ }
+ return nil, trace.NotFound("no creds in region %q with integration %q", region, integration)
+ }),
+ WithAssumeRole("roleA", "abc123"),
+ WithAssumeRole("roleB", "abc123"),
+ WithAssumeRole("roleC", "abc123"),
+ WithAssumeRoleClientProviderFunc(func(cfg aws.Config) stscreds.AssumeRoleAPIClient {
+ creds, err := cfg.Credentials.Retrieve(context.Background())
+ require.NoError(t, err)
+ require.Equal(t, "foo-bar", creds.SessionToken)
+ return &mockAssumeRoleAPIClient{}
+ }),
+ )
+ require.Error(t, err)
+ require.ErrorContains(t, err, "role chain contains more than 2 roles")
+ })
+
t.Run("with an integration credential provider, but using an empty integration falls back to ambient credentials", func(t *testing.T) {
ctx := context.Background()
- _, err := GetConfig(ctx, dummyRegion,
+ _, err := provider.GetConfig(ctx, dummyRegion,
WithCredentialsMaybeIntegration(""),
WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
require.Fail(t, "this function should not be called")
@@ -81,7 +187,7 @@ func TestGetConfigIntegration(t *testing.T) {
t.Run("with an integration credential provider, but using ambient credentials", func(t *testing.T) {
ctx := context.Background()
- _, err := GetConfig(ctx, dummyRegion,
+ _, err := provider.GetConfig(ctx, dummyRegion,
WithAmbientCredentials(),
WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
require.Fail(t, "this function should not be called")
@@ -93,7 +199,7 @@ func TestGetConfigIntegration(t *testing.T) {
t.Run("with an integration credential provider, but no credential source", func(t *testing.T) {
ctx := context.Background()
- _, err := GetConfig(ctx, dummyRegion,
+ _, err := provider.GetConfig(ctx, dummyRegion,
WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
require.Fail(t, "this function should not be called")
return nil, nil
@@ -102,3 +208,16 @@ func TestGetConfigIntegration(t *testing.T) {
require.ErrorContains(t, err, "missing credentials source")
})
}
+
+func TestNewCacheKey(t *testing.T) {
+ roleChain := []AssumeRole{
+ {RoleARN: "roleA"},
+ {RoleARN: "roleB", ExternalID: "abc123"},
+ }
+ got, err := newCacheKey("integration-name", roleChain...)
+ require.NoError(t, err)
+ want := strings.TrimSpace(`
+{"integration":"integration-name","role_chain":[{"role_arn":"roleA","external_id":""},{"role_arn":"roleB","external_id":"abc123"}]}
+`)
+ require.Equal(t, want, got)
+}
diff --git a/lib/cloud/awsconfig/cache.go b/lib/cloud/awsconfig/cache.go
new file mode 100644
index 0000000000000..15c98cecb854d
--- /dev/null
+++ b/lib/cloud/awsconfig/cache.go
@@ -0,0 +1,149 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package awsconfig
+
+import (
+ "context"
+ "encoding/json"
+ "time"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport/lib/utils"
+)
+
+func awsCredentialsCacheOptions(opts *aws.CredentialsCacheOptions) {
+ // expire early to avoid expiration race.
+ opts.ExpiryWindow = 2 * time.Minute
+}
+
+// Cache is an AWS config [Provider] that caches credentials by integration and
+// role.
+type Cache struct {
+ awsConfigCache *utils.FnCache
+}
+
+var _ Provider = (*Cache)(nil)
+
+// NewCache returns a new [Cache].
+func NewCache() (*Cache, error) {
+ c, err := utils.NewFnCache(utils.FnCacheConfig{
+ TTL: 15 * time.Minute,
+ ReloadOnErr: true,
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return &Cache{
+ awsConfigCache: c,
+ }, nil
+}
+
+// GetConfig returns an [aws.Config] for the given region and options.
+func (c *Cache) GetConfig(ctx context.Context, region string, optFns ...OptionsFn) (aws.Config, error) {
+ opts, err := buildOptions(optFns...)
+ if err != nil {
+ return aws.Config{}, trace.Wrap(err)
+ }
+
+ cfg, err := c.getBaseConfig(ctx, region, opts)
+ if err != nil {
+ return aws.Config{}, trace.Wrap(err)
+ }
+ cfg, err = c.getConfigForRoleChain(ctx, cfg, opts)
+ if err != nil {
+ return aws.Config{}, trace.Wrap(err)
+ }
+ return cfg, nil
+}
+
+func (c *Cache) getBaseConfig(ctx context.Context, region string, opts *options) (aws.Config, error) {
+ // The AWS SDK combines config loading with default credential chain
+ // loading.
+ // We cache the entire config by integration name, which is empty for
+ // non-integration config, but only use credentials from it on cache hit.
+ cacheKey, err := newCacheKey(opts.integration)
+ if err != nil {
+ return aws.Config{}, trace.Wrap(err)
+ }
+ var reloaded bool
+ cfg, err := utils.FnCacheGet(ctx, c.awsConfigCache, cacheKey,
+ func(ctx context.Context) (aws.Config, error) {
+ reloaded = true
+ cfg, err := getBaseConfig(ctx, region, opts)
+ return cfg, trace.Wrap(err)
+ })
+ if err != nil {
+ return aws.Config{}, trace.Wrap(err)
+ }
+
+ if reloaded {
+ // If the cache reload func was called, then the config we got back has
+ // already applied our options so we can return the config itself.
+ return cfg, nil
+ }
+
+ // On cache hit we just take the credentials from the cached config.
+ // Then, we apply those credentials while loading config with current
+ // options.
+ cfg, err = loadDefaultConfig(ctx, region, cfg.Credentials, opts)
+ return cfg, trace.Wrap(err)
+}
+
+func (c *Cache) getConfigForRoleChain(ctx context.Context, cfg aws.Config, opts *options) (aws.Config, error) {
+ for i, r := range opts.assumeRoles {
+ // cache credentials by integration and assumed-role chain.
+ cacheKey, err := newCacheKey(opts.integration, opts.assumeRoles[:i+1]...)
+ if err != nil {
+ return aws.Config{}, trace.Wrap(err)
+ }
+ credProvider, err := utils.FnCacheGet(ctx, c.awsConfigCache, cacheKey,
+ func(ctx context.Context) (aws.CredentialsProvider, error) {
+ clt := opts.assumeRoleClientProvider(cfg)
+ credProvider := getAssumeRoleProvider(ctx, clt, r)
+ cc := aws.NewCredentialsCache(credProvider,
+ awsCredentialsCacheOptions,
+ )
+ if _, err := cc.Retrieve(ctx); err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return cc, nil
+ })
+ if err != nil {
+ return aws.Config{}, trace.Wrap(err)
+ }
+ cfg.Credentials = credProvider
+ }
+ return cfg, nil
+}
+
+// newCacheKey returns a cache key for AWS credentials.
+// The cache key can be used to get role credentials without calling AWS STS.
+// Therefore, we marshal the key as JSON to be sure the input cannot be
+// manipulated to retrieve other credentials.
+func newCacheKey(integrationName string, roleChain ...AssumeRole) (string, error) {
+ type configCacheKey struct {
+ Integration string `json:"integration"`
+ RoleChain []AssumeRole `json:"role_chain"`
+ }
+ out, err := json.Marshal(configCacheKey{
+ Integration: integrationName,
+ RoleChain: roleChain,
+ })
+ return string(out), trace.Wrap(err)
+}
diff --git a/lib/cloud/awsconfig/provider.go b/lib/cloud/awsconfig/provider.go
new file mode 100644
index 0000000000000..cff06964ce785
--- /dev/null
+++ b/lib/cloud/awsconfig/provider.go
@@ -0,0 +1,37 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package awsconfig
+
+import (
+ "context"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+)
+
+// Provider provides an [aws.Config].
+type Provider interface {
+ // GetConfig returns an [aws.Config] for the given region and options.
+ GetConfig(ctx context.Context, region string, optFns ...OptionsFn) (aws.Config, error)
+}
+
+// ProviderFunc is a [Provider] adapter for functions.
+type ProviderFunc func(ctx context.Context, region string, optFns ...OptionsFn) (aws.Config, error)
+
+// GetConfig returns an [aws.Config] for the given region and options.
+func (fn ProviderFunc) GetConfig(ctx context.Context, region string, optFns ...OptionsFn) (aws.Config, error) {
+ return fn(ctx, region, optFns...)
+}