diff --git a/lib/srv/alpnproxy/aws_local_proxy.go b/lib/srv/alpnproxy/aws_local_proxy.go index 56c1bbabe4dd5..87cf634f365d1 100644 --- a/lib/srv/alpnproxy/aws_local_proxy.go +++ b/lib/srv/alpnproxy/aws_local_proxy.go @@ -22,10 +22,9 @@ import ( "net/http" "strings" - awsv2 "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/gravitational/trace" "github.com/sirupsen/logrus" @@ -34,20 +33,15 @@ import ( appcommon "github.com/gravitational/teleport/lib/srv/app/common" "github.com/gravitational/teleport/lib/utils" awsutils "github.com/gravitational/teleport/lib/utils/aws" - "github.com/gravitational/teleport/lib/utils/aws/migration" ) // AWSAccessMiddleware verifies the requests to AWS proxy are properly signed. type AWSAccessMiddleware struct { DefaultLocalProxyHTTPMiddleware - // AWSCredentials are AWS Credentials used by LocalProxy for request's signature verification. - AWSCredentials *credentials.Credentials - - // AWSCredentialsV2Provider is an aws sdk v2 credential provider used by - // LocalProxy for request's signature verification if AWSCredentials is not - // specified. - AWSCredentialsV2Provider awsv2.CredentialsProvider + // AWSCredentialsProvider provides credentials for local proxy request + // signature verification. + AWSCredentialsProvider aws.CredentialsProvider Log logrus.FieldLogger @@ -61,11 +55,8 @@ func (m *AWSAccessMiddleware) CheckAndSetDefaults() error { m.Log = logrus.WithField(teleport.ComponentKey, "aws_access") } - if m.AWSCredentials == nil { - if m.AWSCredentialsV2Provider == nil { - return trace.BadParameter("missing AWSCredentials") - } - m.AWSCredentials = credentials.NewCredentials(migration.NewProviderAdapter(m.AWSCredentialsV2Provider)) + if m.AWSCredentialsProvider == nil { + return trace.BadParameter("missing AWS credentials") } return nil @@ -143,7 +134,7 @@ func (m *AWSAccessMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Re } func (m *AWSAccessMiddleware) handleCommonRequest(rw http.ResponseWriter, req *http.Request) bool { - if err := awsutils.VerifyAWSSignature(req, m.AWSCredentials); err != nil { + if err := awsutils.VerifyAWSSignatureV2(req, m.AWSCredentialsProvider); err != nil { m.Log.WithError(err).Error("AWS signature verification failed.") rw.WriteHeader(http.StatusForbidden) return true @@ -152,22 +143,22 @@ func (m *AWSAccessMiddleware) handleCommonRequest(rw http.ResponseWriter, req *h } func (m *AWSAccessMiddleware) handleRequestByAssumedRole(rw http.ResponseWriter, req *http.Request, assumedRole *sts.AssumeRoleOutput) bool { - credentials := credentials.NewStaticCredentials( - aws.StringValue(assumedRole.Credentials.AccessKeyId), - aws.StringValue(assumedRole.Credentials.SecretAccessKey), - aws.StringValue(assumedRole.Credentials.SessionToken), + credentials := credentials.NewStaticCredentialsProvider( + aws.ToString(assumedRole.Credentials.AccessKeyId), + aws.ToString(assumedRole.Credentials.SecretAccessKey), + aws.ToString(assumedRole.Credentials.SessionToken), ) - if err := awsutils.VerifyAWSSignature(req, credentials); err != nil { + if err := awsutils.VerifyAWSSignatureV2(req, credentials); err != nil { m.Log.WithError(err).Error("AWS signature verification failed.") rw.WriteHeader(http.StatusForbidden) return true } - m.Log.Debugf("Rewriting headers for AWS request by assumed role %q.", aws.StringValue(assumedRole.AssumedRoleUser.Arn)) + m.Log.Debugf("Rewriting headers for AWS request by assumed role %q.", aws.ToString(assumedRole.AssumedRoleUser.Arn)) // Add a custom header for marking the special request. - req.Header.Add(appcommon.TeleportAWSAssumedRole, aws.StringValue(assumedRole.AssumedRoleUser.Arn)) + req.Header.Add(appcommon.TeleportAWSAssumedRole, aws.ToString(assumedRole.AssumedRoleUser.Arn)) // Rename the original authorization header to ensure older app agents // (that don't support the requests by assumed roles) will fail. @@ -191,7 +182,7 @@ func (m *AWSAccessMiddleware) HandleResponse(response *http.Response) error { return nil } - if strings.EqualFold(sigV4.Service, sts.EndpointsID) { + if strings.EqualFold(sigV4.Service, "sts") { return trace.Wrap(m.handleSTSResponse(response)) } return nil @@ -219,8 +210,8 @@ func (m *AWSAccessMiddleware) handleSTSResponse(response *http.Response) error { return nil } - m.assumedRoles.Store(aws.StringValue(assumedRole.Credentials.AccessKeyId), assumedRole) - m.Log.Debugf("Saved credentials for assumed role %q.", aws.StringValue(assumedRole.AssumedRoleUser.Arn)) + m.assumedRoles.Store(aws.ToString(assumedRole.Credentials.AccessKeyId), assumedRole) + m.Log.Debugf("Saved credentials for assumed role %q.", aws.ToString(assumedRole.AssumedRoleUser.Arn)) return nil } diff --git a/lib/srv/alpnproxy/aws_local_proxy_test.go b/lib/srv/alpnproxy/aws_local_proxy_test.go index 30f39290d697b..1e01cfed5606d 100644 --- a/lib/srv/alpnproxy/aws_local_proxy_test.go +++ b/lib/srv/alpnproxy/aws_local_proxy_test.go @@ -19,18 +19,17 @@ package alpnproxy import ( + "context" "encoding/xml" "net/http" "net/http/httptest" "testing" "time" - credentialsv2 "github.com/aws/aws-sdk-go-v2/credentials" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - v4 "github.com/aws/aws-sdk-go/aws/signer/v4" - "github.com/aws/aws-sdk-go/private/protocol" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/sts" + ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" "github.com/stretchr/testify/require" awsutils "github.com/gravitational/teleport/lib/utils/aws" @@ -40,89 +39,70 @@ func TestAWSAccessMiddleware(t *testing.T) { t.Parallel() assumedRoleARN := "arn:aws:sts::123456789012:assumed-role/role-name/role-session-name" - localProxyCred := credentials.NewStaticCredentials("local-proxy", "local-proxy-secret", "") - assumedRoleCred := credentials.NewStaticCredentials("assumed-role", "assumed-role-secret", "assumed-role-token") - - tests := []struct { - name string - middleware *AWSAccessMiddleware - }{ - { - name: "v1", - middleware: &AWSAccessMiddleware{ - AWSCredentials: localProxyCred, - }, - }, - { - name: "v2", - middleware: &AWSAccessMiddleware{ - AWSCredentialsV2Provider: credentialsv2.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", ""), - }, - }, - } + localProxyCred := credentials.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", "") + assumedRoleCred := credentials.NewStaticCredentialsProvider("assumed-role", "assumed-role-secret", "assumed-role-token") - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - m := test.middleware - require.NoError(t, m.CheckAndSetDefaults()) - - stsRequestByLocalProxyCred := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", nil) - v4.NewSigner(localProxyCred).Sign(stsRequestByLocalProxyCred, nil, "sts", "us-west-1", time.Now()) - - requestByAssumedRole := httptest.NewRequest(http.MethodGet, "http://s3.amazonaws.com", nil) - v4.NewSigner(assumedRoleCred).Sign(requestByAssumedRole, nil, "s3", "us-west-1", time.Now()) - - t.Run("request no authorization", func(t *testing.T) { - recorder := httptest.NewRecorder() - require.True(t, m.HandleRequest(recorder, httptest.NewRequest("", "http://localhost", nil))) - require.Equal(t, http.StatusForbidden, recorder.Code) - }) - - t.Run("request signed by unknown credentials", func(t *testing.T) { - recorder := httptest.NewRecorder() - require.True(t, m.HandleRequest(recorder, requestByAssumedRole)) - require.Equal(t, http.StatusForbidden, recorder.Code) - }) - - t.Run("request signed by local proxy credentials", func(t *testing.T) { - recorder := httptest.NewRecorder() - require.False(t, m.HandleRequest(recorder, stsRequestByLocalProxyCred)) - require.Equal(t, http.StatusOK, recorder.Code) - }) - - // Verifies sts:AssumeRole output can be handled successfully. The - // credentials should be saved afterwards. - t.Run("handle sts:AssumeRole response", func(t *testing.T) { - response := assumeRoleResponse(t, assumedRoleARN, assumedRoleCred) - response.Request = stsRequestByLocalProxyCred - defer response.Body.Close() - require.NoError(t, m.HandleResponse(response)) - }) - - // This is the same request as the "unknown credentials" test above. But at - // this point, the assumed role credentials should have been saved by the - // middleware so the request can be handled successfully now. - t.Run("request signed by assumed role", func(t *testing.T) { - recorder := httptest.NewRecorder() - require.False(t, m.HandleRequest(recorder, requestByAssumedRole)) - require.Equal(t, http.StatusOK, recorder.Code) - }) - - // Verifies non sts:AssumeRole responses do not give errors. - t.Run("handle sts:GetCallerIdentity response", func(t *testing.T) { - response := getCallerIdentityResponse(t, assumedRoleARN) - response.Request = stsRequestByLocalProxyCred - defer response.Body.Close() - require.NoError(t, m.HandleResponse(response)) - }) - }) + m := &AWSAccessMiddleware{ + AWSCredentialsProvider: credentials.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", ""), } + require.NoError(t, m.CheckAndSetDefaults()) + + stsRequestByLocalProxyCred := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", nil) + + awsutils.NewSignerV2(localProxyCred, "sts").Sign(stsRequestByLocalProxyCred, nil, "sts", "us-west-1", time.Now()) + + requestByAssumedRole := httptest.NewRequest(http.MethodGet, "http://s3.amazonaws.com", nil) + awsutils.NewSignerV2(assumedRoleCred, "s3").Sign(requestByAssumedRole, nil, "s3", "us-west-1", time.Now()) + + t.Run("request no authorization", func(t *testing.T) { + recorder := httptest.NewRecorder() + require.True(t, m.HandleRequest(recorder, httptest.NewRequest("", "http://localhost", nil))) + require.Equal(t, http.StatusForbidden, recorder.Code) + }) + + t.Run("request signed by unknown credentials", func(t *testing.T) { + recorder := httptest.NewRecorder() + require.True(t, m.HandleRequest(recorder, requestByAssumedRole)) + require.Equal(t, http.StatusForbidden, recorder.Code) + }) + + t.Run("request signed by local proxy credentials", func(t *testing.T) { + recorder := httptest.NewRecorder() + require.False(t, m.HandleRequest(recorder, stsRequestByLocalProxyCred)) + require.Equal(t, http.StatusOK, recorder.Code) + }) + + // Verifies sts:AssumeRole output can be handled successfully. The + // credentials should be saved afterwards. + t.Run("handle sts:AssumeRole response", func(t *testing.T) { + response := assumeRoleResponse(t, assumedRoleARN, assumedRoleCred) + response.Request = stsRequestByLocalProxyCred + defer response.Body.Close() + require.NoError(t, m.HandleResponse(response)) + }) + + // This is the same request as the "unknown credentials" test above. But at + // this point, the assumed role credentials should have been saved by the + // middleware so the request can be handled successfully now. + t.Run("request signed by assumed role", func(t *testing.T) { + recorder := httptest.NewRecorder() + require.False(t, m.HandleRequest(recorder, requestByAssumedRole)) + require.Equal(t, http.StatusOK, recorder.Code) + }) + + // Verifies non sts:AssumeRole responses do not give errors. + t.Run("handle sts:GetCallerIdentity response", func(t *testing.T) { + response := getCallerIdentityResponse(t, assumedRoleARN) + response.Request = stsRequestByLocalProxyCred + defer response.Body.Close() + require.NoError(t, m.HandleResponse(response)) + }) } -func assumeRoleResponse(t *testing.T, roleARN string, cred *credentials.Credentials) *http.Response { +func assumeRoleResponse(t *testing.T, roleARN string, provider aws.CredentialsProvider) *http.Response { t.Helper() - credValue, err := cred.Get() + credValue, err := provider.Retrieve(context.Background()) require.NoError(t, err) body, err := awsutils.MarshalXML( @@ -132,18 +112,18 @@ func assumeRoleResponse(t *testing.T, roleARN string, cred *credentials.Credenti }, map[string]any{ "AssumeRoleResult": sts.AssumeRoleOutput{ - AssumedRoleUser: &sts.AssumedRoleUser{ + AssumedRoleUser: &ststypes.AssumedRoleUser{ Arn: aws.String(roleARN), }, - Credentials: &sts.Credentials{ + Credentials: &ststypes.Credentials{ AccessKeyId: aws.String(credValue.AccessKeyID), SecretAccessKey: aws.String(credValue.SecretAccessKey), SessionToken: aws.String(credValue.SessionToken), }, }, - "ResponseMetadata": protocol.ResponseMetadata{ - StatusCode: http.StatusOK, - RequestID: "22222222-3333-3333-3333-333333333333", + "ResponseMetadata": map[string]any{ + "StatusCode": http.StatusOK, + "RequestID": "22222222-3333-3333-3333-333333333333", }, }, ) @@ -163,9 +143,9 @@ func getCallerIdentityResponse(t *testing.T, roleARN string) *http.Response { "GetCallerIdentityResult": sts.GetCallerIdentityOutput{ Arn: aws.String(roleARN), }, - "ResponseMetadata": protocol.ResponseMetadata{ - StatusCode: http.StatusOK, - RequestID: "22222222-3333-3333-3333-333333333333", + "ResponseMetadata": map[string]any{ + "StatusCode": http.StatusOK, + "RequestID": "22222222-3333-3333-3333-333333333333", }, }, ) diff --git a/lib/srv/alpnproxy/helpers_test.go b/lib/srv/alpnproxy/helpers_test.go index b4e9df20e83f0..90af912be29b3 100644 --- a/lib/srv/alpnproxy/helpers_test.go +++ b/lib/srv/alpnproxy/helpers_test.go @@ -153,6 +153,7 @@ func (s *Suite) Start(t *testing.T) { } func mustGenSelfSignedCert(t *testing.T) *tlsca.CertAuthority { + t.Helper() caKey, caCert, err := tlsca.GenerateSelfSignedCA(pkix.Name{ CommonName: "localhost", }, []string{"localhost"}, defaults.CATTL) @@ -183,6 +184,7 @@ func withClock(clock clockwork.Clock) signOptionsFunc { type signOptionsFunc func(o *signOptions) func mustGenCertSignedWithCA(t *testing.T, ca *tlsca.CertAuthority, opts ...signOptionsFunc) tls.Certificate { + t.Helper() options := signOptions{ identity: tlsca.Identity{Username: "test-user"}, clock: clockwork.NewRealClock(), @@ -218,6 +220,7 @@ func mustGenCertSignedWithCA(t *testing.T, ca *tlsca.CertAuthority, opts ...sign } func mustReadFromConnection(t *testing.T, conn net.Conn, want string) { + t.Helper() require.NoError(t, conn.SetReadDeadline(time.Now().Add(time.Second*5))) buff, err := io.ReadAll(conn) require.NoError(t, err) @@ -226,11 +229,13 @@ func mustReadFromConnection(t *testing.T, conn net.Conn, want string) { } func mustCloseConnection(t *testing.T, conn net.Conn) { + t.Helper() err := conn.Close() require.NoError(t, err) } func mustCreateLocalListener(t *testing.T) net.Listener { + t.Helper() l, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) t.Cleanup(func() { @@ -240,6 +245,7 @@ func mustCreateLocalListener(t *testing.T) net.Listener { } func mustCreateCertGenListener(t *testing.T, ca tls.Certificate) net.Listener { + t.Helper() listener, err := NewCertGenListener(CertGenListenerConfig{ ListenAddr: "localhost:0", CA: ca, @@ -253,23 +259,26 @@ func mustCreateCertGenListener(t *testing.T, ca tls.Certificate) net.Listener { } func mustSuccessfullyCallHTTPSServer(t *testing.T, addr string, client http.Client) { + t.Helper() mustCallHTTPSServerAndReceiveCode(t, addr, client, http.StatusOK) } func mustCallHTTPSServerAndReceiveCode(t *testing.T, addr string, client http.Client, expectStatusCode int) { + t.Helper() resp, err := client.Get(fmt.Sprintf("https://%s", addr)) require.NoError(t, err) defer resp.Body.Close() require.Equal(t, expectStatusCode, resp.StatusCode) } -func mustStartHTTPServer(t *testing.T, l net.Listener) { +func mustStartHTTPServer(l net.Listener) { mux := http.NewServeMux() mux.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) {}) go http.Serve(l, mux) } func mustStartLocalProxy(t *testing.T, config LocalProxyConfig) { + t.Helper() lp, err := NewLocalProxy(config) require.NoError(t, err) t.Cleanup(func() { diff --git a/lib/srv/alpnproxy/local_proxy_test.go b/lib/srv/alpnproxy/local_proxy_test.go index f7940ef22c069..347ca22e891ec 100644 --- a/lib/srv/alpnproxy/local_proxy_test.go +++ b/lib/srv/alpnproxy/local_proxy_test.go @@ -33,11 +33,13 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" - v4 "github.com/aws/aws-sdk-go/aws/signer/v4" - "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go-v2/aws" + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" "github.com/gravitational/trace" "github.com/jackc/pgproto3/v2" "github.com/jonboulle/clockwork" @@ -53,72 +55,62 @@ import ( // TestHandleAWSAccessSigVerification tests if LocalProxy verifies the AWS SigV4 signature of incoming request. func TestHandleAWSAccessSigVerification(t *testing.T) { var ( - firstAWSCred = credentials.NewStaticCredentials("userID", "firstSecret", "") - secondAWSCred = credentials.NewStaticCredentials("userID", "secondSecret", "") - thirdAWSCred = credentials.NewStaticCredentials("userID2", "firstSecret", "") + firstAWSCred = credentials.NewStaticCredentialsProvider("userID", "firstSecret", "") + secondAWSCred = credentials.NewStaticCredentialsProvider("userID", "secondSecret", "") + thirdAWSCred = credentials.NewStaticCredentialsProvider("userID2", "firstSecret", "") - awsService = "s3" - awsRegion = "eu-central-1" + awsRegion = "eu-central-1" ) testCases := []struct { name string - proxyCred *credentials.Credentials - signFunc func(*http.Request, io.ReadSeeker, string, string, time.Time) (http.Header, error) - wantErr require.ErrorAssertionFunc + proxyCred aws.CredentialsProvider + clientCred aws.CredentialsProvider + apiOpts []func(*middleware.Stack) error wantStatus int }{ { name: "valid signature", proxyCred: firstAWSCred, - signFunc: v4.NewSigner(firstAWSCred).Sign, - wantErr: require.NoError, + clientCred: firstAWSCred, wantStatus: http.StatusOK, }, { name: "different aws secret access key", proxyCred: secondAWSCred, - signFunc: v4.NewSigner(firstAWSCred).Sign, + clientCred: firstAWSCred, wantStatus: http.StatusForbidden, }, { name: "different aws access key ID", proxyCred: thirdAWSCred, - signFunc: v4.NewSigner(firstAWSCred).Sign, + clientCred: firstAWSCred, wantStatus: http.StatusForbidden, }, { - name: "unsigned request", - proxyCred: firstAWSCred, - signFunc: func(*http.Request, io.ReadSeeker, string, string, time.Time) (http.Header, error) { - // no-op - return nil, nil - }, + name: "unsigned request", + proxyCred: firstAWSCred, + clientCred: nil, wantStatus: http.StatusForbidden, }, { - name: "signed with User-Agent header", - proxyCred: secondAWSCred, - signFunc: func(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) { - // Simulate a case where "User-Agent" is part of the "SignedHeaders". - // The signature does not have to be valid as it will not be compared. - header, err := v4.NewSigner(firstAWSCred).Sign(r, body, service, region, signTime) - if err != nil { - return nil, trace.Wrap(err) - } - - authHeader := r.Header.Get("Authorization") - authHeader = strings.Replace(authHeader, "SignedHeaders=", "SignedHeaders=user-agent;", 1) - r.Header.Set("Authorization", authHeader) - return header, nil + name: "signed with User-Agent header", + proxyCred: secondAWSCred, + clientCred: firstAWSCred, + apiOpts: []func(*middleware.Stack) error{ + func(stack *middleware.Stack) error { + stack.Finalize.Insert( + addUserAgentSignedHeaderMiddleware{}, + "Signing", + middleware.After, + ) + return nil + }, }, wantStatus: http.StatusOK, }, } - httpClient := &http.Client{ - Timeout: 5 * time.Second, - } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { @@ -132,45 +124,49 @@ func TestHandleAWSAccessSigVerification(t *testing.T) { Path: "/", } - payload := []byte("payload content") - req, err := http.NewRequest(http.MethodGet, url.String(), bytes.NewReader(payload)) - require.NoError(t, err) - - tc.signFunc(req, bytes.NewReader(payload), awsService, awsRegion, time.Now()) + clt := sts.New(sts.Options{ + APIOptions: tc.apiOpts, + Region: awsRegion, + Credentials: tc.clientCred, + BaseEndpoint: aws.String(url.String()), + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + RetryMaxAttempts: 0, + }) + _, err := clt.GetCallerIdentity(context.Background(), nil) + if tc.wantStatus == http.StatusOK { + require.NoError(t, err) + return + } - resp, err := httpClient.Do(req) - require.NoError(t, err) - require.Equal(t, tc.wantStatus, resp.StatusCode) - require.NoError(t, resp.Body.Close()) + require.Error(t, err) + var serr *awshttp.ResponseError + require.ErrorAs(t, err, &serr) + require.Equal(t, tc.wantStatus, serr.HTTPStatusCode()) }) } } // Verifies s3 requests are signed without URL escaping to match AWS SDKs. func TestHandleAWSAccessS3Signing(t *testing.T) { - cred := credentials.NewStaticCredentials("access-key", "secret-key", "") - lp := createAWSAccessProxySuite(t, cred) + provider := credentials.NewStaticCredentialsProvider("access-key", "secret-key", "") + lp := createAWSAccessProxySuite(t, provider) // Avoid loading extra things. t.Setenv("AWS_SDK_LOAD_CONFIG", "false") // Create a real AWS SDK s3 client. - awsConfig := aws.NewConfig(). - WithDisableSSL(true). - WithRegion("local"). - WithCredentials(cred). - WithEndpoint(lp.GetAddr()). - WithS3ForcePathStyle(true) - - s3client := s3.New(session.Must(session.NewSession(awsConfig)), - &aws.Config{ - HTTPClient: &http.Client{Timeout: 5 * time.Second}, - MaxRetries: aws.Int(0), - }) + s3client := s3.New(s3.Options{ + Region: "local", + Credentials: provider, + BaseEndpoint: aws.String("http://" + lp.GetAddr()), + UsePathStyle: true, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + RetryMaxAttempts: 0, + }) // Use a bucket name with special charaters. AWS SDK actually signs the // request with the unescaped bucket name. - _, err := s3client.ListObjects(&s3.ListObjectsInput{ + _, err := s3client.ListObjects(context.Background(), &s3.ListObjectsInput{ Bucket: aws.String("=bucket=name="), }) @@ -628,7 +624,7 @@ func TestKubeMiddleware(t *testing.T) { } } -func createAWSAccessProxySuite(t *testing.T, cred *credentials.Credentials) *LocalProxy { +func createAWSAccessProxySuite(t *testing.T, provider aws.CredentialsProvider) *LocalProxy { hs := httptest.NewTLSServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {})) lp, err := NewLocalProxy(LocalProxyConfig{ @@ -637,7 +633,7 @@ func createAWSAccessProxySuite(t *testing.T, cred *credentials.Credentials) *Loc Protocols: []common.Protocol{common.ProtocolHTTP}, ParentContext: context.Background(), InsecureSkipVerify: true, - HTTPMiddleware: &AWSAccessMiddleware{AWSCredentials: cred}, + HTTPMiddleware: &AWSAccessMiddleware{AWSCredentialsProvider: provider}, }) require.NoError(t, err) t.Cleanup(func() { @@ -767,3 +763,23 @@ func TestGetCertsForConn(t *testing.T) { }) } } + +type addUserAgentSignedHeaderMiddleware struct { +} + +func (m addUserAgentSignedHeaderMiddleware) ID() string { return "AddUserAgentSignedHeader" } +func (m addUserAgentSignedHeaderMiddleware) HandleFinalize( + ctx context.Context, + in middleware.FinalizeInput, + next middleware.FinalizeHandler, +) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) { + req, ok := in.Request.(*smithyhttp.Request) + if !ok { + return out, metadata, trace.Errorf("unexpected request middleware type %T", in.Request) + } + + authHeader := req.Header.Get("Authorization") + authHeader = strings.Replace(authHeader, "SignedHeaders=", "SignedHeaders=user-agent;", 1) + req.Header.Set("Authorization", authHeader) + return next.HandleFinalize(ctx, in) +} diff --git a/lib/srv/alpnproxy/proxy_test.go b/lib/srv/alpnproxy/proxy_test.go index 5c00b39a478a3..5e023485eee7d 100644 --- a/lib/srv/alpnproxy/proxy_test.go +++ b/lib/srv/alpnproxy/proxy_test.go @@ -328,7 +328,7 @@ func TestProxyHTTPConnection(t *testing.T) { lw := NewMuxListenerWrapper(l, suite.serverListener) - mustStartHTTPServer(t, lw) + mustStartHTTPServer(lw) suite.router = NewRouter() suite.router.Add(HandlerDecs{ @@ -359,7 +359,7 @@ func TestProxyMakeConnectionHandler(t *testing.T) { // Create a HTTP server and register the listener to ALPN server. lw := NewMuxListenerWrapper(nil, suite.serverListener) - mustStartHTTPServer(t, lw) + mustStartHTTPServer(lw) suite.router = NewRouter() suite.router.Add(HandlerDecs{ diff --git a/lib/utils/aws/aws.go b/lib/utils/aws/aws.go index b79afbdd23f0a..7b0933b2c16b5 100644 --- a/lib/utils/aws/aws.go +++ b/lib/utils/aws/aws.go @@ -31,15 +31,16 @@ import ( "strings" "time" - "github.com/aws/aws-sdk-go/aws/arn" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/aws/aws-sdk-go/aws/credentials" v4 "github.com/aws/aws-sdk-go/aws/signer/v4" - "github.com/aws/aws-sdk-go/service/iam" "github.com/gravitational/trace" apievents "github.com/gravitational/teleport/api/types/events" apiawsutils "github.com/gravitational/teleport/api/utils/aws" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/utils/aws/migration" ) const ( @@ -75,6 +76,8 @@ const ( // used by the AssumeRole call. // https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_iam-quotas.html MaxRoleSessionNameLength = 64 + + iamServiceName = "iam" ) // SigV4 contains parsed content of the AWS Authorization header. @@ -147,6 +150,11 @@ func IsSignedByAWSSigV4(r *http.Request) bool { return strings.HasPrefix(r.Header.Get(AuthorizationHeader), AmazonSigV4AuthorizationPrefix) } +// VerifyAWSSignatureV2 is a temporary AWS SDK migration helper. +func VerifyAWSSignatureV2(req *http.Request, provider aws.CredentialsProvider) error { + return VerifyAWSSignature(req, migration.NewCredentialsAdapter(provider)) +} + // VerifyAWSSignature verifies the request signature ensuring that the request originates from tsh aws command execution // AWS CLI signs the request with random generated credentials that are passed to LocalProxy by // the AWSCredentials LocalProxyConfig configuration. @@ -214,6 +222,11 @@ func VerifyAWSSignature(req *http.Request, credentials *credentials.Credentials) return nil } +// NewSignerV2 is a temporary AWS SDK migration helper. +func NewSignerV2(provider aws.CredentialsProvider, signingServiceName string) *v4.Signer { + return NewSigner(migration.NewCredentialsAdapter(provider), signingServiceName) +} + // NewSigner creates a new V4 signer. func NewSigner(credentials *credentials.Credentials, signingServiceName string) *v4.Signer { options := func(s *v4.Signer) { @@ -384,7 +397,7 @@ func BuildRoleARN(username, region, accountID string) (string, error) { } roleARN := arn.ARN{ Partition: partition, - Service: iam.ServiceName, + Service: iamServiceName, AccountID: accountID, Resource: resource, } @@ -424,7 +437,7 @@ func ParseRoleARN(roleARN string) (*arn.ARN, error) { // Example role ARN: arn:aws:iam::123456789012:role/some-role-name func checkRoleARN(parsed *arn.ARN) error { parts := strings.Split(parsed.Resource, "/") - if parts[0] != "role" || parsed.Service != iam.ServiceName { + if parts[0] != "role" || parsed.Service != iamServiceName { return trace.BadParameter("%q is not an AWS IAM role ARN", parsed) } if len(parts) < 2 || len(parts[len(parts)-1]) == 0 { diff --git a/lib/utils/aws/migration/migration.go b/lib/utils/aws/migration/migration.go index 5288f2ab2a13c..43fdb36b5673a 100644 --- a/lib/utils/aws/migration/migration.go +++ b/lib/utils/aws/migration/migration.go @@ -27,6 +27,12 @@ import ( "github.com/gravitational/trace" ) +// NewCredentialsAdapter adapts an AWS SDK v2 credentials provider to v1 +// credentials. +func NewCredentialsAdapter(providerV2 awsv2.CredentialsProvider) *credentials.Credentials { + return credentials.NewCredentials(NewProviderAdapter(providerV2)) +} + // NewProviderAdapter returns a [ProviderAdapter] that can be used as an AWS SDK // v1 credentials provider. func NewProviderAdapter(providerV2 awsv2.CredentialsProvider) *ProviderAdapter { diff --git a/tool/tsh/common/app_aws.go b/tool/tsh/common/app_aws.go index 7c11458b2f300..aaa6adeec6b01 100644 --- a/tool/tsh/common/app_aws.go +++ b/tool/tsh/common/app_aws.go @@ -170,7 +170,7 @@ func (a *awsApp) GetAppName() string { // through forward proxy. func (a *awsApp) StartLocalProxies(ctx context.Context, opts ...alpnproxy.LocalProxyConfigOpt) error { awsMiddleware := &alpnproxy.AWSAccessMiddleware{ - AWSCredentialsV2Provider: a.GetAWSCredentialsProvider(), + AWSCredentialsProvider: a.GetAWSCredentialsProvider(), } // AWS endpoint URL mode