diff --git a/go.mod b/go.mod index 3d866c974..58736482f 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.22.5 require ( github.com/aws/aws-sdk-go v1.54.6 + github.com/aws/aws-sdk-go-v2 v1.30.4 github.com/fsnotify/fsnotify v1.7.0 github.com/gofrs/flock v0.8.1 github.com/google/go-cmp v0.6.0 @@ -25,6 +26,7 @@ require ( ) require ( + github.com/aws/smithy-go v1.20.4 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect diff --git a/go.sum b/go.sum index 2a3bcb3a0..9d1fc5538 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,9 @@ github.com/aws/aws-sdk-go v1.54.6 h1:HEYUib3yTt8E6vxjMWM3yAq5b+qjj/6aKA62mkgux9g= github.com/aws/aws-sdk-go v1.54.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= +github.com/aws/aws-sdk-go-v2 v1.30.4 h1:frhcagrVNrzmT95RJImMHgabt99vkXGslubDaDagTk8= +github.com/aws/aws-sdk-go-v2 v1.30.4/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0= +github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4= +github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= diff --git a/pkg/filecache/converter.go b/pkg/filecache/converter.go new file mode 100644 index 000000000..ec2f16bde --- /dev/null +++ b/pkg/filecache/converter.go @@ -0,0 +1,55 @@ +package filecache + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go/aws/credentials" +) + +type v2 struct { + creds *credentials.Credentials +} + +var _ aws.CredentialsProvider = &v2{} + +func (p *v2) Retrieve(ctx context.Context) (aws.Credentials, error) { + val, err := p.creds.GetWithContext(ctx) + if err != nil { + return aws.Credentials{}, err + } + resp := aws.Credentials{ + AccessKeyID: val.AccessKeyID, + SecretAccessKey: val.SecretAccessKey, + SessionToken: val.SessionToken, + Source: val.ProviderName, + CanExpire: false, + // Don't have account ID + } + + if expiration, err := p.creds.ExpiresAt(); err == nil { + resp.CanExpire = true + resp.Expires = expiration + } + return resp, nil +} + +// V1ProviderToV2Provider converts a v1 credentials.Provider to a v2 aws.CredentialsProvider +func V1ProviderToV2Provider(p credentials.Provider) aws.CredentialsProvider { + return V1CredentialToV2Provider(credentials.NewCredentials(p)) +} + +// V1CredentialToV2Provider converts a v1 credentials.Credentials to a v2 aws.CredentialsProvider +func V1CredentialToV2Provider(c *credentials.Credentials) aws.CredentialsProvider { + return &v2{creds: c} +} + +// V2CredentialToV1Value converts a v2 aws.Credentials to a v1 credentials.Value +func V2CredentialToV1Value(cred aws.Credentials) credentials.Value { + return credentials.Value{ + AccessKeyID: cred.AccessKeyID, + SecretAccessKey: cred.SecretAccessKey, + SessionToken: cred.SessionToken, + ProviderName: cred.Source, + } +} diff --git a/pkg/filecache/filecache.go b/pkg/filecache/filecache.go index 41597edaa..64092b9f4 100644 --- a/pkg/filecache/filecache.go +++ b/pkg/filecache/filecache.go @@ -10,6 +10,7 @@ import ( "runtime" "time" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/gofrs/flock" "github.com/spf13/afero" @@ -34,7 +35,7 @@ func NewFileLocker(filename string) FileLocker { // cacheFile is a map of
clusterID/roleARNs to cached credentials type cacheFile struct { // a map of clusterIDs/profiles/roleARNs to cachedCredentials - ClusterMap map[string]map[string]map[string]cachedCredential `yaml:"clusters"` + ClusterMap map[string]map[string]map[string]aws.Credentials `yaml:"clusters"` } // a utility type for dealing with compound cache keys @@ -44,19 +45,19 @@ type cacheKey struct { roleARN string } -func (c *cacheFile) Put(key cacheKey, credential cachedCredential) { +func (c *cacheFile) Put(key cacheKey, credential aws.Credentials) { if _, ok := c.ClusterMap[key.clusterID]; !ok { // first use of this cluster id - c.ClusterMap[key.clusterID] = map[string]map[string]cachedCredential{} + c.ClusterMap[key.clusterID] = map[string]map[string]aws.Credentials{} } if _, ok := c.ClusterMap[key.clusterID][key.profile]; !ok { // first use of this profile - c.ClusterMap[key.clusterID][key.profile] = map[string]cachedCredential{} + c.ClusterMap[key.clusterID][key.profile] = map[string]aws.Credentials{} } c.ClusterMap[key.clusterID][key.profile][key.roleARN] = credential } -func (c *cacheFile) Get(key cacheKey) (credential cachedCredential) { +func (c *cacheFile) Get(key cacheKey) (credential aws.Credentials) { if _, ok := c.ClusterMap[key.clusterID]; ok { if _, ok := c.ClusterMap[key.clusterID][key.profile]; ok { // we at least have this cluster and profile combo in the map, if no matching roleARN, map will @@ -67,31 +68,12 @@ func (c *cacheFile) Get(key cacheKey) (credential cachedCredential) { return } -// cachedCredential is a single cached credential entry, along with expiration time -type cachedCredential struct { - Credential credentials.Value - Expiration time.Time - // If set will be used by IsExpired to determine the current time. - // Defaults to time.Now if CurrentTime is not set. Available for testing - // to be able to mock out the current time. - currentTime func() time.Time -} - -// IsExpired determines if the cached credential has expired -func (c *cachedCredential) IsExpired() bool { - curTime := c.currentTime - if curTime == nil { - curTime = time.Now - } - return c.Expiration.Before(curTime()) -} - // readCacheWhileLocked reads the contents of the credential cache and returns the // parsed yaml as a cacheFile object. This method must be called while a shared // lock is held on the filename. func readCacheWhileLocked(fs afero.Fs, filename string) (cache cacheFile, err error) { cache = cacheFile{ - map[string]map[string]map[string]cachedCredential{}, + map[string]map[string]map[string]aws.Credentials{}, } data, err := afero.ReadFile(fs, filename) if err != nil { @@ -149,9 +131,9 @@ type FileCacheProvider struct { fs afero.Fs filelockCreator func(string) FileLocker filename string - credentials *credentials.Credentials // the underlying implementation that has the *real* Provider - cacheKey cacheKey // cache key parameters used to create Provider - cachedCredential cachedCredential // the cached credential, if it exists + provider aws.CredentialsProvider // the underlying implementation that has the *real* Provider + cacheKey cacheKey // cache key parameters used to create Provider + cachedCredential aws.Credentials // the cached credential, if it exists } var _ credentials.Provider = &FileCacheProvider{} @@ -160,8 +142,8 @@ var _ credentials.Provider = &FileCacheProvider{} // and works with an on disk cache to speed up credential usage when the cached copy is not expired. 
// If there are any problems accessing or initializing the cache, an error will be returned, and // callers should just use the existing credentials provider. -func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials.Credentials, opts ...FileCacheOpt) (*FileCacheProvider, error) { - if creds == nil { +func NewFileCacheProvider(clusterID, profile, roleARN string, provider aws.CredentialsProvider, opts ...FileCacheOpt) (*FileCacheProvider, error) { + if provider == nil { return nil, errors.New("no underlying Credentials object provided") } @@ -169,9 +151,9 @@ func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials fs: afero.NewOsFs(), filelockCreator: NewFileLocker, filename: defaultCacheFilename(), - credentials: creds, + provider: provider, cacheKey: cacheKey{clusterID, profile, roleARN}, - cachedCredential: cachedCredential{}, + cachedCredential: aws.Credentials{}, } // override defaults @@ -222,36 +204,40 @@ func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials // otherwise fetching the credential from the underlying Provider and caching the results on disk // with an expiration time. func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { - if !f.cachedCredential.IsExpired() { + return f.RetrieveWithContext(context.Background()) +} + +// RetrieveWithContext implements the Provider interface, returning the cached credential if it is not expired, +// otherwise fetching the credential from the underlying Provider and caching the results on disk +// with an expiration time. +func (f *FileCacheProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) { + if !f.cachedCredential.Expired() && f.cachedCredential.HasKeys() { // use the cached credential - return f.cachedCredential.Credential, nil + return V2CredentialToV1Value(f.cachedCredential), nil } else { _, _ = fmt.Fprintf(os.Stderr, "No cached credential available. Refreshing...\n") // fetch the credentials from the underlying Provider - credential, err := f.credentials.Get() + credential, err := f.provider.Retrieve(ctx) if err != nil { - return credential, err + return V2CredentialToV1Value(credential), err } - if expiration, err := f.credentials.ExpiresAt(); err == nil { - // underlying provider supports Expirer interface, so we can cache + + if credential.CanExpire { + // Credential supports expiration, so we can cache // do file locking on cache to prevent inconsistent writes lock := f.filelockCreator(f.filename) defer lock.Unlock() // wait up to a second for the file to lock - ctx, cancel := context.WithTimeout(context.TODO(), time.Second) + ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() ok, err := lock.TryLockContext(ctx, 250*time.Millisecond) // try to lock every 1/4 second if !ok { // can't get write lock to create/update cache, but still return the credential _, _ = fmt.Fprintf(os.Stderr, "Unable to write lock file %s: %v\n", f.filename, err) - return credential, nil - } - f.cachedCredential = cachedCredential{ - credential, - expiration, - nil, + return V2CredentialToV1Value(credential), nil } + f.cachedCredential = credential // don't really care about read error. Either read the cache, or we create a new cache.
cache, _ := readCacheWhileLocked(f.fs, f.filename) cache.Put(f.cacheKey, f.cachedCredential) @@ -268,19 +254,19 @@ func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { _, _ = fmt.Fprintf(os.Stderr, "Unable to cache credential: %v\n", err) err = nil } - return credential, err + return V2CredentialToV1Value(credential), err } } // IsExpired() implements the Provider interface, deferring to the cached credential first, // but fall back to the underlying Provider if it is expired. func (f *FileCacheProvider) IsExpired() bool { - return f.cachedCredential.IsExpired() && f.credentials.IsExpired() + return f.cachedCredential.CanExpire && f.cachedCredential.Expired() } // ExpiresAt implements the Expirer interface, and gives access to the expiration time of the credential func (f *FileCacheProvider) ExpiresAt() time.Time { - return f.cachedCredential.Expiration + return f.cachedCredential.Expires } // defaultCacheFilename returns the name of the credential cache file, which can either be diff --git a/pkg/filecache/filecache_test.go b/pkg/filecache/filecache_test.go index 60b4a8771..f2db98556 100644 --- a/pkg/filecache/filecache_test.go +++ b/pkg/filecache/filecache_test.go @@ -1,7 +1,6 @@ package filecache import ( - "bytes" "context" "errors" "fmt" @@ -10,7 +9,8 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/google/go-cmp/cmp" "github.com/spf13/afero" ) @@ -20,35 +20,17 @@ const ( // stubProvider implements credentials.Provider with configurable response values type stubProvider struct { - creds credentials.Value - expired bool - err error + creds aws.Credentials + err error } -var _ credentials.Provider = &stubProvider{} +var _ aws.CredentialsProvider = &stubProvider{} -func (s *stubProvider) Retrieve() (credentials.Value, error) { - s.expired = false - s.creds.ProviderName = "stubProvider" +func (s *stubProvider) Retrieve(_ context.Context) (aws.Credentials, error) { + s.creds.Source = "stubProvider" return s.creds, s.err } -func (s *stubProvider) IsExpired() bool { - return s.expired -} - -// stubProviderExpirer implements credentials.Expirer with configurable expiration -type stubProviderExpirer struct { - stubProvider - expiration time.Time -} - -var _ credentials.Expirer = &stubProviderExpirer{} - -func (s *stubProviderExpirer) ExpiresAt() time.Time { - return s.expiration -} - // testFileInfo implements fs.FileInfo with configurable response values type testFileInfo struct { name string @@ -116,22 +98,34 @@ func getMocks() (*testFS, *testFilelock) { } // makeCredential returns a dummy AWS crdential -func makeCredential() credentials.Value { - return credentials.Value{ +func makeCredential() aws.Credentials { + return aws.Credentials{ AccessKeyID: "AKID", SecretAccessKey: "SECRET", SessionToken: "TOKEN", - ProviderName: "stubProvider", + Source: "stubProvider", + CanExpire: false, + } +} + +func makeExpiringCredential(e time.Time) aws.Credentials { + return aws.Credentials{ + AccessKeyID: "AKID", + SecretAccessKey: "SECRET", + SessionToken: "TOKEN", + Source: "stubProvider", + CanExpire: true, + Expires: e, } } // validateFileCacheProvider ensures that the cache provider is properly initialized -func validateFileCacheProvider(t *testing.T, p *FileCacheProvider, err error, c *credentials.Credentials) { +func validateFileCacheProvider(t *testing.T, p *FileCacheProvider, err error, c aws.CredentialsProvider) { t.Helper() if err != nil { t.Errorf("Unexpected error: %v", err) } - if p.credentials != c { 
+ if p.provider != c { t.Errorf("Credentials not copied") } if p.cacheKey.clusterID != "CLUSTER" { @@ -184,24 +178,24 @@ func TestCacheFilename(t *testing.T) { } func TestNewFileCacheProvider_Missing(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { return tfl })) - validateFileCacheProvider(t, p, err, c) - if !p.cachedCredential.IsExpired() { - t.Errorf("missing cache file should result in expired cached credential") + validateFileCacheProvider(t, p, err, provider) + if p.cachedCredential.HasKeys() { + t.Errorf("missing cache file should result in empty cached credential") } } func TestNewFileCacheProvider_BadPermissions(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, _ := getMocks() // afero.MemMapFs always returns tempfile FileInfo, @@ -209,7 +203,7 @@ func TestNewFileCacheProvider_BadPermissions(t *testing.T) { tfs.fileinfo = &testFileInfo{mode: 0777} // bad permissions - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), ) @@ -223,7 +217,7 @@ func TestNewFileCacheProvider_BadPermissions(t *testing.T) { } func TestNewFileCacheProvider_Unlockable(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() tfs.Create(testFilename) @@ -232,7 +226,7 @@ func TestNewFileCacheProvider_Unlockable(t *testing.T) { tfl.success = false tfl.err = errors.New("lock stuck, needs wd-40") - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -245,14 +239,14 @@ func TestNewFileCacheProvider_Unlockable(t *testing.T) { } func TestNewFileCacheProvider_Unreadable(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() tfs.Create(testFilename) tfl.err = fmt.Errorf("open %s: permission denied", testFilename) tfl.success = false - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -270,12 +264,12 @@ func TestNewFileCacheProvider_Unreadable(t *testing.T) { } func TestNewFileCacheProvider_Unparseable(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() tfs.Create(testFilename) - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -297,12 +291,12 @@ func TestNewFileCacheProvider_Unparseable(t *testing.T) { } func TestNewFileCacheProvider_Empty(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() // successfully parse existing but empty cache file - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", 
provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -313,58 +307,60 @@ func TestNewFileCacheProvider_Empty(t *testing.T) { t.Errorf("Unexpected error: %v", err) return } - validateFileCacheProvider(t, p, err, c) - if !p.cachedCredential.IsExpired() { - t.Errorf("empty cache file should result in expired cached credential") + validateFileCacheProvider(t, p, err, provider) + if p.cachedCredential.HasKeys() { + t.Errorf("empty cache file should result in empty cached credential") } } func TestNewFileCacheProvider_ExistingCluster(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() - afero.WriteFile( - tfs, - testFilename, - []byte(`clusters: - CLUSTER: - ARN2: {} -`), - 0700) + tfs.Create(testFilename) + // successfully parse existing cluster without matching arn - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { - tfs.Create(testFilename) + + afero.WriteFile( + tfs, + testFilename, + []byte(`clusters: + CLUSTER: + PROFILE2: {} +`), + 0700) return tfl }), ) - validateFileCacheProvider(t, p, err, c) - if !p.cachedCredential.IsExpired() { - t.Errorf("missing arn in cache file should result in expired cached credential") + validateFileCacheProvider(t, p, err, provider) + if p.cachedCredential.HasKeys() { + t.Errorf("missing profile in cache file should result in empty cached credential") } } func TestNewFileCacheProvider_ExistingARN(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} + expiry := time.Now().Add(time.Hour * 6) content := []byte(`clusters: CLUSTER: PROFILE: ARN: - credential: - accesskeyid: ABC - secretaccesskey: DEF - sessiontoken: GHI - providername: JKL - expiration: 2018-01-02T03:04:56.789Z + accesskeyid: ABC + secretaccesskey: DEF + sessiontoken: GHI + source: JKL + expires: ` + expiry.Format(time.RFC3339Nano) + ` `) tfs, tfl := getMocks() tfs.Create(testFilename) // successfully parse cluster with matching arn - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -377,38 +373,31 @@ func TestNewFileCacheProvider_ExistingARN(t *testing.T) { t.Errorf("Unexpected error: %v", err) return } - validateFileCacheProvider(t, p, err, c) - if p.cachedCredential.Credential.AccessKeyID != "ABC" || p.cachedCredential.Credential.SecretAccessKey != "DEF" || - p.cachedCredential.Credential.SessionToken != "GHI" || p.cachedCredential.Credential.ProviderName != "JKL" { + validateFileCacheProvider(t, p, err, provider) + if p.cachedCredential.AccessKeyID != "ABC" || p.cachedCredential.SecretAccessKey != "DEF" || + p.cachedCredential.SessionToken != "GHI" || p.cachedCredential.Source != "JKL" { t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } - // fiddle with clock - p.cachedCredential.currentTime = func() time.Time { - return time.Date(2017, 12, 25, 12, 23, 45, 678, time.UTC) - } - if p.cachedCredential.IsExpired() { + + if p.cachedCredential.Expired() { t.Errorf("Cached credential should not be expired") } - if p.IsExpired() { - t.Errorf("Cache credential should not be expired") - } - expectedExpiration := time.Date(2018, 01, 02, 03, 04, 56, 789000000, 
time.UTC) - if p.ExpiresAt() != expectedExpiration { + + if p.ExpiresAt() != p.cachedCredential.Expires { t.Errorf("Credential expiration time is not correct, expected %v, got %v", - expectedExpiration, p.ExpiresAt()) + p.cachedCredential.Expires, p.ExpiresAt()) } } func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { - providerCredential := makeCredential() - c := credentials.NewCredentials(&stubProvider{ - creds: providerCredential, - }) + provider := &stubProvider{ + creds: makeCredential(), + } tfs, tfl := getMocks() // don't create the empty cache file, create it in the filelock creator - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -416,45 +405,37 @@ func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { return tfl }), ) - validateFileCacheProvider(t, p, err, c) + validateFileCacheProvider(t, p, err, provider) credential, err := p.Retrieve() if err != nil { t.Errorf("Unexpected error: %v", err) } - if credential != providerCredential { + if credential.AccessKeyID != provider.creds.AccessKeyID || + credential.SecretAccessKey != provider.creds.SecretAccessKey || + credential.SessionToken != provider.creds.SessionToken { t.Errorf("Cache did not return provider credential, got %v, expected %v", - credential, providerCredential) + credential, provider.creds) } } -// makeExpirerCredentials returns an expiring credential -func makeExpirerCredentials() (providerCredential credentials.Value, expiration time.Time, c *credentials.Credentials) { - providerCredential = makeCredential() - expiration = time.Date(2020, 9, 19, 13, 14, 0, 1000000, time.UTC) - c = credentials.NewCredentials(&stubProviderExpirer{ - stubProvider{ - creds: providerCredential, - }, - expiration, - }) - return -} - func TestFileCacheProvider_Retrieve_WithExpirer_Unlockable(t *testing.T) { - providerCredential, _, c := makeExpirerCredentials() + expires := time.Now().Add(time.Hour * 6) + provider := &stubProvider{ + creds: makeExpiringCredential(expires), + } tfs, tfl := getMocks() // don't create the empty cache file, create it in the filelock creator - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { tfs.Create(testFilename) return tfl })) - validateFileCacheProvider(t, p, err, c) + validateFileCacheProvider(t, p, err, provider) // retrieve credential, which will fetch from underlying Provider // fail to get write lock @@ -465,19 +446,22 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unlockable(t *testing.T) { if err != nil { t.Errorf("Unexpected error: %v", err) } - if credential != providerCredential { - t.Errorf("Cache did not return provider credential, got %v, expected %v", - credential, providerCredential) + if credential.AccessKeyID != "AKID" || credential.SecretAccessKey != "SECRET" || + credential.SessionToken != "TOKEN" || credential.ProviderName != "stubProvider" { + t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } } func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { - providerCredential, expiration, c := makeExpirerCredentials() + expires := time.Now().Add(time.Hour * 6) + provider := &stubProvider{ + creds: makeExpiringCredential(expires), + } tfs, tfl := getMocks() // don't create 
the file, let the FileLocker create it - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -485,45 +469,50 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { return tfl }), ) - validateFileCacheProvider(t, p, err, c) + validateFileCacheProvider(t, p, err, provider) credential, err := p.Retrieve() if err != nil { t.Errorf("Unexpected error: %v", err) } - if credential != providerCredential { - t.Errorf("Cache did not return provider credential, got %v, expected %v", - credential, providerCredential) + if credential.AccessKeyID != provider.creds.AccessKeyID || + credential.SecretAccessKey != provider.creds.SecretAccessKey || + credential.SessionToken != provider.creds.SessionToken || + credential.ProviderName != provider.creds.Source { + t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } expectedData := []byte(`clusters: CLUSTER: PROFILE: ARN: - credential: - accesskeyid: AKID - secretaccesskey: SECRET - sessiontoken: TOKEN - providername: stubProvider - expiration: ` + expiration.Format(time.RFC3339Nano) + ` + accesskeyid: AKID + secretaccesskey: SECRET + sessiontoken: TOKEN + source: stubProvider + canexpire: true + expires: ` + expires.Format(time.RFC3339Nano) + ` + accountid: "" `) got, err := afero.ReadFile(tfs, testFilename) if err != nil { t.Errorf("unexpected error reading generated file: %v", err) } - if !bytes.Equal(got, expectedData) { - t.Errorf("Wrong data written to cache, expected: %s, got %s", - expectedData, got) + if diff := cmp.Diff(got, expectedData); diff != "" { + t.Errorf("Wrong data written to cache, %s", diff) } } func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { - providerCredential, _, c := makeExpirerCredentials() + expires := time.Now().Add(time.Hour * 6) + provider := &stubProvider{ + creds: makeExpiringCredential(expires), + } tfs, tfl := getMocks() // don't create the file, let the FileLocker create it - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -531,7 +520,7 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { return tfl }), ) - validateFileCacheProvider(t, p, err, c) + validateFileCacheProvider(t, p, err, provider) // retrieve credential, which will fetch from underlying Provider // same as TestFileCacheProvider_Retrieve_WithExpirer_Unwritable, @@ -540,15 +529,17 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { if err != nil { t.Errorf("Unexpected error: %v", err) } - if credential != providerCredential { - t.Errorf("Cache did not return provider credential, got %v, expected %v", - credential, providerCredential) + if credential.AccessKeyID != provider.creds.AccessKeyID || + credential.SecretAccessKey != provider.creds.SecretAccessKey || + credential.SessionToken != provider.creds.SessionToken || + credential.ProviderName != provider.creds.Source { + t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } } func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) - currentTime := time.Date(2017, 12, 25, 12, 23, 45, 678, time.UTC) + provider := &stubProvider{} + currentTime := time.Now() tfs, 
tfl := getMocks() tfs.Create(testFilename) @@ -559,13 +550,14 @@ func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { PROFILE: ARN: credential: - accesskeyid: ABC - secretaccesskey: DEF - sessiontoken: GHI - providername: JKL - expiration: ` + currentTime.Add(time.Hour*6).Format(time.RFC3339Nano) + ` + accesskeyid: ABC + secretaccesskey: DEF + sessiontoken: GHI + source: JKL + canexpire: true + expires: ` + currentTime.Add(time.Hour*6).Format(time.RFC3339Nano) + ` `) - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -573,10 +565,7 @@ func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { afero.WriteFile(tfs, testFilename, content, 0700) return tfl })) - validateFileCacheProvider(t, p, err, c) - - // fiddle with clock - p.cachedCredential.currentTime = func() time.Time { return currentTime } + validateFileCacheProvider(t, p, err, provider) credential, err := p.Retrieve() if err != nil { @@ -586,4 +575,11 @@ func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { credential.SessionToken != "GHI" || credential.ProviderName != "JKL" { t.Errorf("cached credential not returned") } + + if !p.ExpiresAt().Equal(currentTime.Add(time.Hour * 6)) { + t.Errorf("unexpected expiration time: got %s, wanted %s", + p.ExpiresAt().Format(time.RFC3339Nano), + currentTime.Add(time.Hour*6).Format(time.RFC3339Nano), + ) + } } diff --git a/pkg/token/token.go b/pkg/token/token.go index d9d7fd2e8..716a8cb12 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -248,7 +248,11 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) { profile = session.DefaultSharedConfigProfile } // create a cacheing Provider wrapper around the Credentials - if cacheProvider, err := filecache.NewFileCacheProvider(options.ClusterID, profile, options.AssumeRoleARN, sess.Config.Credentials); err == nil { + if cacheProvider, err := filecache.NewFileCacheProvider( + options.ClusterID, + profile, + options.AssumeRoleARN, + filecache.V1CredentialToV2Provider(sess.Config.Credentials)); err == nil { sess.Config.Credentials = credentials.NewCredentials(cacheProvider) } else { fmt.Fprintf(os.Stderr, "unable to use cache: %v\n", err) diff --git a/tests/integration/go.mod b/tests/integration/go.mod index ee7a84140..6781a0caf 100644 --- a/tests/integration/go.mod +++ b/tests/integration/go.mod @@ -18,6 +18,8 @@ require ( github.com/NYTimes/gziphandler v1.1.1 // indirect github.com/antlr4-go/antlr/v4 v4.13.0 // indirect github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a // indirect + github.com/aws/aws-sdk-go-v2 v1.30.4 // indirect + github.com/aws/smithy-go v1.20.4 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect diff --git a/tests/integration/go.sum b/tests/integration/go.sum index c85dc3777..4794685e6 100644 --- a/tests/integration/go.sum +++ b/tests/integration/go.sum @@ -12,6 +12,10 @@ github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a h1:idn718Q4 github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= github.com/aws/aws-sdk-go v1.54.6 h1:HEYUib3yTt8E6vxjMWM3yAq5b+qjj/6aKA62mkgux9g= github.com/aws/aws-sdk-go v1.54.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= +github.com/aws/aws-sdk-go-v2 v1.30.4 
h1:frhcagrVNrzmT95RJImMHgabt99vkXGslubDaDagTk8= +github.com/aws/aws-sdk-go-v2 v1.30.4/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0= +github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4= +github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM=
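
A minimal usage sketch of the new shim, mirroring the token.go hunk above: a v1 *credentials.Credentials is adapted to the v2 aws.CredentialsProvider interface, wrapped by the file cache, and handed back to v1 callers. The standalone main function, the sigs.k8s.io/aws-iam-authenticator/pkg/filecache import path, and the "my-cluster"/"default"/empty-ARN cache-key values are illustrative assumptions, not part of this change.

package main

import (
	"fmt"
	"os"

	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/session"

	"sigs.k8s.io/aws-iam-authenticator/pkg/filecache"
)

func main() {
	sess := session.Must(session.NewSession())

	// Adapt the session's v1 *credentials.Credentials to the v2
	// aws.CredentialsProvider interface that NewFileCacheProvider now expects.
	provider := filecache.V1CredentialToV2Provider(sess.Config.Credentials)

	// "my-cluster", "default" and the empty role ARN are placeholder cache-key
	// values; token.go passes the real cluster ID, profile and assume-role ARN.
	cacheProvider, err := filecache.NewFileCacheProvider("my-cluster", "default", "", provider)
	if err != nil {
		// On any cache problem, keep using the original, uncached credentials.
		fmt.Fprintf(os.Stderr, "unable to use cache: %v\n", err)
		return
	}

	// FileCacheProvider still implements the v1 credentials.Provider interface,
	// so it can be handed straight back to v1-based callers.
	sess.Config.Credentials = credentials.NewCredentials(cacheProvider)
}

On a later invocation before the cached credential expires, RetrieveWithContext returns the copy stored in the cache file instead of calling the underlying provider.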