Skip to content

Commit

Permalink
AWS OIDC Ping: allow custom ARN
Browse files Browse the repository at this point in the history
This PR changes the Ping method to accept a custom ARN.

This is meant to be used by WebUI to do a health check for the
integration:
- when creating
- when editing
- when selecting during Discover flows

If the Ping method receives an ARN, it will use that value instead of
using the one stored in the backend.
  • Loading branch information
marcoandredinis committed Oct 15, 2024
1 parent 39a86de commit 716ed7c
Show file tree
Hide file tree
Showing 13 changed files with 219 additions and 168 deletions.
194 changes: 103 additions & 91 deletions api/gen/proto/go/teleport/integration/v1/awsoidc_service.pb.go

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion api/proto/teleport/integration/v1/awsoidc_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -529,8 +529,13 @@ message ListEKSClustersResponse {
// PingRequest is a request for doing an health check against the configured integration.
message PingRequest {
// Integration is the AWS OIDC Integration name.
// Required.
// Required if ARN is empty.
string integration = 1;

// The AWS ARN to be used when generating the token.
// This is used to test another ARN before saving the Integration.
// Required if integration is empty.
string arn = 2;
}

// PingResponse contains the response for the Ping operation.
Expand Down
52 changes: 35 additions & 17 deletions lib/auth/integration/integrationv1/awsoidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,35 +157,48 @@ func NewAWSOIDCService(cfg *AWSOIDCServiceConfig) (*AWSOIDCService, error) {

var _ integrationpb.AWSOIDCServiceServer = (*AWSOIDCService)(nil)

func (s *AWSOIDCService) awsClientReq(ctx context.Context, integrationName, region string) (*awsoidc.AWSClientRequest, error) {
func (s *AWSOIDCService) roleARNForIntegration(ctx context.Context, integrationName string) (string, error) {
integration, err := s.integrationService.GetIntegration(ctx, &integrationpb.GetIntegrationRequest{
Name: integrationName,
})
if err != nil {
return nil, trace.Wrap(err)
return "", trace.Wrap(err)
}

if integration.GetSubKind() != types.IntegrationSubKindAWSOIDC {
return nil, trace.BadParameter("integration subkind (%s) mismatch", integration.GetSubKind())
return "", trace.BadParameter("integration subkind (%s) mismatch", integration.GetSubKind())
}

if integration.GetAWSOIDCIntegrationSpec() == nil {
return nil, trace.BadParameter("missing spec fields for %q (%q) integration", integration.GetName(), integration.GetSubKind())
return "", trace.BadParameter("missing spec fields for %q (%q) integration", integration.GetName(), integration.GetSubKind())
}

return integration.GetAWSOIDCIntegrationSpec().RoleARN, nil
}

func (s *AWSOIDCService) awsClientReqWithARN(ctx context.Context, integrationName, region, arn string) (*awsoidc.AWSClientRequest, error) {
token, err := s.integrationService.generateAWSOIDCTokenWithoutAuthZ(ctx, integrationName)
if err != nil {
return nil, trace.Wrap(err)
}

return &awsoidc.AWSClientRequest{
IntegrationName: integrationName,
Token: token.Token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: region,
Token: token.Token,
RoleARN: arn,
Region: region,
}, nil
}

func (s *AWSOIDCService) awsClientReq(ctx context.Context, integrationName, region string) (*awsoidc.AWSClientRequest, error) {
roleARN, err := s.roleARNForIntegration(ctx, integrationName)
if err != nil {
return nil, trace.Wrap(err)
}

return s.awsClientReqWithARN(ctx, integrationName, region, roleARN)

}

// ListEICE returns a paginated list of EC2 Instance Connect Endpoints.
func (s *AWSOIDCService) ListEICE(ctx context.Context, req *integrationpb.ListEICERequest) (*integrationpb.ListEICEResponse, error) {
authCtx, err := s.authorizer.Authorize(ctx)
Expand Down Expand Up @@ -788,15 +801,20 @@ func (s *AWSOIDCService) Ping(ctx context.Context, req *integrationpb.PingReques
return nil, trace.Wrap(err)
}

if req.Integration == "" {
return nil, trace.BadParameter("integration is required")
}

// Instead of asking the user for a region (or storing a default region), we use the sentinel value for the global region.
// This improves the UX, because it is one less input we require from the user.
awsClientReq, err := s.awsClientReq(ctx, req.Integration, awsutils.AWSGlobalRegion)
if err != nil {
return nil, trace.Wrap(err)
var awsClientReq *awsoidc.AWSClientRequest
switch {
case req.Arn != "":
awsClientReq, err = s.awsClientReqWithARN(ctx, req.Integration, awsutils.AWSGlobalRegion, req.Arn)
if err != nil {
return nil, trace.Wrap(err)
}
case req.Integration != "":
awsClientReq, err = s.awsClientReq(ctx, req.Integration, awsutils.AWSGlobalRegion)
if err != nil {
return nil, trace.Wrap(err)
}
default:
return nil, trace.BadParameter("one of arn and integration is required")
}

awsClient, err := awsoidc.NewPingClient(ctx, awsClientReq)
Expand Down
15 changes: 10 additions & 5 deletions lib/auth/integration/integrationv1/awsoidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,6 @@ func TestGenerateAWSOIDCToken(t *testing.T) {
})
require.NoError(t, err)

t.Run("without integration (v15 and lower clients) returns an error", func(t *testing.T) {
_, err := resourceSvc.GenerateAWSOIDCToken(ctx, &integrationv1.GenerateAWSOIDCTokenRequest{})
require.Error(t, err)
})

t.Run("with integration in rpc call but no issuer defined", func(t *testing.T) {
resp, err := resourceSvc.GenerateAWSOIDCToken(ctx, &integrationv1.GenerateAWSOIDCTokenRequest{
Integration: integrationNameWithoutIssuer,
Expand Down Expand Up @@ -323,6 +318,16 @@ func TestRBAC(t *testing.T) {
return err
},
},
{
name: "Ping with arn",
fn: func() error {
_, err := awsoidService.Ping(userCtx, &integrationv1.PingRequest{
Integration: integrationName,
Arn: "some-arn",
})
return err
},
},
} {
t.Run(tt.name, func(t *testing.T) {
err := tt.fn()
Expand Down
7 changes: 0 additions & 7 deletions lib/integrations/awsoidc/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@ import (

// AWSClientRequest contains the required fields to set up an AWS service client.
type AWSClientRequest struct {
// IntegrationName is the integration name that is going to issue an API Call.
IntegrationName string

// Token is the token used to issue the API Call.
Token string

Expand All @@ -55,10 +52,6 @@ type AWSClientRequest struct {

// CheckAndSetDefaults checks if the required fields are present.
func (req *AWSClientRequest) CheckAndSetDefaults() error {
if req.IntegrationName == "" {
return trace.BadParameter("integration name is required")
}

if req.Token == "" {
return trace.BadParameter("token is required")
}
Expand Down
24 changes: 12 additions & 12 deletions lib/integrations/awsoidc/clients_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,30 @@ import (
func TestCheckAndSetDefaults(t *testing.T) {
t.Run("invalid regions must return an error", func(t *testing.T) {
err := (&AWSClientRequest{
IntegrationName: "my-integration",
Token: "token",
RoleARN: "some-arn",
Region: "?",
//IntegrationName: "my-integration",
Token: "token",
RoleARN: "some-arn",
Region: "?",
}).CheckAndSetDefaults()

require.True(t, trace.IsBadParameter(err))
})
t.Run("valid region", func(t *testing.T) {
err := (&AWSClientRequest{
IntegrationName: "my-integration",
Token: "token",
RoleARN: "some-arn",
Region: "us-east-1",
//IntegrationName: "my-integration",
Token: "token",
RoleARN: "some-arn",
Region: "us-east-1",
}).CheckAndSetDefaults()
require.NoError(t, err)
})

t.Run("empty region", func(t *testing.T) {
err := (&AWSClientRequest{
IntegrationName: "my-integration",
Token: "token",
RoleARN: "some-arn",
Region: "",
//IntegrationName: "my-integration",
Token: "token",
RoleARN: "some-arn",
Region: "",
}).CheckAndSetDefaults()
require.NoError(t, err)
})
Expand Down
9 changes: 4 additions & 5 deletions lib/integrations/awsoidc/deployservice_vcr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,10 @@ func TestDeployDBService(t *testing.T) {
return &AWSClientRequest{
// To record new fixtures you will need a valid token.
// You can get one by getting the generated token in a real cluster.
Token: awsOIDCToken,
RoleARN: awsOIDCRoleARN,
Region: awsRegion,
IntegrationName: integrationName,
httpClient: httpClient,
Token: awsOIDCToken,
RoleARN: awsOIDCRoleARN,
Region: awsRegion,
httpClient: httpClient,
}
}

Expand Down
43 changes: 25 additions & 18 deletions lib/integrations/awsoidc/token_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ type KeyStoreManager interface {
// GenerateAWSOIDCTokenRequest contains the required elements to generate an AWS OIDC Token (JWT).
type GenerateAWSOIDCTokenRequest struct {
// Integration is the AWS OIDC Integration name.
// This field is only used to obtain custom Issuers (those stored at S3 buckets).
// If empty, the default issuer for the cluster (its public endpoint URL) will be used.
Integration string
// Username is the JWT Username (on behalf of claim)
Username string
Expand All @@ -70,9 +72,6 @@ type GenerateAWSOIDCTokenRequest struct {

// CheckAndSetDefaults checks the request params.
func (g *GenerateAWSOIDCTokenRequest) CheckAndSetDefaults() error {
if g.Integration == "" {
return trace.BadParameter("integration missing")
}
if g.Username == "" {
return trace.BadParameter("username missing")
}
Expand All @@ -86,7 +85,28 @@ func (g *GenerateAWSOIDCTokenRequest) CheckAndSetDefaults() error {
return nil
}

func issuerForIntegration(ctx context.Context, integration types.Integration, cacheClt Cache) (string, error) {
// IssuerForIntegration returns the issuer for a given integration.
// Returns the default Issuer (oidc.IssuerForCluster) if integrationName is empty.
// All calls should be replaced with oidc.IssuerForCluster when IssuerS3URI is removed (it is currently deprecated).
func issuerForIntegration(ctx context.Context, cacheClt Cache, integrationName string) (string, error) {
if integrationName == "" {
issuer, err := oidc.IssuerForCluster(ctx, cacheClt, "")
return issuer, trace.Wrap(err)
}

integration, err := cacheClt.GetIntegration(ctx, integrationName)
if err != nil {
return "", trace.Wrap(err)
}

if integration.GetSubKind() != types.IntegrationSubKindAWSOIDC {
return "", trace.BadParameter("integration subkind (%s) mismatch", integration.GetSubKind())
}

if integration.GetAWSOIDCIntegrationSpec() == nil {
return "", trace.BadParameter("missing spec fields for %q (%q) integration", integration.GetName(), integration.GetSubKind())
}

issuerS3URI := integration.GetAWSOIDCIntegrationSpec().IssuerS3URI
if issuerS3URI == "" {
issuer, err := oidc.IssuerForCluster(ctx, cacheClt, "")
Expand All @@ -107,20 +127,7 @@ func GenerateAWSOIDCToken(ctx context.Context, cacheClt Cache, keyStoreManager K
return "", trace.Wrap(err)
}

integration, err := cacheClt.GetIntegration(ctx, req.Integration)
if err != nil {
return "", trace.Wrap(err)
}

if integration.GetSubKind() != types.IntegrationSubKindAWSOIDC {
return "", trace.BadParameter("integration subkind (%s) mismatch", integration.GetSubKind())
}

if integration.GetAWSOIDCIntegrationSpec() == nil {
return "", trace.BadParameter("missing spec fields for %q (%q) integration", integration.GetName(), integration.GetSubKind())
}

issuer, err := issuerForIntegration(ctx, integration, cacheClt)
issuer, err := issuerForIntegration(ctx, cacheClt, req.Integration)
if err != nil {
return "", trace.Wrap(err)
}
Expand Down
7 changes: 3 additions & 4 deletions lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,10 +540,9 @@ func (s *localSite) setupTunnelForOpenSSHEICENode(ctx context.Context, targetSer
}

openTunnelClt, err := awsoidc.NewOpenTunnelEC2Client(ctx, &awsoidc.AWSClientRequest{
IntegrationName: integration.GetName(),
Token: token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: awsInfo.Region,
Token: token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: awsInfo.Region,
})
if err != nil {
return nil, trace.BadParameter("failed to create the ec2 open tunnel client: %v", err)
Expand Down
7 changes: 3 additions & 4 deletions lib/service/awsoidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,9 @@ func (updater *AWSOIDCDeployServiceUpdater) updateAWSOIDCDeployService(ctx conte
}

req := &awsoidc.AWSClientRequest{
IntegrationName: integration.GetName(),
Token: token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: awsRegion,
Token: token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: awsRegion,
}

// The deploy service client is initialized using AWS OIDC integration.
Expand Down
7 changes: 3 additions & 4 deletions lib/srv/forward/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -689,10 +689,9 @@ func (s *Server) sendSSHPublicKeyToTarget(ctx context.Context) (ssh.Signer, erro
}

sendSSHClient, err := awsoidc.NewEICESendSSHPublicKeyClient(ctx, &awsoidc.AWSClientRequest{
IntegrationName: integration.GetName(),
Token: token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: awsInfo.Region,
Token: token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: awsInfo.Region,
})
if err != nil {
return nil, trace.BadParameter("failed to create an aws client to send ssh public key: %v", err)
Expand Down
7 changes: 7 additions & 0 deletions lib/web/integrations_awsoidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -1434,6 +1434,7 @@ func getServiceURLs(dbServices []types.DatabaseService, accountID, region, telep
}

// awsOIDCPing performs an health check for the integration.
// If ARN is present in the request body, that's the ARN that will be used instead of using the one stored in the integration.
// Returns meta information: account id and assumed the ARN for the IAM Role.
func (h *Handler) awsOIDCPing(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (any, error) {
ctx := r.Context()
Expand All @@ -1443,13 +1444,19 @@ func (h *Handler) awsOIDCPing(w http.ResponseWriter, r *http.Request, p httprout
return nil, trace.BadParameter("an integration name is required")
}

var req ui.AWSOIDCPingRequest
if err := httplib.ReadJSON(r, &req); err != nil {
return nil, trace.Wrap(err)
}

clt, err := sctx.GetUserClient(ctx, site)
if err != nil {
return nil, trace.Wrap(err)
}

pingResp, err := clt.IntegrationAWSOIDCClient().Ping(ctx, &integrationv1.PingRequest{
Integration: integrationName,
Arn: req.ARN,
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down
8 changes: 8 additions & 0 deletions lib/web/ui/integration.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,3 +516,11 @@ type AWSOIDCPingResponse struct {
// UserID is the unique identifier of the calling entity.
UserID string `json:"userId"`
}

// AWSOIDCPingRequest contains ping request fields.
type AWSOIDCPingRequest struct {
// ARN is optional, and used for cases such as
// pinging to check validity before upserting an
// AWS OIDC integration.
ARN string `json:"arn,omitempty"`
}

0 comments on commit 716ed7c

Please sign in to comment.