Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add configurable Now time for signature generation #741

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion pkg/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/prometheus/client_golang/prometheus"
Expand Down Expand Up @@ -198,13 +199,15 @@ type Generator interface {
type generator struct {
forwardSessionName bool
cache bool
nowFunc func() time.Time
}

// NewGenerator creates a Generator and returns it.
func NewGenerator(forwardSessionName bool, cache bool) (Generator, error) {
return generator{
forwardSessionName: forwardSessionName,
cache: cache,
nowFunc: time.Now,
}, nil
}

Expand Down Expand Up @@ -332,12 +335,23 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) {
return g.GetWithSTS(options.ClusterID, stsAPI)
}

func getNamedSigningHandler(nowFunc func() time.Time) request.NamedHandler {
return request.NamedHandler{
Name: "v4.SignRequestHandler", Fn: func(req *request.Request) {
v4.SignSDKRequestWithCurrentTime(req, nowFunc)
},
}
}

// GetWithSTS returns a token valid for clusterID using the given STS client.
func (g generator) GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, error) {
// generate an sts:GetCallerIdentity request and add our custom cluster ID header
request, _ := stsAPI.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{})
request.HTTPRequest.Header.Add(clusterIDHeader, clusterID)

// override the Sign handler so we can control the now time for testing.
request.Handlers.Sign.Swap("v4.SignRequestHandler", getNamedSigningHandler(g.nowFunc))

// Sign the request. The expires parameter (sets the x-amz-expires header) is
// currently ignored by STS, and the token expires 15 minutes after the x-amz-date
// timestamp regardless. We set it to 60 seconds for backwards compatibility (the
Expand All @@ -350,7 +364,7 @@ func (g generator) GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token,
}

// Set token expiration to 1 minute before the presigned URL expires for some cushion
tokenExpiration := time.Now().Local().Add(presignedURLExpiration - 1*time.Minute)
tokenExpiration := g.nowFunc().Local().Add(presignedURLExpiration - 1*time.Minute)
// TODO: this may need to be a constant-time base64 encoding
return Token{v1Prefix + base64.RawURLEncoding.EncodeToString([]byte(presignedURLString)), tokenExpiration}, nil
}
Expand Down
62 changes: 62 additions & 0 deletions pkg/token/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ import (
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/google/go-cmp/cmp"
"github.com/prometheus/client_golang/prometheus"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand Down Expand Up @@ -582,3 +586,61 @@ func Test_getDefaultHostNameForRegion(t *testing.T) {
})
}
}

func TestGetWithSTS(t *testing.T) {
clusterID := "test-cluster"

cases := []struct {
name string
creds *credentials.Credentials
nowTime time.Time
want Token
wantErr error
}{
{
"Non-zero time",
// Example non-real credentials
func() *credentials.Credentials {
decodedAkid, _ := base64.StdEncoding.DecodeString("QVNJQVIyVEc0NFY2QVMzWlpFN0M=")
decodedSk, _ := base64.StdEncoding.DecodeString("NEtENWNudEdjVm1MV1JkRjV3dk5SdXpOTDVReG1wNk9LVlk2RnovUQ==")
return credentials.NewStaticCredentials(
string(decodedAkid),
string(decodedSk),
"",
)
}(),
time.Unix(1682640000, 0),
Token{
Token: "k8s-aws-v1.aHR0cHM6Ly9zdHMudXMtd2VzdC0yLmFtYXpvbmF3cy5jb20vP0FjdGlvbj1HZXRDYWxsZXJJZGVudGl0eSZWZXJzaW9uPTIwMTEtMDYtMTUmWC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BU0lBUjJURzQ0VjZBUzNaWkU3QyUyRjIwMjMwNDI4JTJGdXMtd2VzdC0yJTJGc3RzJTJGYXdzNF9yZXF1ZXN0JlgtQW16LURhdGU9MjAyMzA0MjhUMDAwMDAwWiZYLUFtei1FeHBpcmVzPTAmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JTNCeC1rOHMtYXdzLWlkJlgtQW16LVNpZ25hdHVyZT00ZDdhYmZkZTk2NzI1ZWI4YTc3MzgyNDg0MTZlNGI1ZDA4ZDlkYmQ3MThiNGY2ZGQ2OTBmOGZiNzUwMTMyOWQ1",
Expiration: time.Unix(1682640000, 0).Local().Add(time.Minute * 14),
},
nil,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need a test case that has an error?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only possible error is a request.Presign() error, which I'm not really concerned with for the purposes of this test

},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
svc := sts.New(session.Must(session.NewSession(
&aws.Config{
Credentials: tc.creds,
Region: aws.String("us-west-2"),
STSRegionalEndpoint: endpoints.RegionalSTSEndpoint,
},
)))

gen := &generator{
forwardSessionName: false,
cache: false,
nowFunc: func() time.Time { return tc.nowTime },
}

got, err := gen.GetWithSTS(clusterID, svc)
if diff := cmp.Diff(err, tc.wantErr); diff != "" {
t.Errorf("Unexpected error: %s", diff)
}
if diff := cmp.Diff(tc.want, got); diff != "" {
t.Errorf("Got unexpected token: %s", diff)
}
})
}
}