Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

migrate alpnproxy to AWS SDK v2 #49973

Merged
merged 1 commit into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 19 additions & 28 deletions lib/srv/alpnproxy/aws_local_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@ import (
"net/http"
"strings"

awsv2 "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"

Expand All @@ -34,20 +33,15 @@ import (
appcommon "github.com/gravitational/teleport/lib/srv/app/common"
"github.com/gravitational/teleport/lib/utils"
awsutils "github.com/gravitational/teleport/lib/utils/aws"
"github.com/gravitational/teleport/lib/utils/aws/migration"
)

// AWSAccessMiddleware verifies the requests to AWS proxy are properly signed.
type AWSAccessMiddleware struct {
DefaultLocalProxyHTTPMiddleware

// AWSCredentials are AWS Credentials used by LocalProxy for request's signature verification.
AWSCredentials *credentials.Credentials

// AWSCredentialsV2Provider is an aws sdk v2 credential provider used by
// LocalProxy for request's signature verification if AWSCredentials is not
// specified.
AWSCredentialsV2Provider awsv2.CredentialsProvider
// AWSCredentialsProvider provides credentials for local proxy request
// signature verification.
AWSCredentialsProvider aws.CredentialsProvider

Log logrus.FieldLogger

Expand All @@ -61,11 +55,8 @@ func (m *AWSAccessMiddleware) CheckAndSetDefaults() error {
m.Log = logrus.WithField(teleport.ComponentKey, "aws_access")
}

if m.AWSCredentials == nil {
if m.AWSCredentialsV2Provider == nil {
return trace.BadParameter("missing AWSCredentials")
}
m.AWSCredentials = credentials.NewCredentials(migration.NewProviderAdapter(m.AWSCredentialsV2Provider))
if m.AWSCredentialsProvider == nil {
return trace.BadParameter("missing AWS credentials")
}

return nil
Expand Down Expand Up @@ -143,7 +134,7 @@ func (m *AWSAccessMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Re
}

func (m *AWSAccessMiddleware) handleCommonRequest(rw http.ResponseWriter, req *http.Request) bool {
if err := awsutils.VerifyAWSSignature(req, m.AWSCredentials); err != nil {
if err := awsutils.VerifyAWSSignatureV2(req, m.AWSCredentialsProvider); err != nil {
m.Log.WithError(err).Error("AWS signature verification failed.")
rw.WriteHeader(http.StatusForbidden)
return true
Expand All @@ -152,22 +143,22 @@ func (m *AWSAccessMiddleware) handleCommonRequest(rw http.ResponseWriter, req *h
}

func (m *AWSAccessMiddleware) handleRequestByAssumedRole(rw http.ResponseWriter, req *http.Request, assumedRole *sts.AssumeRoleOutput) bool {
credentials := credentials.NewStaticCredentials(
aws.StringValue(assumedRole.Credentials.AccessKeyId),
aws.StringValue(assumedRole.Credentials.SecretAccessKey),
aws.StringValue(assumedRole.Credentials.SessionToken),
credentials := credentials.NewStaticCredentialsProvider(
aws.ToString(assumedRole.Credentials.AccessKeyId),
aws.ToString(assumedRole.Credentials.SecretAccessKey),
aws.ToString(assumedRole.Credentials.SessionToken),
)

if err := awsutils.VerifyAWSSignature(req, credentials); err != nil {
if err := awsutils.VerifyAWSSignatureV2(req, credentials); err != nil {
m.Log.WithError(err).Error("AWS signature verification failed.")
rw.WriteHeader(http.StatusForbidden)
return true
}

m.Log.Debugf("Rewriting headers for AWS request by assumed role %q.", aws.StringValue(assumedRole.AssumedRoleUser.Arn))
m.Log.Debugf("Rewriting headers for AWS request by assumed role %q.", aws.ToString(assumedRole.AssumedRoleUser.Arn))

// Add a custom header for marking the special request.
req.Header.Add(appcommon.TeleportAWSAssumedRole, aws.StringValue(assumedRole.AssumedRoleUser.Arn))
req.Header.Add(appcommon.TeleportAWSAssumedRole, aws.ToString(assumedRole.AssumedRoleUser.Arn))

// Rename the original authorization header to ensure older app agents
// (that don't support the requests by assumed roles) will fail.
Expand All @@ -191,7 +182,7 @@ func (m *AWSAccessMiddleware) HandleResponse(response *http.Response) error {
return nil
}

if strings.EqualFold(sigV4.Service, sts.EndpointsID) {
if strings.EqualFold(sigV4.Service, "sts") {
return trace.Wrap(m.handleSTSResponse(response))
}
return nil
Expand Down Expand Up @@ -219,8 +210,8 @@ func (m *AWSAccessMiddleware) handleSTSResponse(response *http.Response) error {
return nil
}

m.assumedRoles.Store(aws.StringValue(assumedRole.Credentials.AccessKeyId), assumedRole)
m.Log.Debugf("Saved credentials for assumed role %q.", aws.StringValue(assumedRole.AssumedRoleUser.Arn))
m.assumedRoles.Store(aws.ToString(assumedRole.Credentials.AccessKeyId), assumedRole)
m.Log.Debugf("Saved credentials for assumed role %q.", aws.ToString(assumedRole.AssumedRoleUser.Arn))
return nil
}

Expand Down
162 changes: 71 additions & 91 deletions lib/srv/alpnproxy/aws_local_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,17 @@
package alpnproxy

import (
"context"
"encoding/xml"
"net/http"
"net/http/httptest"
"testing"
"time"

credentialsv2 "github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/private/protocol"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/sts"
ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types"
"github.com/stretchr/testify/require"

awsutils "github.com/gravitational/teleport/lib/utils/aws"
Expand All @@ -40,89 +39,70 @@ func TestAWSAccessMiddleware(t *testing.T) {
t.Parallel()

assumedRoleARN := "arn:aws:sts::123456789012:assumed-role/role-name/role-session-name"
localProxyCred := credentials.NewStaticCredentials("local-proxy", "local-proxy-secret", "")
assumedRoleCred := credentials.NewStaticCredentials("assumed-role", "assumed-role-secret", "assumed-role-token")

tests := []struct {
name string
middleware *AWSAccessMiddleware
}{
{
name: "v1",
middleware: &AWSAccessMiddleware{
AWSCredentials: localProxyCred,
},
},
{
name: "v2",
middleware: &AWSAccessMiddleware{
AWSCredentialsV2Provider: credentialsv2.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", ""),
},
},
}
localProxyCred := credentials.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", "")
assumedRoleCred := credentials.NewStaticCredentialsProvider("assumed-role", "assumed-role-secret", "assumed-role-token")

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
m := test.middleware
require.NoError(t, m.CheckAndSetDefaults())

stsRequestByLocalProxyCred := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", nil)
v4.NewSigner(localProxyCred).Sign(stsRequestByLocalProxyCred, nil, "sts", "us-west-1", time.Now())

requestByAssumedRole := httptest.NewRequest(http.MethodGet, "http://s3.amazonaws.com", nil)
v4.NewSigner(assumedRoleCred).Sign(requestByAssumedRole, nil, "s3", "us-west-1", time.Now())

t.Run("request no authorization", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.True(t, m.HandleRequest(recorder, httptest.NewRequest("", "http://localhost", nil)))
require.Equal(t, http.StatusForbidden, recorder.Code)
})

t.Run("request signed by unknown credentials", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.True(t, m.HandleRequest(recorder, requestByAssumedRole))
require.Equal(t, http.StatusForbidden, recorder.Code)
})

t.Run("request signed by local proxy credentials", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.False(t, m.HandleRequest(recorder, stsRequestByLocalProxyCred))
require.Equal(t, http.StatusOK, recorder.Code)
})

// Verifies sts:AssumeRole output can be handled successfully. The
// credentials should be saved afterwards.
t.Run("handle sts:AssumeRole response", func(t *testing.T) {
response := assumeRoleResponse(t, assumedRoleARN, assumedRoleCred)
response.Request = stsRequestByLocalProxyCred
defer response.Body.Close()
require.NoError(t, m.HandleResponse(response))
})

// This is the same request as the "unknown credentials" test above. But at
// this point, the assumed role credentials should have been saved by the
// middleware so the request can be handled successfully now.
t.Run("request signed by assumed role", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.False(t, m.HandleRequest(recorder, requestByAssumedRole))
require.Equal(t, http.StatusOK, recorder.Code)
})

// Verifies non sts:AssumeRole responses do not give errors.
t.Run("handle sts:GetCallerIdentity response", func(t *testing.T) {
response := getCallerIdentityResponse(t, assumedRoleARN)
response.Request = stsRequestByLocalProxyCred
defer response.Body.Close()
require.NoError(t, m.HandleResponse(response))
})
})
m := &AWSAccessMiddleware{
AWSCredentialsProvider: credentials.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", ""),
}
require.NoError(t, m.CheckAndSetDefaults())

stsRequestByLocalProxyCred := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", nil)

awsutils.NewSignerV2(localProxyCred, "sts").Sign(stsRequestByLocalProxyCred, nil, "sts", "us-west-1", time.Now())

requestByAssumedRole := httptest.NewRequest(http.MethodGet, "http://s3.amazonaws.com", nil)
awsutils.NewSignerV2(assumedRoleCred, "s3").Sign(requestByAssumedRole, nil, "s3", "us-west-1", time.Now())

t.Run("request no authorization", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.True(t, m.HandleRequest(recorder, httptest.NewRequest("", "http://localhost", nil)))
require.Equal(t, http.StatusForbidden, recorder.Code)
})

t.Run("request signed by unknown credentials", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.True(t, m.HandleRequest(recorder, requestByAssumedRole))
require.Equal(t, http.StatusForbidden, recorder.Code)
})

t.Run("request signed by local proxy credentials", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.False(t, m.HandleRequest(recorder, stsRequestByLocalProxyCred))
require.Equal(t, http.StatusOK, recorder.Code)
})

// Verifies sts:AssumeRole output can be handled successfully. The
// credentials should be saved afterwards.
t.Run("handle sts:AssumeRole response", func(t *testing.T) {
response := assumeRoleResponse(t, assumedRoleARN, assumedRoleCred)
response.Request = stsRequestByLocalProxyCred
defer response.Body.Close()
require.NoError(t, m.HandleResponse(response))
})

// This is the same request as the "unknown credentials" test above. But at
// this point, the assumed role credentials should have been saved by the
// middleware so the request can be handled successfully now.
t.Run("request signed by assumed role", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.False(t, m.HandleRequest(recorder, requestByAssumedRole))
require.Equal(t, http.StatusOK, recorder.Code)
})

// Verifies non sts:AssumeRole responses do not give errors.
t.Run("handle sts:GetCallerIdentity response", func(t *testing.T) {
response := getCallerIdentityResponse(t, assumedRoleARN)
response.Request = stsRequestByLocalProxyCred
defer response.Body.Close()
require.NoError(t, m.HandleResponse(response))
})
}

func assumeRoleResponse(t *testing.T, roleARN string, cred *credentials.Credentials) *http.Response {
func assumeRoleResponse(t *testing.T, roleARN string, provider aws.CredentialsProvider) *http.Response {
t.Helper()

credValue, err := cred.Get()
credValue, err := provider.Retrieve(context.Background())
require.NoError(t, err)

body, err := awsutils.MarshalXML(
Expand All @@ -132,18 +112,18 @@ func assumeRoleResponse(t *testing.T, roleARN string, cred *credentials.Credenti
},
map[string]any{
"AssumeRoleResult": sts.AssumeRoleOutput{
AssumedRoleUser: &sts.AssumedRoleUser{
AssumedRoleUser: &ststypes.AssumedRoleUser{
Arn: aws.String(roleARN),
},
Credentials: &sts.Credentials{
Credentials: &ststypes.Credentials{
AccessKeyId: aws.String(credValue.AccessKeyID),
SecretAccessKey: aws.String(credValue.SecretAccessKey),
SessionToken: aws.String(credValue.SessionToken),
},
},
"ResponseMetadata": protocol.ResponseMetadata{
StatusCode: http.StatusOK,
RequestID: "22222222-3333-3333-3333-333333333333",
"ResponseMetadata": map[string]any{
"StatusCode": http.StatusOK,
"RequestID": "22222222-3333-3333-3333-333333333333",
},
},
)
Expand All @@ -163,9 +143,9 @@ func getCallerIdentityResponse(t *testing.T, roleARN string) *http.Response {
"GetCallerIdentityResult": sts.GetCallerIdentityOutput{
Arn: aws.String(roleARN),
},
"ResponseMetadata": protocol.ResponseMetadata{
StatusCode: http.StatusOK,
RequestID: "22222222-3333-3333-3333-333333333333",
"ResponseMetadata": map[string]any{
"StatusCode": http.StatusOK,
"RequestID": "22222222-3333-3333-3333-333333333333",
},
},
)
Expand Down
11 changes: 10 additions & 1 deletion lib/srv/alpnproxy/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ func (s *Suite) Start(t *testing.T) {
}

func mustGenSelfSignedCert(t *testing.T) *tlsca.CertAuthority {
t.Helper()
caKey, caCert, err := tlsca.GenerateSelfSignedCA(pkix.Name{
CommonName: "localhost",
}, []string{"localhost"}, defaults.CATTL)
Expand Down Expand Up @@ -183,6 +184,7 @@ func withClock(clock clockwork.Clock) signOptionsFunc {
type signOptionsFunc func(o *signOptions)

func mustGenCertSignedWithCA(t *testing.T, ca *tlsca.CertAuthority, opts ...signOptionsFunc) tls.Certificate {
t.Helper()
options := signOptions{
identity: tlsca.Identity{Username: "test-user"},
clock: clockwork.NewRealClock(),
Expand Down Expand Up @@ -218,6 +220,7 @@ func mustGenCertSignedWithCA(t *testing.T, ca *tlsca.CertAuthority, opts ...sign
}

func mustReadFromConnection(t *testing.T, conn net.Conn, want string) {
t.Helper()
require.NoError(t, conn.SetReadDeadline(time.Now().Add(time.Second*5)))
buff, err := io.ReadAll(conn)
require.NoError(t, err)
Expand All @@ -226,11 +229,13 @@ func mustReadFromConnection(t *testing.T, conn net.Conn, want string) {
}

func mustCloseConnection(t *testing.T, conn net.Conn) {
t.Helper()
err := conn.Close()
require.NoError(t, err)
}

func mustCreateLocalListener(t *testing.T) net.Listener {
t.Helper()
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(func() {
Expand All @@ -240,6 +245,7 @@ func mustCreateLocalListener(t *testing.T) net.Listener {
}

func mustCreateCertGenListener(t *testing.T, ca tls.Certificate) net.Listener {
t.Helper()
listener, err := NewCertGenListener(CertGenListenerConfig{
ListenAddr: "localhost:0",
CA: ca,
Expand All @@ -253,23 +259,26 @@ func mustCreateCertGenListener(t *testing.T, ca tls.Certificate) net.Listener {
}

func mustSuccessfullyCallHTTPSServer(t *testing.T, addr string, client http.Client) {
t.Helper()
mustCallHTTPSServerAndReceiveCode(t, addr, client, http.StatusOK)
}

func mustCallHTTPSServerAndReceiveCode(t *testing.T, addr string, client http.Client, expectStatusCode int) {
t.Helper()
resp, err := client.Get(fmt.Sprintf("https://%s", addr))
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, expectStatusCode, resp.StatusCode)
}

func mustStartHTTPServer(t *testing.T, l net.Listener) {
func mustStartHTTPServer(l net.Listener) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) {})
go http.Serve(l, mux)
}

func mustStartLocalProxy(t *testing.T, config LocalProxyConfig) {
t.Helper()
lp, err := NewLocalProxy(config)
require.NoError(t, err)
t.Cleanup(func() {
Expand Down
Loading
Loading