Skip to content

Commit

Permalink
migrate tsh common to aws sdk v2 (#49449)
Browse files Browse the repository at this point in the history
* add AWS SDK v2 migration utils package

* migrate tool/tsh/common to AWS SDK v2

* adapt AWSAccessMiddleware to AWS SDK v2
  • Loading branch information
GavinFrazar authored Dec 9, 2024
1 parent 02cadb4 commit 3208c3d
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 82 deletions.
12 changes: 11 additions & 1 deletion lib/srv/alpnproxy/aws_local_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"net/http"
"strings"

awsv2 "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
Expand All @@ -33,6 +34,7 @@ import (
appcommon "github.com/gravitational/teleport/lib/srv/app/common"
"github.com/gravitational/teleport/lib/utils"
awsutils "github.com/gravitational/teleport/lib/utils/aws"
"github.com/gravitational/teleport/lib/utils/aws/migration"
)

// AWSAccessMiddleware verifies the requests to AWS proxy are properly signed.
Expand All @@ -42,6 +44,11 @@ type AWSAccessMiddleware struct {
// AWSCredentials are AWS Credentials used by LocalProxy for request's signature verification.
AWSCredentials *credentials.Credentials

// AWSCredentialsV2Provider is an aws sdk v2 credential provider used by
// LocalProxy for request's signature verification if AWSCredentials is not
// specified.
AWSCredentialsV2Provider awsv2.CredentialsProvider

Log logrus.FieldLogger

assumedRoles utils.SyncMap[string, *sts.AssumeRoleOutput]
Expand All @@ -55,7 +62,10 @@ func (m *AWSAccessMiddleware) CheckAndSetDefaults() error {
}

if m.AWSCredentials == nil {
return trace.BadParameter("missing AWSCredentials")
if m.AWSCredentialsV2Provider == nil {
return trace.BadParameter("missing AWSCredentials")
}
m.AWSCredentials = credentials.NewCredentials(migration.NewProviderAdapter(m.AWSCredentialsV2Provider))
}

return nil
Expand Down
125 changes: 73 additions & 52 deletions lib/srv/alpnproxy/aws_local_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"testing"
"time"

credentialsv2 "github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
Expand All @@ -42,60 +43,80 @@ func TestAWSAccessMiddleware(t *testing.T) {
localProxyCred := credentials.NewStaticCredentials("local-proxy", "local-proxy-secret", "")
assumedRoleCred := credentials.NewStaticCredentials("assumed-role", "assumed-role-secret", "assumed-role-token")

stsRequestByLocalProxyCred := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", nil)
v4.NewSigner(localProxyCred).Sign(stsRequestByLocalProxyCred, nil, "sts", "us-west-1", time.Now())

requestByAssumedRole := httptest.NewRequest(http.MethodGet, "http://s3.amazonaws.com", nil)
v4.NewSigner(assumedRoleCred).Sign(requestByAssumedRole, nil, "s3", "us-west-1", time.Now())
tests := []struct {
name string
middleware *AWSAccessMiddleware
}{
{
name: "v1",
middleware: &AWSAccessMiddleware{
AWSCredentials: localProxyCred,
},
},
{
name: "v2",
middleware: &AWSAccessMiddleware{
AWSCredentialsV2Provider: credentialsv2.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", ""),
},
},
}

m := &AWSAccessMiddleware{
AWSCredentials: localProxyCred,
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
m := test.middleware
require.NoError(t, m.CheckAndSetDefaults())

stsRequestByLocalProxyCred := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", nil)
v4.NewSigner(localProxyCred).Sign(stsRequestByLocalProxyCred, nil, "sts", "us-west-1", time.Now())

requestByAssumedRole := httptest.NewRequest(http.MethodGet, "http://s3.amazonaws.com", nil)
v4.NewSigner(assumedRoleCred).Sign(requestByAssumedRole, nil, "s3", "us-west-1", time.Now())

t.Run("request no authorization", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.True(t, m.HandleRequest(recorder, httptest.NewRequest("", "http://localhost", nil)))
require.Equal(t, http.StatusForbidden, recorder.Code)
})

t.Run("request signed by unknown credentials", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.True(t, m.HandleRequest(recorder, requestByAssumedRole))
require.Equal(t, http.StatusForbidden, recorder.Code)
})

t.Run("request signed by local proxy credentials", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.False(t, m.HandleRequest(recorder, stsRequestByLocalProxyCred))
require.Equal(t, http.StatusOK, recorder.Code)
})

// Verifies sts:AssumeRole output can be handled successfully. The
// credentials should be saved afterwards.
t.Run("handle sts:AssumeRole response", func(t *testing.T) {
response := assumeRoleResponse(t, assumedRoleARN, assumedRoleCred)
response.Request = stsRequestByLocalProxyCred
defer response.Body.Close()
require.NoError(t, m.HandleResponse(response))
})

// This is the same request as the "unknown credentials" test above. But at
// this point, the assumed role credentials should have been saved by the
// middleware so the request can be handled successfully now.
t.Run("request signed by assumed role", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.False(t, m.HandleRequest(recorder, requestByAssumedRole))
require.Equal(t, http.StatusOK, recorder.Code)
})

// Verifies non sts:AssumeRole responses do not give errors.
t.Run("handle sts:GetCallerIdentity response", func(t *testing.T) {
response := getCallerIdentityResponse(t, assumedRoleARN)
response.Request = stsRequestByLocalProxyCred
defer response.Body.Close()
require.NoError(t, m.HandleResponse(response))
})
})
}
require.NoError(t, m.CheckAndSetDefaults())

t.Run("request no authorization", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.True(t, m.HandleRequest(recorder, httptest.NewRequest("", "http://localhost", nil)))
require.Equal(t, http.StatusForbidden, recorder.Code)
})

t.Run("request signed by unknown credentials", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.True(t, m.HandleRequest(recorder, requestByAssumedRole))
require.Equal(t, http.StatusForbidden, recorder.Code)
})

t.Run("request signed by local proxy credentials", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.False(t, m.HandleRequest(recorder, stsRequestByLocalProxyCred))
require.Equal(t, http.StatusOK, recorder.Code)
})

// Verifies sts:AssumeRole output can be handled successfully. The
// credentials should be saved afterwards.
t.Run("handle sts:AssumeRole response", func(t *testing.T) {
response := assumeRoleResponse(t, assumedRoleARN, assumedRoleCred)
response.Request = stsRequestByLocalProxyCred
defer response.Body.Close()
require.NoError(t, m.HandleResponse(response))
})

// This is the same request as the "unknown credentials" test above. But at
// this point, the assumed role credentials should have been saved by the
// middleware so the request can be handled successfully now.
t.Run("request signed by assumed role", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.False(t, m.HandleRequest(recorder, requestByAssumedRole))
require.Equal(t, http.StatusOK, recorder.Code)
})

// Verifies non sts:AssumeRole responses do not give errors.
t.Run("handle sts:GetCallerIdentity response", func(t *testing.T) {
response := getCallerIdentityResponse(t, assumedRoleARN)
response.Request = stsRequestByLocalProxyCred
defer response.Body.Close()
require.NoError(t, m.HandleResponse(response))
})
}

func assumeRoleResponse(t *testing.T, roleARN string, cred *credentials.Credentials) *http.Response {
Expand Down
92 changes: 92 additions & 0 deletions lib/utils/aws/migration/migration.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package migration

import (
"context"
"sync"

awsv2 "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/gravitational/trace"
)

// NewProviderAdapter returns a [ProviderAdapter] that can be used as an AWS SDK
// v1 credentials provider.
func NewProviderAdapter(providerV2 awsv2.CredentialsProvider) *ProviderAdapter {
return &ProviderAdapter{
providerV2: providerV2,
}
}

var _ credentials.ProviderWithContext = (*ProviderAdapter)(nil)

// ProviderAdapter adapts an [awsv2.CredentialsProvider] to an AWS SDK v1
// credentials provider.
type ProviderAdapter struct {
providerV2 awsv2.CredentialsProvider

m sync.RWMutex
// creds are retrieved and saved to satisfy IsExpired.
creds awsv2.Credentials
}

func (a *ProviderAdapter) IsExpired() bool {
a.m.RLock()
defer a.m.RUnlock()

var emptyCreds awsv2.Credentials
return a.creds == emptyCreds || a.creds.Expired()
}

func (a *ProviderAdapter) Retrieve() (credentials.Value, error) {
return a.RetrieveWithContext(context.Background())
}

func (a *ProviderAdapter) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
creds, err := a.retrieveLocked(ctx)
if err != nil {
return credentials.Value{}, trace.Wrap(err)
}

return credentials.Value{
AccessKeyID: creds.AccessKeyID,
SecretAccessKey: creds.SecretAccessKey,
SessionToken: creds.SessionToken,
ProviderName: creds.Source,
}, nil
}

func (a *ProviderAdapter) retrieveLocked(ctx context.Context) (awsv2.Credentials, error) {
a.m.Lock()
defer a.m.Unlock()

var emptyCreds awsv2.Credentials
if a.creds != emptyCreds && !a.creds.Expired() {
return a.creds, nil
}

creds, err := a.providerV2.Retrieve(ctx)
if err != nil {
return emptyCreds, trace.Wrap(err)
}

a.creds = creds
return creds, nil
}
45 changes: 16 additions & 29 deletions tool/tsh/common/app_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ import (
"strings"
"sync"

awsarn "github.com/aws/aws-sdk-go/aws/arn"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/arn"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/google/uuid"
"github.com/gravitational/trace"

Expand Down Expand Up @@ -136,7 +137,7 @@ type awsApp struct {

cf *CLIConf

credentials *credentials.Credentials
credentials aws.CredentialsProvider
credentialsOnce sync.Once
}

Expand Down Expand Up @@ -168,13 +169,8 @@ func (a *awsApp) GetAppName() string {
// The first method is always preferred as the original hostname is preserved
// through forward proxy.
func (a *awsApp) StartLocalProxies(ctx context.Context, opts ...alpnproxy.LocalProxyConfigOpt) error {
cred, err := a.GetAWSCredentials()
if err != nil {
return trace.Wrap(err)
}

awsMiddleware := &alpnproxy.AWSAccessMiddleware{
AWSCredentials: cred,
AWSCredentialsV2Provider: a.GetAWSCredentialsProvider(),
}

// AWS endpoint URL mode
Expand All @@ -184,14 +180,14 @@ func (a *awsApp) StartLocalProxies(ctx context.Context, opts ...alpnproxy.LocalP
}

// HTTPS proxy mode
err = a.StartLocalProxyWithForwarder(ctx, alpnproxy.MatchAWSRequests, alpnproxy.WithHTTPMiddleware(awsMiddleware))
err := a.StartLocalProxyWithForwarder(ctx, alpnproxy.MatchAWSRequests, alpnproxy.WithHTTPMiddleware(awsMiddleware))
return trace.Wrap(err)
}

// GetAWSCredentials generates fake AWS credentials that are used for
// signing an AWS request during AWS API calls and verified on local AWS proxy
// side.
func (a *awsApp) GetAWSCredentials() (*credentials.Credentials, error) {
// GetAWSCredentialsProvider returns an [aws.CredentialsProvider] that generates
// fake AWS credentials that are used for signing an AWS request during AWS API
// calls and verified on local AWS proxy side.
func (a *awsApp) GetAWSCredentialsProvider() aws.CredentialsProvider {
// There is no specific format or value required for access key and secret,
// as long as the AWS clients and the local proxy are using the same
// credentials. The only constraint is the access key must have a length
Expand All @@ -200,17 +196,13 @@ func (a *awsApp) GetAWSCredentials() (*credentials.Credentials, error) {
//
// https://docs.aws.amazon.com/STS/latest/APIReference/API_Credentials.html
a.credentialsOnce.Do(func() {
a.credentials = credentials.NewStaticCredentials(
a.credentials = credentials.NewStaticCredentialsProvider(
getEnvOrDefault(awsAccessKeyIDEnvVar, uuid.NewString()),
getEnvOrDefault(awsSecretAccessKeyEnvVar, uuid.NewString()),
"",
)
})

if a.credentials == nil {
return nil, trace.BadParameter("missing credentials")
}
return a.credentials, nil
return a.credentials
}

// GetEnvVars returns required environment variables to configure the
Expand All @@ -220,12 +212,7 @@ func (a *awsApp) GetEnvVars() (map[string]string, error) {
return nil, trace.NotFound("ALPN proxy is not running")
}

cred, err := a.GetAWSCredentials()
if err != nil {
return nil, trace.Wrap(err)
}

credValues, err := cred.Get()
cred, err := a.GetAWSCredentialsProvider().Retrieve(context.Background())
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -234,8 +221,8 @@ func (a *awsApp) GetEnvVars() (map[string]string, error) {
// AWS CLI and SDKs can load credentials through environment variables.
//
// https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-envvars.html
"AWS_ACCESS_KEY_ID": credValues.AccessKeyID,
"AWS_SECRET_ACCESS_KEY": credValues.SecretAccessKey,
"AWS_ACCESS_KEY_ID": cred.AccessKeyID,
"AWS_SECRET_ACCESS_KEY": cred.SecretAccessKey,
"AWS_CA_BUNDLE": a.appInfo.appLocalCAPath(a.cf.SiteName),
}

Expand Down Expand Up @@ -318,7 +305,7 @@ func getARNFromFlags(cf *CLIConf, app types.Application, logins []string) (strin
}

// Match by role ARN.
if awsarn.IsARN(cf.AWSRole) {
if arn.IsARN(cf.AWSRole) {
if role, found := roles.FindRoleByARN(cf.AWSRole); found {
return role.ARN, nil
}
Expand Down

0 comments on commit 3208c3d

Please sign in to comment.