From 117608ac555f3138eaa3da5b3d029dc5a73019e3 Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Wed, 4 Dec 2024 17:23:49 -0800 Subject: [PATCH] migrate lib/srv/alpnproxy to AWS SDK v2 --- lib/srv/alpnproxy/aws_local_proxy.go | 47 +++---- lib/srv/alpnproxy/aws_local_proxy_test.go | 162 ++++++++++------------ lib/srv/alpnproxy/helpers_test.go | 2 +- lib/srv/alpnproxy/local_proxy_test.go | 157 ++++++++++++--------- lib/utils/aws/aws.go | 21 ++- lib/utils/aws/migration/migration.go | 6 + tool/tsh/common/app_aws.go | 2 +- 7 files changed, 206 insertions(+), 191 deletions(-) diff --git a/lib/srv/alpnproxy/aws_local_proxy.go b/lib/srv/alpnproxy/aws_local_proxy.go index 56c1bbabe4dd5..87cf634f365d1 100644 --- a/lib/srv/alpnproxy/aws_local_proxy.go +++ b/lib/srv/alpnproxy/aws_local_proxy.go @@ -22,10 +22,9 @@ import ( "net/http" "strings" - awsv2 "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/gravitational/trace" "github.com/sirupsen/logrus" @@ -34,20 +33,15 @@ import ( appcommon "github.com/gravitational/teleport/lib/srv/app/common" "github.com/gravitational/teleport/lib/utils" awsutils "github.com/gravitational/teleport/lib/utils/aws" - "github.com/gravitational/teleport/lib/utils/aws/migration" ) // AWSAccessMiddleware verifies the requests to AWS proxy are properly signed. type AWSAccessMiddleware struct { DefaultLocalProxyHTTPMiddleware - // AWSCredentials are AWS Credentials used by LocalProxy for request's signature verification. - AWSCredentials *credentials.Credentials - - // AWSCredentialsV2Provider is an aws sdk v2 credential provider used by - // LocalProxy for request's signature verification if AWSCredentials is not - // specified. - AWSCredentialsV2Provider awsv2.CredentialsProvider + // AWSCredentialsProvider provides credentials for local proxy request + // signature verification. + AWSCredentialsProvider aws.CredentialsProvider Log logrus.FieldLogger @@ -61,11 +55,8 @@ func (m *AWSAccessMiddleware) CheckAndSetDefaults() error { m.Log = logrus.WithField(teleport.ComponentKey, "aws_access") } - if m.AWSCredentials == nil { - if m.AWSCredentialsV2Provider == nil { - return trace.BadParameter("missing AWSCredentials") - } - m.AWSCredentials = credentials.NewCredentials(migration.NewProviderAdapter(m.AWSCredentialsV2Provider)) + if m.AWSCredentialsProvider == nil { + return trace.BadParameter("missing AWS credentials") } return nil @@ -143,7 +134,7 @@ func (m *AWSAccessMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Re } func (m *AWSAccessMiddleware) handleCommonRequest(rw http.ResponseWriter, req *http.Request) bool { - if err := awsutils.VerifyAWSSignature(req, m.AWSCredentials); err != nil { + if err := awsutils.VerifyAWSSignatureV2(req, m.AWSCredentialsProvider); err != nil { m.Log.WithError(err).Error("AWS signature verification failed.") rw.WriteHeader(http.StatusForbidden) return true @@ -152,22 +143,22 @@ func (m *AWSAccessMiddleware) handleCommonRequest(rw http.ResponseWriter, req *h } func (m *AWSAccessMiddleware) handleRequestByAssumedRole(rw http.ResponseWriter, req *http.Request, assumedRole *sts.AssumeRoleOutput) bool { - credentials := credentials.NewStaticCredentials( - aws.StringValue(assumedRole.Credentials.AccessKeyId), - aws.StringValue(assumedRole.Credentials.SecretAccessKey), - aws.StringValue(assumedRole.Credentials.SessionToken), + credentials := credentials.NewStaticCredentialsProvider( + aws.ToString(assumedRole.Credentials.AccessKeyId), + aws.ToString(assumedRole.Credentials.SecretAccessKey), + aws.ToString(assumedRole.Credentials.SessionToken), ) - if err := awsutils.VerifyAWSSignature(req, credentials); err != nil { + if err := awsutils.VerifyAWSSignatureV2(req, credentials); err != nil { m.Log.WithError(err).Error("AWS signature verification failed.") rw.WriteHeader(http.StatusForbidden) return true } - m.Log.Debugf("Rewriting headers for AWS request by assumed role %q.", aws.StringValue(assumedRole.AssumedRoleUser.Arn)) + m.Log.Debugf("Rewriting headers for AWS request by assumed role %q.", aws.ToString(assumedRole.AssumedRoleUser.Arn)) // Add a custom header for marking the special request. - req.Header.Add(appcommon.TeleportAWSAssumedRole, aws.StringValue(assumedRole.AssumedRoleUser.Arn)) + req.Header.Add(appcommon.TeleportAWSAssumedRole, aws.ToString(assumedRole.AssumedRoleUser.Arn)) // Rename the original authorization header to ensure older app agents // (that don't support the requests by assumed roles) will fail. @@ -191,7 +182,7 @@ func (m *AWSAccessMiddleware) HandleResponse(response *http.Response) error { return nil } - if strings.EqualFold(sigV4.Service, sts.EndpointsID) { + if strings.EqualFold(sigV4.Service, "sts") { return trace.Wrap(m.handleSTSResponse(response)) } return nil @@ -219,8 +210,8 @@ func (m *AWSAccessMiddleware) handleSTSResponse(response *http.Response) error { return nil } - m.assumedRoles.Store(aws.StringValue(assumedRole.Credentials.AccessKeyId), assumedRole) - m.Log.Debugf("Saved credentials for assumed role %q.", aws.StringValue(assumedRole.AssumedRoleUser.Arn)) + m.assumedRoles.Store(aws.ToString(assumedRole.Credentials.AccessKeyId), assumedRole) + m.Log.Debugf("Saved credentials for assumed role %q.", aws.ToString(assumedRole.AssumedRoleUser.Arn)) return nil } diff --git a/lib/srv/alpnproxy/aws_local_proxy_test.go b/lib/srv/alpnproxy/aws_local_proxy_test.go index 30f39290d697b..1e01cfed5606d 100644 --- a/lib/srv/alpnproxy/aws_local_proxy_test.go +++ b/lib/srv/alpnproxy/aws_local_proxy_test.go @@ -19,18 +19,17 @@ package alpnproxy import ( + "context" "encoding/xml" "net/http" "net/http/httptest" "testing" "time" - credentialsv2 "github.com/aws/aws-sdk-go-v2/credentials" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - v4 "github.com/aws/aws-sdk-go/aws/signer/v4" - "github.com/aws/aws-sdk-go/private/protocol" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/sts" + ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" "github.com/stretchr/testify/require" awsutils "github.com/gravitational/teleport/lib/utils/aws" @@ -40,89 +39,70 @@ func TestAWSAccessMiddleware(t *testing.T) { t.Parallel() assumedRoleARN := "arn:aws:sts::123456789012:assumed-role/role-name/role-session-name" - localProxyCred := credentials.NewStaticCredentials("local-proxy", "local-proxy-secret", "") - assumedRoleCred := credentials.NewStaticCredentials("assumed-role", "assumed-role-secret", "assumed-role-token") - - tests := []struct { - name string - middleware *AWSAccessMiddleware - }{ - { - name: "v1", - middleware: &AWSAccessMiddleware{ - AWSCredentials: localProxyCred, - }, - }, - { - name: "v2", - middleware: &AWSAccessMiddleware{ - AWSCredentialsV2Provider: credentialsv2.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", ""), - }, - }, - } + localProxyCred := credentials.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", "") + assumedRoleCred := credentials.NewStaticCredentialsProvider("assumed-role", "assumed-role-secret", "assumed-role-token") - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - m := test.middleware - require.NoError(t, m.CheckAndSetDefaults()) - - stsRequestByLocalProxyCred := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", nil) - v4.NewSigner(localProxyCred).Sign(stsRequestByLocalProxyCred, nil, "sts", "us-west-1", time.Now()) - - requestByAssumedRole := httptest.NewRequest(http.MethodGet, "http://s3.amazonaws.com", nil) - v4.NewSigner(assumedRoleCred).Sign(requestByAssumedRole, nil, "s3", "us-west-1", time.Now()) - - t.Run("request no authorization", func(t *testing.T) { - recorder := httptest.NewRecorder() - require.True(t, m.HandleRequest(recorder, httptest.NewRequest("", "http://localhost", nil))) - require.Equal(t, http.StatusForbidden, recorder.Code) - }) - - t.Run("request signed by unknown credentials", func(t *testing.T) { - recorder := httptest.NewRecorder() - require.True(t, m.HandleRequest(recorder, requestByAssumedRole)) - require.Equal(t, http.StatusForbidden, recorder.Code) - }) - - t.Run("request signed by local proxy credentials", func(t *testing.T) { - recorder := httptest.NewRecorder() - require.False(t, m.HandleRequest(recorder, stsRequestByLocalProxyCred)) - require.Equal(t, http.StatusOK, recorder.Code) - }) - - // Verifies sts:AssumeRole output can be handled successfully. The - // credentials should be saved afterwards. - t.Run("handle sts:AssumeRole response", func(t *testing.T) { - response := assumeRoleResponse(t, assumedRoleARN, assumedRoleCred) - response.Request = stsRequestByLocalProxyCred - defer response.Body.Close() - require.NoError(t, m.HandleResponse(response)) - }) - - // This is the same request as the "unknown credentials" test above. But at - // this point, the assumed role credentials should have been saved by the - // middleware so the request can be handled successfully now. - t.Run("request signed by assumed role", func(t *testing.T) { - recorder := httptest.NewRecorder() - require.False(t, m.HandleRequest(recorder, requestByAssumedRole)) - require.Equal(t, http.StatusOK, recorder.Code) - }) - - // Verifies non sts:AssumeRole responses do not give errors. - t.Run("handle sts:GetCallerIdentity response", func(t *testing.T) { - response := getCallerIdentityResponse(t, assumedRoleARN) - response.Request = stsRequestByLocalProxyCred - defer response.Body.Close() - require.NoError(t, m.HandleResponse(response)) - }) - }) + m := &AWSAccessMiddleware{ + AWSCredentialsProvider: credentials.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", ""), } + require.NoError(t, m.CheckAndSetDefaults()) + + stsRequestByLocalProxyCred := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", nil) + + awsutils.NewSignerV2(localProxyCred, "sts").Sign(stsRequestByLocalProxyCred, nil, "sts", "us-west-1", time.Now()) + + requestByAssumedRole := httptest.NewRequest(http.MethodGet, "http://s3.amazonaws.com", nil) + awsutils.NewSignerV2(assumedRoleCred, "s3").Sign(requestByAssumedRole, nil, "s3", "us-west-1", time.Now()) + + t.Run("request no authorization", func(t *testing.T) { + recorder := httptest.NewRecorder() + require.True(t, m.HandleRequest(recorder, httptest.NewRequest("", "http://localhost", nil))) + require.Equal(t, http.StatusForbidden, recorder.Code) + }) + + t.Run("request signed by unknown credentials", func(t *testing.T) { + recorder := httptest.NewRecorder() + require.True(t, m.HandleRequest(recorder, requestByAssumedRole)) + require.Equal(t, http.StatusForbidden, recorder.Code) + }) + + t.Run("request signed by local proxy credentials", func(t *testing.T) { + recorder := httptest.NewRecorder() + require.False(t, m.HandleRequest(recorder, stsRequestByLocalProxyCred)) + require.Equal(t, http.StatusOK, recorder.Code) + }) + + // Verifies sts:AssumeRole output can be handled successfully. The + // credentials should be saved afterwards. + t.Run("handle sts:AssumeRole response", func(t *testing.T) { + response := assumeRoleResponse(t, assumedRoleARN, assumedRoleCred) + response.Request = stsRequestByLocalProxyCred + defer response.Body.Close() + require.NoError(t, m.HandleResponse(response)) + }) + + // This is the same request as the "unknown credentials" test above. But at + // this point, the assumed role credentials should have been saved by the + // middleware so the request can be handled successfully now. + t.Run("request signed by assumed role", func(t *testing.T) { + recorder := httptest.NewRecorder() + require.False(t, m.HandleRequest(recorder, requestByAssumedRole)) + require.Equal(t, http.StatusOK, recorder.Code) + }) + + // Verifies non sts:AssumeRole responses do not give errors. + t.Run("handle sts:GetCallerIdentity response", func(t *testing.T) { + response := getCallerIdentityResponse(t, assumedRoleARN) + response.Request = stsRequestByLocalProxyCred + defer response.Body.Close() + require.NoError(t, m.HandleResponse(response)) + }) } -func assumeRoleResponse(t *testing.T, roleARN string, cred *credentials.Credentials) *http.Response { +func assumeRoleResponse(t *testing.T, roleARN string, provider aws.CredentialsProvider) *http.Response { t.Helper() - credValue, err := cred.Get() + credValue, err := provider.Retrieve(context.Background()) require.NoError(t, err) body, err := awsutils.MarshalXML( @@ -132,18 +112,18 @@ func assumeRoleResponse(t *testing.T, roleARN string, cred *credentials.Credenti }, map[string]any{ "AssumeRoleResult": sts.AssumeRoleOutput{ - AssumedRoleUser: &sts.AssumedRoleUser{ + AssumedRoleUser: &ststypes.AssumedRoleUser{ Arn: aws.String(roleARN), }, - Credentials: &sts.Credentials{ + Credentials: &ststypes.Credentials{ AccessKeyId: aws.String(credValue.AccessKeyID), SecretAccessKey: aws.String(credValue.SecretAccessKey), SessionToken: aws.String(credValue.SessionToken), }, }, - "ResponseMetadata": protocol.ResponseMetadata{ - StatusCode: http.StatusOK, - RequestID: "22222222-3333-3333-3333-333333333333", + "ResponseMetadata": map[string]any{ + "StatusCode": http.StatusOK, + "RequestID": "22222222-3333-3333-3333-333333333333", }, }, ) @@ -163,9 +143,9 @@ func getCallerIdentityResponse(t *testing.T, roleARN string) *http.Response { "GetCallerIdentityResult": sts.GetCallerIdentityOutput{ Arn: aws.String(roleARN), }, - "ResponseMetadata": protocol.ResponseMetadata{ - StatusCode: http.StatusOK, - RequestID: "22222222-3333-3333-3333-333333333333", + "ResponseMetadata": map[string]any{ + "StatusCode": http.StatusOK, + "RequestID": "22222222-3333-3333-3333-333333333333", }, }, ) diff --git a/lib/srv/alpnproxy/helpers_test.go b/lib/srv/alpnproxy/helpers_test.go index b4e9df20e83f0..abbfcd3eb46d0 100644 --- a/lib/srv/alpnproxy/helpers_test.go +++ b/lib/srv/alpnproxy/helpers_test.go @@ -263,7 +263,7 @@ func mustCallHTTPSServerAndReceiveCode(t *testing.T, addr string, client http.Cl require.Equal(t, expectStatusCode, resp.StatusCode) } -func mustStartHTTPServer(t *testing.T, l net.Listener) { +func mustStartHTTPServer(_ *testing.T, l net.Listener) { mux := http.NewServeMux() mux.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) {}) go http.Serve(l, mux) diff --git a/lib/srv/alpnproxy/local_proxy_test.go b/lib/srv/alpnproxy/local_proxy_test.go index f7940ef22c069..785067478e7e5 100644 --- a/lib/srv/alpnproxy/local_proxy_test.go +++ b/lib/srv/alpnproxy/local_proxy_test.go @@ -21,8 +21,11 @@ package alpnproxy import ( "bytes" "context" + "crypto/sha256" "crypto/tls" "crypto/x509" + "encoding/hex" + "errors" "io" "net" "net/http" @@ -33,11 +36,13 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" - v4 "github.com/aws/aws-sdk-go/aws/signer/v4" - "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go-v2/aws" + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" "github.com/gravitational/trace" "github.com/jackc/pgproto3/v2" "github.com/jonboulle/clockwork" @@ -53,72 +58,62 @@ import ( // TestHandleAWSAccessSigVerification tests if LocalProxy verifies the AWS SigV4 signature of incoming request. func TestHandleAWSAccessSigVerification(t *testing.T) { var ( - firstAWSCred = credentials.NewStaticCredentials("userID", "firstSecret", "") - secondAWSCred = credentials.NewStaticCredentials("userID", "secondSecret", "") - thirdAWSCred = credentials.NewStaticCredentials("userID2", "firstSecret", "") + firstAWSCred = credentials.NewStaticCredentialsProvider("userID", "firstSecret", "") + secondAWSCred = credentials.NewStaticCredentialsProvider("userID", "secondSecret", "") + thirdAWSCred = credentials.NewStaticCredentialsProvider("userID2", "firstSecret", "") - awsService = "s3" - awsRegion = "eu-central-1" + awsRegion = "eu-central-1" ) testCases := []struct { name string - proxyCred *credentials.Credentials - signFunc func(*http.Request, io.ReadSeeker, string, string, time.Time) (http.Header, error) - wantErr require.ErrorAssertionFunc + proxyCred aws.CredentialsProvider + clientCred aws.CredentialsProvider + apiOpts []func(*middleware.Stack) error wantStatus int }{ { name: "valid signature", proxyCred: firstAWSCred, - signFunc: v4.NewSigner(firstAWSCred).Sign, - wantErr: require.NoError, + clientCred: firstAWSCred, wantStatus: http.StatusOK, }, { name: "different aws secret access key", proxyCred: secondAWSCred, - signFunc: v4.NewSigner(firstAWSCred).Sign, + clientCred: firstAWSCred, wantStatus: http.StatusForbidden, }, { name: "different aws access key ID", proxyCred: thirdAWSCred, - signFunc: v4.NewSigner(firstAWSCred).Sign, + clientCred: firstAWSCred, wantStatus: http.StatusForbidden, }, { - name: "unsigned request", - proxyCred: firstAWSCred, - signFunc: func(*http.Request, io.ReadSeeker, string, string, time.Time) (http.Header, error) { - // no-op - return nil, nil - }, + name: "unsigned request", + proxyCred: firstAWSCred, + clientCred: nil, wantStatus: http.StatusForbidden, }, { - name: "signed with User-Agent header", - proxyCred: secondAWSCred, - signFunc: func(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) { - // Simulate a case where "User-Agent" is part of the "SignedHeaders". - // The signature does not have to be valid as it will not be compared. - header, err := v4.NewSigner(firstAWSCred).Sign(r, body, service, region, signTime) - if err != nil { - return nil, trace.Wrap(err) - } - - authHeader := r.Header.Get("Authorization") - authHeader = strings.Replace(authHeader, "SignedHeaders=", "SignedHeaders=user-agent;", 1) - r.Header.Set("Authorization", authHeader) - return header, nil + name: "signed with User-Agent header", + proxyCred: secondAWSCred, + clientCred: firstAWSCred, + apiOpts: []func(*middleware.Stack) error{ + func(stack *middleware.Stack) error { + stack.Finalize.Insert( + addUserAgentSignedHeaderMiddleware{}, + "Signing", + middleware.After, + ) + return nil + }, }, wantStatus: http.StatusOK, }, } - httpClient := &http.Client{ - Timeout: 5 * time.Second, - } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { @@ -132,45 +127,49 @@ func TestHandleAWSAccessSigVerification(t *testing.T) { Path: "/", } - payload := []byte("payload content") - req, err := http.NewRequest(http.MethodGet, url.String(), bytes.NewReader(payload)) - require.NoError(t, err) - - tc.signFunc(req, bytes.NewReader(payload), awsService, awsRegion, time.Now()) + clt := sts.New(sts.Options{ + APIOptions: tc.apiOpts, + Region: awsRegion, + Credentials: tc.clientCred, + BaseEndpoint: aws.String(url.String()), + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + RetryMaxAttempts: 0, + }) + _, err := clt.GetCallerIdentity(context.Background(), nil) + if tc.wantStatus == http.StatusOK { + require.NoError(t, err) + return + } - resp, err := httpClient.Do(req) - require.NoError(t, err) - require.Equal(t, tc.wantStatus, resp.StatusCode) - require.NoError(t, resp.Body.Close()) + require.Error(t, err) + var serr *awshttp.ResponseError + require.True(t, errors.As(err, &serr)) + require.Equal(t, tc.wantStatus, serr.HTTPStatusCode()) }) } } // Verifies s3 requests are signed without URL escaping to match AWS SDKs. func TestHandleAWSAccessS3Signing(t *testing.T) { - cred := credentials.NewStaticCredentials("access-key", "secret-key", "") - lp := createAWSAccessProxySuite(t, cred) + provider := credentials.NewStaticCredentialsProvider("access-key", "secret-key", "") + lp := createAWSAccessProxySuite(t, provider) // Avoid loading extra things. t.Setenv("AWS_SDK_LOAD_CONFIG", "false") // Create a real AWS SDK s3 client. - awsConfig := aws.NewConfig(). - WithDisableSSL(true). - WithRegion("local"). - WithCredentials(cred). - WithEndpoint(lp.GetAddr()). - WithS3ForcePathStyle(true) - - s3client := s3.New(session.Must(session.NewSession(awsConfig)), - &aws.Config{ - HTTPClient: &http.Client{Timeout: 5 * time.Second}, - MaxRetries: aws.Int(0), - }) + s3client := s3.New(s3.Options{ + Region: "local", + Credentials: provider, + BaseEndpoint: aws.String("http://" + lp.GetAddr()), + UsePathStyle: true, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + RetryMaxAttempts: 0, + }) // Use a bucket name with special charaters. AWS SDK actually signs the // request with the unescaped bucket name. - _, err := s3client.ListObjects(&s3.ListObjectsInput{ + _, err := s3client.ListObjects(context.Background(), &s3.ListObjectsInput{ Bucket: aws.String("=bucket=name="), }) @@ -628,7 +627,7 @@ func TestKubeMiddleware(t *testing.T) { } } -func createAWSAccessProxySuite(t *testing.T, cred *credentials.Credentials) *LocalProxy { +func createAWSAccessProxySuite(t *testing.T, provider aws.CredentialsProvider) *LocalProxy { hs := httptest.NewTLSServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {})) lp, err := NewLocalProxy(LocalProxyConfig{ @@ -637,7 +636,7 @@ func createAWSAccessProxySuite(t *testing.T, cred *credentials.Credentials) *Loc Protocols: []common.Protocol{common.ProtocolHTTP}, ParentContext: context.Background(), InsecureSkipVerify: true, - HTTPMiddleware: &AWSAccessMiddleware{AWSCredentials: cred}, + HTTPMiddleware: &AWSAccessMiddleware{AWSCredentialsProvider: provider}, }) require.NoError(t, err) t.Cleanup(func() { @@ -767,3 +766,29 @@ func TestGetCertsForConn(t *testing.T) { }) } } + +func hashPayload(payload []byte) string { + hasher := sha256.New() + io.Copy(hasher, bytes.NewReader(payload)) + return hex.EncodeToString(hasher.Sum(nil)) +} + +type addUserAgentSignedHeaderMiddleware struct { +} + +func (m addUserAgentSignedHeaderMiddleware) ID() string { return "AddUserAgentSignedHeader" } +func (m addUserAgentSignedHeaderMiddleware) HandleFinalize( + ctx context.Context, + in middleware.FinalizeInput, + next middleware.FinalizeHandler, +) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) { + req, ok := in.Request.(*smithyhttp.Request) + if !ok { + return out, metadata, trace.Errorf("unexpected request middleware type %T", in.Request) + } + + authHeader := req.Header.Get("Authorization") + authHeader = strings.Replace(authHeader, "SignedHeaders=", "SignedHeaders=user-agent;", 1) + req.Header.Set("Authorization", authHeader) + return next.HandleFinalize(ctx, in) +} diff --git a/lib/utils/aws/aws.go b/lib/utils/aws/aws.go index b79afbdd23f0a..7b0933b2c16b5 100644 --- a/lib/utils/aws/aws.go +++ b/lib/utils/aws/aws.go @@ -31,15 +31,16 @@ import ( "strings" "time" - "github.com/aws/aws-sdk-go/aws/arn" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/aws/aws-sdk-go/aws/credentials" v4 "github.com/aws/aws-sdk-go/aws/signer/v4" - "github.com/aws/aws-sdk-go/service/iam" "github.com/gravitational/trace" apievents "github.com/gravitational/teleport/api/types/events" apiawsutils "github.com/gravitational/teleport/api/utils/aws" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/utils/aws/migration" ) const ( @@ -75,6 +76,8 @@ const ( // used by the AssumeRole call. // https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_iam-quotas.html MaxRoleSessionNameLength = 64 + + iamServiceName = "iam" ) // SigV4 contains parsed content of the AWS Authorization header. @@ -147,6 +150,11 @@ func IsSignedByAWSSigV4(r *http.Request) bool { return strings.HasPrefix(r.Header.Get(AuthorizationHeader), AmazonSigV4AuthorizationPrefix) } +// VerifyAWSSignatureV2 is a temporary AWS SDK migration helper. +func VerifyAWSSignatureV2(req *http.Request, provider aws.CredentialsProvider) error { + return VerifyAWSSignature(req, migration.NewCredentialsAdapter(provider)) +} + // VerifyAWSSignature verifies the request signature ensuring that the request originates from tsh aws command execution // AWS CLI signs the request with random generated credentials that are passed to LocalProxy by // the AWSCredentials LocalProxyConfig configuration. @@ -214,6 +222,11 @@ func VerifyAWSSignature(req *http.Request, credentials *credentials.Credentials) return nil } +// NewSignerV2 is a temporary AWS SDK migration helper. +func NewSignerV2(provider aws.CredentialsProvider, signingServiceName string) *v4.Signer { + return NewSigner(migration.NewCredentialsAdapter(provider), signingServiceName) +} + // NewSigner creates a new V4 signer. func NewSigner(credentials *credentials.Credentials, signingServiceName string) *v4.Signer { options := func(s *v4.Signer) { @@ -384,7 +397,7 @@ func BuildRoleARN(username, region, accountID string) (string, error) { } roleARN := arn.ARN{ Partition: partition, - Service: iam.ServiceName, + Service: iamServiceName, AccountID: accountID, Resource: resource, } @@ -424,7 +437,7 @@ func ParseRoleARN(roleARN string) (*arn.ARN, error) { // Example role ARN: arn:aws:iam::123456789012:role/some-role-name func checkRoleARN(parsed *arn.ARN) error { parts := strings.Split(parsed.Resource, "/") - if parts[0] != "role" || parsed.Service != iam.ServiceName { + if parts[0] != "role" || parsed.Service != iamServiceName { return trace.BadParameter("%q is not an AWS IAM role ARN", parsed) } if len(parts) < 2 || len(parts[len(parts)-1]) == 0 { diff --git a/lib/utils/aws/migration/migration.go b/lib/utils/aws/migration/migration.go index 5288f2ab2a13c..43fdb36b5673a 100644 --- a/lib/utils/aws/migration/migration.go +++ b/lib/utils/aws/migration/migration.go @@ -27,6 +27,12 @@ import ( "github.com/gravitational/trace" ) +// NewCredentialsAdapter adapts an AWS SDK v2 credentials provider to v1 +// credentials. +func NewCredentialsAdapter(providerV2 awsv2.CredentialsProvider) *credentials.Credentials { + return credentials.NewCredentials(NewProviderAdapter(providerV2)) +} + // NewProviderAdapter returns a [ProviderAdapter] that can be used as an AWS SDK // v1 credentials provider. func NewProviderAdapter(providerV2 awsv2.CredentialsProvider) *ProviderAdapter { diff --git a/tool/tsh/common/app_aws.go b/tool/tsh/common/app_aws.go index 7c11458b2f300..aaa6adeec6b01 100644 --- a/tool/tsh/common/app_aws.go +++ b/tool/tsh/common/app_aws.go @@ -170,7 +170,7 @@ func (a *awsApp) GetAppName() string { // through forward proxy. func (a *awsApp) StartLocalProxies(ctx context.Context, opts ...alpnproxy.LocalProxyConfigOpt) error { awsMiddleware := &alpnproxy.AWSAccessMiddleware{ - AWSCredentialsV2Provider: a.GetAWSCredentialsProvider(), + AWSCredentialsProvider: a.GetAWSCredentialsProvider(), } // AWS endpoint URL mode