Skip to content

Commit

Permalink
migrate lib/srv/alpnproxy to AWS SDK v2
Browse files Browse the repository at this point in the history
  • Loading branch information
GavinFrazar committed Dec 11, 2024
1 parent 622e578 commit b2809ee
Show file tree
Hide file tree
Showing 8 changed files with 208 additions and 193 deletions.
47 changes: 19 additions & 28 deletions lib/srv/alpnproxy/aws_local_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@ import (
"net/http"
"strings"

awsv2 "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"

Expand All @@ -34,20 +33,15 @@ import (
appcommon "github.com/gravitational/teleport/lib/srv/app/common"
"github.com/gravitational/teleport/lib/utils"
awsutils "github.com/gravitational/teleport/lib/utils/aws"
"github.com/gravitational/teleport/lib/utils/aws/migration"
)

// AWSAccessMiddleware verifies the requests to AWS proxy are properly signed.
type AWSAccessMiddleware struct {
DefaultLocalProxyHTTPMiddleware

// AWSCredentials are AWS Credentials used by LocalProxy for request's signature verification.
AWSCredentials *credentials.Credentials

// AWSCredentialsV2Provider is an aws sdk v2 credential provider used by
// LocalProxy for request's signature verification if AWSCredentials is not
// specified.
AWSCredentialsV2Provider awsv2.CredentialsProvider
// AWSCredentialsProvider provides credentials for local proxy request
// signature verification.
AWSCredentialsProvider aws.CredentialsProvider

Log logrus.FieldLogger

Expand All @@ -61,11 +55,8 @@ func (m *AWSAccessMiddleware) CheckAndSetDefaults() error {
m.Log = logrus.WithField(teleport.ComponentKey, "aws_access")
}

if m.AWSCredentials == nil {
if m.AWSCredentialsV2Provider == nil {
return trace.BadParameter("missing AWSCredentials")
}
m.AWSCredentials = credentials.NewCredentials(migration.NewProviderAdapter(m.AWSCredentialsV2Provider))
if m.AWSCredentialsProvider == nil {
return trace.BadParameter("missing AWS credentials")
}

return nil
Expand Down Expand Up @@ -143,7 +134,7 @@ func (m *AWSAccessMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Re
}

func (m *AWSAccessMiddleware) handleCommonRequest(rw http.ResponseWriter, req *http.Request) bool {
if err := awsutils.VerifyAWSSignature(req, m.AWSCredentials); err != nil {
if err := awsutils.VerifyAWSSignatureV2(req, m.AWSCredentialsProvider); err != nil {
m.Log.WithError(err).Error("AWS signature verification failed.")
rw.WriteHeader(http.StatusForbidden)
return true
Expand All @@ -152,22 +143,22 @@ func (m *AWSAccessMiddleware) handleCommonRequest(rw http.ResponseWriter, req *h
}

func (m *AWSAccessMiddleware) handleRequestByAssumedRole(rw http.ResponseWriter, req *http.Request, assumedRole *sts.AssumeRoleOutput) bool {
credentials := credentials.NewStaticCredentials(
aws.StringValue(assumedRole.Credentials.AccessKeyId),
aws.StringValue(assumedRole.Credentials.SecretAccessKey),
aws.StringValue(assumedRole.Credentials.SessionToken),
credentials := credentials.NewStaticCredentialsProvider(
aws.ToString(assumedRole.Credentials.AccessKeyId),
aws.ToString(assumedRole.Credentials.SecretAccessKey),
aws.ToString(assumedRole.Credentials.SessionToken),
)

if err := awsutils.VerifyAWSSignature(req, credentials); err != nil {
if err := awsutils.VerifyAWSSignatureV2(req, credentials); err != nil {
m.Log.WithError(err).Error("AWS signature verification failed.")
rw.WriteHeader(http.StatusForbidden)
return true
}

m.Log.Debugf("Rewriting headers for AWS request by assumed role %q.", aws.StringValue(assumedRole.AssumedRoleUser.Arn))
m.Log.Debugf("Rewriting headers for AWS request by assumed role %q.", aws.ToString(assumedRole.AssumedRoleUser.Arn))

// Add a custom header for marking the special request.
req.Header.Add(appcommon.TeleportAWSAssumedRole, aws.StringValue(assumedRole.AssumedRoleUser.Arn))
req.Header.Add(appcommon.TeleportAWSAssumedRole, aws.ToString(assumedRole.AssumedRoleUser.Arn))

// Rename the original authorization header to ensure older app agents
// (that don't support the requests by assumed roles) will fail.
Expand All @@ -191,7 +182,7 @@ func (m *AWSAccessMiddleware) HandleResponse(response *http.Response) error {
return nil
}

if strings.EqualFold(sigV4.Service, sts.EndpointsID) {
if strings.EqualFold(sigV4.Service, "sts") {
return trace.Wrap(m.handleSTSResponse(response))
}
return nil
Expand Down Expand Up @@ -219,8 +210,8 @@ func (m *AWSAccessMiddleware) handleSTSResponse(response *http.Response) error {
return nil
}

m.assumedRoles.Store(aws.StringValue(assumedRole.Credentials.AccessKeyId), assumedRole)
m.Log.Debugf("Saved credentials for assumed role %q.", aws.StringValue(assumedRole.AssumedRoleUser.Arn))
m.assumedRoles.Store(aws.ToString(assumedRole.Credentials.AccessKeyId), assumedRole)
m.Log.Debugf("Saved credentials for assumed role %q.", aws.ToString(assumedRole.AssumedRoleUser.Arn))
return nil
}

Expand Down
162 changes: 71 additions & 91 deletions lib/srv/alpnproxy/aws_local_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,17 @@
package alpnproxy

import (
"context"
"encoding/xml"
"net/http"
"net/http/httptest"
"testing"
"time"

credentialsv2 "github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/private/protocol"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/sts"
ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types"
"github.com/stretchr/testify/require"

awsutils "github.com/gravitational/teleport/lib/utils/aws"
Expand All @@ -40,89 +39,70 @@ func TestAWSAccessMiddleware(t *testing.T) {
t.Parallel()

assumedRoleARN := "arn:aws:sts::123456789012:assumed-role/role-name/role-session-name"
localProxyCred := credentials.NewStaticCredentials("local-proxy", "local-proxy-secret", "")
assumedRoleCred := credentials.NewStaticCredentials("assumed-role", "assumed-role-secret", "assumed-role-token")

tests := []struct {
name string
middleware *AWSAccessMiddleware
}{
{
name: "v1",
middleware: &AWSAccessMiddleware{
AWSCredentials: localProxyCred,
},
},
{
name: "v2",
middleware: &AWSAccessMiddleware{
AWSCredentialsV2Provider: credentialsv2.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", ""),
},
},
}
localProxyCred := credentials.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", "")
assumedRoleCred := credentials.NewStaticCredentialsProvider("assumed-role", "assumed-role-secret", "assumed-role-token")

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
m := test.middleware
require.NoError(t, m.CheckAndSetDefaults())

stsRequestByLocalProxyCred := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", nil)
v4.NewSigner(localProxyCred).Sign(stsRequestByLocalProxyCred, nil, "sts", "us-west-1", time.Now())

requestByAssumedRole := httptest.NewRequest(http.MethodGet, "http://s3.amazonaws.com", nil)
v4.NewSigner(assumedRoleCred).Sign(requestByAssumedRole, nil, "s3", "us-west-1", time.Now())

t.Run("request no authorization", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.True(t, m.HandleRequest(recorder, httptest.NewRequest("", "http://localhost", nil)))
require.Equal(t, http.StatusForbidden, recorder.Code)
})

t.Run("request signed by unknown credentials", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.True(t, m.HandleRequest(recorder, requestByAssumedRole))
require.Equal(t, http.StatusForbidden, recorder.Code)
})

t.Run("request signed by local proxy credentials", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.False(t, m.HandleRequest(recorder, stsRequestByLocalProxyCred))
require.Equal(t, http.StatusOK, recorder.Code)
})

// Verifies sts:AssumeRole output can be handled successfully. The
// credentials should be saved afterwards.
t.Run("handle sts:AssumeRole response", func(t *testing.T) {
response := assumeRoleResponse(t, assumedRoleARN, assumedRoleCred)
response.Request = stsRequestByLocalProxyCred
defer response.Body.Close()
require.NoError(t, m.HandleResponse(response))
})

// This is the same request as the "unknown credentials" test above. But at
// this point, the assumed role credentials should have been saved by the
// middleware so the request can be handled successfully now.
t.Run("request signed by assumed role", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.False(t, m.HandleRequest(recorder, requestByAssumedRole))
require.Equal(t, http.StatusOK, recorder.Code)
})

// Verifies non sts:AssumeRole responses do not give errors.
t.Run("handle sts:GetCallerIdentity response", func(t *testing.T) {
response := getCallerIdentityResponse(t, assumedRoleARN)
response.Request = stsRequestByLocalProxyCred
defer response.Body.Close()
require.NoError(t, m.HandleResponse(response))
})
})
m := &AWSAccessMiddleware{
AWSCredentialsProvider: credentials.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", ""),
}
require.NoError(t, m.CheckAndSetDefaults())

stsRequestByLocalProxyCred := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", nil)

awsutils.NewSignerV2(localProxyCred, "sts").Sign(stsRequestByLocalProxyCred, nil, "sts", "us-west-1", time.Now())

requestByAssumedRole := httptest.NewRequest(http.MethodGet, "http://s3.amazonaws.com", nil)
awsutils.NewSignerV2(assumedRoleCred, "s3").Sign(requestByAssumedRole, nil, "s3", "us-west-1", time.Now())

t.Run("request no authorization", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.True(t, m.HandleRequest(recorder, httptest.NewRequest("", "http://localhost", nil)))
require.Equal(t, http.StatusForbidden, recorder.Code)
})

t.Run("request signed by unknown credentials", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.True(t, m.HandleRequest(recorder, requestByAssumedRole))
require.Equal(t, http.StatusForbidden, recorder.Code)
})

t.Run("request signed by local proxy credentials", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.False(t, m.HandleRequest(recorder, stsRequestByLocalProxyCred))
require.Equal(t, http.StatusOK, recorder.Code)
})

// Verifies sts:AssumeRole output can be handled successfully. The
// credentials should be saved afterwards.
t.Run("handle sts:AssumeRole response", func(t *testing.T) {
response := assumeRoleResponse(t, assumedRoleARN, assumedRoleCred)
response.Request = stsRequestByLocalProxyCred
defer response.Body.Close()
require.NoError(t, m.HandleResponse(response))
})

// This is the same request as the "unknown credentials" test above. But at
// this point, the assumed role credentials should have been saved by the
// middleware so the request can be handled successfully now.
t.Run("request signed by assumed role", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.False(t, m.HandleRequest(recorder, requestByAssumedRole))
require.Equal(t, http.StatusOK, recorder.Code)
})

// Verifies non sts:AssumeRole responses do not give errors.
t.Run("handle sts:GetCallerIdentity response", func(t *testing.T) {
response := getCallerIdentityResponse(t, assumedRoleARN)
response.Request = stsRequestByLocalProxyCred
defer response.Body.Close()
require.NoError(t, m.HandleResponse(response))
})
}

func assumeRoleResponse(t *testing.T, roleARN string, cred *credentials.Credentials) *http.Response {
func assumeRoleResponse(t *testing.T, roleARN string, provider aws.CredentialsProvider) *http.Response {
t.Helper()

credValue, err := cred.Get()
credValue, err := provider.Retrieve(context.Background())
require.NoError(t, err)

body, err := awsutils.MarshalXML(
Expand All @@ -132,18 +112,18 @@ func assumeRoleResponse(t *testing.T, roleARN string, cred *credentials.Credenti
},
map[string]any{
"AssumeRoleResult": sts.AssumeRoleOutput{
AssumedRoleUser: &sts.AssumedRoleUser{
AssumedRoleUser: &ststypes.AssumedRoleUser{
Arn: aws.String(roleARN),
},
Credentials: &sts.Credentials{
Credentials: &ststypes.Credentials{
AccessKeyId: aws.String(credValue.AccessKeyID),
SecretAccessKey: aws.String(credValue.SecretAccessKey),
SessionToken: aws.String(credValue.SessionToken),
},
},
"ResponseMetadata": protocol.ResponseMetadata{
StatusCode: http.StatusOK,
RequestID: "22222222-3333-3333-3333-333333333333",
"ResponseMetadata": map[string]any{
"StatusCode": http.StatusOK,
"RequestID": "22222222-3333-3333-3333-333333333333",
},
},
)
Expand All @@ -163,9 +143,9 @@ func getCallerIdentityResponse(t *testing.T, roleARN string) *http.Response {
"GetCallerIdentityResult": sts.GetCallerIdentityOutput{
Arn: aws.String(roleARN),
},
"ResponseMetadata": protocol.ResponseMetadata{
StatusCode: http.StatusOK,
RequestID: "22222222-3333-3333-3333-333333333333",
"ResponseMetadata": map[string]any{
"StatusCode": http.StatusOK,
"RequestID": "22222222-3333-3333-3333-333333333333",
},
},
)
Expand Down
11 changes: 10 additions & 1 deletion lib/srv/alpnproxy/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ func (s *Suite) Start(t *testing.T) {
}

func mustGenSelfSignedCert(t *testing.T) *tlsca.CertAuthority {
t.Helper()
caKey, caCert, err := tlsca.GenerateSelfSignedCA(pkix.Name{
CommonName: "localhost",
}, []string{"localhost"}, defaults.CATTL)
Expand Down Expand Up @@ -183,6 +184,7 @@ func withClock(clock clockwork.Clock) signOptionsFunc {
type signOptionsFunc func(o *signOptions)

func mustGenCertSignedWithCA(t *testing.T, ca *tlsca.CertAuthority, opts ...signOptionsFunc) tls.Certificate {
t.Helper()
options := signOptions{
identity: tlsca.Identity{Username: "test-user"},
clock: clockwork.NewRealClock(),
Expand Down Expand Up @@ -218,6 +220,7 @@ func mustGenCertSignedWithCA(t *testing.T, ca *tlsca.CertAuthority, opts ...sign
}

func mustReadFromConnection(t *testing.T, conn net.Conn, want string) {
t.Helper()
require.NoError(t, conn.SetReadDeadline(time.Now().Add(time.Second*5)))
buff, err := io.ReadAll(conn)
require.NoError(t, err)
Expand All @@ -226,11 +229,13 @@ func mustReadFromConnection(t *testing.T, conn net.Conn, want string) {
}

func mustCloseConnection(t *testing.T, conn net.Conn) {
t.Helper()
err := conn.Close()
require.NoError(t, err)
}

func mustCreateLocalListener(t *testing.T) net.Listener {
t.Helper()
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(func() {
Expand All @@ -240,6 +245,7 @@ func mustCreateLocalListener(t *testing.T) net.Listener {
}

func mustCreateCertGenListener(t *testing.T, ca tls.Certificate) net.Listener {
t.Helper()
listener, err := NewCertGenListener(CertGenListenerConfig{
ListenAddr: "localhost:0",
CA: ca,
Expand All @@ -253,23 +259,26 @@ func mustCreateCertGenListener(t *testing.T, ca tls.Certificate) net.Listener {
}

func mustSuccessfullyCallHTTPSServer(t *testing.T, addr string, client http.Client) {
t.Helper()
mustCallHTTPSServerAndReceiveCode(t, addr, client, http.StatusOK)
}

func mustCallHTTPSServerAndReceiveCode(t *testing.T, addr string, client http.Client, expectStatusCode int) {
t.Helper()
resp, err := client.Get(fmt.Sprintf("https://%s", addr))
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, expectStatusCode, resp.StatusCode)
}

func mustStartHTTPServer(t *testing.T, l net.Listener) {
func mustStartHTTPServer(l net.Listener) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) {})
go http.Serve(l, mux)
}

func mustStartLocalProxy(t *testing.T, config LocalProxyConfig) {
t.Helper()
lp, err := NewLocalProxy(config)
require.NoError(t, err)
t.Cleanup(func() {
Expand Down
Loading

0 comments on commit b2809ee

Please sign in to comment.