diff --git a/pkg/token/token.go b/pkg/token/token.go index 88ab70299..7936f4c88 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -34,6 +34,7 @@ import ( "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" + v4 "github.com/aws/aws-sdk-go/aws/signer/v4" "github.com/aws/aws-sdk-go/service/sts" "github.com/aws/aws-sdk-go/service/sts/stsiface" "github.com/prometheus/client_golang/prometheus" @@ -198,6 +199,7 @@ type Generator interface { type generator struct { forwardSessionName bool cache bool + nowFunc func() time.Time } // NewGenerator creates a Generator and returns it. @@ -205,6 +207,7 @@ func NewGenerator(forwardSessionName bool, cache bool) (Generator, error) { return generator{ forwardSessionName: forwardSessionName, cache: cache, + nowFunc: time.Now, }, nil } @@ -332,12 +335,23 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) { return g.GetWithSTS(options.ClusterID, stsAPI) } +func getNamedSigningHandler(nowFunc func() time.Time) request.NamedHandler { + return request.NamedHandler{ + Name: "v4.SignRequestHandler", Fn: func(req *request.Request) { + v4.SignSDKRequestWithCurrentTime(req, nowFunc) + }, + } +} + // GetWithSTS returns a token valid for clusterID using the given STS client. func (g generator) GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, error) { // generate an sts:GetCallerIdentity request and add our custom cluster ID header request, _ := stsAPI.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{}) request.HTTPRequest.Header.Add(clusterIDHeader, clusterID) + // override the Sign handler so we can control the now time for testing. + request.Handlers.Sign.Swap("v4.SignRequestHandler", getNamedSigningHandler(g.nowFunc)) + // Sign the request. The expires parameter (sets the x-amz-expires header) is // currently ignored by STS, and the token expires 15 minutes after the x-amz-date // timestamp regardless. We set it to 60 seconds for backwards compatibility (the @@ -350,7 +364,7 @@ func (g generator) GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, } // Set token expiration to 1 minute before the presigned URL expires for some cushion - tokenExpiration := time.Now().Local().Add(presignedURLExpiration - 1*time.Minute) + tokenExpiration := g.nowFunc().Local().Add(presignedURLExpiration - 1*time.Minute) // TODO: this may need to be a constant-time base64 encoding return Token{v1Prefix + base64.RawURLEncoding.EncodeToString([]byte(presignedURLString)), tokenExpiration}, nil } diff --git a/pkg/token/token_test.go b/pkg/token/token_test.go index 021718815..6cc87eb80 100644 --- a/pkg/token/token_test.go +++ b/pkg/token/token_test.go @@ -14,7 +14,11 @@ import ( "testing" "time" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/endpoints" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/sts" "github.com/google/go-cmp/cmp" "github.com/prometheus/client_golang/prometheus" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -582,3 +586,61 @@ func Test_getDefaultHostNameForRegion(t *testing.T) { }) } } + +func TestGetWithSTS(t *testing.T) { + clusterID := "test-cluster" + + cases := []struct { + name string + creds *credentials.Credentials + nowTime time.Time + want Token + wantErr error + }{ + { + "Non-zero time", + // Example non-real credentials + func() *credentials.Credentials { + decodedAkid, _ := base64.StdEncoding.DecodeString("QVNJQVIyVEc0NFY2QVMzWlpFN0M=") + decodedSk, _ := base64.StdEncoding.DecodeString("NEtENWNudEdjVm1MV1JkRjV3dk5SdXpOTDVReG1wNk9LVlk2RnovUQ==") + return credentials.NewStaticCredentials( + string(decodedAkid), + string(decodedSk), + "", + ) + }(), + time.Unix(1682640000, 0), + Token{ + Token: "k8s-aws-v1.aHR0cHM6Ly9zdHMudXMtd2VzdC0yLmFtYXpvbmF3cy5jb20vP0FjdGlvbj1HZXRDYWxsZXJJZGVudGl0eSZWZXJzaW9uPTIwMTEtMDYtMTUmWC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BU0lBUjJURzQ0VjZBUzNaWkU3QyUyRjIwMjMwNDI4JTJGdXMtd2VzdC0yJTJGc3RzJTJGYXdzNF9yZXF1ZXN0JlgtQW16LURhdGU9MjAyMzA0MjhUMDAwMDAwWiZYLUFtei1FeHBpcmVzPTAmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JTNCeC1rOHMtYXdzLWlkJlgtQW16LVNpZ25hdHVyZT00ZDdhYmZkZTk2NzI1ZWI4YTc3MzgyNDg0MTZlNGI1ZDA4ZDlkYmQ3MThiNGY2ZGQ2OTBmOGZiNzUwMTMyOWQ1", + Expiration: time.Unix(1682640000, 0).Local().Add(time.Minute * 14), + }, + nil, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + svc := sts.New(session.Must(session.NewSession( + &aws.Config{ + Credentials: tc.creds, + Region: aws.String("us-west-2"), + STSRegionalEndpoint: endpoints.RegionalSTSEndpoint, + }, + ))) + + gen := &generator{ + forwardSessionName: false, + cache: false, + nowFunc: func() time.Time { return tc.nowTime }, + } + + got, err := gen.GetWithSTS(clusterID, svc) + if diff := cmp.Diff(err, tc.wantErr); diff != "" { + t.Errorf("Unexpected error: %s", diff) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("Got unexpected token: %s", diff) + } + }) + } +}