Skip to content

Commit

Permalink
synchronize access to generateOIDCTokenFn
Browse files Browse the repository at this point in the history
  • Loading branch information
nklaassen committed Jan 6, 2025
1 parent ab420c1 commit 540c8bd
Showing 1 changed file with 43 additions and 26 deletions.
69 changes: 43 additions & 26 deletions lib/integrations/awsoidc/credprovider/credentialscache.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,19 @@ type CredentialsCache struct {
roleARN arn.ARN
integration string

// generateOIDCTokenFn can be dynamically set after auth is initialized.
generateOIDCTokenFn GenerateOIDCTokenFn

// initialized communicates (via closing channel) that generateOIDCTokenFn is set.
initialized chan struct{}
// generateOIDCTokenFn can be dynamically set after creating the credential
// cache, this is a workaround for a dependency cycle where audit storage
// depends on the credential cache, the auth server depends on audit
// storage, and the credential cache depends on the auth server for a
// GenerateOIDCTokenFn.
generateOIDCTokenFn GenerateOIDCTokenFn
generateOIDCTokenFnMu sync.Mutex
// gotGenerateOIDCTokenFn communicates (via closing channel) that
// generateOIDCTokenFn is set.
gotGenerateOIDCTokenFn chan struct{}
closeGotGenerateOIDCTokenFn func()
// allowRetrieveBeforeInit allows the Retrieve method to return an error if
// initialized has not been closed yet, instead of waiting for it to be
// [gotGenerateOIDCTokenFn] has not been closed yet, instead of waiting for it to be
// closed.
allowRetrieveBeforeInit bool

Expand Down Expand Up @@ -136,36 +142,47 @@ func NewCredentialsCache(options CredentialsCacheOptions) (*CredentialsCache, er
return nil, trace.Wrap(err, "creating credentials cache")
}

initialized := make(chan struct{})
gotGenerateOIDCTokenFn := make(chan struct{})
closeGotGenerateOIDCTokenFn := sync.OnceFunc(func() { close(gotGenerateOIDCTokenFn) })
if options.GenerateOIDCTokenFn != nil {
close(initialized)
closeGotGenerateOIDCTokenFn()
}

gotFirstCredsOrErr := make(chan struct{})
closeGotFirstCredsOrErr := sync.OnceFunc(func() { close(gotFirstCredsOrErr) })

return &CredentialsCache{
roleARN: options.RoleARN,
integration: options.Integration,
generateOIDCTokenFn: options.GenerateOIDCTokenFn,
initialized: initialized,
allowRetrieveBeforeInit: options.AllowRetrieveBeforeInit,
log: options.Log.With("integration", options.Integration),
gotFirstCredsOrErr: gotFirstCredsOrErr,
closeGotFirstCredsOrErr: sync.OnceFunc(func() { close(gotFirstCredsOrErr) }),
credsOrErr: credsOrErr{err: errNotReady},
clock: options.Clock,
stsClient: options.STSClient,
roleARN: options.RoleARN,
integration: options.Integration,
generateOIDCTokenFn: options.GenerateOIDCTokenFn,
gotGenerateOIDCTokenFn: gotGenerateOIDCTokenFn,
closeGotGenerateOIDCTokenFn: closeGotGenerateOIDCTokenFn,
allowRetrieveBeforeInit: options.AllowRetrieveBeforeInit,
log: options.Log.With("integration", options.Integration),
gotFirstCredsOrErr: gotFirstCredsOrErr,
closeGotFirstCredsOrErr: closeGotFirstCredsOrErr,
credsOrErr: credsOrErr{err: errNotReady},
clock: options.Clock,
stsClient: options.STSClient,
}, nil
}

// SetGenerateOIDCTokenFn can be used to set a GenerateOIDCTokenFn after
// creating the credential cache, when dependencies require the credential cache
// to be created before a valid GenerateOIDCTokenFn can be created.
//
// This must be called exactly once if and only if no GenerateOIDCTokenFn was
// passed to NewCredentialsCache.
func (cc *CredentialsCache) SetGenerateOIDCTokenFn(fn GenerateOIDCTokenFn) {
cc.generateOIDCTokenFnMu.Lock()
defer cc.generateOIDCTokenFnMu.Unlock()
cc.generateOIDCTokenFn = fn
close(cc.initialized)
close(cc.gotGenerateOIDCTokenFn)
}

// getGenerateOIDCTokenFn must not be called before [cc.gotGenerateOIDCTokenFn]
// has been closed, or it will return nil.
func (cc *CredentialsCache) getGenerateOIDCTokenFn() GenerateOIDCTokenFn {
cc.generateOIDCTokenFnMu.Lock()
defer cc.generateOIDCTokenFnMu.Unlock()
return cc.generateOIDCTokenFn
}

// Retrieve implements [aws.CredentialsProvider] and returns the latest cached
Expand Down Expand Up @@ -193,9 +210,9 @@ func (cc *CredentialsCache) retrieve(ctx context.Context) (aws.Credentials, erro
}

func (cc *CredentialsCache) Run(ctx context.Context) {
// Wait for initialized signal before running loop.
// Wait for a generateOIDCTokenFn before running loop.
select {
case <-cc.initialized:
case <-cc.gotGenerateOIDCTokenFn:
case <-ctx.Done():
cc.log.DebugContext(ctx, "Context canceled before initialized.")
return
Expand Down Expand Up @@ -265,7 +282,7 @@ func (cc *CredentialsCache) refresh(ctx context.Context) (aws.Credentials, error
defer cc.log.InfoContext(ctx, "Exiting AWS credentials refresh")

cc.log.InfoContext(ctx, "Generating Token")
oidcToken, err := cc.generateOIDCTokenFn(ctx, cc.integration)
oidcToken, err := cc.getGenerateOIDCTokenFn()(ctx, cc.integration)
if err != nil {
cc.log.ErrorContext(ctx, "Token generation failed", errorValue(err))
return aws.Credentials{}, trace.Wrap(err)
Expand Down

0 comments on commit 540c8bd

Please sign in to comment.