diff --git a/api/gen/proto/go/teleport/machineid/v1/bot_instance.pb.go b/api/gen/proto/go/teleport/machineid/v1/bot_instance.pb.go index 757c72160aa17..ec0d5c2dd24d3 100644 --- a/api/gen/proto/go/teleport/machineid/v1/bot_instance.pb.go +++ b/api/gen/proto/go/teleport/machineid/v1/bot_instance.pb.go @@ -318,11 +318,14 @@ type BotInstanceStatusAuthentication struct { // Server. AuthenticatedAt *timestamppb.Timestamp `protobuf:"bytes,1,opt,name=authenticated_at,json=authenticatedAt,proto3" json:"authenticated_at,omitempty"` // The join method used for this join or renewal. + // Deprecated: prefer using join_attrs.meta.join_method JoinMethod string `protobuf:"bytes,2,opt,name=join_method,json=joinMethod,proto3" json:"join_method,omitempty"` // The join token used for this join or renewal. This is only populated for // delegated join methods as the value for `token` join methods is sensitive. + // Deprecated: prefer using join_attrs.meta.join_token_name JoinToken string `protobuf:"bytes,3,opt,name=join_token,json=joinToken,proto3" json:"join_token,omitempty"` // The metadata sourced from the join method. + // Deprecated: prefer using join_attrs. Metadata *structpb.Struct `protobuf:"bytes,4,opt,name=metadata,proto3" json:"metadata,omitempty"` // On each renewal, this generation is incremented. For delegated join // methods, this counter is not checked during renewal. For the `token` join diff --git a/api/proto/teleport/machineid/v1/bot_instance.proto b/api/proto/teleport/machineid/v1/bot_instance.proto index 5904e8896a6bd..76a3820f2bfac 100644 --- a/api/proto/teleport/machineid/v1/bot_instance.proto +++ b/api/proto/teleport/machineid/v1/bot_instance.proto @@ -90,12 +90,16 @@ message BotInstanceStatusAuthentication { // Server. google.protobuf.Timestamp authenticated_at = 1; // The join method used for this join or renewal. + // Deprecated: prefer using join_attrs.meta.join_method string join_method = 2; // The join token used for this join or renewal. This is only populated for // delegated join methods as the value for `token` join methods is sensitive. + // Deprecated: prefer using join_attrs.meta.join_token_name string join_token = 3; // The metadata sourced from the join method. + // Deprecated: prefer using join_attrs. google.protobuf.Struct metadata = 4; + // On each renewal, this generation is incremented. For delegated join // methods, this counter is not checked during renewal. For the `token` join // method, this counter is checked during renewal and the Bot is locked out if diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 067cea661c7e1..82bd49e68befb 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -71,6 +71,7 @@ import ( headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" mfav1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/mfa/v1" notificationsv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/notifications/v1" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/internalutils/stream" "github.com/gravitational/teleport/api/metadata" "github.com/gravitational/teleport/api/types" @@ -2290,6 +2291,9 @@ type certRequest struct { // botInstanceID is the unique identifier of the bot instance associated // with this cert, if any botInstanceID string + // joinAttributes holds attributes derived from attested metadata from the + // join process, should any exist. + joinAttributes *workloadidentityv1pb.JoinAttrs } // check verifies the cert request is valid. @@ -3370,7 +3374,8 @@ func generateCert(ctx context.Context, a *Server, req certRequest, caType types. AssetTag: req.deviceExtensions.AssetTag, CredentialID: req.deviceExtensions.CredentialID, }, - UserType: req.user.GetUserType(), + UserType: req.user.GetUserType(), + JoinAttributes: req.joinAttributes, } var signedTLSCert []byte diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 85cc6fe6237b1..fe50d3d0af68d 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -3453,6 +3453,9 @@ func (a *ServerWithRoles) generateUserCerts(ctx context.Context, req proto.UserC // `updateBotInstance()` is called below, and this (empty) value will be // overridden. botInstanceID: a.context.Identity.GetIdentity().BotInstanceID, + // Propagate any join attributes from the current identity to the new + // identity. + joinAttributes: a.context.Identity.GetIdentity().JoinAttributes, } if user.GetName() != a.context.User.GetName() { diff --git a/lib/auth/bot.go b/lib/auth/bot.go index 104518ea7687e..c08ae5f1d7580 100644 --- a/lib/auth/bot.go +++ b/lib/auth/bot.go @@ -31,6 +31,7 @@ import ( "github.com/gravitational/teleport/api/client/proto" headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" apiutils "github.com/gravitational/teleport/api/utils" @@ -315,7 +316,7 @@ func (a *Server) updateBotInstance( if templateAuthRecord != nil { authRecord.JoinToken = templateAuthRecord.JoinToken authRecord.JoinMethod = templateAuthRecord.JoinMethod - authRecord.Metadata = templateAuthRecord.Metadata + authRecord.JoinAttrs = templateAuthRecord.JoinAttrs } // An empty bot instance most likely means a bot is rejoining after an @@ -493,6 +494,7 @@ func (a *Server) generateInitialBotCerts( expires time.Time, renewable bool, initialAuth *machineidv1pb.BotInstanceStatusAuthentication, existingInstanceID string, currentIdentityGeneration int32, + joinAttrs *workloadidentityv1pb.JoinAttrs, ) (*proto.Certs, string, error) { var err error @@ -535,16 +537,17 @@ func (a *Server) generateInitialBotCerts( // Generate certificate certReq := certRequest{ - user: userState, - ttl: expires.Sub(a.GetClock().Now()), - sshPublicKey: sshPubKey, - tlsPublicKey: tlsPubKey, - checker: checker, - traits: accessInfo.Traits, - renewable: renewable, - includeHostCA: true, - loginIP: loginIP, - botName: botName, + user: userState, + ttl: expires.Sub(a.GetClock().Now()), + sshPublicKey: sshPubKey, + tlsPublicKey: tlsPubKey, + checker: checker, + traits: accessInfo.Traits, + renewable: renewable, + includeHostCA: true, + loginIP: loginIP, + botName: botName, + joinAttributes: joinAttrs, } if existingInstanceID == "" { diff --git a/lib/auth/bot_test.go b/lib/auth/bot_test.go index ae4ddb14136b9..2e019ffa7123e 100644 --- a/lib/auth/bot_test.go +++ b/lib/auth/bot_test.go @@ -42,6 +42,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" "google.golang.org/grpc" + "google.golang.org/protobuf/testing/protocmp" "github.com/gravitational/teleport" apiclient "github.com/gravitational/teleport/api/client" @@ -49,10 +50,12 @@ import ( "github.com/gravitational/teleport/api/client/webclient" headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/metadata" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/integrations/lib/testing/fakejoin" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/auth/join" "github.com/gravitational/teleport/lib/auth/machineid/machineidv1" @@ -216,6 +219,146 @@ func TestRegisterBotCertificateGenerationCheck(t *testing.T) { } } +// TestBotJoinAttrs_Kubernetes validates that a bot can join using the +// Kubernetes join method and that the correct join attributes are encoded in +// the resulting bot cert, and, that when this cert is used to produce role +// certificates, the correct attributes are encoded in the role cert. +// +// Whilst this specifically tests the Kubernetes join method, it tests by proxy +// the implementation for most of the join methods. +func TestBotJoinAttrs_Kubernetes(t *testing.T) { + t.Parallel() + + srv := newTestTLSServer(t) + ctx := context.Background() + + role, err := CreateRole(ctx, srv.Auth(), "example", types.RoleSpecV6{}) + require.NoError(t, err) + + // Create a new bot. + client, err := srv.NewClient(TestAdmin()) + require.NoError(t, err) + bot, err := client.BotServiceClient().CreateBot(ctx, &machineidv1pb.CreateBotRequest{ + Bot: &machineidv1pb.Bot{ + Metadata: &headerv1.Metadata{ + Name: "test", + }, + Spec: &machineidv1pb.BotSpec{ + Roles: []string{"example"}, + }, + }, + }) + require.NoError(t, err) + + k8s, err := fakejoin.NewKubernetesSigner(srv.Clock()) + require.NoError(t, err) + jwks, err := k8s.GetMarshaledJWKS() + require.NoError(t, err) + fakePSAT, err := k8s.SignServiceAccountJWT( + "my-pod", + "my-namespace", + "my-service-account", + srv.ClusterName(), + ) + require.NoError(t, err) + + tok, err := types.NewProvisionTokenFromSpec( + "my-k8s-token", + time.Time{}, + types.ProvisionTokenSpecV2{ + Roles: types.SystemRoles{types.RoleBot}, + JoinMethod: types.JoinMethodKubernetes, + BotName: bot.Metadata.Name, + Kubernetes: &types.ProvisionTokenSpecV2Kubernetes{ + Type: types.KubernetesJoinTypeStaticJWKS, + StaticJWKS: &types.ProvisionTokenSpecV2Kubernetes_StaticJWKSConfig{ + JWKS: jwks, + }, + Allow: []*types.ProvisionTokenSpecV2Kubernetes_Rule{ + { + ServiceAccount: "my-namespace:my-service-account", + }, + }, + }, + }, + ) + require.NoError(t, err) + require.NoError(t, client.CreateToken(ctx, tok)) + + result, err := join.Register(ctx, join.RegisterParams{ + Token: tok.GetName(), + JoinMethod: types.JoinMethodKubernetes, + ID: state.IdentityID{ + Role: types.RoleBot, + }, + AuthServers: []utils.NetAddr{*utils.MustParseAddr(srv.Addr().String())}, + KubernetesReadFileFunc: func(name string) ([]byte, error) { + return []byte(fakePSAT), nil + }, + }) + require.NoError(t, err) + + // Validate correct join attributes are encoded. + cert, err := tlsca.ParseCertificatePEM(result.Certs.TLS) + require.NoError(t, err) + ident, err := tlsca.FromSubject(cert.Subject, cert.NotAfter) + require.NoError(t, err) + wantAttrs := &workloadidentityv1pb.JoinAttrs{ + Meta: &workloadidentityv1pb.JoinAttrsMeta{ + JoinTokenName: tok.GetName(), + JoinMethod: string(types.JoinMethodKubernetes), + }, + Kubernetes: &workloadidentityv1pb.JoinAttrsKubernetes{ + ServiceAccount: &workloadidentityv1pb.JoinAttrsKubernetesServiceAccount{ + Namespace: "my-namespace", + Name: "my-service-account", + }, + Pod: &workloadidentityv1pb.JoinAttrsKubernetesPod{ + Name: "my-pod", + }, + Subject: "system:serviceaccount:my-namespace:my-service-account", + }, + } + require.Empty(t, cmp.Diff( + ident.JoinAttributes, + wantAttrs, + protocmp.Transform(), + )) + + // Now, try to produce a role certificate using the bot cert, to ensure + // that the join attributes are correctly propagated. + privateKeyPEM, err := keys.MarshalPrivateKey(result.PrivateKey) + require.NoError(t, err) + tlsCert, err := tls.X509KeyPair(result.Certs.TLS, privateKeyPEM) + require.NoError(t, err) + sshPub, err := ssh.NewPublicKey(result.PrivateKey.Public()) + require.NoError(t, err) + tlsPub, err := keys.MarshalPublicKey(result.PrivateKey.Public()) + require.NoError(t, err) + botClient := srv.NewClientWithCert(tlsCert) + roleCerts, err := botClient.GenerateUserCerts(ctx, proto.UserCertsRequest{ + SSHPublicKey: ssh.MarshalAuthorizedKey(sshPub), + TLSPublicKey: tlsPub, + Username: bot.Status.UserName, + RoleRequests: []string{ + role.GetName(), + }, + UseRoleRequests: true, + Expires: srv.Clock().Now().Add(time.Hour), + }) + require.NoError(t, err) + + roleCert, err := tlsca.ParseCertificatePEM(roleCerts.TLS) + require.NoError(t, err) + roleIdent, err := tlsca.FromSubject(roleCert.Subject, roleCert.NotAfter) + require.NoError(t, err) + require.Empty(t, cmp.Diff( + roleIdent.JoinAttributes, + wantAttrs, + protocmp.Transform(), + )) +} + // TestRegisterBotInstance tests that bot instances are created properly on join func TestRegisterBotInstance(t *testing.T) { t.Parallel() @@ -282,7 +425,6 @@ func TestRegisterBotInstance(t *testing.T) { require.Equal(t, int32(1), ia.Generation) require.Equal(t, string(types.JoinMethodToken), ia.JoinMethod) require.Equal(t, token.GetSafeName(), ia.JoinToken) - // The latest authentications field should contain the same record (and // only that record.) require.Len(t, botInstance.GetStatus().LatestAuthentications, 1) diff --git a/lib/auth/join.go b/lib/auth/join.go index 00d4f8847f1e9..ad92db5eb3a0d 100644 --- a/lib/auth/join.go +++ b/lib/auth/join.go @@ -22,6 +22,7 @@ import ( "context" "crypto/rand" "encoding/base64" + "encoding/json" "log/slog" "net" "slices" @@ -34,6 +35,7 @@ import ( "github.com/gravitational/teleport/api/client/proto" machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/auth/machineid/machineidv1" @@ -104,12 +106,6 @@ func (a *Server) checkTokenJoinRequestCommon(ctx context.Context, req *types.Reg return provisionToken, nil } -type joinAttributeSourcer interface { - // JoinAuditAttributes returns a series of attributes that can be inserted into - // audit events related to a specific join. - JoinAuditAttributes() (map[string]interface{}, error) -} - func setRemoteAddrFromContext(ctx context.Context, req *types.RegisterUsingTokenRequest) error { var addr string if clientIP, err := authz.ClientSrcAddrFromContext(ctx); err == nil { @@ -132,7 +128,7 @@ func (a *Server) handleJoinFailure( ctx context.Context, origErr error, pt types.ProvisionToken, - attributeSource joinAttributeSourcer, + rawJoinAttrs any, req *types.RegisterUsingTokenRequest, ) { attrs := []slog.Attr{slog.Any("error", origErr)} @@ -145,19 +141,13 @@ func (a *Server) handleJoinFailure( }...) } - // Fetch and encode attributes if they are available. - var attributesProto *apievents.Struct - if attributeSource != nil { - var err error - attributes, err := attributeSource.JoinAuditAttributes() - if err != nil { - a.logger.WarnContext(ctx, "Unable to fetch join attributes from join method", "error", err) - } - attrs = append(attrs, slog.Any("attributes", attributes)) - attributesProto, err = apievents.EncodeMap(attributes) - if err != nil { - a.logger.WarnContext(ctx, "Unable to encode join attributes for audit event", "error", err) - } + // Fetch and encode rawJoinAttrs if they are available. + attributesStruct, err := rawJoinAttrsToStruct(rawJoinAttrs) + if err != nil { + a.logger.WarnContext(ctx, "Unable to fetch join attributes from join method", "error", err) + } + if attributesStruct != nil { + attrs = append(attrs, slog.Any("attributes", attributesStruct)) } // Add log fields from token if available. @@ -179,7 +169,7 @@ func (a *Server) handleJoinFailure( Code: events.BotJoinFailureCode, }, Status: status, - Attributes: attributesProto, + Attributes: attributesStruct, ConnectionMetadata: apievents.ConnectionMetadata{ RemoteAddr: req.RemoteAddr, }, @@ -197,7 +187,7 @@ func (a *Server) handleJoinFailure( Code: events.InstanceJoinFailureCode, }, Status: status, - Attributes: attributesProto, + Attributes: attributesStruct, } if pt != nil { instanceJoinEvent.Method = string(pt.GetJoinMethod()) @@ -228,12 +218,13 @@ func (a *Server) handleJoinFailure( // If the token includes a specific join method, the rules for that join method // will be checked. func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (certs *proto.Certs, err error) { - var joinAttributeSrc joinAttributeSourcer + attrs := &workloadidentityv1pb.JoinAttrs{} + var rawClaims any var provisionToken types.ProvisionToken defer func() { // Emit a log message and audit event on join failure. if err != nil { - a.handleJoinFailure(ctx, err, provisionToken, joinAttributeSrc, req) + a.handleJoinFailure(ctx, err, provisionToken, rawClaims, req) } }() @@ -255,7 +246,8 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodGitHub: claims, err := a.checkGitHubJoinRequest(ctx, req) if claims != nil { - joinAttributeSrc = claims + rawClaims = claims + attrs.Github = claims.JoinAttrs() } if err != nil { return nil, trace.Wrap(err) @@ -263,7 +255,8 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodGitLab: claims, err := a.checkGitLabJoinRequest(ctx, req) if claims != nil { - joinAttributeSrc = claims + rawClaims = claims + attrs.Gitlab = claims.JoinAttrs() } if err != nil { return nil, trace.Wrap(err) @@ -271,7 +264,8 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodCircleCI: claims, err := a.checkCircleCIJoinRequest(ctx, req) if claims != nil { - joinAttributeSrc = claims + rawClaims = claims + attrs.Circleci = claims.JoinAttrs() } if err != nil { return nil, trace.Wrap(err) @@ -279,7 +273,8 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodKubernetes: claims, err := a.checkKubernetesJoinRequest(ctx, req) if claims != nil { - joinAttributeSrc = claims + rawClaims = claims + attrs.Kubernetes = claims.JoinAttrs() } if err != nil { return nil, trace.Wrap(err) @@ -287,7 +282,8 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodGCP: claims, err := a.checkGCPJoinRequest(ctx, req) if claims != nil { - joinAttributeSrc = claims + rawClaims = claims + attrs.Gcp = claims.JoinAttrs() } if err != nil { return nil, trace.Wrap(err) @@ -295,7 +291,8 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodSpacelift: claims, err := a.checkSpaceliftJoinRequest(ctx, req) if claims != nil { - joinAttributeSrc = claims + rawClaims = claims + attrs.Spacelift = claims.JoinAttrs() } if err != nil { return nil, trace.Wrap(err) @@ -303,7 +300,8 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodTerraformCloud: claims, err := a.checkTerraformCloudJoinRequest(ctx, req) if claims != nil { - joinAttributeSrc = claims + rawClaims = claims + attrs.TerraformCloud = claims.JoinAttrs() } if err != nil { return nil, trace.Wrap(err) @@ -311,7 +309,8 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodBitbucket: claims, err := a.checkBitbucketJoinRequest(ctx, req) if claims != nil { - joinAttributeSrc = claims + rawClaims = claims + attrs.Bitbucket = claims.JoinAttrs() } if err != nil { return nil, trace.Wrap(err) @@ -334,10 +333,16 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin // With all elements of the token validated, we can now generate & return // certificates. if req.Role == types.RoleBot { - certs, err = a.generateCertsBot(ctx, provisionToken, req, joinAttributeSrc) + certs, err = a.generateCertsBot( + ctx, + provisionToken, + req, + rawClaims, + attrs, + ) return certs, trace.Wrap(err) } - certs, err = a.generateCerts(ctx, provisionToken, req, joinAttributeSrc) + certs, err = a.generateCerts(ctx, provisionToken, req, rawClaims) return certs, trace.Wrap(err) } @@ -345,7 +350,8 @@ func (a *Server) generateCertsBot( ctx context.Context, provisionToken types.ProvisionToken, req *types.RegisterUsingTokenRequest, - joinAttributeSrc joinAttributeSourcer, + rawJoinClaims any, + attrs *workloadidentityv1pb.JoinAttrs, ) (*proto.Certs, error) { // bots use this endpoint but get a user cert // botResourceName must be set, enforced in CheckAndSetDefaults @@ -393,6 +399,27 @@ func (a *Server) generateCertsBot( RemoteAddr: req.RemoteAddr, }, } + var err error + joinEvent.Attributes, err = rawJoinAttrsToStruct(rawJoinClaims) + if err != nil { + a.logger.WarnContext( + ctx, + "Unable to encode join attributes for join audit event", + "error", err, + ) + } + + // Prepare join attributes for encoding into the X509 cert and for inclusion + // in audit logs. + if attrs == nil { + attrs = &workloadidentityv1pb.JoinAttrs{} + } + attrs.Meta = &workloadidentityv1pb.JoinAttrsMeta{ + JoinMethod: string(joinMethod), + } + if joinMethod != types.JoinMethodToken { + attrs.Meta.JoinTokenName = provisionToken.GetName() + } auth := &machineidv1pb.BotInstanceStatusAuthentication{ AuthenticatedAt: timestamppb.New(a.GetClock().Now()), @@ -404,22 +431,13 @@ func (a *Server) generateCertsBot( // TODO(nklaassen): consider logging the SSH public key as well, for now // the SSH and TLS public keys are still identical for tbot. PublicKey: req.PublicTLSKey, + JoinAttrs: attrs, } - if joinAttributeSrc != nil { - attributes, err := joinAttributeSrc.JoinAuditAttributes() - if err != nil { - a.logger.WarnContext(ctx, "Unable to fetch join attributes from join method", "error", err) - } - joinEvent.Attributes, err = apievents.EncodeMap(attributes) - if err != nil { - a.logger.WarnContext(ctx, "Unable to encode join attributes for audit event", "error", err) - } - - auth.Metadata, err = structpb.NewStruct(attributes) - if err != nil { - a.logger.WarnContext(ctx, "Unable to encode struct value for join metadata", "error", err) - } + // TODO(noah): In v19, we can drop writing to the deprecated Metadata field. + auth.Metadata, err = rawJoinAttrsToGoogleStruct(rawJoinClaims) + if err != nil { + a.logger.WarnContext(ctx, "Unable to encode struct value for join metadata", "error", err) } certs, botInstanceID, err := a.generateInitialBotCerts( @@ -434,6 +452,7 @@ func (a *Server) generateCertsBot( auth, req.BotInstanceID, req.BotGeneration, + attrs, ) if err != nil { return nil, trace.Wrap(err) @@ -465,7 +484,7 @@ func (a *Server) generateCerts( ctx context.Context, provisionToken types.ProvisionToken, req *types.RegisterUsingTokenRequest, - joinAttributeSrc joinAttributeSourcer, + rawJoinClaims any, ) (*proto.Certs, error) { if req.Expires != nil { return nil, trace.BadParameter("'expires' cannot be set on join for non-bot certificates") @@ -534,15 +553,9 @@ func (a *Server) generateCerts( RemoteAddr: req.RemoteAddr, }, } - if joinAttributeSrc != nil { - attributes, err := joinAttributeSrc.JoinAuditAttributes() - if err != nil { - a.logger.WarnContext(ctx, "Unable to fetch join attributes from join method", "error", err) - } - joinEvent.Attributes, err = apievents.EncodeMap(attributes) - if err != nil { - a.logger.WarnContext(ctx, "Unable to encode join attributes for audit event", "error", err) - } + joinEvent.Attributes, err = rawJoinAttrsToStruct(rawJoinClaims) + if err != nil { + a.logger.WarnContext(ctx, "Unable to fetch join attributes from join method", "error", err) } if err := a.emitter.EmitAuditEvent(ctx, joinEvent); err != nil { a.logger.WarnContext(ctx, "Failed to emit instance join event", "error", err) @@ -550,6 +563,36 @@ func (a *Server) generateCerts( return certs, nil } +func rawJoinAttrsToStruct(in any) (*apievents.Struct, error) { + if in == nil { + return nil, nil + } + attrBytes, err := json.Marshal(in) + if err != nil { + return nil, trace.Wrap(err, "marshaling join attributes") + } + out := &apievents.Struct{} + if err := out.UnmarshalJSON(attrBytes); err != nil { + return nil, trace.Wrap(err, "unmarshaling join attributes") + } + return out, nil +} + +func rawJoinAttrsToGoogleStruct(in any) (*structpb.Struct, error) { + if in == nil { + return nil, nil + } + attrBytes, err := json.Marshal(in) + if err != nil { + return nil, trace.Wrap(err, "marshaling join attributes") + } + out := &structpb.Struct{} + if err := out.UnmarshalJSON(attrBytes); err != nil { + return nil, trace.Wrap(err, "unmarshaling join attributes") + } + return out, nil +} + func generateChallenge(encoding *base64.Encoding, length int) (string, error) { // read crypto-random bytes to generate the challenge challengeRawBytes := make([]byte, length) diff --git a/lib/auth/join_azure.go b/lib/auth/join_azure.go index 70ae17918b7fa..df5a1632e05e0 100644 --- a/lib/auth/join_azure.go +++ b/lib/auth/join_azure.go @@ -38,6 +38,7 @@ import ( "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/utils" @@ -312,37 +313,49 @@ func azureResourceGroupIsAllowed(allowedResourceGroups []string, vmResourceGroup return false } -func (a *Server) checkAzureRequest(ctx context.Context, challenge string, req *proto.RegisterUsingAzureMethodRequest, cfg *azureRegisterConfig) error { +func azureJoinToAttrs(vm *azure.VirtualMachine) *workloadidentityv1pb.JoinAttrsAzure { + return &workloadidentityv1pb.JoinAttrsAzure{ + Subscription: vm.Subscription, + ResourceGroup: vm.ResourceGroup, + } +} + +func (a *Server) checkAzureRequest( + ctx context.Context, + challenge string, + req *proto.RegisterUsingAzureMethodRequest, + cfg *azureRegisterConfig, +) (*workloadidentityv1pb.JoinAttrsAzure, error) { requestStart := a.clock.Now() tokenName := req.RegisterUsingTokenRequest.Token provisionToken, err := a.GetToken(ctx, tokenName) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } if provisionToken.GetJoinMethod() != types.JoinMethodAzure { - return trace.AccessDenied("this token does not support the Azure join method") + return nil, trace.AccessDenied("this token does not support the Azure join method") + } + token, ok := provisionToken.(*types.ProvisionTokenV2) + if !ok { + return nil, trace.BadParameter("azure join method only supports ProvisionTokenV2, '%T' was provided", provisionToken) } subID, vmID, err := parseAndVerifyAttestedData(ctx, req.AttestedData, challenge, cfg.certificateAuthorities) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } vm, err := verifyVMIdentity(ctx, cfg, req.AccessToken, subID, vmID, requestStart) if err != nil { - return trace.Wrap(err) - } - - token, ok := provisionToken.(*types.ProvisionTokenV2) - if !ok { - return trace.BadParameter("azure join method only supports ProvisionTokenV2, '%T' was provided", provisionToken) + return nil, trace.Wrap(err) } + attrs := azureJoinToAttrs(vm) if err := checkAzureAllowRules(vm, token.GetName(), token.Spec.Azure.Allow); err != nil { - return trace.Wrap(err) + return attrs, trace.Wrap(err) } - return nil + return attrs, nil } func generateAzureChallenge() (string, error) { @@ -397,7 +410,8 @@ func (a *Server) RegisterUsingAzureMethodWithOpts( return nil, trace.Wrap(err) } - if err := a.checkAzureRequest(ctx, challenge, req, cfg); err != nil { + joinAttrs, err := a.checkAzureRequest(ctx, challenge, req, cfg) + if err != nil { return nil, trace.Wrap(err) } @@ -407,6 +421,9 @@ func (a *Server) RegisterUsingAzureMethodWithOpts( provisionToken, req.RegisterUsingTokenRequest, nil, + &workloadidentityv1pb.JoinAttrs{ + Azure: joinAttrs, + }, ) return certs, trace.Wrap(err) } diff --git a/lib/auth/join_iam.go b/lib/auth/join_iam.go index 7b284733bee1c..9ecfedd07bebd 100644 --- a/lib/auth/join_iam.go +++ b/lib/auth/join_iam.go @@ -34,6 +34,7 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth/join/iam" "github.com/gravitational/teleport/lib/utils" @@ -172,6 +173,18 @@ type awsIdentity struct { Arn string `json:"Arn"` } +// JoinAttrs returns the protobuf representation of the attested identity. +// This is used for auditing and for evaluation of WorkloadIdentity rules and +// templating. +func (c *awsIdentity) JoinAttrs() *workloadidentityv1pb.JoinAttrsAWSIAM { + attrs := &workloadidentityv1pb.JoinAttrsAWSIAM{ + Account: c.Account, + Arn: c.Arn, + } + + return attrs +} + // getCallerIdentityReponse is used for JSON parsing type getCallerIdentityResponse struct { GetCallerIdentityResult awsIdentity `json:"GetCallerIdentityResult"` @@ -260,41 +273,48 @@ func checkIAMAllowRules(identity *awsIdentity, token string, allowRules []*types // checkIAMRequest checks if the given request satisfies the token rules and // included the required challenge. -func (a *Server) checkIAMRequest(ctx context.Context, challenge string, req *proto.RegisterUsingIAMMethodRequest, cfg *iamRegisterConfig) error { +// +// If the joining entity presents a valid IAM identity, this will be returned, +// even if the identity does not match the token's allow rules. This is to +// support inclusion in audit logs. +func (a *Server) checkIAMRequest(ctx context.Context, challenge string, req *proto.RegisterUsingIAMMethodRequest, cfg *iamRegisterConfig) (*awsIdentity, error) { tokenName := req.RegisterUsingTokenRequest.Token provisionToken, err := a.GetToken(ctx, tokenName) if err != nil { - return trace.Wrap(err, "getting token") + return nil, trace.Wrap(err, "getting token") } if provisionToken.GetJoinMethod() != types.JoinMethodIAM { - return trace.AccessDenied("this token does not support the IAM join method") + return nil, trace.AccessDenied("this token does not support the IAM join method") } // parse the incoming http request to the sts:GetCallerIdentity endpoint identityRequest, err := parseSTSRequest(req.StsIdentityRequest) if err != nil { - return trace.Wrap(err, "parsing STS request") + return nil, trace.Wrap(err, "parsing STS request") } // validate that the host, method, and headers are correct and the expected // challenge is included in the signed portion of the request if err := validateSTSIdentityRequest(identityRequest, challenge, cfg); err != nil { - return trace.Wrap(err, "validating STS request") + return nil, trace.Wrap(err, "validating STS request") } // send the signed request to the public AWS API and get the node identity // from the response identity, err := executeSTSIdentityRequest(ctx, a.httpClientForAWSSTS, identityRequest) if err != nil { - return trace.Wrap(err, "executing STS request") + return nil, trace.Wrap(err, "executing STS request") } // check that the node identity matches an allow rule for this token if err := checkIAMAllowRules(identity, provisionToken.GetName(), provisionToken.GetAllowRules()); err != nil { - return trace.Wrap(err, "checking allow rules") + // We return the identity since it's "validated" but does not match the + // rules. This allows us to include it in a failed join audit event + // as additional context to help the user understand why the join failed. + return identity, trace.Wrap(err, "checking allow rules") } - return nil + return identity, nil } func generateIAMChallenge() (string, error) { @@ -341,10 +361,13 @@ func (a *Server) RegisterUsingIAMMethodWithOpts( ) (certs *proto.Certs, err error) { var provisionToken types.ProvisionToken var joinRequest *types.RegisterUsingTokenRequest + var joinFailureMetadata any defer func() { // Emit a log message and audit event on join failure. if err != nil { - a.handleJoinFailure(ctx, err, provisionToken, nil, joinRequest) + a.handleJoinFailure( + ctx, err, provisionToken, joinFailureMetadata, joinRequest, + ) } }() @@ -375,15 +398,27 @@ func (a *Server) RegisterUsingIAMMethodWithOpts( } // check that the GetCallerIdentity request is valid and matches the token - if err := a.checkIAMRequest(ctx, challenge, req, cfg); err != nil { + verifiedIdentity, err := a.checkIAMRequest(ctx, challenge, req, cfg) + if verifiedIdentity != nil { + joinFailureMetadata = verifiedIdentity + } + if err != nil { return nil, trace.Wrap(err, "checking iam request") } if req.RegisterUsingTokenRequest.Role == types.RoleBot { - certs, err := a.generateCertsBot(ctx, provisionToken, req.RegisterUsingTokenRequest, nil) + certs, err := a.generateCertsBot( + ctx, + provisionToken, + req.RegisterUsingTokenRequest, + verifiedIdentity, + &workloadidentityv1pb.JoinAttrs{ + Iam: verifiedIdentity.JoinAttrs(), + }, + ) return certs, trace.Wrap(err, "generating bot certs") } - certs, err = a.generateCerts(ctx, provisionToken, req.RegisterUsingTokenRequest, nil) + certs, err = a.generateCerts(ctx, provisionToken, req.RegisterUsingTokenRequest, verifiedIdentity) return certs, trace.Wrap(err, "generating certs") } diff --git a/lib/auth/join_tpm.go b/lib/auth/join_tpm.go index 12463e8ecd811..df2e6b4e4cbcc 100644 --- a/lib/auth/join_tpm.go +++ b/lib/auth/join_tpm.go @@ -28,6 +28,7 @@ import ( "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/tpm" @@ -39,11 +40,13 @@ func (a *Server) RegisterUsingTPMMethod( solveChallenge client.RegisterTPMChallengeResponseFunc, ) (_ *proto.Certs, err error) { var provisionToken types.ProvisionToken - var attributeSrc joinAttributeSourcer + var joinFailureMetadata any defer func() { // Emit a log message and audit event on join failure. if err != nil { - a.handleJoinFailure(ctx, err, provisionToken, attributeSrc, initReq.JoinRequest) + a.handleJoinFailure( + ctx, err, provisionToken, joinFailureMetadata, initReq.JoinRequest, + ) } }() @@ -97,10 +100,12 @@ func (a *Server) RegisterUsingTPMMethod( return solution.Solution, nil }, }) + if validatedEK != nil { + joinFailureMetadata = validatedEK + } if err != nil { return nil, trace.Wrap(err, "validating TPM EK") } - attributeSrc = validatedEK if err := checkTPMAllowRules(validatedEK, ptv2.Spec.TPM.Allow); err != nil { return nil, trace.Wrap(err) @@ -108,7 +113,13 @@ func (a *Server) RegisterUsingTPMMethod( if initReq.JoinRequest.Role == types.RoleBot { certs, err := a.generateCertsBot( - ctx, ptv2, initReq.JoinRequest, validatedEK, + ctx, + ptv2, + initReq.JoinRequest, + validatedEK, + &workloadidentityv1pb.JoinAttrs{ + Tpm: validatedEK.JoinAttrs(), + }, ) return certs, trace.Wrap(err, "generating certs for bot") } diff --git a/lib/auth/machineid/workloadidentityv1/decision_test.go b/lib/auth/machineid/workloadidentityv1/decision_test.go index e8cb267bb0879..5d00bf7595669 100644 --- a/lib/auth/machineid/workloadidentityv1/decision_test.go +++ b/lib/auth/machineid/workloadidentityv1/decision_test.go @@ -95,6 +95,23 @@ func Test_getFieldStringValue(t *testing.T) { want: "jeff", requireErr: require.NoError, }, + { + // This test ensures that the proto name (e.g service_account) is + // used instead of the Go name (e.g serviceAccount). + name: "underscored", + in: &workloadidentityv1pb.Attrs{ + Join: &workloadidentityv1pb.JoinAttrs{ + Kubernetes: &workloadidentityv1pb.JoinAttrsKubernetes{ + ServiceAccount: &workloadidentityv1pb.JoinAttrsKubernetesServiceAccount{ + Namespace: "default", + }, + }, + }, + }, + attr: "join.kubernetes.service_account.namespace", + want: "default", + requireErr: require.NoError, + }, { name: "bool", in: &workloadidentityv1pb.Attrs{ diff --git a/lib/auth/machineid/workloadidentityv1/issuer_service.go b/lib/auth/machineid/workloadidentityv1/issuer_service.go index eb75befe32b0b..6842ae01632ec 100644 --- a/lib/auth/machineid/workloadidentityv1/issuer_service.go +++ b/lib/auth/machineid/workloadidentityv1/issuer_service.go @@ -135,6 +135,7 @@ func (s *IssuanceService) deriveAttrs( BotName: authzCtx.Identity.GetIdentity().BotName, Labels: authzCtx.User.GetAllLabels(), }, + Join: authzCtx.Identity.GetIdentity().JoinAttributes, } return attrs, nil diff --git a/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go b/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go index 3f615c2749c89..e5f23dc96216c 100644 --- a/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go +++ b/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go @@ -19,6 +19,7 @@ package workloadidentityv1_test import ( "context" "crypto" + "crypto/tls" "crypto/x509" "errors" "fmt" @@ -34,23 +35,32 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" "google.golang.org/protobuf/types/known/durationpb" + apiproto "github.com/gravitational/teleport/api/client/proto" headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" + machineidv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1" workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/events" + apiutils "github.com/gravitational/teleport/api/utils" + "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/integrations/lib/testing/fakejoin" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/authclient" + "github.com/gravitational/teleport/lib/auth/join" "github.com/gravitational/teleport/lib/auth/machineid/workloadidentityv1/experiment" + "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/cryptosuites" libevents "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" libjwt "github.com/gravitational/teleport/lib/jwt" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils" ) func TestMain(m *testing.M) { @@ -137,6 +147,193 @@ func newIssuanceTestPack(t *testing.T, ctx context.Context) *issuanceTestPack { } } +// TestIssueWorkloadIdentityE2E performs a more E2E test than the RPC specific +// tests in this package. The idea is to validate that the various Auth Server +// APIs necessary for a bot to join and then issue a workload identity are +// functioning correctly. +func TestIssueWorkloadIdentityE2E(t *testing.T) { + experimentStatus := experiment.Enabled() + defer experiment.SetEnabled(experimentStatus) + experiment.SetEnabled(true) + + ctx := context.Background() + tp := newIssuanceTestPack(t, ctx) + + role, err := types.NewRole("my-role", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Rules: []types.Rule{ + types.NewRule(types.KindWorkloadIdentity, []string{types.VerbRead, types.VerbList}), + }, + WorkloadIdentityLabels: map[string]apiutils.Strings{ + "my-label": []string{"my-value"}, + }, + }, + }) + require.NoError(t, err) + + wid, err := tp.srv.Auth().CreateWorkloadIdentity(ctx, &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "my-wid", + Labels: map[string]string{ + "my-label": "my-value", + }, + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Rules: &workloadidentityv1pb.WorkloadIdentityRules{ + Allow: []*workloadidentityv1pb.WorkloadIdentityRule{ + { + Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ + { + Attribute: "join.kubernetes.service_account.namespace", + Equals: "my-namespace", + }, + }, + }, + }, + }, + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ + Id: "/example/{{ user.name }}/{{ join.kubernetes.service_account.namespace }}/{{ join.kubernetes.pod.name }}/{{ workload.unix.pid }}", + }, + }, + }) + require.NoError(t, err) + + bot := &machineidv1.Bot{ + Kind: types.KindBot, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "my-bot", + }, + Spec: &machineidv1.BotSpec{ + Roles: []string{ + role.GetName(), + }, + }, + } + + k8s, err := fakejoin.NewKubernetesSigner(tp.clock) + require.NoError(t, err) + jwks, err := k8s.GetMarshaledJWKS() + require.NoError(t, err) + fakePSAT, err := k8s.SignServiceAccountJWT( + "my-pod", + "my-namespace", + "my-service-account", + tp.srv.ClusterName(), + ) + require.NoError(t, err) + + token, err := types.NewProvisionTokenFromSpec( + "my-k8s-token", + time.Time{}, + types.ProvisionTokenSpecV2{ + Roles: types.SystemRoles{types.RoleBot}, + JoinMethod: types.JoinMethodKubernetes, + BotName: bot.Metadata.Name, + Kubernetes: &types.ProvisionTokenSpecV2Kubernetes{ + Type: types.KubernetesJoinTypeStaticJWKS, + StaticJWKS: &types.ProvisionTokenSpecV2Kubernetes_StaticJWKSConfig{ + JWKS: jwks, + }, + Allow: []*types.ProvisionTokenSpecV2Kubernetes_Rule{ + { + ServiceAccount: "my-namespace:my-service-account", + }, + }, + }, + }, + ) + require.NoError(t, err) + + adminClient, err := tp.srv.NewClient(auth.TestAdmin()) + require.NoError(t, err) + _, err = adminClient.CreateRole(ctx, role) + require.NoError(t, err) + _, err = adminClient.BotServiceClient().CreateBot(ctx, &machineidv1.CreateBotRequest{ + Bot: bot, + }) + require.NoError(t, err) + err = adminClient.CreateToken(ctx, token) + require.NoError(t, err) + + // With the basic setup complete, we can now "fake" a join. + botCerts, err := join.Register(ctx, join.RegisterParams{ + Token: token.GetName(), + JoinMethod: types.JoinMethodKubernetes, + ID: state.IdentityID{ + Role: types.RoleBot, + }, + AuthServers: []utils.NetAddr{*utils.MustParseAddr(tp.srv.Addr().String())}, + KubernetesReadFileFunc: func(name string) ([]byte, error) { + return []byte(fakePSAT), nil + }, + }) + require.NoError(t, err) + + // We now have to actually impersonate the role cert to be able to issue + // a workload identity. + privateKeyPEM, err := keys.MarshalPrivateKey(botCerts.PrivateKey) + require.NoError(t, err) + tlsCert, err := tls.X509KeyPair(botCerts.Certs.TLS, privateKeyPEM) + require.NoError(t, err) + sshPub, err := ssh.NewPublicKey(botCerts.PrivateKey.Public()) + require.NoError(t, err) + tlsPub, err := keys.MarshalPublicKey(botCerts.PrivateKey.Public()) + require.NoError(t, err) + botClient := tp.srv.NewClientWithCert(tlsCert) + certs, err := botClient.GenerateUserCerts(ctx, apiproto.UserCertsRequest{ + SSHPublicKey: ssh.MarshalAuthorizedKey(sshPub), + TLSPublicKey: tlsPub, + Username: "bot-my-bot", + RoleRequests: []string{ + role.GetName(), + }, + UseRoleRequests: true, + Expires: tp.clock.Now().Add(time.Hour), + }) + require.NoError(t, err) + roleTLSCert, err := tls.X509KeyPair(certs.TLS, privateKeyPEM) + require.NoError(t, err) + roleClient := tp.srv.NewClientWithCert(roleTLSCert) + + // Generate a keypair to generate x509 SVIDs for. + workloadKey, err := cryptosuites.GenerateKeyWithAlgorithm(cryptosuites.ECDSAP256) + require.NoError(t, err) + workloadKeyPubBytes, err := x509.MarshalPKIXPublicKey(workloadKey.Public()) + require.NoError(t, err) + // Finally, we can request the issuance of a SVID + c := workloadidentityv1pb.NewWorkloadIdentityIssuanceServiceClient( + roleClient.GetConnection(), + ) + res, err := c.IssueWorkloadIdentity(ctx, &workloadidentityv1pb.IssueWorkloadIdentityRequest{ + Name: wid.Metadata.Name, + WorkloadAttrs: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ + Pid: 123, + }, + }, + Credential: &workloadidentityv1pb.IssueWorkloadIdentityRequest_X509SvidParams{ + X509SvidParams: &workloadidentityv1pb.X509SVIDParams{ + PublicKey: workloadKeyPubBytes, + }, + }, + }) + require.NoError(t, err) + + // Perform a minimal validation of the returned credential - enough to prove + // that the returned value is a valid SVID with the SPIFFE ID we expect. + // Other tests in this package validate this more fully. + x509SVID := res.GetCredential().GetX509Svid() + require.NotNil(t, x509SVID) + cert, err := x509.ParseCertificate(x509SVID.GetCert()) + require.NoError(t, err) + // Check included public key matches + require.Equal(t, workloadKey.Public(), cert.PublicKey) + require.Equal(t, "spiffe://localhost/example/bot-my-bot/my-namespace/my-pod/123", cert.URIs[0].String()) +} + func TestIssueWorkloadIdentity(t *testing.T) { experimentStatus := experiment.Enabled() defer experiment.SetEnabled(experimentStatus) diff --git a/lib/bitbucket/bitbucket.go b/lib/bitbucket/bitbucket.go index ee9923337f9e8..653d724c1a971 100644 --- a/lib/bitbucket/bitbucket.go +++ b/lib/bitbucket/bitbucket.go @@ -19,8 +19,7 @@ package bitbucket import ( - "github.com/gravitational/trace" - "github.com/mitchellh/mapstructure" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" ) // IDTokenClaims @@ -60,19 +59,17 @@ type IDTokenClaims struct { BranchName string `json:"branchName"` } -// JoinAuditAttributes returns a series of attributes that can be inserted into -// audit events related to a specific join. -func (c *IDTokenClaims) JoinAuditAttributes() (map[string]any, error) { - res := map[string]any{} - d, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - TagName: "json", - Result: &res, - }) - if err != nil { - return nil, trace.Wrap(err) +// JoinAttrs returns the protobuf representation of the attested identity. +// This is used for auditing and for evaluation of WorkloadIdentity rules and +// templating. +func (c *IDTokenClaims) JoinAttrs() *workloadidentityv1pb.JoinAttrsBitbucket { + return &workloadidentityv1pb.JoinAttrsBitbucket{ + Sub: c.Sub, + StepUuid: c.StepUUID, + RepositoryUuid: c.RepositoryUUID, + PipelineUuid: c.PipelineUUID, + WorkspaceUuid: c.WorkspaceUUID, + DeploymentEnvironmentUuid: c.DeploymentEnvironmentUUID, + BranchName: c.BranchName, } - if err := d.Decode(c); err != nil { - return nil, trace.Wrap(err) - } - return res, nil } diff --git a/lib/circleci/circleci.go b/lib/circleci/circleci.go index 0f0c351c5eae3..ef796322d5220 100644 --- a/lib/circleci/circleci.go +++ b/lib/circleci/circleci.go @@ -32,8 +32,7 @@ package circleci import ( "fmt" - "github.com/gravitational/trace" - "github.com/mitchellh/mapstructure" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" ) const IssuerURLTemplate = "https://oidc.circleci.com/org/%s" @@ -55,20 +54,13 @@ type IDTokenClaims struct { ProjectID string `json:"oidc.circleci.com/project-id"` } -// JoinAuditAttributes returns a series of attributes that can be inserted into -// audit events related to a specific join. -func (c *IDTokenClaims) JoinAuditAttributes() (map[string]interface{}, error) { - res := map[string]interface{}{} - d, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - TagName: "json", - Result: &res, - }) - if err != nil { - return nil, trace.Wrap(err) +// JoinAttrs returns the protobuf representation of the attested identity. +// This is used for auditing and for evaluation of WorkloadIdentity rules and +// templating. +func (c *IDTokenClaims) JoinAttrs() *workloadidentityv1pb.JoinAttrsCircleCI { + return &workloadidentityv1pb.JoinAttrsCircleCI{ + Sub: c.Sub, + ContextIds: c.ContextIDs, + ProjectId: c.ProjectID, } - - if err := d.Decode(c); err != nil { - return nil, trace.Wrap(err) - } - return res, nil } diff --git a/lib/gcp/gcp.go b/lib/gcp/gcp.go index 4fd77ca6a4f52..a1ab7eb9daafa 100644 --- a/lib/gcp/gcp.go +++ b/lib/gcp/gcp.go @@ -19,8 +19,7 @@ package gcp import ( - "github.com/gravitational/trace" - "github.com/mitchellh/mapstructure" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" ) // defaultIssuerHost is the issuer for GCP ID tokens. @@ -52,20 +51,21 @@ type IDTokenClaims struct { Google Google `json:"google"` } -// JoinAuditAttributes returns a series of attributes that can be inserted into -// audit events related to a specific join. -func (c *IDTokenClaims) JoinAuditAttributes() (map[string]interface{}, error) { - res := map[string]interface{}{} - d, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - TagName: "json", - Result: &res, - }) - if err != nil { - return nil, trace.Wrap(err) +// JoinAttrs returns the protobuf representation of the attested identity. +// This is used for auditing and for evaluation of WorkloadIdentity rules and +// templating. +func (c *IDTokenClaims) JoinAttrs() *workloadidentityv1pb.JoinAttrsGCP { + attrs := &workloadidentityv1pb.JoinAttrsGCP{ + ServiceAccount: c.Email, } - - if err := d.Decode(c); err != nil { - return nil, trace.Wrap(err) + if c.Google.ComputeEngine.InstanceName != "" { + attrs.Gce = &workloadidentityv1pb.JoinAttrsGCPGCE{ + Project: c.Google.ComputeEngine.ProjectID, + Zone: c.Google.ComputeEngine.Zone, + Id: c.Google.ComputeEngine.InstanceID, + Name: c.Google.ComputeEngine.InstanceName, + } } - return res, nil + + return attrs } diff --git a/lib/githubactions/githubactions.go b/lib/githubactions/githubactions.go index f2921a9636d18..c2642904c6990 100644 --- a/lib/githubactions/githubactions.go +++ b/lib/githubactions/githubactions.go @@ -19,8 +19,7 @@ package githubactions import ( - "github.com/gravitational/trace" - "github.com/mitchellh/mapstructure" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" ) // GitHub Workload Identity @@ -101,20 +100,23 @@ type IDTokenClaims struct { Workflow string `json:"workflow"` } -// JoinAuditAttributes returns a series of attributes that can be inserted into -// audit events related to a specific join. -func (c *IDTokenClaims) JoinAuditAttributes() (map[string]interface{}, error) { - res := map[string]interface{}{} - d, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - TagName: "json", - Result: &res, - }) - if err != nil { - return nil, trace.Wrap(err) +// JoinAttrs returns the protobuf representation of the attested identity. +// This is used for auditing and for evaluation of WorkloadIdentity rules and +// templating. +func (c *IDTokenClaims) JoinAttrs() *workloadidentityv1pb.JoinAttrsGitHub { + attrs := &workloadidentityv1pb.JoinAttrsGitHub{ + Sub: c.Sub, + Actor: c.Actor, + Environment: c.Environment, + Ref: c.Ref, + RefType: c.RefType, + Repository: c.Repository, + RepositoryOwner: c.RepositoryOwner, + Workflow: c.Workflow, + EventName: c.EventName, + Sha: c.SHA, + RunId: c.RunID, } - if err := d.Decode(c); err != nil { - return nil, trace.Wrap(err) - } - return res, nil + return attrs } diff --git a/lib/gitlab/gitlab.go b/lib/gitlab/gitlab.go index 1129e6509d6c3..9daf1c4a68d8d 100644 --- a/lib/gitlab/gitlab.go +++ b/lib/gitlab/gitlab.go @@ -19,8 +19,7 @@ package gitlab import ( - "github.com/gravitational/trace" - "github.com/mitchellh/mapstructure" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" ) // GitLab Workload Identity @@ -112,20 +111,28 @@ type IDTokenClaims struct { ProjectVisibility string `json:"project_visibility"` } -// JoinAuditAttributes returns a series of attributes that can be inserted into -// audit events related to a specific join. -func (c *IDTokenClaims) JoinAuditAttributes() (map[string]interface{}, error) { - res := map[string]interface{}{} - d, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - TagName: "json", - Result: &res, - }) - if err != nil { - return nil, trace.Wrap(err) +// JoinAttrs returns the protobuf representation of the attested identity. +// This is used for auditing and for evaluation of WorkloadIdentity rules and +// templating. +func (c *IDTokenClaims) JoinAttrs() *workloadidentityv1pb.JoinAttrsGitLab { + attrs := &workloadidentityv1pb.JoinAttrsGitLab{ + Sub: c.Sub, + Ref: c.Ref, + RefType: c.RefType, + RefProtected: c.RefProtected == "true", + NamespacePath: c.NamespacePath, + ProjectPath: c.ProjectPath, + UserLogin: c.UserLogin, + UserEmail: c.UserEmail, + PipelineId: c.PipelineID, + Environment: c.Environment, + EnvironmentProtected: c.EnvironmentProtected == "true", + RunnerId: int64(c.RunnerID), + RunnerEnvironment: c.RunnerEnvironment, + Sha: c.SHA, + CiConfigRefUri: c.CIConfigRefURI, + CiConfigSha: c.CIConfigSHA, } - if err := d.Decode(c); err != nil { - return nil, trace.Wrap(err) - } - return res, nil + return attrs } diff --git a/lib/kube/token/validator.go b/lib/kube/token/validator.go index 056b5ee1def0d..0d88af8d46735 100644 --- a/lib/kube/token/validator.go +++ b/lib/kube/token/validator.go @@ -29,13 +29,13 @@ import ( "github.com/go-jose/go-jose/v3" josejwt "github.com/go-jose/go-jose/v3/jwt" "github.com/gravitational/trace" - "github.com/mitchellh/mapstructure" v1 "k8s.io/api/authentication/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/version" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils" ) @@ -60,24 +60,14 @@ type ValidationResult struct { // This will be prepended with `system:serviceaccount:` for service // accounts. Username string `json:"username"` + attrs *workloadidentityv1pb.JoinAttrsKubernetes } -// JoinAuditAttributes returns a series of attributes that can be inserted into -// audit events related to a specific join. -func (c *ValidationResult) JoinAuditAttributes() (map[string]interface{}, error) { - res := map[string]interface{}{} - d, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - TagName: "json", - Result: &res, - Squash: true, - }) - if err != nil { - return nil, trace.Wrap(err) - } - if err := d.Decode(c); err != nil { - return nil, trace.Wrap(err) - } - return res, nil +// JoinAttrs returns the protobuf representation of the attested identity. +// This is used for auditing and for evaluation of WorkloadIdentity rules and +// templating. +func (c *ValidationResult) JoinAttrs() *workloadidentityv1pb.JoinAttrsKubernetes { + return c.attrs } // TokenReviewValidator validates a Kubernetes Service Account JWT using the @@ -180,8 +170,11 @@ func (v *TokenReviewValidator) Validate(ctx context.Context, token, clusterName // Check the Username is a service account. // A user token would not match rules anyway, but we can produce a more relevant error message here. - if !strings.HasPrefix(reviewResult.Status.User.Username, ServiceAccountNamePrefix) { - return nil, trace.BadParameter("token user is not a service account: %s", reviewResult.Status.User.Username) + namespace, serviceAccount, err := serviceAccountFromUsername( + reviewResult.Status.User.Username, + ) + if err != nil { + return nil, trace.Wrap(err) } if !slices.Contains(reviewResult.Status.User.Groups, serviceAccountGroup) { @@ -203,20 +196,47 @@ func (v *TokenReviewValidator) Validate(ctx context.Context, token, clusterName // We know if the token is bound to a pod if its name is in the Extra userInfo. // If the token is not bound while Kubernetes supports bound tokens we abort. - if _, ok := reviewResult.Status.User.Extra[extraDataPodNameField]; !ok && boundTokenSupport { + podName, podNamePresent := reviewResult.Status.User.Extra[extraDataPodNameField] + if !podNamePresent && boundTokenSupport { return nil, trace.BadParameter( "legacy SA tokens are not accepted as kubernetes version %s supports bound tokens", kubeVersion.String(), ) } + attrs := &workloadidentityv1pb.JoinAttrsKubernetes{ + Subject: reviewResult.Status.User.Username, + ServiceAccount: &workloadidentityv1pb.JoinAttrsKubernetesServiceAccount{ + Name: serviceAccount, + Namespace: namespace, + }, + } + if podNamePresent && len(podName) == 1 { + attrs.Pod = &workloadidentityv1pb.JoinAttrsKubernetesPod{ + Name: podName[0], + } + } + return &ValidationResult{ Raw: reviewResult.Status, Type: types.KubernetesJoinTypeInCluster, Username: reviewResult.Status.User.Username, + attrs: attrs, }, nil } +func serviceAccountFromUsername(username string) (namespace, name string, err error) { + cut, hasPrefix := strings.CutPrefix(username, ServiceAccountNamePrefix+":") + if !hasPrefix { + return "", "", trace.BadParameter("token user is not a service account: %s", username) + } + parts := strings.Split(cut, ":") + if len(parts) != 2 { + return "", "", trace.BadParameter("token user has malformed service account name: %s", username) + } + return parts[0], parts[1], nil +} + func kubernetesSupportsBoundTokens(gitVersion string) (bool, error) { kubeVersion, err := version.ParseSemantic(gitVersion) if err != nil { @@ -319,5 +339,15 @@ func ValidateTokenWithJWKS( Raw: claims, Type: types.KubernetesJoinTypeStaticJWKS, Username: claims.Subject, + attrs: &workloadidentityv1pb.JoinAttrsKubernetes{ + Subject: claims.Subject, + Pod: &workloadidentityv1pb.JoinAttrsKubernetesPod{ + Name: claims.Kubernetes.Pod.Name, + }, + ServiceAccount: &workloadidentityv1pb.JoinAttrsKubernetesServiceAccount{ + Name: claims.Kubernetes.ServiceAccount.Name, + Namespace: claims.Kubernetes.Namespace, + }, + }, }, nil } diff --git a/lib/kube/token/validator_test.go b/lib/kube/token/validator_test.go index 49054df1e47ee..70d68fddb766d 100644 --- a/lib/kube/token/validator_test.go +++ b/lib/kube/token/validator_test.go @@ -26,9 +26,12 @@ import ( "github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3/jwt" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/gravitational/trace" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/testing/protocmp" v1 "k8s.io/api/authentication/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/version" @@ -37,6 +40,7 @@ import ( "k8s.io/client-go/kubernetes/fake" ctest "k8s.io/client-go/testing" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/cryptosuites" ) @@ -168,6 +172,7 @@ func TestIDTokenValidator_Validate(t *testing.T) { review *v1.TokenReview kubeVersion *version.Info wantResult *ValidationResult + wantAttrs *workloadidentityv1pb.JoinAttrsKubernetes clusterAudiences []string expectedAudiences []string expectedError error @@ -196,6 +201,16 @@ func TestIDTokenValidator_Validate(t *testing.T) { Username: "system:serviceaccount:namespace:my-service-account", // Raw will be filled in during test run to value of review }, + wantAttrs: &workloadidentityv1pb.JoinAttrsKubernetes{ + Subject: "system:serviceaccount:namespace:my-service-account", + Pod: &workloadidentityv1pb.JoinAttrsKubernetesPod{ + Name: "podA", + }, + ServiceAccount: &workloadidentityv1pb.JoinAttrsKubernetesServiceAccount{ + Name: "my-service-account", + Namespace: "namespace", + }, + }, kubeVersion: &boundTokenKubernetesVersion, expectedError: nil, // As the cluster doesn't have default audiences, we should not set @@ -226,6 +241,16 @@ func TestIDTokenValidator_Validate(t *testing.T) { Username: "system:serviceaccount:namespace:my-service-account", // Raw will be filled in during test run to value of review }, + wantAttrs: &workloadidentityv1pb.JoinAttrsKubernetes{ + Subject: "system:serviceaccount:namespace:my-service-account", + Pod: &workloadidentityv1pb.JoinAttrsKubernetesPod{ + Name: "podA", + }, + ServiceAccount: &workloadidentityv1pb.JoinAttrsKubernetesServiceAccount{ + Name: "my-service-account", + Namespace: "namespace", + }, + }, kubeVersion: &boundTokenKubernetesVersion, expectedError: nil, clusterAudiences: defaultKubeAudiences, @@ -253,6 +278,13 @@ func TestIDTokenValidator_Validate(t *testing.T) { Username: "system:serviceaccount:namespace:my-service-account", // Raw will be filled in during test run to value of review }, + wantAttrs: &workloadidentityv1pb.JoinAttrsKubernetes{ + Subject: "system:serviceaccount:namespace:my-service-account", + ServiceAccount: &workloadidentityv1pb.JoinAttrsKubernetesServiceAccount{ + Name: "my-service-account", + Namespace: "namespace", + }, + }, kubeVersion: &legacyTokenKubernetesVersion, expectedError: nil, }, @@ -352,7 +384,19 @@ func TestIDTokenValidator_Validate(t *testing.T) { return } require.NoError(t, err) - require.Equal(t, tt.wantResult, result) + require.Empty(t, cmp.Diff( + tt.wantResult, + result, + cmpopts.IgnoreUnexported(ValidationResult{}), + )) + if tt.wantAttrs != nil { + gotAttrs := result.JoinAttrs() + require.Empty(t, cmp.Diff( + tt.wantAttrs, + gotAttrs, + protocmp.Transform(), + )) + } }) } } @@ -440,6 +484,7 @@ func TestValidateTokenWithJWKS(t *testing.T) { claims ServiceAccountClaims wantResult *ValidationResult + wantAttrs *workloadidentityv1pb.JoinAttrsKubernetes wantErr string }{ { @@ -459,6 +504,16 @@ func TestValidateTokenWithJWKS(t *testing.T) { Type: types.KubernetesJoinTypeStaticJWKS, Username: "system:serviceaccount:default:my-service-account", }, + wantAttrs: &workloadidentityv1pb.JoinAttrsKubernetes{ + Subject: "system:serviceaccount:default:my-service-account", + Pod: &workloadidentityv1pb.JoinAttrsKubernetesPod{ + Name: "my-pod-797959fdf-wptbj", + }, + ServiceAccount: &workloadidentityv1pb.JoinAttrsKubernetesServiceAccount{ + Name: "my-service-account", + Namespace: "default", + }, + }, }, { name: "missing bound pod claim", @@ -607,7 +662,19 @@ func TestValidateTokenWithJWKS(t *testing.T) { return } require.NoError(t, err) - require.Equal(t, tt.wantResult, result) + require.Empty(t, cmp.Diff( + tt.wantResult, + result, + cmpopts.IgnoreUnexported(ValidationResult{}), + )) + if tt.wantAttrs != nil { + gotAttrs := result.JoinAttrs() + require.Empty(t, cmp.Diff( + tt.wantAttrs, + gotAttrs, + protocmp.Transform(), + )) + } }) } } diff --git a/lib/spacelift/spacelift.go b/lib/spacelift/spacelift.go index ddaba2f11cfd2..413620e324ae4 100644 --- a/lib/spacelift/spacelift.go +++ b/lib/spacelift/spacelift.go @@ -19,8 +19,7 @@ package spacelift import ( - "github.com/gravitational/trace" - "github.com/mitchellh/mapstructure" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" ) // IDTokenClaims @@ -49,20 +48,17 @@ type IDTokenClaims struct { Scope string `json:"scope"` } -// JoinAuditAttributes returns a series of attributes that can be inserted into -// audit events related to a specific join. -func (c *IDTokenClaims) JoinAuditAttributes() (map[string]interface{}, error) { - res := map[string]interface{}{} - d, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - TagName: "json", - Result: &res, - }) - if err != nil { - return nil, trace.Wrap(err) +// JoinAttrs returns the protobuf representation of the attested identity. +// This is used for auditing and for evaluation of WorkloadIdentity rules and +// templating. +func (c *IDTokenClaims) JoinAttrs() *workloadidentityv1pb.JoinAttrsSpacelift { + return &workloadidentityv1pb.JoinAttrsSpacelift{ + Sub: c.Sub, + SpaceId: c.SpaceID, + CallerType: c.CallerType, + CallerId: c.CallerID, + RunType: c.RunType, + RunId: c.RunID, + Scope: c.Scope, } - - if err := d.Decode(c); err != nil { - return nil, trace.Wrap(err) - } - return res, nil } diff --git a/lib/terraformcloud/terraform.go b/lib/terraformcloud/terraform.go index ded2340c2e5d1..c9db802130ae2 100644 --- a/lib/terraformcloud/terraform.go +++ b/lib/terraformcloud/terraform.go @@ -19,8 +19,7 @@ package terraformcloud import ( - "github.com/gravitational/trace" - "github.com/mitchellh/mapstructure" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" ) // IDTokenClaims @@ -52,20 +51,17 @@ type IDTokenClaims struct { RunPhase string `json:"terraform_run_phase"` } -// JoinAuditAttributes returns a series of attributes that can be inserted into -// audit events related to a specific join. -func (c *IDTokenClaims) JoinAuditAttributes() (map[string]interface{}, error) { - res := map[string]interface{}{} - d, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - TagName: "json", - Result: &res, - }) - if err != nil { - return nil, trace.Wrap(err) +// JoinAttrs returns the protobuf representation of the attested identity. +// This is used for auditing and for evaluation of WorkloadIdentity rules and +// templating. +func (c *IDTokenClaims) JoinAttrs() *workloadidentityv1pb.JoinAttrsTerraformCloud { + return &workloadidentityv1pb.JoinAttrsTerraformCloud{ + Sub: c.Sub, + OrganizationName: c.OrganizationName, + ProjectName: c.ProjectName, + WorkspaceName: c.WorkspaceName, + FullWorkspace: c.FullWorkspace, + RunId: c.RunID, + RunPhase: c.RunPhase, } - - if err := d.Decode(c); err != nil { - return nil, trace.Wrap(err) - } - return res, nil } diff --git a/lib/tlsca/ca.go b/lib/tlsca/ca.go index 3edde794e5860..a7e6ad24e39e4 100644 --- a/lib/tlsca/ca.go +++ b/lib/tlsca/ca.go @@ -36,8 +36,10 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + "google.golang.org/protobuf/encoding/protojson" "github.com/gravitational/teleport" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/types/wrappers" @@ -203,6 +205,10 @@ type Identity struct { // UserType indicates if the User was created by an SSO Provider or locally. UserType types.UserType + + // JoinAttributes holds the attributes that resulted from the + // Bot/Agent join process. + JoinAttributes *workloadidentityv1pb.JoinAttrs } // RouteToApp holds routing information for applications. @@ -556,6 +562,10 @@ var ( // BotInstanceASN1ExtensionOID is an extension that encodes a unique bot // instance identifier into a certificate. BotInstanceASN1ExtensionOID = asn1.ObjectIdentifier{1, 3, 9999, 2, 20} + + // JoinAttributesASN1ExtensionOID is an extension that encodes the + // attributes that resulted from the Bot/Agent join process. + JoinAttributesASN1ExtensionOID = asn1.ObjectIdentifier{1, 3, 9999, 2, 21} ) // Device Trust OIDs. @@ -895,6 +905,24 @@ func (id *Identity) Subject() (pkix.Name, error) { ) } + if id.JoinAttributes != nil { + encoded, err := protojson.MarshalOptions{ + // Use the proto field names as this is what we use in the + // templating engine and this being consistent for any user who + // inspects the cert is kind. + UseProtoNames: true, + }.Marshal(id.JoinAttributes) + if err != nil { + return pkix.Name{}, trace.Wrap(err, "encoding join attributes as protojson") + } + subject.ExtraNames = append(subject.ExtraNames, + pkix.AttributeTypeAndValue{ + Type: JoinAttributesASN1ExtensionOID, + Value: string(encoded), + }, + ) + } + // Device extensions. if devID := id.DeviceExtensions.DeviceID; devID != "" { subject.ExtraNames = append(subject.ExtraNames, pkix.AttributeTypeAndValue{ @@ -1158,6 +1186,19 @@ func FromSubject(subject pkix.Name, expires time.Time) (*Identity, error) { if val, ok := attr.Value.(string); ok { id.UserType = types.UserType(val) } + case attr.Type.Equal(JoinAttributesASN1ExtensionOID): + if val, ok := attr.Value.(string); ok { + id.JoinAttributes = &workloadidentityv1pb.JoinAttrs{} + unmarshaler := protojson.UnmarshalOptions{ + // We specifically want to DiscardUnknown or unmarshaling + // will fail if the proto message was issued by a newer + // auth server w/ new fields. + DiscardUnknown: true, + } + if err := unmarshaler.Unmarshal([]byte(val), id.JoinAttributes); err != nil { + return nil, trace.Wrap(err) + } + } } } diff --git a/lib/tlsca/ca_test.go b/lib/tlsca/ca_test.go index 022facef5d0cf..50295f1e7bcf9 100644 --- a/lib/tlsca/ca_test.go +++ b/lib/tlsca/ca_test.go @@ -34,8 +34,10 @@ import ( "github.com/jonboulle/clockwork" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/testing/protocmp" "github.com/gravitational/teleport" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/cryptosuites" @@ -154,6 +156,58 @@ func TestRenewableIdentity(t *testing.T) { require.True(t, parsed.Renewable) } +func TestJoinAttributes(t *testing.T) { + t.Parallel() + + clock := clockwork.NewFakeClock() + expires := clock.Now().Add(1 * time.Hour) + + ca, err := FromKeys([]byte(fixtures.TLSCACertPEM), []byte(fixtures.TLSCAKeyPEM)) + require.NoError(t, err) + + privateKey, err := cryptosuites.GenerateKeyWithAlgorithm(cryptosuites.ECDSAP256) + require.NoError(t, err) + + identity := Identity{ + Username: "bot-bernard", + Groups: []string{"bot-bernard"}, + BotName: "bernard", + BotInstanceID: "1234-5678", + Expires: expires, + JoinAttributes: &workloadidentityv1pb.JoinAttrs{ + Kubernetes: &workloadidentityv1pb.JoinAttrsKubernetes{ + ServiceAccount: &workloadidentityv1pb.JoinAttrsKubernetesServiceAccount{ + Namespace: "default", + Name: "foo", + }, + Pod: &workloadidentityv1pb.JoinAttrsKubernetesPod{ + Name: "bar", + }, + }, + }, + } + + subj, err := identity.Subject() + require.NoError(t, err) + require.NotNil(t, subj) + + certBytes, err := ca.GenerateCertificate(CertificateRequest{ + Clock: clock, + PublicKey: privateKey.Public(), + Subject: subj, + NotAfter: expires, + }) + require.NoError(t, err) + + cert, err := ParseCertificatePEM(certBytes) + require.NoError(t, err) + + parsed, err := FromSubject(cert.Subject, expires) + require.NoError(t, err) + require.NotNil(t, parsed) + require.Empty(t, cmp.Diff(parsed, &identity, protocmp.Transform())) +} + // TestKubeExtensions test ASN1 subject kubernetes extensions func TestKubeExtensions(t *testing.T) { clock := clockwork.NewFakeClock() diff --git a/lib/tpm/validate.go b/lib/tpm/validate.go index 268857d35e4ff..126133d31e644 100644 --- a/lib/tpm/validate.go +++ b/lib/tpm/validate.go @@ -27,6 +27,8 @@ import ( "github.com/google/go-attestation/attest" "github.com/gravitational/trace" + + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" ) // ValidateParams are the parameters required to validate a TPM. @@ -63,14 +65,17 @@ type ValidatedTPM struct { EKCertVerified bool `json:"ek_cert_verified"` } -// JoinAuditAttributes returns a series of attributes that can be inserted into -// audit events related to a specific join. -func (c *ValidatedTPM) JoinAuditAttributes() (map[string]interface{}, error) { - return map[string]interface{}{ - "ek_pub_hash": c.EKPubHash, - "ek_cert_serial": c.EKCertSerial, - "ek_cert_verified": c.EKCertVerified, - }, nil +// JoinAttrs returns the protobuf representation of the attested identity. +// This is used for auditing and for evaluation of WorkloadIdentity rules and +// templating. +func (c *ValidatedTPM) JoinAttrs() *workloadidentityv1pb.JoinAttrsTPM { + attrs := &workloadidentityv1pb.JoinAttrsTPM{ + EkPubHash: c.EKPubHash, + EkCertSerial: c.EKCertSerial, + EkCertVerified: c.EKCertVerified, + } + + return attrs } // Validate takes the parameters from a remote TPM and performs the necessary diff --git a/tool/tctl/common/bots_command.go b/tool/tctl/common/bots_command.go index 1cd290cb1bcd2..fa8ffbf7861cd 100644 --- a/tool/tctl/common/bots_command.go +++ b/tool/tctl/common/bots_command.go @@ -588,7 +588,10 @@ func (c *BotsCommand) ListBotInstances(ctx context.Context, client *authclient.C ) joined := i.Status.InitialAuthentication.AuthenticatedAt.AsTime().Format(time.RFC3339) - initialJoinMethod := i.Status.InitialAuthentication.JoinMethod + initialJoinMethod := cmp.Or( + i.Status.InitialAuthentication.GetJoinAttrs().GetMeta().GetJoinMethod(), + i.Status.InitialAuthentication.JoinMethod, + ) lastSeen := i.Status.InitialAuthentication.AuthenticatedAt.AsTime() @@ -599,8 +602,12 @@ func (c *BotsCommand) ListBotInstances(ctx context.Context, client *authclient.C generation = fmt.Sprint(auth.Generation) - if auth.JoinMethod == initialJoinMethod { - joinMethod = auth.JoinMethod + authJM := cmp.Or( + auth.GetJoinAttrs().GetMeta().GetJoinMethod(), + auth.JoinMethod, + ) + if authJM == initialJoinMethod { + joinMethod = authJM } else { // If the join method changed, show the original method and latest joinMethod = fmt.Sprintf("%s (%s)", auth.JoinMethod, initialJoinMethod) @@ -844,9 +851,13 @@ func splitEntries(flag string) []string { func formatBotInstanceAuthentication(record *machineidv1pb.BotInstanceStatusAuthentication) string { table := asciitable.MakeHeadlessTable(2) table.AddRow([]string{"Authenticated At:", record.AuthenticatedAt.AsTime().Format(time.RFC3339)}) - table.AddRow([]string{"Join Method:", record.JoinMethod}) - table.AddRow([]string{"Join Token:", record.JoinToken}) - table.AddRow([]string{"Join Metadata:", record.Metadata.String()}) + table.AddRow([]string{"Join Method:", cmp.Or(record.GetJoinAttrs().GetMeta().GetJoinMethod(), record.JoinMethod)}) + table.AddRow([]string{"Join Token:", cmp.Or(record.GetJoinAttrs().GetMeta().GetJoinTokenName(), record.JoinToken)}) + var meta fmt.Stringer = record.Metadata + if attrs := record.GetJoinAttrs(); attrs != nil { + meta = attrs + } + table.AddRow([]string{"Join Metadata:", meta.String()}) table.AddRow([]string{"Generation:", fmt.Sprint(record.Generation)}) table.AddRow([]string{"Public Key:", fmt.Sprintf("<%d bytes>", len(record.PublicKey))})