Skip to content

Commit

Permalink
Workload ID: Add WorkloadIdentity local service and cache config (#49942
Browse files Browse the repository at this point in the history
) (#49988)

* Add WorkloadIdentity store and cache

* Update lib/services/local/workload_identity.go



* Update lib/services/local/workload_identity.go



* Update lib/cache/resource_workload_identity.go



---------

Co-authored-by: Edward Dowling <[email protected]>
Co-authored-by: Edoardo Spadolini <[email protected]>
  • Loading branch information
3 people authored Dec 10, 2024
1 parent e51c708 commit 025be5d
Show file tree
Hide file tree
Showing 16 changed files with 1,121 additions and 0 deletions.
2 changes: 2 additions & 0 deletions lib/auth/accesspoint/accesspoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ type Config struct {
Users services.UsersService
WebSession types.WebSessionInterface
WebToken types.WebTokenInterface
WorkloadIdentity cache.WorkloadIdentityReader
DynamicWindowsDesktops services.DynamicWindowsDesktops
WindowsDesktops services.WindowsDesktops
AutoUpdateService services.AutoUpdateServiceGetter
Expand Down Expand Up @@ -201,6 +202,7 @@ func NewCache(cfg Config) (*cache.Cache, error) {
Users: cfg.Users,
WebSession: cfg.WebSession,
WebToken: cfg.WebToken,
WorkloadIdentity: cfg.WorkloadIdentity,
WindowsDesktops: cfg.WindowsDesktops,
DynamicWindowsDesktops: cfg.DynamicWindowsDesktops,
ProvisioningStates: cfg.ProvisioningStates,
Expand Down
9 changes: 9 additions & 0 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,13 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
return nil, trace.Wrap(err, "creating SPIFFEFederation service")
}
}
if cfg.WorkloadIdentity == nil {
workloadIdentity, err := local.NewWorkloadIdentityService(cfg.Backend)
if err != nil {
return nil, trace.Wrap(err, "creating WorkloadIdentity service")
}
cfg.WorkloadIdentity = workloadIdentity
}
if cfg.Logger == nil {
cfg.Logger = slog.With(teleport.ComponentKey, teleport.ComponentAuth)
}
Expand Down Expand Up @@ -486,6 +493,7 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
StaticHostUser: cfg.StaticHostUsers,
ProvisioningStates: cfg.ProvisioningStates,
IdentityCenter: cfg.IdentityCenter,
WorkloadIdentities: cfg.WorkloadIdentity,
}

as := Server{
Expand Down Expand Up @@ -703,6 +711,7 @@ type Services struct {
services.AutoUpdateService
services.ProvisioningStates
services.IdentityCenter
services.WorkloadIdentities
}

// GetWebSession returns existing web session described by req.
Expand Down
7 changes: 7 additions & 0 deletions lib/auth/authclient/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
userprovisioningpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/userprovisioning/v2"
userspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/users/v1"
usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1"
workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/accesslist"
"github.com/gravitational/teleport/api/types/discoveryconfig"
Expand Down Expand Up @@ -1229,6 +1230,12 @@ type Cache interface {
// pagination.
ListSPIFFEFederations(ctx context.Context, pageSize int, lastToken string) ([]*machineidv1.SPIFFEFederation, string, error)

// GetWorkloadIdentity gets a WorkloadIdentity by name.
GetWorkloadIdentity(ctx context.Context, name string) (*workloadidentityv1pb.WorkloadIdentity, error)
// ListWorkloadIdentities lists all SPIFFE Federations using Google style
// pagination.
ListWorkloadIdentities(ctx context.Context, pageSize int, lastToken string) ([]*workloadidentityv1pb.WorkloadIdentity, string, error)

// ListStaticHostUsers lists static host users.
ListStaticHostUsers(ctx context.Context, pageSize int, startKey string) ([]*userprovisioningpb.StaticHostUser, string, error)
// GetStaticHostUser returns a static host user by name.
Expand Down
1 change: 1 addition & 0 deletions lib/auth/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) {
SecReports: svces.SecReports,
SnowflakeSession: svces.Identity,
SPIFFEFederations: svces.SPIFFEFederations,
WorkloadIdentity: svces.WorkloadIdentities,
StaticHostUsers: svces.StaticHostUser,
Trust: svces.TrustInternal,
UserGroups: svces.UserGroups,
Expand Down
4 changes: 4 additions & 0 deletions lib/auth/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,10 @@ type InitConfig struct {
// SPIFFEFederations is a service that manages storing SPIFFE federations.
SPIFFEFederations services.SPIFFEFederations

// WorkloadIdentity is the service for storing and retrieving
// WorkloadIdentity resources.
WorkloadIdentity services.WorkloadIdentities

// StaticHostUsers is a service that manages host users that should be
// created on SSH nodes.
StaticHostUsers services.StaticHostUser
Expand Down
12 changes: 12 additions & 0 deletions lib/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ func ForAuth(cfg Config) Config {
{Kind: types.KindIdentityCenterAccount},
{Kind: types.KindIdentityCenterPrincipalAssignment},
{Kind: types.KindIdentityCenterAccountAssignment},
{Kind: types.KindWorkloadIdentity},
}
cfg.QueueSize = defaults.AuthQueueSize
// We don't want to enable partial health for auth cache because auth uses an event stream
Expand Down Expand Up @@ -550,6 +551,7 @@ type Cache struct {
staticHostUsersCache *local.StaticHostUserService
provisioningStatesCache *local.ProvisioningStateService
identityCenterCache *local.IdentityCenterService
workloadIdentityCache workloadIdentityCacher

// closed indicates that the cache has been closed
closed atomic.Bool
Expand Down Expand Up @@ -732,6 +734,9 @@ type Config struct {
SPIFFEFederations SPIFFEFederationReader
// StaticHostUsers is the static host user service.
StaticHostUsers services.StaticHostUser
// WorkloadIdentity is the upstream Workload Identities service that we're
// caching
WorkloadIdentity WorkloadIdentityReader
// Backend is a backend for local cache
Backend backend.Backend
// MaxRetryPeriod is the maximum period between cache retries on failures
Expand Down Expand Up @@ -998,6 +1003,12 @@ func New(config Config) (*Cache, error) {
return nil, trace.Wrap(err)
}

workloadIdentityCache, err := local.NewWorkloadIdentityService(config.Backend)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}

staticHostUserCache, err := local.NewStaticHostUserService(config.Backend)
if err != nil {
cancel()
Expand Down Expand Up @@ -1070,6 +1081,7 @@ func New(config Config) (*Cache, error) {
staticHostUsersCache: staticHostUserCache,
provisioningStatesCache: provisioningStatesCache,
identityCenterCache: identityCenterCache,
workloadIdentityCache: workloadIdentityCache,
Logger: log.WithFields(log.Fields{
teleport.ComponentKey: config.Component,
}),
Expand Down
13 changes: 13 additions & 0 deletions lib/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ type testPack struct {
autoUpdateService services.AutoUpdateService
provisioningStates services.ProvisioningStates
identityCenter services.IdentityCenter
workloadIdentity *local.WorkloadIdentityService
}

// testFuncs are functions to support testing an object in a cache.
Expand Down Expand Up @@ -362,6 +363,12 @@ func newPackWithoutCache(dir string, opts ...packOption) (*testPack, error) {
}
p.spiffeFederations = spiffeFederationsSvc

workloadIdentitySvc, err := local.NewWorkloadIdentityService(p.backend)
if err != nil {
return nil, trace.Wrap(err)
}
p.workloadIdentity = workloadIdentitySvc

databaseObjectsSvc, err := local.NewDatabaseObjectService(p.backend)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -455,6 +462,7 @@ func newPack(dir string, setupConfig func(c Config) Config, opts ...packOption)
AutoUpdateService: p.autoUpdateService,
ProvisioningStates: p.provisioningStates,
IdentityCenter: p.identityCenter,
WorkloadIdentity: p.workloadIdentity,
MaxRetryPeriod: 200 * time.Millisecond,
EventsC: p.eventsC,
}))
Expand Down Expand Up @@ -866,6 +874,7 @@ func TestCompletenessInit(t *testing.T) {
StaticHostUsers: p.staticHostUsers,
AutoUpdateService: p.autoUpdateService,
ProvisioningStates: p.provisioningStates,
WorkloadIdentity: p.workloadIdentity,
MaxRetryPeriod: 200 * time.Millisecond,
IdentityCenter: p.identityCenter,
EventsC: p.eventsC,
Expand Down Expand Up @@ -951,6 +960,7 @@ func TestCompletenessReset(t *testing.T) {
AutoUpdateService: p.autoUpdateService,
ProvisioningStates: p.provisioningStates,
IdentityCenter: p.identityCenter,
WorkloadIdentity: p.workloadIdentity,
MaxRetryPeriod: 200 * time.Millisecond,
EventsC: p.eventsC,
}))
Expand Down Expand Up @@ -1161,6 +1171,7 @@ func TestListResources_NodesTTLVariant(t *testing.T) {
AutoUpdateService: p.autoUpdateService,
ProvisioningStates: p.provisioningStates,
IdentityCenter: p.identityCenter,
WorkloadIdentity: p.workloadIdentity,
MaxRetryPeriod: 200 * time.Millisecond,
EventsC: p.eventsC,
neverOK: true, // ensure reads are never healthy
Expand Down Expand Up @@ -1256,6 +1267,7 @@ func initStrategy(t *testing.T) {
AutoUpdateService: p.autoUpdateService,
ProvisioningStates: p.provisioningStates,
IdentityCenter: p.identityCenter,
WorkloadIdentity: p.workloadIdentity,
MaxRetryPeriod: 200 * time.Millisecond,
EventsC: p.eventsC,
}))
Expand Down Expand Up @@ -3521,6 +3533,7 @@ func TestCacheWatchKindExistsInEvents(t *testing.T) {
types.KindIdentityCenterAccount: types.Resource153ToLegacy(newIdentityCenterAccount("some_account")),
types.KindIdentityCenterAccountAssignment: types.Resource153ToLegacy(newIdentityCenterAccountAssignment("some_account_assignment")),
types.KindIdentityCenterPrincipalAssignment: types.Resource153ToLegacy(newIdentityCenterPrincipalAssignment("some_principal_assignment")),
types.KindWorkloadIdentity: types.Resource153ToLegacy(newWorkloadIdentity("some_identifier")),
}

for name, cfg := range cases {
Expand Down
11 changes: 11 additions & 0 deletions lib/cache/collections.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
userprovisioningpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/userprovisioning/v2"
userspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/users/v1"
usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1"
workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/accesslist"
"github.com/gravitational/teleport/api/types/discoveryconfig"
Expand Down Expand Up @@ -176,6 +177,7 @@ type cacheCollections struct {
identityCenterAccounts collectionReader[identityCenterAccountGetter]
identityCenterPrincipalAssignments collectionReader[identityCenterPrincipalAssignmentGetter]
identityCenterAccountAssignments collectionReader[identityCenterAccountAssignmentGetter]
workloadIdentity collectionReader[WorkloadIdentityReader]
}

// setupCollections returns a registry of collections.
Expand Down Expand Up @@ -704,6 +706,15 @@ func setupCollections(c *Cache, watches []types.WatchKind) (*cacheCollections, e
watch: watch,
}
collections.byKind[resourceKind] = collections.spiffeFederations
case types.KindWorkloadIdentity:
if c.Config.WorkloadIdentity == nil {
return nil, trace.BadParameter("missing parameter WorkloadIdentity")
}
collections.workloadIdentity = &genericCollection[*workloadidentityv1pb.WorkloadIdentity, WorkloadIdentityReader, workloadIdentityExecutor]{
cache: c,
watch: watch,
}
collections.byKind[resourceKind] = collections.workloadIdentity
case types.KindAutoUpdateConfig:
if c.AutoUpdateService == nil {
return nil, trace.BadParameter("missing parameter AutoUpdateService")
Expand Down
119 changes: 119 additions & 0 deletions lib/cache/resource_workload_identity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Teleport
// Copyright (C) 2024 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

//nolint:unused // Because the executors generate a large amount of false positives.
package cache

import (
"context"

"github.com/gravitational/trace"

workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
"github.com/gravitational/teleport/api/types"
)

// WorkloadIdentityReader is an interface that defines the methods for getting
// WorkloadIdentity. This is returned as the reader for the WorkloadIdentity
// collection but is also used by the executor to read the full list of
// WorkloadIdentity on initialization.
type WorkloadIdentityReader interface {
ListWorkloadIdentities(ctx context.Context, pageSize int, nextToken string) ([]*workloadidentityv1pb.WorkloadIdentity, string, error)
GetWorkloadIdentity(ctx context.Context, name string) (*workloadidentityv1pb.WorkloadIdentity, error)
}

// workloadIdentityCacher is used for storing and retrieving WorkloadIdentity
// from the cache's local backend.
type workloadIdentityCacher interface {
WorkloadIdentityReader
UpsertWorkloadIdentity(ctx context.Context, resource *workloadidentityv1pb.WorkloadIdentity) (*workloadidentityv1pb.WorkloadIdentity, error)
DeleteWorkloadIdentity(ctx context.Context, name string) error
DeleteAllWorkloadIdentities(ctx context.Context) error
}

type workloadIdentityExecutor struct{}

var _ executor[*workloadidentityv1pb.WorkloadIdentity, WorkloadIdentityReader] = workloadIdentityExecutor{}

func (workloadIdentityExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]*workloadidentityv1pb.WorkloadIdentity, error) {
var out []*workloadidentityv1pb.WorkloadIdentity
var nextToken string
for {
var page []*workloadidentityv1pb.WorkloadIdentity
var err error

const defaultPageSize = 0
page, nextToken, err = cache.Config.WorkloadIdentity.ListWorkloadIdentities(ctx, defaultPageSize, nextToken)
if err != nil {
return nil, trace.Wrap(err)
}
out = append(out, page...)
if nextToken == "" {
break
}
}
return out, nil
}

func (workloadIdentityExecutor) upsert(ctx context.Context, cache *Cache, resource *workloadidentityv1pb.WorkloadIdentity) error {
_, err := cache.workloadIdentityCache.UpsertWorkloadIdentity(ctx, resource)
return trace.Wrap(err)
}

func (workloadIdentityExecutor) deleteAll(ctx context.Context, cache *Cache) error {
return trace.Wrap(cache.workloadIdentityCache.DeleteAllWorkloadIdentities(ctx))
}

func (workloadIdentityExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error {
return trace.Wrap(cache.workloadIdentityCache.DeleteWorkloadIdentity(ctx, resource.GetName()))
}

func (workloadIdentityExecutor) isSingleton() bool { return false }

func (workloadIdentityExecutor) getReader(cache *Cache, cacheOK bool) WorkloadIdentityReader {
if cacheOK {
return cache.workloadIdentityCache
}
return cache.Config.WorkloadIdentity
}

// ListWorkloadIdentities returns a paginated list of WorkloadIdentity resources.
func (c *Cache) ListWorkloadIdentities(ctx context.Context, pageSize int, nextToken string) ([]*workloadidentityv1pb.WorkloadIdentity, string, error) {
ctx, span := c.Tracer.Start(ctx, "cache/ListWorkloadIdentities")
defer span.End()

rg, err := readCollectionCache(c, c.collections.workloadIdentity)
if err != nil {
return nil, "", trace.Wrap(err)
}
defer rg.Release()
out, nextKey, err := rg.reader.ListWorkloadIdentities(ctx, pageSize, nextToken)
return out, nextKey, trace.Wrap(err)
}

// GetWorkloadIdentity returns a single WorkloadIdentity by name
func (c *Cache) GetWorkloadIdentity(ctx context.Context, name string) (*workloadidentityv1pb.WorkloadIdentity, error) {
ctx, span := c.Tracer.Start(ctx, "cache/GetWorkloadIdentity")
defer span.End()

rg, err := readCollectionCache(c, c.collections.workloadIdentity)
if err != nil {
return nil, trace.Wrap(err)
}
defer rg.Release()
out, err := rg.reader.GetWorkloadIdentity(ctx, name)
return out, trace.Wrap(err)
}
Loading

0 comments on commit 025be5d

Please sign in to comment.