From 540c8bd196f256004030b214b5a14bd7c2912bff Mon Sep 17 00:00:00 2001 From: Nic Klaassen Date: Mon, 6 Jan 2025 09:43:31 -0800 Subject: [PATCH] synchronize access to generateOIDCTokenFn --- .../awsoidc/credprovider/credentialscache.go | 69 ++++++++++++------- 1 file changed, 43 insertions(+), 26 deletions(-) diff --git a/lib/integrations/awsoidc/credprovider/credentialscache.go b/lib/integrations/awsoidc/credprovider/credentialscache.go index 5eceb6aa3e7b7..ad1db35e94bea 100644 --- a/lib/integrations/awsoidc/credprovider/credentialscache.go +++ b/lib/integrations/awsoidc/credprovider/credentialscache.go @@ -62,13 +62,19 @@ type CredentialsCache struct { roleARN arn.ARN integration string - // generateOIDCTokenFn can be dynamically set after auth is initialized. - generateOIDCTokenFn GenerateOIDCTokenFn - - // initialized communicates (via closing channel) that generateOIDCTokenFn is set. - initialized chan struct{} + // generateOIDCTokenFn can be dynamically set after creating the credential + // cache, this is a workaround for a dependency cycle where audit storage + // depends on the credential cache, the auth server depends on audit + // storage, and the credential cache depends on the auth server for a + // GenerateOIDCTokenFn. + generateOIDCTokenFn GenerateOIDCTokenFn + generateOIDCTokenFnMu sync.Mutex + // gotGenerateOIDCTokenFn communicates (via closing channel) that + // generateOIDCTokenFn is set. + gotGenerateOIDCTokenFn chan struct{} + closeGotGenerateOIDCTokenFn func() // allowRetrieveBeforeInit allows the Retrieve method to return an error if - // initialized has not been closed yet, instead of waiting for it to be + // [gotGenerateOIDCTokenFn] has not been closed yet, instead of waiting for it to be // closed. allowRetrieveBeforeInit bool @@ -136,36 +142,47 @@ func NewCredentialsCache(options CredentialsCacheOptions) (*CredentialsCache, er return nil, trace.Wrap(err, "creating credentials cache") } - initialized := make(chan struct{}) + gotGenerateOIDCTokenFn := make(chan struct{}) + closeGotGenerateOIDCTokenFn := sync.OnceFunc(func() { close(gotGenerateOIDCTokenFn) }) if options.GenerateOIDCTokenFn != nil { - close(initialized) + closeGotGenerateOIDCTokenFn() } + gotFirstCredsOrErr := make(chan struct{}) + closeGotFirstCredsOrErr := sync.OnceFunc(func() { close(gotFirstCredsOrErr) }) return &CredentialsCache{ - roleARN: options.RoleARN, - integration: options.Integration, - generateOIDCTokenFn: options.GenerateOIDCTokenFn, - initialized: initialized, - allowRetrieveBeforeInit: options.AllowRetrieveBeforeInit, - log: options.Log.With("integration", options.Integration), - gotFirstCredsOrErr: gotFirstCredsOrErr, - closeGotFirstCredsOrErr: sync.OnceFunc(func() { close(gotFirstCredsOrErr) }), - credsOrErr: credsOrErr{err: errNotReady}, - clock: options.Clock, - stsClient: options.STSClient, + roleARN: options.RoleARN, + integration: options.Integration, + generateOIDCTokenFn: options.GenerateOIDCTokenFn, + gotGenerateOIDCTokenFn: gotGenerateOIDCTokenFn, + closeGotGenerateOIDCTokenFn: closeGotGenerateOIDCTokenFn, + allowRetrieveBeforeInit: options.AllowRetrieveBeforeInit, + log: options.Log.With("integration", options.Integration), + gotFirstCredsOrErr: gotFirstCredsOrErr, + closeGotFirstCredsOrErr: closeGotFirstCredsOrErr, + credsOrErr: credsOrErr{err: errNotReady}, + clock: options.Clock, + stsClient: options.STSClient, }, nil } // SetGenerateOIDCTokenFn can be used to set a GenerateOIDCTokenFn after // creating the credential cache, when dependencies require the credential cache // to be created before a valid GenerateOIDCTokenFn can be created. -// -// This must be called exactly once if and only if no GenerateOIDCTokenFn was -// passed to NewCredentialsCache. func (cc *CredentialsCache) SetGenerateOIDCTokenFn(fn GenerateOIDCTokenFn) { + cc.generateOIDCTokenFnMu.Lock() + defer cc.generateOIDCTokenFnMu.Unlock() cc.generateOIDCTokenFn = fn - close(cc.initialized) + close(cc.gotGenerateOIDCTokenFn) +} + +// getGenerateOIDCTokenFn must not be called before [cc.gotGenerateOIDCTokenFn] +// has been closed, or it will return nil. +func (cc *CredentialsCache) getGenerateOIDCTokenFn() GenerateOIDCTokenFn { + cc.generateOIDCTokenFnMu.Lock() + defer cc.generateOIDCTokenFnMu.Unlock() + return cc.generateOIDCTokenFn } // Retrieve implements [aws.CredentialsProvider] and returns the latest cached @@ -193,9 +210,9 @@ func (cc *CredentialsCache) retrieve(ctx context.Context) (aws.Credentials, erro } func (cc *CredentialsCache) Run(ctx context.Context) { - // Wait for initialized signal before running loop. + // Wait for a generateOIDCTokenFn before running loop. select { - case <-cc.initialized: + case <-cc.gotGenerateOIDCTokenFn: case <-ctx.Done(): cc.log.DebugContext(ctx, "Context canceled before initialized.") return @@ -265,7 +282,7 @@ func (cc *CredentialsCache) refresh(ctx context.Context) (aws.Credentials, error defer cc.log.InfoContext(ctx, "Exiting AWS credentials refresh") cc.log.InfoContext(ctx, "Generating Token") - oidcToken, err := cc.generateOIDCTokenFn(ctx, cc.integration) + oidcToken, err := cc.getGenerateOIDCTokenFn()(ctx, cc.integration) if err != nil { cc.log.ErrorContext(ctx, "Token generation failed", errorValue(err)) return aws.Credentials{}, trace.Wrap(err)