Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v17] migrate tsh common to aws sdk v2 #50112

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion lib/srv/alpnproxy/aws_local_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"net/http"
"strings"

awsv2 "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
Expand All @@ -33,6 +34,7 @@ import (
appcommon "github.com/gravitational/teleport/lib/srv/app/common"
"github.com/gravitational/teleport/lib/utils"
awsutils "github.com/gravitational/teleport/lib/utils/aws"
"github.com/gravitational/teleport/lib/utils/aws/migration"
)

// AWSAccessMiddleware verifies the requests to AWS proxy are properly signed.
Expand All @@ -42,6 +44,11 @@ type AWSAccessMiddleware struct {
// AWSCredentials are AWS Credentials used by LocalProxy for request's signature verification.
AWSCredentials *credentials.Credentials

// AWSCredentialsV2Provider is an aws sdk v2 credential provider used by
// LocalProxy for request's signature verification if AWSCredentials is not
// specified.
AWSCredentialsV2Provider awsv2.CredentialsProvider

Log logrus.FieldLogger

assumedRoles utils.SyncMap[string, *sts.AssumeRoleOutput]
Expand All @@ -55,7 +62,10 @@ func (m *AWSAccessMiddleware) CheckAndSetDefaults() error {
}

if m.AWSCredentials == nil {
return trace.BadParameter("missing AWSCredentials")
if m.AWSCredentialsV2Provider == nil {
return trace.BadParameter("missing AWSCredentials")
}
m.AWSCredentials = credentials.NewCredentials(migration.NewProviderAdapter(m.AWSCredentialsV2Provider))
}

return nil
Expand Down
125 changes: 73 additions & 52 deletions lib/srv/alpnproxy/aws_local_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"testing"
"time"

credentialsv2 "github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
Expand All @@ -42,60 +43,80 @@ func TestAWSAccessMiddleware(t *testing.T) {
localProxyCred := credentials.NewStaticCredentials("local-proxy", "local-proxy-secret", "")
assumedRoleCred := credentials.NewStaticCredentials("assumed-role", "assumed-role-secret", "assumed-role-token")

stsRequestByLocalProxyCred := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", nil)
v4.NewSigner(localProxyCred).Sign(stsRequestByLocalProxyCred, nil, "sts", "us-west-1", time.Now())

requestByAssumedRole := httptest.NewRequest(http.MethodGet, "http://s3.amazonaws.com", nil)
v4.NewSigner(assumedRoleCred).Sign(requestByAssumedRole, nil, "s3", "us-west-1", time.Now())
tests := []struct {
name string
middleware *AWSAccessMiddleware
}{
{
name: "v1",
middleware: &AWSAccessMiddleware{
AWSCredentials: localProxyCred,
},
},
{
name: "v2",
middleware: &AWSAccessMiddleware{
AWSCredentialsV2Provider: credentialsv2.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", ""),
},
},
}

m := &AWSAccessMiddleware{
AWSCredentials: localProxyCred,
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
m := test.middleware
require.NoError(t, m.CheckAndSetDefaults())

stsRequestByLocalProxyCred := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", nil)
v4.NewSigner(localProxyCred).Sign(stsRequestByLocalProxyCred, nil, "sts", "us-west-1", time.Now())

requestByAssumedRole := httptest.NewRequest(http.MethodGet, "http://s3.amazonaws.com", nil)
v4.NewSigner(assumedRoleCred).Sign(requestByAssumedRole, nil, "s3", "us-west-1", time.Now())

t.Run("request no authorization", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.True(t, m.HandleRequest(recorder, httptest.NewRequest("", "http://localhost", nil)))
require.Equal(t, http.StatusForbidden, recorder.Code)
})

t.Run("request signed by unknown credentials", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.True(t, m.HandleRequest(recorder, requestByAssumedRole))
require.Equal(t, http.StatusForbidden, recorder.Code)
})

t.Run("request signed by local proxy credentials", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.False(t, m.HandleRequest(recorder, stsRequestByLocalProxyCred))
require.Equal(t, http.StatusOK, recorder.Code)
})

// Verifies sts:AssumeRole output can be handled successfully. The
// credentials should be saved afterwards.
t.Run("handle sts:AssumeRole response", func(t *testing.T) {
response := assumeRoleResponse(t, assumedRoleARN, assumedRoleCred)
response.Request = stsRequestByLocalProxyCred
defer response.Body.Close()
require.NoError(t, m.HandleResponse(response))
})

// This is the same request as the "unknown credentials" test above. But at
// this point, the assumed role credentials should have been saved by the
// middleware so the request can be handled successfully now.
t.Run("request signed by assumed role", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.False(t, m.HandleRequest(recorder, requestByAssumedRole))
require.Equal(t, http.StatusOK, recorder.Code)
})

// Verifies non sts:AssumeRole responses do not give errors.
t.Run("handle sts:GetCallerIdentity response", func(t *testing.T) {
response := getCallerIdentityResponse(t, assumedRoleARN)
response.Request = stsRequestByLocalProxyCred
defer response.Body.Close()
require.NoError(t, m.HandleResponse(response))
})
})
}
require.NoError(t, m.CheckAndSetDefaults())

t.Run("request no authorization", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.True(t, m.HandleRequest(recorder, httptest.NewRequest("", "http://localhost", nil)))
require.Equal(t, http.StatusForbidden, recorder.Code)
})

t.Run("request signed by unknown credentials", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.True(t, m.HandleRequest(recorder, requestByAssumedRole))
require.Equal(t, http.StatusForbidden, recorder.Code)
})

t.Run("request signed by local proxy credentials", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.False(t, m.HandleRequest(recorder, stsRequestByLocalProxyCred))
require.Equal(t, http.StatusOK, recorder.Code)
})

// Verifies sts:AssumeRole output can be handled successfully. The
// credentials should be saved afterwards.
t.Run("handle sts:AssumeRole response", func(t *testing.T) {
response := assumeRoleResponse(t, assumedRoleARN, assumedRoleCred)
response.Request = stsRequestByLocalProxyCred
defer response.Body.Close()
require.NoError(t, m.HandleResponse(response))
})

// This is the same request as the "unknown credentials" test above. But at
// this point, the assumed role credentials should have been saved by the
// middleware so the request can be handled successfully now.
t.Run("request signed by assumed role", func(t *testing.T) {
recorder := httptest.NewRecorder()
require.False(t, m.HandleRequest(recorder, requestByAssumedRole))
require.Equal(t, http.StatusOK, recorder.Code)
})

// Verifies non sts:AssumeRole responses do not give errors.
t.Run("handle sts:GetCallerIdentity response", func(t *testing.T) {
response := getCallerIdentityResponse(t, assumedRoleARN)
response.Request = stsRequestByLocalProxyCred
defer response.Body.Close()
require.NoError(t, m.HandleResponse(response))
})
}

func assumeRoleResponse(t *testing.T, roleARN string, cred *credentials.Credentials) *http.Response {
Expand Down
92 changes: 92 additions & 0 deletions lib/utils/aws/migration/migration.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package migration

import (
"context"
"sync"

awsv2 "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/gravitational/trace"
)

// NewProviderAdapter returns a [ProviderAdapter] that can be used as an AWS SDK
// v1 credentials provider.
func NewProviderAdapter(providerV2 awsv2.CredentialsProvider) *ProviderAdapter {
return &ProviderAdapter{
providerV2: providerV2,
}
}

var _ credentials.ProviderWithContext = (*ProviderAdapter)(nil)

// ProviderAdapter adapts an [awsv2.CredentialsProvider] to an AWS SDK v1
// credentials provider.
type ProviderAdapter struct {
providerV2 awsv2.CredentialsProvider

m sync.RWMutex
// creds are retrieved and saved to satisfy IsExpired.
creds awsv2.Credentials
}

func (a *ProviderAdapter) IsExpired() bool {
a.m.RLock()
defer a.m.RUnlock()

var emptyCreds awsv2.Credentials
return a.creds == emptyCreds || a.creds.Expired()
}

func (a *ProviderAdapter) Retrieve() (credentials.Value, error) {
return a.RetrieveWithContext(context.Background())
}

func (a *ProviderAdapter) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
creds, err := a.retrieveLocked(ctx)
if err != nil {
return credentials.Value{}, trace.Wrap(err)
}

return credentials.Value{
AccessKeyID: creds.AccessKeyID,
SecretAccessKey: creds.SecretAccessKey,
SessionToken: creds.SessionToken,
ProviderName: creds.Source,
}, nil
}

func (a *ProviderAdapter) retrieveLocked(ctx context.Context) (awsv2.Credentials, error) {
a.m.Lock()
defer a.m.Unlock()

var emptyCreds awsv2.Credentials
if a.creds != emptyCreds && !a.creds.Expired() {
return a.creds, nil
}

creds, err := a.providerV2.Retrieve(ctx)
if err != nil {
return emptyCreds, trace.Wrap(err)
}

a.creds = creds
return creds, nil
}
45 changes: 16 additions & 29 deletions tool/tsh/common/app_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ import (
"strings"
"sync"

awsarn "github.com/aws/aws-sdk-go/aws/arn"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/arn"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/google/uuid"
"github.com/gravitational/trace"

Expand Down Expand Up @@ -136,7 +137,7 @@ type awsApp struct {

cf *CLIConf

credentials *credentials.Credentials
credentials aws.CredentialsProvider
credentialsOnce sync.Once
}

Expand Down Expand Up @@ -168,13 +169,8 @@ func (a *awsApp) GetAppName() string {
// The first method is always preferred as the original hostname is preserved
// through forward proxy.
func (a *awsApp) StartLocalProxies(ctx context.Context, opts ...alpnproxy.LocalProxyConfigOpt) error {
cred, err := a.GetAWSCredentials()
if err != nil {
return trace.Wrap(err)
}

awsMiddleware := &alpnproxy.AWSAccessMiddleware{
AWSCredentials: cred,
AWSCredentialsV2Provider: a.GetAWSCredentialsProvider(),
}

// AWS endpoint URL mode
Expand All @@ -184,14 +180,14 @@ func (a *awsApp) StartLocalProxies(ctx context.Context, opts ...alpnproxy.LocalP
}

// HTTPS proxy mode
err = a.StartLocalProxyWithForwarder(ctx, alpnproxy.MatchAWSRequests, alpnproxy.WithHTTPMiddleware(awsMiddleware))
err := a.StartLocalProxyWithForwarder(ctx, alpnproxy.MatchAWSRequests, alpnproxy.WithHTTPMiddleware(awsMiddleware))
return trace.Wrap(err)
}

// GetAWSCredentials generates fake AWS credentials that are used for
// signing an AWS request during AWS API calls and verified on local AWS proxy
// side.
func (a *awsApp) GetAWSCredentials() (*credentials.Credentials, error) {
// GetAWSCredentialsProvider returns an [aws.CredentialsProvider] that generates
// fake AWS credentials that are used for signing an AWS request during AWS API
// calls and verified on local AWS proxy side.
func (a *awsApp) GetAWSCredentialsProvider() aws.CredentialsProvider {
// There is no specific format or value required for access key and secret,
// as long as the AWS clients and the local proxy are using the same
// credentials. The only constraint is the access key must have a length
Expand All @@ -200,17 +196,13 @@ func (a *awsApp) GetAWSCredentials() (*credentials.Credentials, error) {
//
// https://docs.aws.amazon.com/STS/latest/APIReference/API_Credentials.html
a.credentialsOnce.Do(func() {
a.credentials = credentials.NewStaticCredentials(
a.credentials = credentials.NewStaticCredentialsProvider(
getEnvOrDefault(awsAccessKeyIDEnvVar, uuid.NewString()),
getEnvOrDefault(awsSecretAccessKeyEnvVar, uuid.NewString()),
"",
)
})

if a.credentials == nil {
return nil, trace.BadParameter("missing credentials")
}
return a.credentials, nil
return a.credentials
}

// GetEnvVars returns required environment variables to configure the
Expand All @@ -220,12 +212,7 @@ func (a *awsApp) GetEnvVars() (map[string]string, error) {
return nil, trace.NotFound("ALPN proxy is not running")
}

cred, err := a.GetAWSCredentials()
if err != nil {
return nil, trace.Wrap(err)
}

credValues, err := cred.Get()
cred, err := a.GetAWSCredentialsProvider().Retrieve(context.Background())
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -234,8 +221,8 @@ func (a *awsApp) GetEnvVars() (map[string]string, error) {
// AWS CLI and SDKs can load credentials through environment variables.
//
// https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-envvars.html
"AWS_ACCESS_KEY_ID": credValues.AccessKeyID,
"AWS_SECRET_ACCESS_KEY": credValues.SecretAccessKey,
"AWS_ACCESS_KEY_ID": cred.AccessKeyID,
"AWS_SECRET_ACCESS_KEY": cred.SecretAccessKey,
"AWS_CA_BUNDLE": a.appInfo.appLocalCAPath(a.cf.SiteName),
}

Expand Down Expand Up @@ -318,7 +305,7 @@ func getARNFromFlags(cf *CLIConf, app types.Application, logins []string) (strin
}

// Match by role ARN.
if awsarn.IsARN(cf.AWSRole) {
if arn.IsARN(cf.AWSRole) {
if role, found := roles.FindRoleByARN(cf.AWSRole); found {
return role.ARN, nil
}
Expand Down
Loading