Skip to content

Commit

Permalink
migrate lib/srv/alpnproxy to AWS SDK v2
Browse files Browse the repository at this point in the history
  • Loading branch information
GavinFrazar committed Dec 9, 2024
1 parent 622e578 commit 117608a
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 191 deletions.
47 changes: 19 additions & 28 deletions lib/srv/alpnproxy/aws_local_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@ import (
"net/http"
"strings"

awsv2 "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"

Expand All @@ -34,20 +33,15 @@ import (
appcommon "github.com/gravitational/teleport/lib/srv/app/common"
"github.com/gravitational/teleport/lib/utils"
awsutils "github.com/gravitational/teleport/lib/utils/aws"
"github.com/gravitational/teleport/lib/utils/aws/migration"
)

// AWSAccessMiddleware verifies the requests to AWS proxy are properly signed.
type AWSAccessMiddleware struct {
DefaultLocalProxyHTTPMiddleware

// AWSCredentials are AWS Credentials used by LocalProxy for request's signature verification.
AWSCredentials *credentials.Credentials

// AWSCredentialsV2Provider is an aws sdk v2 credential provider used by
// LocalProxy for request's signature verification if AWSCredentials is not
// specified.
AWSCredentialsV2Provider awsv2.CredentialsProvider
// AWSCredentialsProvider provides credentials for local proxy request
// signature verification.
AWSCredentialsProvider aws.CredentialsProvider

Log logrus.FieldLogger

Expand All @@ -61,11 +55,8 @@ func (m *AWSAccessMiddleware) CheckAndSetDefaults() error {
m.Log = logrus.WithField(teleport.ComponentKey, "aws_access")
}

if m.AWSCredentials == nil {
if m.AWSCredentialsV2Provider == nil {
return trace.BadParameter("missing AWSCredentials")
}
m.AWSCredentials = credentials.NewCredentials(migration.NewProviderAdapter(m.AWSCredentialsV2Provider))
if m.AWSCredentialsProvider == nil {
return trace.BadParameter("missing AWS credentials")
}

return nil
Expand Down Expand Up @@ -143,7 +134,7 @@ func (m *AWSAccessMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Re
}

func (m *AWSAccessMiddleware) handleCommonRequest(rw http.ResponseWriter, req *http.Request) bool {
if err := awsutils.VerifyAWSSignature(req, m.AWSCredentials); err != nil {
if err := awsutils.VerifyAWSSignatureV2(req, m.AWSCredentialsProvider); err != nil {
m.Log.WithError(err).Error("AWS signature verification failed.")
rw.WriteHeader(http.StatusForbidden)
return true
Expand All @@ -152,22 +143,22 @@ func (m *AWSAccessMiddleware) handleCommonRequest(rw http.ResponseWriter, req *h
}

func (m *AWSAccessMiddleware) handleRequestByAssumedRole(rw http.ResponseWriter, req *http.Request, assumedRole *sts.AssumeRoleOutput) bool {
credentials := credentials.NewStaticCredentials(
aws.StringValue(assumedRole.Credentials.AccessKeyId),
aws.StringValue(assumedRole.Credentials.SecretAccessKey),
aws.StringValue(assumedRole.Credentials.SessionToken),
credentials := credentials.NewStaticCredentialsProvider(
aws.ToString(assumedRole.Credentials.AccessKeyId),
aws.ToString(assumedRole.Credentials.SecretAccessKey),
aws.ToString(assumedRole.Credentials.SessionToken),
)

if err := awsutils.VerifyAWSSignature(req, credentials); err != nil {
if err := awsutils.VerifyAWSSignatureV2(req, credentials); err != nil {
m.Log.WithError(err).Error("AWS signature verification failed.")
rw.WriteHeader(http.StatusForbidden)
return true
}

m.Log.Debugf("Rewriting headers for AWS request by assumed role %q.", aws.StringValue(assumedRole.AssumedRoleUser.Arn))
m.Log.Debugf("Rewriting headers for AWS request by assumed role %q.", aws.ToString(assumedRole.AssumedRoleUser.Arn))

// Add a custom header for marking the special request.
req.Header.Add(appcommon.TeleportAWSAssumedRole, aws.StringValue(assumedRole.AssumedRoleUser.Arn))
req.Header.Add(appcommon.TeleportAWSAssumedRole, aws.ToString(assumedRole.AssumedRoleUser.Arn))

// Rename the original authorization header to ensure older app agents
// (that don't support the requests by assumed roles) will fail.
Expand All @@ -191,7 +182,7 @@ func (m *AWSAccessMiddleware) HandleResponse(response *http.Response) error {
return nil
}

if strings.EqualFold(sigV4.Service, sts.EndpointsID) {
if strings.EqualFold(sigV4.Service, "sts") {
return trace.Wrap(m.handleSTSResponse(response))
}
return nil
Expand Down Expand Up @@ -219,8 +210,8 @@ func (m *AWSAccessMiddleware) handleSTSResponse(response *http.Response) error {
return nil
}

m.assumedRoles.Store(aws.StringValue(assumedRole.Credentials.AccessKeyId), assumedRole)
m.Log.Debugf("Saved credentials for assumed role %q.", aws.StringValue(assumedRole.AssumedRoleUser.Arn))
m.assumedRoles.Store(aws.ToString(assumedRole.Credentials.AccessKeyId), assumedRole)
m.Log.Debugf("Saved credentials for assumed role %q.", aws.ToString(assumedRole.AssumedRoleUser.Arn))
return nil
}

Expand Down
162 changes: 71 additions & 91 deletions lib/srv/alpnproxy/aws_local_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,17 @@
package alpnproxy

import (
"context"
"encoding/xml"
"net/http"
"net/http/httptest"
"testing"
"time"

credentialsv2 "github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/private/protocol"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/sts"
ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types"
"github.com/stretchr/testify/require"

awsutils "github.com/gravitational/teleport/lib/utils/aws"
Expand All @@ -40,89 +39,70 @@ func TestAWSAccessMiddleware(t *testing.T) {
t.Parallel()

assumedRoleARN := "arn:aws:sts::123456789012:assumed-role/role-name/role-session-name"
localProxyCred := credentials.NewStaticCredentials("local-proxy", "local-proxy-secret", "")
assumedRoleCred := credentials.NewStaticCredentials("assumed-role", "assumed-role-secret", "assumed-role-token")

tests := []struct {
name string
middleware *AWSAccessMiddleware
}{
{
name: "v1",
middleware: &AWSAccessMiddleware{
AWSCredentials: localProxyCred,
},
},
{
name: "v2",
middleware: &AWSAccessMiddleware{
AWSCredentialsV2Provider: credentialsv2.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", ""),
},
},
}
localProxyCred := credentials.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", "")
assumedRoleCred := credentials.NewStaticCredentialsProvider("assumed-role", "assumed-role-secret", "assumed-role-token")

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
m := test.middleware
require.NoError(t, m.CheckAndSetDefaults())

stsRequestByLocalProxyCred := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", nil)
v4.NewSigner(localProxyCred).Sign(stsRequestByLocalProxyCred, nil, "sts", "us-west-1", time.Now())

requestByAssumedRole := httptest.NewRequest(http.MethodGet, "http://s3.amazonaws.com", nil)
v4.NewSigner(assumedRoleCred).Sign(requestByAssumedRole, nil, "s3", "us-west-1", time.Now())

t.Run("request no authorization", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.True(t, m.HandleRequest(recorder, httptest.NewRequest("", "http://localhost", nil)))
require.Equal(t, http.StatusForbidden, recorder.Code)
})

t.Run("request signed by unknown credentials", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.True(t, m.HandleRequest(recorder, requestByAssumedRole))
require.Equal(t, http.StatusForbidden, recorder.Code)
})

t.Run("request signed by local proxy credentials", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.False(t, m.HandleRequest(recorder, stsRequestByLocalProxyCred))
require.Equal(t, http.StatusOK, recorder.Code)
})

// Verifies sts:AssumeRole output can be handled successfully. The
// credentials should be saved afterwards.
t.Run("handle sts:AssumeRole response", func(t *testing.T) {
response := assumeRoleResponse(t, assumedRoleARN, assumedRoleCred)
response.Request = stsRequestByLocalProxyCred
defer response.Body.Close()
require.NoError(t, m.HandleResponse(response))
})

// This is the same request as the "unknown credentials" test above. But at
// this point, the assumed role credentials should have been saved by the
// middleware so the request can be handled successfully now.
t.Run("request signed by assumed role", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.False(t, m.HandleRequest(recorder, requestByAssumedRole))
require.Equal(t, http.StatusOK, recorder.Code)
})

// Verifies non sts:AssumeRole responses do not give errors.
t.Run("handle sts:GetCallerIdentity response", func(t *testing.T) {
response := getCallerIdentityResponse(t, assumedRoleARN)
response.Request = stsRequestByLocalProxyCred
defer response.Body.Close()
require.NoError(t, m.HandleResponse(response))
})
})
m := &AWSAccessMiddleware{
AWSCredentialsProvider: credentials.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", ""),
}
require.NoError(t, m.CheckAndSetDefaults())

stsRequestByLocalProxyCred := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", nil)

awsutils.NewSignerV2(localProxyCred, "sts").Sign(stsRequestByLocalProxyCred, nil, "sts", "us-west-1", time.Now())

requestByAssumedRole := httptest.NewRequest(http.MethodGet, "http://s3.amazonaws.com", nil)
awsutils.NewSignerV2(assumedRoleCred, "s3").Sign(requestByAssumedRole, nil, "s3", "us-west-1", time.Now())

t.Run("request no authorization", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.True(t, m.HandleRequest(recorder, httptest.NewRequest("", "http://localhost", nil)))
require.Equal(t, http.StatusForbidden, recorder.Code)
})

t.Run("request signed by unknown credentials", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.True(t, m.HandleRequest(recorder, requestByAssumedRole))
require.Equal(t, http.StatusForbidden, recorder.Code)
})

t.Run("request signed by local proxy credentials", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.False(t, m.HandleRequest(recorder, stsRequestByLocalProxyCred))
require.Equal(t, http.StatusOK, recorder.Code)
})

// Verifies sts:AssumeRole output can be handled successfully. The
// credentials should be saved afterwards.
t.Run("handle sts:AssumeRole response", func(t *testing.T) {
response := assumeRoleResponse(t, assumedRoleARN, assumedRoleCred)
response.Request = stsRequestByLocalProxyCred
defer response.Body.Close()
require.NoError(t, m.HandleResponse(response))
})

// This is the same request as the "unknown credentials" test above. But at
// this point, the assumed role credentials should have been saved by the
// middleware so the request can be handled successfully now.
t.Run("request signed by assumed role", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.False(t, m.HandleRequest(recorder, requestByAssumedRole))
require.Equal(t, http.StatusOK, recorder.Code)
})

// Verifies non sts:AssumeRole responses do not give errors.
t.Run("handle sts:GetCallerIdentity response", func(t *testing.T) {
response := getCallerIdentityResponse(t, assumedRoleARN)
response.Request = stsRequestByLocalProxyCred
defer response.Body.Close()
require.NoError(t, m.HandleResponse(response))
})
}

func assumeRoleResponse(t *testing.T, roleARN string, cred *credentials.Credentials) *http.Response {
func assumeRoleResponse(t *testing.T, roleARN string, provider aws.CredentialsProvider) *http.Response {
t.Helper()

credValue, err := cred.Get()
credValue, err := provider.Retrieve(context.Background())
require.NoError(t, err)

body, err := awsutils.MarshalXML(
Expand All @@ -132,18 +112,18 @@ func assumeRoleResponse(t *testing.T, roleARN string, cred *credentials.Credenti
},
map[string]any{
"AssumeRoleResult": sts.AssumeRoleOutput{
AssumedRoleUser: &sts.AssumedRoleUser{
AssumedRoleUser: &ststypes.AssumedRoleUser{
Arn: aws.String(roleARN),
},
Credentials: &sts.Credentials{
Credentials: &ststypes.Credentials{
AccessKeyId: aws.String(credValue.AccessKeyID),
SecretAccessKey: aws.String(credValue.SecretAccessKey),
SessionToken: aws.String(credValue.SessionToken),
},
},
"ResponseMetadata": protocol.ResponseMetadata{
StatusCode: http.StatusOK,
RequestID: "22222222-3333-3333-3333-333333333333",
"ResponseMetadata": map[string]any{
"StatusCode": http.StatusOK,
"RequestID": "22222222-3333-3333-3333-333333333333",
},
},
)
Expand All @@ -163,9 +143,9 @@ func getCallerIdentityResponse(t *testing.T, roleARN string) *http.Response {
"GetCallerIdentityResult": sts.GetCallerIdentityOutput{
Arn: aws.String(roleARN),
},
"ResponseMetadata": protocol.ResponseMetadata{
StatusCode: http.StatusOK,
RequestID: "22222222-3333-3333-3333-333333333333",
"ResponseMetadata": map[string]any{
"StatusCode": http.StatusOK,
"RequestID": "22222222-3333-3333-3333-333333333333",
},
},
)
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/alpnproxy/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ func mustCallHTTPSServerAndReceiveCode(t *testing.T, addr string, client http.Cl
require.Equal(t, expectStatusCode, resp.StatusCode)
}

func mustStartHTTPServer(t *testing.T, l net.Listener) {
func mustStartHTTPServer(_ *testing.T, l net.Listener) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) {})
go http.Serve(l, mux)
Expand Down
Loading

0 comments on commit 117608a

Please sign in to comment.