Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing an optional ARN when health checking an AWSOIDC integration #46935

Merged
merged 3 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 106 additions & 93 deletions api/gen/proto/go/teleport/integration/v1/awsoidc_service.pb.go

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion api/proto/teleport/integration/v1/awsoidc_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -529,8 +529,13 @@ message ListEKSClustersResponse {
// PingRequest is a request for doing an health check against the configured integration.
message PingRequest {
// Integration is the AWS OIDC Integration name.
// Required.
// Required if ARN is empty.
string integration = 1;

// The AWS Role ARN to be used when generating the token.
// This is used to test another ARN before saving the Integration.
// Required if integration is empty.
string role_arn = 2;
}

// PingResponse contains the response for the Ping operation.
Expand Down
52 changes: 35 additions & 17 deletions lib/auth/integration/integrationv1/awsoidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,35 +157,48 @@ func NewAWSOIDCService(cfg *AWSOIDCServiceConfig) (*AWSOIDCService, error) {

var _ integrationpb.AWSOIDCServiceServer = (*AWSOIDCService)(nil)

func (s *AWSOIDCService) awsClientReq(ctx context.Context, integrationName, region string) (*awsoidc.AWSClientRequest, error) {
func (s *AWSOIDCService) roleARNForIntegration(ctx context.Context, integrationName string) (string, error) {
integration, err := s.integrationService.GetIntegration(ctx, &integrationpb.GetIntegrationRequest{
Name: integrationName,
})
if err != nil {
return nil, trace.Wrap(err)
return "", trace.Wrap(err)
}

if integration.GetSubKind() != types.IntegrationSubKindAWSOIDC {
return nil, trace.BadParameter("integration subkind (%s) mismatch", integration.GetSubKind())
return "", trace.BadParameter("integration subkind (%s) mismatch", integration.GetSubKind())
}

if integration.GetAWSOIDCIntegrationSpec() == nil {
return nil, trace.BadParameter("missing spec fields for %q (%q) integration", integration.GetName(), integration.GetSubKind())
return "", trace.BadParameter("missing spec fields for %q (%q) integration", integration.GetName(), integration.GetSubKind())
}

return integration.GetAWSOIDCIntegrationSpec().RoleARN, nil
}

func (s *AWSOIDCService) awsClientReqWithARN(ctx context.Context, integrationName, region, arn string) (*awsoidc.AWSClientRequest, error) {
token, err := s.integrationService.generateAWSOIDCTokenWithoutAuthZ(ctx, integrationName)
if err != nil {
return nil, trace.Wrap(err)
}

return &awsoidc.AWSClientRequest{
IntegrationName: integrationName,
Token: token.Token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: region,
Token: token.Token,
RoleARN: arn,
Region: region,
}, nil
}

func (s *AWSOIDCService) awsClientReq(ctx context.Context, integrationName, region string) (*awsoidc.AWSClientRequest, error) {
roleARN, err := s.roleARNForIntegration(ctx, integrationName)
if err != nil {
return nil, trace.Wrap(err)
}

return s.awsClientReqWithARN(ctx, integrationName, region, roleARN)

}

// ListEICE returns a paginated list of EC2 Instance Connect Endpoints.
func (s *AWSOIDCService) ListEICE(ctx context.Context, req *integrationpb.ListEICERequest) (*integrationpb.ListEICEResponse, error) {
authCtx, err := s.authorizer.Authorize(ctx)
Expand Down Expand Up @@ -788,15 +801,20 @@ func (s *AWSOIDCService) Ping(ctx context.Context, req *integrationpb.PingReques
return nil, trace.Wrap(err)
}

if req.Integration == "" {
return nil, trace.BadParameter("integration is required")
}

// Instead of asking the user for a region (or storing a default region), we use the sentinel value for the global region.
// This improves the UX, because it is one less input we require from the user.
awsClientReq, err := s.awsClientReq(ctx, req.Integration, awsutils.AWSGlobalRegion)
if err != nil {
return nil, trace.Wrap(err)
var awsClientReq *awsoidc.AWSClientRequest
switch {
case req.GetRoleArn() != "":
awsClientReq, err = s.awsClientReqWithARN(ctx, req.Integration, awsutils.AWSGlobalRegion, req.GetRoleArn())
if err != nil {
return nil, trace.Wrap(err)
}
case req.GetIntegration() != "":
awsClientReq, err = s.awsClientReq(ctx, req.GetIntegration(), awsutils.AWSGlobalRegion)
if err != nil {
return nil, trace.Wrap(err)
}
default:
return nil, trace.BadParameter("one of arn and integration is required")
}

awsClient, err := awsoidc.NewPingClient(ctx, awsClientReq)
Expand Down
15 changes: 10 additions & 5 deletions lib/auth/integration/integrationv1/awsoidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,6 @@ func TestGenerateAWSOIDCToken(t *testing.T) {
})
require.NoError(t, err)

t.Run("without integration (v15 and lower clients) returns an error", func(t *testing.T) {
_, err := resourceSvc.GenerateAWSOIDCToken(ctx, &integrationv1.GenerateAWSOIDCTokenRequest{})
require.Error(t, err)
})

t.Run("with integration in rpc call but no issuer defined", func(t *testing.T) {
resp, err := resourceSvc.GenerateAWSOIDCToken(ctx, &integrationv1.GenerateAWSOIDCTokenRequest{
Integration: integrationNameWithoutIssuer,
Expand Down Expand Up @@ -323,6 +318,16 @@ func TestRBAC(t *testing.T) {
return err
},
},
{
name: "Ping with arn",
fn: func() error {
_, err := awsoidService.Ping(userCtx, &integrationv1.PingRequest{
Integration: integrationName,
RoleArn: "some-arn",
})
return err
},
},
} {
t.Run(tt.name, func(t *testing.T) {
err := tt.fn()
Expand Down
7 changes: 0 additions & 7 deletions lib/integrations/awsoidc/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@ import (

// AWSClientRequest contains the required fields to set up an AWS service client.
type AWSClientRequest struct {
// IntegrationName is the integration name that is going to issue an API Call.
IntegrationName string

// Token is the token used to issue the API Call.
Token string

Expand All @@ -55,10 +52,6 @@ type AWSClientRequest struct {

// CheckAndSetDefaults checks if the required fields are present.
func (req *AWSClientRequest) CheckAndSetDefaults() error {
if req.IntegrationName == "" {
return trace.BadParameter("integration name is required")
}

if req.Token == "" {
return trace.BadParameter("token is required")
}
Expand Down
21 changes: 9 additions & 12 deletions lib/integrations/awsoidc/clients_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,27 @@ import (
func TestCheckAndSetDefaults(t *testing.T) {
t.Run("invalid regions must return an error", func(t *testing.T) {
err := (&AWSClientRequest{
IntegrationName: "my-integration",
Token: "token",
RoleARN: "some-arn",
Region: "?",
Token: "token",
RoleARN: "some-arn",
Region: "?",
}).CheckAndSetDefaults()

require.True(t, trace.IsBadParameter(err))
})
t.Run("valid region", func(t *testing.T) {
err := (&AWSClientRequest{
IntegrationName: "my-integration",
Token: "token",
RoleARN: "some-arn",
Region: "us-east-1",
Token: "token",
RoleARN: "some-arn",
Region: "us-east-1",
}).CheckAndSetDefaults()
require.NoError(t, err)
})

t.Run("empty region", func(t *testing.T) {
err := (&AWSClientRequest{
IntegrationName: "my-integration",
Token: "token",
RoleARN: "some-arn",
Region: "",
Token: "token",
RoleARN: "some-arn",
Region: "",
}).CheckAndSetDefaults()
require.NoError(t, err)
})
Expand Down
9 changes: 4 additions & 5 deletions lib/integrations/awsoidc/deployservice_vcr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,10 @@ func TestDeployDBService(t *testing.T) {
return &AWSClientRequest{
// To record new fixtures you will need a valid token.
// You can get one by getting the generated token in a real cluster.
Token: awsOIDCToken,
RoleARN: awsOIDCRoleARN,
Region: awsRegion,
IntegrationName: integrationName,
httpClient: httpClient,
Token: awsOIDCToken,
RoleARN: awsOIDCRoleARN,
Region: awsRegion,
httpClient: httpClient,
}
}

Expand Down
43 changes: 25 additions & 18 deletions lib/integrations/awsoidc/token_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ type KeyStoreManager interface {
// GenerateAWSOIDCTokenRequest contains the required elements to generate an AWS OIDC Token (JWT).
type GenerateAWSOIDCTokenRequest struct {
// Integration is the AWS OIDC Integration name.
// This field is only used to obtain custom Issuers (those stored at S3 buckets).
// If empty, the default issuer for the cluster (its public endpoint URL) will be used.
Integration string
// Username is the JWT Username (on behalf of claim)
Username string
Expand All @@ -70,9 +72,6 @@ type GenerateAWSOIDCTokenRequest struct {

// CheckAndSetDefaults checks the request params.
func (g *GenerateAWSOIDCTokenRequest) CheckAndSetDefaults() error {
if g.Integration == "" {
return trace.BadParameter("integration missing")
}
if g.Username == "" {
return trace.BadParameter("username missing")
}
Expand All @@ -86,7 +85,28 @@ func (g *GenerateAWSOIDCTokenRequest) CheckAndSetDefaults() error {
return nil
}

func issuerForIntegration(ctx context.Context, integration types.Integration, cacheClt Cache) (string, error) {
// IssuerForIntegration returns the issuer for a given integration.
// Returns the default Issuer (oidc.IssuerForCluster) if integrationName is empty.
// All calls should be replaced with oidc.IssuerForCluster when IssuerS3URI is removed (it is currently deprecated).
func issuerForIntegration(ctx context.Context, cacheClt Cache, integrationName string) (string, error) {
if integrationName == "" {
issuer, err := oidc.IssuerForCluster(ctx, cacheClt, "")
return issuer, trace.Wrap(err)
}

integration, err := cacheClt.GetIntegration(ctx, integrationName)
if err != nil {
return "", trace.Wrap(err)
}

if integration.GetSubKind() != types.IntegrationSubKindAWSOIDC {
return "", trace.BadParameter("integration subkind (%s) mismatch", integration.GetSubKind())
}

if integration.GetAWSOIDCIntegrationSpec() == nil {
return "", trace.BadParameter("missing spec fields for %q (%q) integration", integration.GetName(), integration.GetSubKind())
}

issuerS3URI := integration.GetAWSOIDCIntegrationSpec().IssuerS3URI
if issuerS3URI == "" {
issuer, err := oidc.IssuerForCluster(ctx, cacheClt, "")
Expand All @@ -107,20 +127,7 @@ func GenerateAWSOIDCToken(ctx context.Context, cacheClt Cache, keyStoreManager K
return "", trace.Wrap(err)
}

integration, err := cacheClt.GetIntegration(ctx, req.Integration)
if err != nil {
return "", trace.Wrap(err)
}

if integration.GetSubKind() != types.IntegrationSubKindAWSOIDC {
return "", trace.BadParameter("integration subkind (%s) mismatch", integration.GetSubKind())
}

if integration.GetAWSOIDCIntegrationSpec() == nil {
return "", trace.BadParameter("missing spec fields for %q (%q) integration", integration.GetName(), integration.GetSubKind())
}

issuer, err := issuerForIntegration(ctx, integration, cacheClt)
issuer, err := issuerForIntegration(ctx, cacheClt, req.Integration)
if err != nil {
return "", trace.Wrap(err)
}
Expand Down
7 changes: 3 additions & 4 deletions lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,10 +540,9 @@ func (s *localSite) setupTunnelForOpenSSHEICENode(ctx context.Context, targetSer
}

openTunnelClt, err := awsoidc.NewOpenTunnelEC2Client(ctx, &awsoidc.AWSClientRequest{
IntegrationName: integration.GetName(),
Token: token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: awsInfo.Region,
Token: token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: awsInfo.Region,
})
if err != nil {
return nil, trace.BadParameter("failed to create the ec2 open tunnel client: %v", err)
Expand Down
7 changes: 3 additions & 4 deletions lib/service/awsoidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,9 @@ func (updater *AWSOIDCDeployServiceUpdater) updateAWSOIDCDeployService(ctx conte
}

req := &awsoidc.AWSClientRequest{
IntegrationName: integration.GetName(),
Token: token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: awsRegion,
Token: token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: awsRegion,
}

// The deploy service client is initialized using AWS OIDC integration.
Expand Down
7 changes: 3 additions & 4 deletions lib/srv/forward/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -690,10 +690,9 @@ func (s *Server) sendSSHPublicKeyToTarget(ctx context.Context) (ssh.Signer, erro
}

sendSSHClient, err := awsoidc.NewEICESendSSHPublicKeyClient(ctx, &awsoidc.AWSClientRequest{
IntegrationName: integration.GetName(),
Token: token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: awsInfo.Region,
Token: token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: awsInfo.Region,
})
if err != nil {
return nil, trace.BadParameter("failed to create an aws client to send ssh public key: %v", err)
Expand Down
7 changes: 7 additions & 0 deletions lib/web/integrations_awsoidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -1434,6 +1434,7 @@ func getServiceURLs(dbServices []types.DatabaseService, accountID, region, telep
}

// awsOIDCPing performs an health check for the integration.
// If ARN is present in the request body, that's the ARN that will be used instead of using the one stored in the integration.
// Returns meta information: account id and assumed the ARN for the IAM Role.
func (h *Handler) awsOIDCPing(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (any, error) {
ctx := r.Context()
Expand All @@ -1443,13 +1444,19 @@ func (h *Handler) awsOIDCPing(w http.ResponseWriter, r *http.Request, p httprout
return nil, trace.BadParameter("an integration name is required")
}

var req ui.AWSOIDCPingRequest
if err := httplib.ReadJSON(r, &req); err != nil {
return nil, trace.Wrap(err)
}

clt, err := sctx.GetUserClient(ctx, site)
if err != nil {
return nil, trace.Wrap(err)
}

pingResp, err := clt.IntegrationAWSOIDCClient().Ping(ctx, &integrationv1.PingRequest{
Integration: integrationName,
RoleArn: req.RoleARN,
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down
8 changes: 8 additions & 0 deletions lib/web/ui/integration.go
Original file line number Diff line number Diff line change
Expand Up @@ -525,3 +525,11 @@ type AWSOIDCPingResponse struct {
// UserID is the unique identifier of the calling entity.
UserID string `json:"userId"`
}

// AWSOIDCPingRequest contains ping request fields.
type AWSOIDCPingRequest struct {
// RoleARN is optional, and used for cases such as
// pinging to check validity before upserting an
// AWS OIDC integration.
RoleARN string `json:"roleArn,omitempty"`
}
Loading