From 2445199447b762a82c0ffd19f96d9eb4c3809120 Mon Sep 17 00:00:00 2001 From: Tim Ross Date: Mon, 14 Oct 2024 18:05:32 -0400 Subject: [PATCH] Generalize resource watchers Consolidate resource watchers into a single watcher that leverages generics. While most of the resource watchers were similar, some resources have some one off functionality. These watchers have not been touched, however, all that could be refactored to use the generic watcher easily were. --- lib/kube/proxy/server.go | 13 +- lib/kube/proxy/utils_testing.go | 3 +- lib/kube/proxy/watcher.go | 6 +- lib/proxy/peer/client.go | 4 +- lib/proxy/router.go | 9 +- lib/proxy/router_test.go | 4 +- lib/reversetunnel/localsite.go | 23 +- lib/reversetunnel/localsite_test.go | 16 +- lib/reversetunnel/peer.go | 5 +- lib/reversetunnel/remotesite.go | 20 +- lib/reversetunnel/srv.go | 10 +- lib/reversetunnelclient/api.go | 3 +- lib/service/service.go | 1 + lib/services/readonly/readonly.go | 314 +++++ lib/services/watcher.go | 1738 ++++++++------------------- lib/services/watcher_test.go | 123 +- lib/srv/app/server.go | 3 +- lib/srv/app/watcher.go | 6 +- lib/srv/db/server.go | 3 +- lib/srv/db/watcher.go | 6 +- lib/srv/desktop/discovery.go | 5 +- lib/srv/discovery/discovery.go | 63 +- lib/srv/discovery/discovery_test.go | 7 +- lib/srv/regular/sshserver_test.go | 4 +- lib/utils/fncache.go | 2 + lib/web/apiserver.go | 10 +- lib/web/apiserver_test.go | 3 + 27 files changed, 1042 insertions(+), 1362 deletions(-) diff --git a/lib/kube/proxy/server.go b/lib/kube/proxy/server.go index 580abc957795b..8c770c16bbdb8 100644 --- a/lib/kube/proxy/server.go +++ b/lib/kube/proxy/server.go @@ -45,6 +45,7 @@ import ( "github.com/gravitational/teleport/lib/multiplexer" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/srv/ingress" ) @@ -98,7 +99,7 @@ type TLSServerConfig struct { // kubernetes cluster name. Proxy uses this map to route requests to the correct // kubernetes_service. The servers are kept in memory to avoid making unnecessary // unmarshal calls followed by filtering and to improve memory usage. - KubernetesServersWatcher *services.KubeServerWatcher + KubernetesServersWatcher *services.GenericWatcher[types.KubeServer, readonly.KubeServer] // PROXYProtocolMode controls behavior related to unsigned PROXY protocol headers. PROXYProtocolMode multiplexer.PROXYProtocolMode // InventoryHandle is used to send kube server heartbeats via the inventory control stream. @@ -170,7 +171,7 @@ type TLSServer struct { closeContext context.Context closeFunc context.CancelFunc // kubeClusterWatcher monitors changes to kube cluster resources. - kubeClusterWatcher *services.KubeClusterWatcher + kubeClusterWatcher *services.GenericWatcher[types.KubeCluster, readonly.KubeCluster] // reconciler reconciles proxied kube clusters with kube_clusters resources. 
reconciler *services.Reconciler[types.KubeCluster] // monitoredKubeClusters contains all kube clusters the proxied kube_clusters are @@ -620,7 +621,9 @@ func (t *TLSServer) getKubernetesServersForKubeClusterFunc() (getKubeServersByNa }, nil case ProxyService: return func(ctx context.Context, name string) ([]types.KubeServer, error) { - servers, err := t.KubernetesServersWatcher.GetKubeServersByClusterName(ctx, name) + servers, err := t.KubernetesServersWatcher.CurrentResourcesWithFilter(ctx, func(ks readonly.KubeServer) bool { + return ks.GetCluster().GetName() == name + }) return servers, trace.Wrap(err) }, nil case LegacyProxyService: @@ -630,7 +633,9 @@ func (t *TLSServer) getKubernetesServersForKubeClusterFunc() (getKubeServersByNa // and forward the request to the next proxy. kube, err := t.getKubeClusterWithServiceLabels(name) if err != nil { - servers, err := t.KubernetesServersWatcher.GetKubeServersByClusterName(ctx, name) + servers, err := t.KubernetesServersWatcher.CurrentResourcesWithFilter(ctx, func(ks readonly.KubeServer) bool { + return ks.GetCluster().GetName() == name + }) return servers, trace.Wrap(err) } srv, err := types.NewKubernetesServerV3FromCluster(kube, "", t.HostID) diff --git a/lib/kube/proxy/utils_testing.go b/lib/kube/proxy/utils_testing.go index 4621b7d51bec1..462638df203c4 100644 --- a/lib/kube/proxy/utils_testing.go +++ b/lib/kube/proxy/utils_testing.go @@ -294,6 +294,7 @@ func SetupTestContext(ctx context.Context, t *testing.T, cfg TestConfig) *TestCo Component: teleport.ComponentKube, Client: client, }, + KubernetesServerGetter: client, }, ) require.NoError(t, err) @@ -387,7 +388,7 @@ func SetupTestContext(ctx context.Context, t *testing.T, cfg TestConfig) *TestCo // Ensure watcher has the correct list of clusters. require.Eventually(t, func() bool { - kubeServers, err := kubeServersWatcher.GetKubernetesServers(ctx) + kubeServers, err := kubeServersWatcher.CurrentResources(ctx) return err == nil && len(kubeServers) == len(cfg.Clusters) }, 3*time.Second, time.Millisecond*100) diff --git a/lib/kube/proxy/watcher.go b/lib/kube/proxy/watcher.go index 047b0a4401f89..24e52e2d9c923 100644 --- a/lib/kube/proxy/watcher.go +++ b/lib/kube/proxy/watcher.go @@ -29,6 +29,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/utils" ) @@ -89,7 +90,7 @@ func (s *TLSServer) startReconciler(ctx context.Context) (err error) { // startKubeClusterResourceWatcher starts watching changes to Kube Clusters resources and // registers/unregisters the proxied Kube Cluster accordingly. 
-func (s *TLSServer) startKubeClusterResourceWatcher(ctx context.Context) (*services.KubeClusterWatcher, error) { +func (s *TLSServer) startKubeClusterResourceWatcher(ctx context.Context) (*services.GenericWatcher[types.KubeCluster, readonly.KubeCluster], error) { if len(s.ResourceMatchers) == 0 || s.KubeServiceType != KubeService { s.log.Debug("Not initializing Kube Cluster resource watcher.") return nil, nil @@ -102,6 +103,7 @@ func (s *TLSServer) startKubeClusterResourceWatcher(ctx context.Context) (*servi // Logger: s.log, Client: s.AccessPoint, }, + KubernetesClusterGetter: s.AccessPoint, }) if err != nil { return nil, trace.Wrap(err) @@ -110,7 +112,7 @@ func (s *TLSServer) startKubeClusterResourceWatcher(ctx context.Context) (*servi defer watcher.Close() for { select { - case clusters := <-watcher.KubeClustersC: + case clusters := <-watcher.ResourcesC: s.monitoredKubeClusters.setResources(clusters) select { case s.reconcileCh <- struct{}{}: diff --git a/lib/proxy/peer/client.go b/lib/proxy/peer/client.go index c00b67006875f..fe3659a92bf4a 100644 --- a/lib/proxy/peer/client.go +++ b/lib/proxy/peer/client.go @@ -51,6 +51,7 @@ import ( // AccessPoint is the subset of the auth cache consumed by the [Client]. type AccessPoint interface { types.Events + services.ProxyGetter } // ClientConfig configures a Client instance. @@ -416,6 +417,7 @@ func (c *Client) sync() { Client: c.config.AccessPoint, Logger: c.config.Log, }, + ProxyGetter: c.config.AccessPoint, ProxyDiffer: func(old, new types.Server) bool { return old.GetPeerAddr() != new.GetPeerAddr() }, @@ -434,7 +436,7 @@ func (c *Client) sync() { case <-proxyWatcher.Done(): c.config.Log.DebugContext(c.ctx, "stopping peer proxy sync: proxy watcher done") return - case proxies := <-proxyWatcher.ProxiesC: + case proxies := <-proxyWatcher.ResourcesC: if err := c.updateConnections(proxies); err != nil { c.config.Log.ErrorContext(c.ctx, "error syncing peer proxies", "error", err) } diff --git a/lib/proxy/router.go b/lib/proxy/router.go index 18e22adc798ee..f54f9718af604 100644 --- a/lib/proxy/router.go +++ b/lib/proxy/router.go @@ -44,6 +44,7 @@ import ( "github.com/gravitational/teleport/lib/observability/metrics" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/teleagent" "github.com/gravitational/teleport/lib/utils" ) @@ -383,7 +384,7 @@ func (r *Router) getRemoteCluster(ctx context.Context, clusterName string, check // site is the minimum interface needed to match servers // for a reversetunnelclient.RemoteSite. It makes testing easier. 
type site interface { - GetNodes(ctx context.Context, fn func(n services.Node) bool) ([]types.Server, error) + GetNodes(ctx context.Context, fn func(n readonly.Server) bool) ([]types.Server, error) GetClusterNetworkingConfig(ctx context.Context) (types.ClusterNetworkingConfig, error) } @@ -394,13 +395,13 @@ type remoteSite struct { } // GetNodes uses the wrapped sites NodeWatcher to filter nodes -func (r remoteSite) GetNodes(ctx context.Context, fn func(n services.Node) bool) ([]types.Server, error) { +func (r remoteSite) GetNodes(ctx context.Context, fn func(n readonly.Server) bool) ([]types.Server, error) { watcher, err := r.site.NodeWatcher() if err != nil { return nil, trace.Wrap(err) } - return watcher.GetNodes(ctx, fn), nil + return watcher.CurrentResourcesWithFilter(ctx, fn) } // GetClusterNetworkingConfig uses the wrapped sites cache to retrieve the ClusterNetworkingConfig @@ -450,7 +451,7 @@ func getServerWithResolver(ctx context.Context, host, port string, site site, re var maxScore int scores := make(map[string]int) - matches, err := site.GetNodes(ctx, func(server services.Node) bool { + matches, err := site.GetNodes(ctx, func(server readonly.Server) bool { score := routeMatcher.RouteToServerScore(server) if score < 1 { return false diff --git a/lib/proxy/router_test.go b/lib/proxy/router_test.go index 48268cf355961..177875fda2fd6 100644 --- a/lib/proxy/router_test.go +++ b/lib/proxy/router_test.go @@ -37,7 +37,7 @@ import ( "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/observability/tracing" "github.com/gravitational/teleport/lib/reversetunnelclient" - "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/teleagent" "github.com/gravitational/teleport/lib/utils" ) @@ -51,7 +51,7 @@ func (t testSite) GetClusterNetworkingConfig(ctx context.Context) (types.Cluster return t.cfg, nil } -func (t testSite) GetNodes(ctx context.Context, fn func(n services.Node) bool) ([]types.Server, error) { +func (t testSite) GetNodes(ctx context.Context, fn func(n readonly.Server) bool) ([]types.Server, error) { var out []types.Server for _, s := range t.nodes { if fn(s) { diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index 54e9d5db3a680..c6bdd82b3d83a 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -44,6 +44,7 @@ import ( "github.com/gravitational/teleport/lib/reversetunnel/track" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/srv/forward" "github.com/gravitational/teleport/lib/teleagent" "github.com/gravitational/teleport/lib/utils" @@ -180,7 +181,7 @@ func (s *localSite) CachingAccessPoint() (authclient.RemoteProxyAccessPoint, err } // NodeWatcher returns a services.NodeWatcher for this cluster. 
-func (s *localSite) NodeWatcher() (*services.NodeWatcher, error) { +func (s *localSite) NodeWatcher() (*services.GenericWatcher[types.Server, readonly.Server], error) { return s.srv.NodeWatcher, nil } @@ -739,7 +740,11 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch return case <-proxyResyncTicker.Chan(): var req discoveryRequest - req.SetProxies(s.srv.proxyWatcher.GetCurrent()) + proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx) + if err != nil { + logger.WithError(err).Warn("Failed to get proxy set") + } + req.SetProxies(proxies) if err := rconn.sendDiscoveryRequest(req); err != nil { logger.WithError(err).Debug("Marking connection invalid on error") @@ -764,9 +769,12 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch if firstHeartbeat { // as soon as the agent connects and sends a first heartbeat // send it the list of current proxies back - current := s.srv.proxyWatcher.GetCurrent() - if len(current) > 0 { - rconn.updateProxies(current) + proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx) + if err != nil { + logger.WithError(err).Warn("Failed to get proxy set") + } + if len(proxies) > 0 { + rconn.updateProxies(proxies) } reverseSSHTunnels.WithLabelValues(rconn.tunnelType).Inc() firstHeartbeat = false @@ -935,7 +943,7 @@ func (s *localSite) periodicFunctions() { // sshTunnelStats reports SSH tunnel statistics for the cluster. func (s *localSite) sshTunnelStats() error { - missing := s.srv.NodeWatcher.GetNodes(s.srv.ctx, func(server services.Node) bool { + missing, err := s.srv.NodeWatcher.CurrentResourcesWithFilter(s.srv.ctx, func(server readonly.Server) bool { // Skip over any servers that have a TTL larger than announce TTL (10 // minutes) and are non-IoT SSH servers (they won't have tunnels). // @@ -967,6 +975,9 @@ func (s *localSite) sshTunnelStats() error { return err != nil }) + if err != nil { + return trace.Wrap(err) + } // Update Prometheus metrics and also log if any tunnels are missing. 
missingSSHTunnels.Set(float64(len(missing))) diff --git a/lib/reversetunnel/localsite_test.go b/lib/reversetunnel/localsite_test.go index 3397aed763683..195a1e76510c2 100644 --- a/lib/reversetunnel/localsite_test.go +++ b/lib/reversetunnel/localsite_test.go @@ -58,14 +58,16 @@ func TestRemoteConnCleanup(t *testing.T) { clock := clockwork.NewFakeClock() + clt := &mockLocalSiteClient{} watcher, err := services.NewProxyWatcher(ctx, services.ProxyWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ Component: "test", Logger: utils.NewSlogLoggerForTests(), Clock: clock, - Client: &mockLocalSiteClient{}, + Client: clt, }, - ProxiesC: make(chan []types.Server, 2), + ProxyGetter: clt, + ProxiesC: make(chan []types.Server, 2), }) require.NoError(t, err) require.NoError(t, watcher.WaitInitialization()) @@ -249,17 +251,19 @@ func TestProxyResync(t *testing.T) { proxy2, err := types.NewServer(uuid.NewString(), types.KindProxy, types.ServerSpecV2{}) require.NoError(t, err) + clt := &mockLocalSiteClient{ + proxies: []types.Server{proxy1, proxy2}, + } // set up the watcher and wait for it to be initialized watcher, err := services.NewProxyWatcher(ctx, services.ProxyWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ Component: "test", Logger: utils.NewSlogLoggerForTests(), Clock: clock, - Client: &mockLocalSiteClient{ - proxies: []types.Server{proxy1, proxy2}, - }, + Client: clt, }, - ProxiesC: make(chan []types.Server, 2), + ProxyGetter: clt, + ProxiesC: make(chan []types.Server, 2), }) require.NoError(t, err) require.NoError(t, watcher.WaitInitialization()) diff --git a/lib/reversetunnel/peer.go b/lib/reversetunnel/peer.go index fc16cbe11cefa..570be5edf4bbe 100644 --- a/lib/reversetunnel/peer.go +++ b/lib/reversetunnel/peer.go @@ -33,6 +33,7 @@ import ( "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" ) func newClusterPeers(clusterName string) *clusterPeers { @@ -90,7 +91,7 @@ func (p *clusterPeers) CachingAccessPoint() (authclient.RemoteProxyAccessPoint, return peer.CachingAccessPoint() } -func (p *clusterPeers) NodeWatcher() (*services.NodeWatcher, error) { +func (p *clusterPeers) NodeWatcher() (*services.GenericWatcher[types.Server, readonly.Server], error) { peer, err := p.pickPeer() if err != nil { return nil, trace.Wrap(err) @@ -202,7 +203,7 @@ func (s *clusterPeer) CachingAccessPoint() (authclient.RemoteProxyAccessPoint, e return nil, trace.ConnectionProblem(nil, "unable to fetch access point, this proxy %v has not been discovered yet, try again later", s) } -func (s *clusterPeer) NodeWatcher() (*services.NodeWatcher, error) { +func (s *clusterPeer) NodeWatcher() (*services.GenericWatcher[types.Server, readonly.Server], error) { return nil, trace.ConnectionProblem(nil, "unable to fetch node watcher, this proxy %v has not been discovered yet, try again later", s) } diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index 8e8b7e4c3fe79..f9617f33b87d5 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -42,6 +42,7 @@ import ( "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/srv/forward" 
"github.com/gravitational/teleport/lib/teleagent" "github.com/gravitational/teleport/lib/utils" @@ -85,7 +86,7 @@ type remoteSite struct { remoteAccessPoint authclient.RemoteProxyAccessPoint // nodeWatcher provides access the node set for the remote site - nodeWatcher *services.NodeWatcher + nodeWatcher *services.GenericWatcher[types.Server, readonly.Server] // remoteCA is the last remote certificate authority recorded by the client. // It is used to detect CA rotation status changes. If the rotation @@ -164,7 +165,7 @@ func (s *remoteSite) CachingAccessPoint() (authclient.RemoteProxyAccessPoint, er } // NodeWatcher returns the services.NodeWatcher for the remote cluster. -func (s *remoteSite) NodeWatcher() (*services.NodeWatcher, error) { +func (s *remoteSite) NodeWatcher() (*services.GenericWatcher[types.Server, readonly.Server], error) { return s.nodeWatcher, nil } @@ -429,7 +430,11 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch return case <-proxyResyncTicker.Chan(): var req discoveryRequest - req.SetProxies(s.srv.proxyWatcher.GetCurrent()) + proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx) + if err != nil { + logger.WithError(err).Warn("Failed to get proxy set") + } + req.SetProxies(proxies) if err := conn.sendDiscoveryRequest(req); err != nil { logger.WithError(err).Debug("Marking connection invalid on error") @@ -458,9 +463,12 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch if firstHeartbeat { // as soon as the agent connects and sends a first heartbeat // send it the list of current proxies back - current := s.srv.proxyWatcher.GetCurrent() - if len(current) > 0 { - conn.updateProxies(current) + proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx) + if err != nil { + logger.WithError(err).Warn("Failed to get proxy set") + } + if len(proxies) > 0 { + conn.updateProxies(proxies) } firstHeartbeat = false } diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index 19dfd9e2d43ca..4cf45cf81a15f 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -49,6 +49,7 @@ import ( "github.com/gravitational/teleport/lib/proxy/peer" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/srv/ingress" "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/sshutils" @@ -114,7 +115,7 @@ type server struct { // proxyWatcher monitors changes to the proxies // and broadcasts updates - proxyWatcher *services.ProxyWatcher + proxyWatcher *services.GenericWatcher[types.Server, readonly.Server] // offlineThreshold is how long to wait for a keep alive message before // marking a reverse tunnel connection as invalid. @@ -201,7 +202,7 @@ type Config struct { LockWatcher *services.LockWatcher // NodeWatcher is a node watcher. - NodeWatcher *services.NodeWatcher + NodeWatcher *services.GenericWatcher[types.Server, readonly.Server] // CertAuthorityWatcher is a cert authority watcher. 
CertAuthorityWatcher *services.CertAuthorityWatcher @@ -307,9 +308,6 @@ func NewServer(cfg Config) (reversetunnelclient.Server, error) { }, ProxiesC: make(chan []types.Server, 10), ProxyGetter: cfg.LocalAccessPoint, - ProxyDiffer: func(_, _ types.Server) bool { - return true // we always want to store the most recently heartbeated proxy - }, }) if err != nil { cancel() @@ -401,7 +399,7 @@ func (s *server) periodicFunctions() { s.log.Debugf("Closing.") return // Proxies have been updated, notify connected agents about the update. - case proxies := <-s.proxyWatcher.ProxiesC: + case proxies := <-s.proxyWatcher.ResourcesC: s.fanOutProxies(proxies) case <-ticker.C: if err := s.fetchClusterPeers(); err != nil { diff --git a/lib/reversetunnelclient/api.go b/lib/reversetunnelclient/api.go index f7e8dfb47ef63..e044bf4beb012 100644 --- a/lib/reversetunnelclient/api.go +++ b/lib/reversetunnelclient/api.go @@ -31,6 +31,7 @@ import ( "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/proxy/peer" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/teleagent" ) @@ -123,7 +124,7 @@ type RemoteSite interface { // but is resilient to auth server crashes CachingAccessPoint() (authclient.RemoteProxyAccessPoint, error) // NodeWatcher returns the node watcher that maintains the node set for the site - NodeWatcher() (*services.NodeWatcher, error) + NodeWatcher() (*services.GenericWatcher[types.Server, readonly.Server], error) // GetTunnelsCount returns the amount of active inbound tunnels // from the remote cluster GetTunnelsCount() int diff --git a/lib/service/service.go b/lib/service/service.go index 01e4fdad21224..fd8262d5ba98c 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -5025,6 +5025,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { Logger: process.logger.With(teleport.ComponentKey, teleport.Component(teleport.ComponentReverseTunnelServer, process.id)), Client: accessPoint, }, + KubernetesServerGetter: accessPoint, }) if err != nil { return trace.Wrap(err) diff --git a/lib/services/readonly/readonly.go b/lib/services/readonly/readonly.go index db65197a4338a..c4ed3185ace66 100644 --- a/lib/services/readonly/readonly.go +++ b/lib/services/readonly/readonly.go @@ -137,3 +137,317 @@ func (a sealedAccessGraphSettings) SecretsScanConfig() clusterconfigpb.AccessGra func (a sealedAccessGraphSettings) Clone() *clusterconfigpb.AccessGraphSettings { return protobuf.Clone(a.AccessGraphSettings).(*clusterconfigpb.AccessGraphSettings) } + +// Resource is a read only variant of [types.Resource]. +type Resource interface { + // GetKind returns resource kind + GetKind() string + // GetSubKind returns resource subkind + GetSubKind() string + // GetVersion returns resource version + GetVersion() string + // GetName returns the name of the resource + GetName() string + // Expiry returns object expiry setting + Expiry() time.Time + // GetMetadata returns object metadata + GetMetadata() types.Metadata + // GetRevision returns the revision + GetRevision() string +} + +// ResourceWithOrigin is a read only variant of [types.ResourceWithOrigin]. +type ResourceWithOrigin interface { + Resource + // Origin returns the origin value of the resource. + Origin() string +} + +// ResourceWithLabels is a read only variant of [types.ResourceWithLabels]. 
+type ResourceWithLabels interface { + ResourceWithOrigin + // GetLabel retrieves the label with the provided key. + GetLabel(key string) (value string, ok bool) + // GetAllLabels returns all resource's labels. + GetAllLabels() map[string]string + // GetStaticLabels returns the resource's static labels. + GetStaticLabels() map[string]string + // MatchSearch goes through select field values of a resource + // and tries to match against the list of search values. + MatchSearch(searchValues []string) bool +} + +// Application is a read only variant of [types.Application]. +type Application interface { + // ResourceWithLabels provides common resource methods. + ResourceWithLabels + // GetNamespace returns the app namespace. + GetNamespace() string + // GetStaticLabels returns the app static labels. + GetStaticLabels() map[string]string + // GetDynamicLabels returns the app dynamic labels. + GetDynamicLabels() map[string]types.CommandLabel + // String returns string representation of the app. + String() string + // GetDescription returns the app description. + GetDescription() string + // GetURI returns the app connection endpoint. + GetURI() string + // GetPublicAddr returns the app public address. + GetPublicAddr() string + // GetInsecureSkipVerify returns the app insecure setting. + GetInsecureSkipVerify() bool + // GetRewrite returns the app rewrite configuration. + GetRewrite() *types.Rewrite + // IsAWSConsole returns true if this app is AWS management console. + IsAWSConsole() bool + // IsAzureCloud returns true if this app represents Azure Cloud instance. + IsAzureCloud() bool + // IsGCP returns true if this app represents GCP instance. + IsGCP() bool + // IsTCP returns true if this app represents a TCP endpoint. + IsTCP() bool + // GetProtocol returns the application protocol. + GetProtocol() string + // GetAWSAccountID returns value of label containing AWS account ID on this app. + GetAWSAccountID() string + // GetAWSExternalID returns the AWS External ID configured for this app. + GetAWSExternalID() string + // GetUserGroups will get the list of user group IDs associated with the application. + GetUserGroups() []string + // Copy returns a copy of this app resource. + Copy() *types.AppV3 + // GetIntegration will return the Integration. + // If present, the Application must use the Integration's credentials instead of ambient credentials to access Cloud APIs. + GetIntegration() string + // GetRequiredAppNames will return a list of required apps names that should be authenticated during this apps authentication process. + GetRequiredAppNames() []string + // GetCORS returns the CORS configuration for the app. + GetCORS() *types.CORSPolicy +} + +// KubeServer is a read only variant of [types.KubeServer]. +type KubeServer interface { + // ResourceWithLabels provides common resource methods. + ResourceWithLabels + // GetNamespace returns server namespace. + GetNamespace() string + // GetTeleportVersion returns the teleport version the server is running on. + GetTeleportVersion() string + // GetHostname returns the server hostname. + GetHostname() string + // GetHostID returns ID of the host the server is running on. + GetHostID() string + // GetRotation gets the state of certificate authority rotation. + GetRotation() types.Rotation + // String returns string representation of the server. + String() string + // Copy returns a copy of this kube server object. 
+ Copy() types.KubeServer + // CloneResource returns a copy of the KubeServer as a ResourceWithLabels + CloneResource() types.ResourceWithLabels + // GetCluster returns the Kubernetes Cluster this kube server proxies. + GetCluster() types.KubeCluster + // GetProxyIDs returns a list of proxy ids this service is connected to. + GetProxyIDs() []string +} + +// KubeCluster is a read only variant of [types.KubeCluster]. +type KubeCluster interface { + // ResourceWithLabels provides common resource methods. + ResourceWithLabels + // GetNamespace returns the kube cluster namespace. + GetNamespace() string + // GetStaticLabels returns the kube cluster static labels. + GetStaticLabels() map[string]string + // GetDynamicLabels returns the kube cluster dynamic labels. + GetDynamicLabels() map[string]types.CommandLabel + // GetKubeconfig returns the kubeconfig payload. + GetKubeconfig() []byte + // String returns string representation of the kube cluster. + String() string + // GetDescription returns the kube cluster description. + GetDescription() string + // GetAzureConfig gets the Azure config. + GetAzureConfig() types.KubeAzure + // GetAWSConfig gets the AWS config. + GetAWSConfig() types.KubeAWS + // GetGCPConfig gets the GCP config. + GetGCPConfig() types.KubeGCP + // IsAzure indentifies if the KubeCluster contains Azure details. + IsAzure() bool + // IsAWS indentifies if the KubeCluster contains AWS details. + IsAWS() bool + // IsGCP indentifies if the KubeCluster contains GCP details. + IsGCP() bool + // IsKubeconfig identifies if the KubeCluster contains kubeconfig data. + IsKubeconfig() bool + // Copy returns a copy of this kube cluster resource. + Copy() *types.KubernetesClusterV3 + // GetCloud gets the cloud this kube cluster is running on, or an empty string if it + // isn't running on a cloud provider. + GetCloud() string +} + +// Database is a read only variant of [types.Database]. +type Database interface { + // ResourceWithLabels provides common resource methods. + ResourceWithLabels + // GetNamespace returns the database namespace. + GetNamespace() string + // GetStaticLabels returns the database static labels. + GetStaticLabels() map[string]string + // GetDynamicLabels returns the database dynamic labels. + GetDynamicLabels() map[string]types.CommandLabel + // String returns string representation of the database. + String() string + // GetDescription returns the database description. + GetDescription() string + // GetProtocol returns the database protocol. + GetProtocol() string + // GetURI returns the database connection endpoint. + GetURI() string + // GetCA returns the database CA certificate. + GetCA() string + // GetTLS returns the database TLS configuration. + GetTLS() types.DatabaseTLS + // GetStatusCA gets the database CA certificate in the status field. + GetStatusCA() string + // GetMySQL returns the database options from spec. + GetMySQL() types.MySQLOptions + // GetOracle returns the database options from spec. + GetOracle() types.OracleOptions + // GetMySQLServerVersion returns the MySQL server version either from configuration or + // reported by the database. + GetMySQLServerVersion() string + // GetAWS returns the database AWS metadata. + GetAWS() types.AWS + // GetGCP returns GCP information for Cloud SQL databases. + GetGCP() types.GCPCloudSQL + // GetAzure returns Azure database server metadata. + GetAzure() types.Azure + // GetAD returns Active Directory database configuration. 
+ GetAD() types.AD + // GetType returns the database authentication type: self-hosted, RDS, Redshift or Cloud SQL. + GetType() string + // GetSecretStore returns secret store configurations. + GetSecretStore() types.SecretStore + // GetManagedUsers returns a list of database users that are managed by Teleport. + GetManagedUsers() []string + // GetMongoAtlas returns Mongo Atlas database metadata. + GetMongoAtlas() types.MongoAtlas + // IsRDS returns true if this is an RDS/Aurora database. + IsRDS() bool + // IsRDSProxy returns true if this is an RDS Proxy database. + IsRDSProxy() bool + // IsRedshift returns true if this is a Redshift database. + IsRedshift() bool + // IsCloudSQL returns true if this is a Cloud SQL database. + IsCloudSQL() bool + // IsAzure returns true if this is an Azure database. + IsAzure() bool + // IsElastiCache returns true if this is an AWS ElastiCache database. + IsElastiCache() bool + // IsMemoryDB returns true if this is an AWS MemoryDB database. + IsMemoryDB() bool + // IsAWSHosted returns true if database is hosted by AWS. + IsAWSHosted() bool + // IsCloudHosted returns true if database is hosted in the cloud (AWS, Azure or Cloud SQL). + IsCloudHosted() bool + // RequireAWSIAMRolesAsUsers returns true for database types that require + // AWS IAM roles as database users. + RequireAWSIAMRolesAsUsers() bool + // SupportAWSIAMRoleARNAsUsers returns true for database types that support + // AWS IAM roles as database users. + SupportAWSIAMRoleARNAsUsers() bool + // Copy returns a copy of this database resource. + Copy() *types.DatabaseV3 + // GetAdminUser returns database privileged user information. + GetAdminUser() types.DatabaseAdminUser + // SupportsAutoUsers returns true if this database supports automatic + // user provisioning. + SupportsAutoUsers() bool + // GetEndpointType returns the endpoint type of the database, if available. + GetEndpointType() string + // GetCloud gets the cloud this database is running on, or an empty string if it + // isn't running on a cloud provider. + GetCloud() string + // IsUsernameCaseInsensitive returns true if the database username is case + // insensitive. + IsUsernameCaseInsensitive() bool +} + +// Server is a read only variant of [types.Server]. +type Server interface { + // ResourceWithLabels provides common resource headers + ResourceWithLabels + // GetTeleportVersion returns the teleport version the server is running on + GetTeleportVersion() string + // GetAddr return server address + GetAddr() string + // GetHostname returns server hostname + GetHostname() string + // GetNamespace returns server namespace + GetNamespace() string + // GetLabels returns server's static label key pairs + GetLabels() map[string]string + // GetCmdLabels gets command labels + GetCmdLabels() map[string]types.CommandLabel + // GetPublicAddr returns a public address where this server can be reached. + GetPublicAddr() string + // GetPublicAddrs returns a list of public addresses where this server can be reached. + GetPublicAddrs() []string + // GetRotation gets the state of certificate authority rotation. + GetRotation() types.Rotation + // GetUseTunnel gets if a reverse tunnel should be used to connect to this node. + GetUseTunnel() bool + // String returns string representation of the server + String() string + // GetPeerAddr returns the peer address of the server. + GetPeerAddr() string + // GetProxyIDs returns a list of proxy ids this service is connected to. 
+ GetProxyIDs() []string + // DeepCopy creates a clone of this server value + DeepCopy() types.Server + + // CloneResource is used to return a clone of the Server and match the CloneAny interface + // This is helpful when interfacing with multiple types at the same time in unified resources + CloneResource() types.ResourceWithLabels + + // GetCloudMetadata gets the cloud metadata for the server. + GetCloudMetadata() *types.CloudMetadata + // GetAWSInfo returns the AWSInfo for the server. + GetAWSInfo() *types.AWSInfo + + // IsOpenSSHNode returns whether the connection to this Server must use OpenSSH. + // This returns true for SubKindOpenSSHNode and SubKindOpenSSHEICENode. + IsOpenSSHNode() bool + + // IsEICE returns whether the Node is an EICE instance. + // Must be `openssh-ec2-ice` subkind and have the AccountID and InstanceID information (AWS Metadata or Labels). + IsEICE() bool + + // GetAWSInstanceID returns the AWS Instance ID if this node comes from an EC2 instance. + GetAWSInstanceID() string + // GetAWSAccountID returns the AWS Account ID if this node comes from an EC2 instance. + GetAWSAccountID() string +} + +// DynamicWindowsDesktop represents a Windows desktop host that is automatically discovered by Windows Desktop Service. +type DynamicWindowsDesktop interface { + // ResourceWithLabels provides common resource methods. + ResourceWithLabels + // GetAddr returns the network address of this host. + GetAddr() string + // GetDomain returns the ActiveDirectory domain of this host. + GetDomain() string + // NonAD checks whether this is a standalone host that + // is not joined to an Active Directory domain. + NonAD() bool + // GetScreenSize returns the desired size of the screen to use for sessions + // to this host. Returns (0, 0) if no screen size is set, which means to + // use the size passed by the client over TDP. + GetScreenSize() (width, height uint32) + // Copy returns a copy of this dynamic Windows desktop + Copy() *types.DynamicWindowsDesktopV1 +} diff --git a/lib/services/watcher.go b/lib/services/watcher.go index 93daf0aee5cd6..7699f5459b070 100644 --- a/lib/services/watcher.go +++ b/lib/services/watcher.go @@ -34,6 +34,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/utils" logutils "github.com/gravitational/teleport/lib/utils/log" ) @@ -87,21 +88,21 @@ func watchKindsString(kinds []types.WatchKind) string { // ResourceWatcherConfig configures resource watcher. type ResourceWatcherConfig struct { - // Component is a component used in logs. - Component string + // Clock is used to control time. + Clock clockwork.Clock + // Client is used to create new watchers + Client types.Events // Logger emits log messages. Logger *slog.Logger + // ResetC is a channel to notify of internal watcher reset (used in tests). + ResetC chan time.Duration + // Component is a component used in logs. + Component string // MaxRetryPeriod is the maximum retry period on failed watchers. MaxRetryPeriod time.Duration - // Clock is used to control time. - Clock clockwork.Clock - // Client is used to create new watchers. - Client types.Events // MaxStaleness is a maximum acceptable staleness for the locally maintained // resources, zero implies no staleness detection. MaxStaleness time.Duration - // ResetC is a channel to notify of internal watcher reset (used in tests). 
- ResetC chan time.Duration // QueueSize is an optional queue size QueueSize int } @@ -165,28 +166,23 @@ func newResourceWatcher(ctx context.Context, collector resourceCollector, cfg Re // resourceWatcher monitors additions, updates and deletions // to a set of resources. type resourceWatcher struct { - ResourceWatcherConfig - collector resourceCollector - - // ctx is a context controlling the lifetime of this resourceWatcher - // instance. - ctx context.Context - cancel context.CancelFunc - - // retry is used to manage backoff logic for watchers. - retry retryutils.Retry - // failureStartedAt records when the current sync failures were first // detected, zero if there are no failures present. failureStartedAt time.Time - + collector resourceCollector + // ctx is a context controlling the lifetime of this resourceWatcher + // instance. + ctx context.Context + // retry is used to manage backoff logic for watchers. + retry retryutils.Retry + cancel context.CancelFunc // LoopC is a channel to check whether the watch loop is running // (used in tests). LoopC chan struct{} - // StaleC is a channel that can trigger the condition of resource staleness // (used in tests). StaleC chan struct{} + ResourceWatcherConfig } // Done returns a channel that signals resource watcher closure. @@ -380,195 +376,523 @@ func (p *resourceWatcher) watch() error { // ProxyWatcherConfig is a ProxyWatcher configuration. type ProxyWatcherConfig struct { - ResourceWatcherConfig // ProxyGetter is used to directly fetch the list of active proxies. ProxyGetter // ProxyDiffer is used to decide whether a put operation on an existing proxy should // trigger a event. ProxyDiffer func(old, new types.Server) bool // ProxiesC is a channel used to report the current proxy set. It receives - // a fresh list at startup and subsequently a list of all known proxies + // a fresh list at startup and subsequently a list of all known proxy // whenever an addition or deletion is detected. ProxiesC chan []types.Server + ResourceWatcherConfig +} + +// NewProxyWatcher returns a new instance of GenericWatcher that is configured +// to watch for changes. +func NewProxyWatcher(ctx context.Context, cfg ProxyWatcherConfig) (*GenericWatcher[types.Server, readonly.Server], error) { + if cfg.ProxyGetter == nil { + return nil, trace.BadParameter("ProxyGetter must be provided") + } + + if cfg.ProxyDiffer == nil { + cfg.ProxyDiffer = func(old, new types.Server) bool { return true } + } + + w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.Server, readonly.Server]{ + ResourceWatcherConfig: cfg.ResourceWatcherConfig, + ResourceKind: types.KindProxy, + ResourceKey: types.Server.GetName, + ResourceGetter: func(ctx context.Context) ([]types.Server, error) { + return cfg.ProxyGetter.GetProxies() + }, + ResourcesC: cfg.ProxiesC, + ResourceDiffer: cfg.ProxyDiffer, + RequireResourcesForInitialBroadcast: true, + CloneFunc: types.Server.DeepCopy, + }) + return w, trace.Wrap(err) +} + +// DatabaseWatcherConfig is a DatabaseWatcher configuration. +type DatabaseWatcherConfig struct { + // DatabaseGetter is responsible for fetching database resources. + DatabaseGetter + // DatabasesC receives up-to-date list of all database resources. + DatabasesC chan []types.Database + // ResourceWatcherConfig is the resource watcher configuration. + ResourceWatcherConfig +} + +// NewDatabaseWatcher returns a new instance of DatabaseWatcher. 
+func NewDatabaseWatcher(ctx context.Context, cfg DatabaseWatcherConfig) (*GenericWatcher[types.Database, readonly.Database], error) { + if cfg.DatabaseGetter == nil { + return nil, trace.BadParameter("DatabaseGetter must be provided") + } + + w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.Database, readonly.Database]{ + ResourceWatcherConfig: cfg.ResourceWatcherConfig, + ResourceKind: types.KindDatabase, + ResourceKey: types.Database.GetName, + ResourceGetter: func(ctx context.Context) ([]types.Database, error) { + return cfg.DatabaseGetter.GetDatabases(ctx) + }, + ResourcesC: cfg.DatabasesC, + CloneFunc: func(resource types.Database) types.Database { + return resource.Copy() + }, + }) + return w, trace.Wrap(err) +} + +// AppWatcherConfig is an AppWatcher configuration. +type AppWatcherConfig struct { + // AppGetter is responsible for fetching application resources. + AppGetter + // AppsC receives up-to-date list of all application resources. + AppsC chan []types.Application + // ResourceWatcherConfig is the resource watcher configuration. + ResourceWatcherConfig +} + +// NewAppWatcher returns a new instance of AppWatcher. +func NewAppWatcher(ctx context.Context, cfg AppWatcherConfig) (*GenericWatcher[types.Application, readonly.Application], error) { + if cfg.AppGetter == nil { + return nil, trace.BadParameter("AppGetter must be provided") + } + + w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.Application, readonly.Application]{ + ResourceWatcherConfig: cfg.ResourceWatcherConfig, + ResourceKind: types.KindApp, + ResourceKey: types.Application.GetName, + ResourceGetter: func(ctx context.Context) ([]types.Application, error) { + return cfg.AppGetter.GetApps(ctx) + }, + ResourcesC: cfg.AppsC, + CloneFunc: func(resource types.Application) types.Application { + return resource.Copy() + }, + }) + + return w, trace.Wrap(err) +} + +// KubeServerWatcherConfig is an KubeServerWatcher configuration. +type KubeServerWatcherConfig struct { + // KubernetesServerGetter is responsible for fetching kube_server resources. + KubernetesServerGetter + // ResourceWatcherConfig is the resource watcher configuration. + ResourceWatcherConfig +} + +// NewKubeServerWatcher returns a new instance of KubeServerWatcher. +func NewKubeServerWatcher(ctx context.Context, cfg KubeServerWatcherConfig) (*GenericWatcher[types.KubeServer, readonly.KubeServer], error) { + if cfg.KubernetesServerGetter == nil { + return nil, trace.BadParameter("KubernetesServerGetter must be provided") + } + + w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.KubeServer, readonly.KubeServer]{ + ResourceWatcherConfig: cfg.ResourceWatcherConfig, + ResourceKind: types.KindKubeServer, + ResourceGetter: func(ctx context.Context) ([]types.KubeServer, error) { + return cfg.KubernetesServerGetter.GetKubernetesServers(ctx) + }, + ResourceKey: func(resource types.KubeServer) string { + return resource.GetHostID() + resource.GetName() + }, + DisableUpdateBroadcast: true, + CloneFunc: types.KubeServer.Copy, + }) + return w, trace.Wrap(err) +} + +// KubeClusterWatcherConfig is an KubeClusterWatcher configuration. +type KubeClusterWatcherConfig struct { + // KubernetesGetter is responsible for fetching kube_cluster resources. + KubernetesClusterGetter + // KubeClustersC receives up-to-date list of all kube_cluster resources. + KubeClustersC chan []types.KubeCluster + // ResourceWatcherConfig is the resource watcher configuration. 
+ ResourceWatcherConfig +} + +// NewKubeClusterWatcher returns a new instance of KubeClusterWatcher. +func NewKubeClusterWatcher(ctx context.Context, cfg KubeClusterWatcherConfig) (*GenericWatcher[types.KubeCluster, readonly.KubeCluster], error) { + if cfg.KubernetesClusterGetter == nil { + return nil, trace.BadParameter("KubernetesClusterGetter must be provided") + } + + w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.KubeCluster, readonly.KubeCluster]{ + ResourceWatcherConfig: cfg.ResourceWatcherConfig, + ResourceKind: types.KindKubernetesCluster, + ResourceGetter: func(ctx context.Context) ([]types.KubeCluster, error) { + return cfg.KubernetesClusterGetter.GetKubernetesClusters(ctx) + }, + ResourceKey: types.KubeCluster.GetName, + ResourcesC: cfg.KubeClustersC, + CloneFunc: func(resource types.KubeCluster) types.KubeCluster { + return resource.Copy() + }, + }) + return w, trace.Wrap(err) +} + +type DynamicWindowsDesktopGetter interface { + ListDynamicWindowsDesktops(ctx context.Context, pageSize int, pageToken string) ([]types.DynamicWindowsDesktop, string, error) +} + +// DynamicWindowsDesktopWatcherConfig is a DynamicWindowsDesktopWatcher configuration. +type DynamicWindowsDesktopWatcherConfig struct { + // DynamicWindowsDesktopGetter is responsible for fetching DynamicWindowsDesktop resources. + DynamicWindowsDesktopGetter + // DynamicWindowsDesktopsC receives up-to-date list of all DynamicWindowsDesktop resources. + DynamicWindowsDesktopsC chan []types.DynamicWindowsDesktop + // ResourceWatcherConfig is the resource watcher configuration. + ResourceWatcherConfig +} + +// NewDynamicWindowsDesktopWatcher returns a new instance of DynamicWindowsDesktopWatcher. +func NewDynamicWindowsDesktopWatcher(ctx context.Context, cfg DynamicWindowsDesktopWatcherConfig) (*GenericWatcher[types.DynamicWindowsDesktop, readonly.DynamicWindowsDesktop], error) { + if cfg.DynamicWindowsDesktopGetter == nil { + return nil, trace.BadParameter("KubernetesClusterGetter must be provided") + } + + w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.DynamicWindowsDesktop, readonly.DynamicWindowsDesktop]{ + ResourceWatcherConfig: cfg.ResourceWatcherConfig, + ResourceKind: types.KindDynamicWindowsDesktop, + ResourceGetter: func(ctx context.Context) ([]types.DynamicWindowsDesktop, error) { + var desktops []types.DynamicWindowsDesktop + next := "" + for { + d, token, err := cfg.DynamicWindowsDesktopGetter.ListDynamicWindowsDesktops(ctx, defaults.MaxIterationLimit, next) + if err != nil { + return nil, err + } + desktops = append(desktops, d...) + if token == "" { + break + } + next = token + } + return desktops, nil + }, + ResourceKey: types.DynamicWindowsDesktop.GetName, + ResourcesC: cfg.DynamicWindowsDesktopsC, + CloneFunc: func(resource types.DynamicWindowsDesktop) types.DynamicWindowsDesktop { + return resource.Copy() + }, + }) + return w, trace.Wrap(err) +} + +// GenericWatcherConfig is a generic resource watcher configuration. +type GenericWatcherConfig[T any, R any] struct { + // ResourceGetter is used to directly fetch the current set of resources. + ResourceGetter func(context.Context) ([]T, error) + // ResourceDiffer is used to decide whether a put operation on an existing ResourceGetter should + // trigger an event. + ResourceDiffer func(old, new T) bool + // ResourceKey defines how the resources should be keyed. + ResourceKey func(resource T) string + // ResourcesC is a channel used to report the current resourxe set. 
It receives + // a fresh list at startup and subsequently a list of all known resourxes + // whenever an addition or deletion is detected. + ResourcesC chan []T + // CloneFunc defines how a resource is cloned. All resources provided via + // the broadcast mechanism, or retrieved via [GenericWatcer.CurrentResources] + // or [GenericWatcher.CurrentResourcesWithFilter] will be cloned by this + // mechanism before being provided to callers. + CloneFunc func(resource T) T + ResourceWatcherConfig + // ResourceKind specifies the kind of resource the watcher is monitoring. + ResourceKind string + // RequireResourcesForInitialBroadcast indicates whether an update should be + // performed if the initial set of resources is empty. + RequireResourcesForInitialBroadcast bool + // DisableUpdateBroadcast turns off emitting updates on changes. When this + // mode is opted into, users must invoke [GenericWatcher.CurrentResources] or + // [GenericWatcher.CurrentResourcesWithFilter] manually to retrieve the active + // resource set. + DisableUpdateBroadcast bool } // CheckAndSetDefaults checks parameters and sets default values. -func (cfg *ProxyWatcherConfig) CheckAndSetDefaults() error { +func (cfg *GenericWatcherConfig[T, R]) CheckAndSetDefaults() error { if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { return trace.Wrap(err) } - if cfg.ProxyGetter == nil { - getter, ok := cfg.Client.(ProxyGetter) - if !ok { - return trace.BadParameter("missing parameter ProxyGetter and Client not usable as ProxyGetter") - } - cfg.ProxyGetter = getter + + if cfg.ResourceGetter == nil { + return trace.BadParameter("ResourceGetter not provided to generic resource watcher") + } + + if cfg.ResourceKind == "" { + return trace.BadParameter("ResourceKind not provided to generic resource watcher") + } + + if cfg.ResourceKey == nil { + return trace.BadParameter("ResourceKey not provided to generic resource watcher") + } + + if cfg.ResourceDiffer == nil { + cfg.ResourceDiffer = func(T, T) bool { return true } } - if cfg.ProxiesC == nil { - cfg.ProxiesC = make(chan []types.Server) + + if cfg.ResourcesC == nil { + cfg.ResourcesC = make(chan []T) } return nil } -// NewProxyWatcher returns a new instance of ProxyWatcher. -func NewProxyWatcher(ctx context.Context, cfg ProxyWatcherConfig) (*ProxyWatcher, error) { +// NewGenericResourceWatcher returns a new instance of resource watcher. +func NewGenericResourceWatcher[T any, R any](ctx context.Context, cfg GenericWatcherConfig[T, R]) (*GenericWatcher[T, R], error) { if err := cfg.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } - collector := &proxyCollector{ - ProxyWatcherConfig: cfg, - initializationC: make(chan struct{}), + + cache, err := utils.NewFnCache(utils.FnCacheConfig{ + Context: ctx, + TTL: 3 * time.Second, + Clock: cfg.Clock, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + collector := &genericCollector[T, R]{ + GenericWatcherConfig: cfg, + initializationC: make(chan struct{}), + cache: cache, } + collector.stale.Store(true) watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) if err != nil { return nil, trace.Wrap(err) } - return &ProxyWatcher{watcher, collector}, nil + return &GenericWatcher[T, R]{watcher, collector}, nil } -// ProxyWatcher is built on top of resourceWatcher to monitor additions -// and deletions to the set of proxies. -type ProxyWatcher struct { +// GenericWatcher is built on top of resourceWatcher to monitor additions +// and deletions to the set of resources. 
+type GenericWatcher[T any, R any] struct { *resourceWatcher - *proxyCollector + *genericCollector[T, R] +} + +// ResourceCount returns the current number of resources known to the watcher. +func (g *GenericWatcher[T, R]) ResourceCount() int { + g.rw.RLock() + defer g.rw.RUnlock() + return len(g.current) +} + +// CurrentResources returns a copy of the resources known to the watcher. +func (g *GenericWatcher[T, R]) CurrentResources(ctx context.Context) ([]T, error) { + if err := g.refreshStaleResources(ctx); err != nil { + return nil, trace.Wrap(err) + } + + g.rw.RLock() + defer g.rw.RUnlock() + + return resourcesToSlice(g.current, g.CloneFunc), nil +} + +// CurrentResourcesWithFilter returns a copy of the resources known to the watcher +// that match the provided filter. +func (g *GenericWatcher[T, R]) CurrentResourcesWithFilter(ctx context.Context, filter func(R) bool) ([]T, error) { + if err := g.refreshStaleResources(ctx); err != nil { + return nil, trace.Wrap(err) + } + + g.rw.RLock() + defer g.rw.RUnlock() + + r := func(a any) R { + return a.(R) + } + + var out []T + for _, resource := range g.current { + if filter(r(resource)) { + out = append(out, g.CloneFunc(resource)) + } + } + + return out, nil } -// proxyCollector accompanies resourceWatcher when monitoring proxies. -type proxyCollector struct { - ProxyWatcherConfig - // current holds a map of the currently known proxies (keyed by server name, +// genericCollector accompanies resourceWatcher when monitoring proxies. +type genericCollector[T any, R any] struct { + GenericWatcherConfig[T, R] + // current holds a map of the currently known resources (keyed by server name, // RWMutex protected). - current map[string]types.Server - rw sync.RWMutex + current map[string]T initializationC chan struct{} - once sync.Once + // cache is a helper for temporarily storing the results of CurrentResources. + // It's used to limit the number of calls to the backend. + cache *utils.FnCache + rw sync.RWMutex + once sync.Once + // stale is used to indicate that the watcher is stale and needs to be + // refreshed. + stale atomic.Bool } -// GetCurrent returns the currently stored proxies. -func (p *proxyCollector) GetCurrent() []types.Server { - p.rw.RLock() - defer p.rw.RUnlock() - return serverMapValues(p.current) +// resourceKinds specifies the resource kind to watch. +func (g *genericCollector[T, R]) resourceKinds() []types.WatchKind { + return []types.WatchKind{{Kind: g.ResourceKind}} } -// resourceKinds specifies the resource kind to watch. -func (p *proxyCollector) resourceKinds() []types.WatchKind { - return []types.WatchKind{{Kind: types.KindProxy}} +// getResources gets the list of current resources. +func (g *genericCollector[T, R]) getResources(ctx context.Context) (map[string]T, error) { + resources, err := g.GenericWatcherConfig.ResourceGetter(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + current := make(map[string]T, len(resources)) + for _, resource := range resources { + current[g.GenericWatcherConfig.ResourceKey(resource)] = resource + } + return current, nil +} + +func (g *genericCollector[T, R]) refreshStaleResources(ctx context.Context) error { + if !g.stale.Load() { + return nil + } + + _, err := utils.FnCacheGet(ctx, g.cache, g.GenericWatcherConfig.ResourceKind, func(ctx context.Context) (any, error) { + current, err := g.getResources(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + // There is a chance that the watcher reinitialized while + // getting resources happened above. 
Check if we are still stale + if g.stale.CompareAndSwap(true, false) { + g.rw.Lock() + g.current = current + g.rw.Unlock() + } + + return nil, nil + }) + + return trace.Wrap(err) } // getResourcesAndUpdateCurrent is called when the resources should be // (re-)fetched directly. -func (p *proxyCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { - proxies, err := p.ProxyGetter.GetProxies() +func (g *genericCollector[T, R]) getResourcesAndUpdateCurrent(ctx context.Context) error { + newCurrent, err := g.getResources(ctx) if err != nil { return trace.Wrap(err) } - newCurrent := make(map[string]types.Server, len(proxies)) - for _, proxy := range proxies { - newCurrent[proxy.GetName()] = proxy - } - p.rw.Lock() - defer p.rw.Unlock() - p.current = newCurrent - // only emit an empty proxy list if the collector has already been initialized - // to prevent an empty slice being sent out on creation of the watcher - if len(proxies) > 0 || (len(proxies) == 0 && p.isInitialized()) { - p.broadcastUpdate(ctx) + g.rw.Lock() + defer g.rw.Unlock() + g.current = newCurrent + g.stale.Store(false) + // Only emit an empty set of resources if the watcher is already initialized, + // or if explicitly opted into by for the watcher. + if len(newCurrent) > 0 || g.isInitialized() || + (!g.RequireResourcesForInitialBroadcast && len(newCurrent) == 0) { + g.broadcastUpdate(ctx) } - p.defineCollectorAsInitialized() + g.defineCollectorAsInitialized() return nil } -func (p *proxyCollector) defineCollectorAsInitialized() { - p.once.Do(func() { +func (g *genericCollector[T, R]) defineCollectorAsInitialized() { + g.once.Do(func() { // mark watcher as initialized. - close(p.initializationC) + close(g.initializationC) }) } // processEventsAndUpdateCurrent is called when a watcher event is received. -func (p *proxyCollector) processEventsAndUpdateCurrent(ctx context.Context, events []types.Event) { - p.rw.Lock() - defer p.rw.Unlock() +func (g *genericCollector[T, R]) processEventsAndUpdateCurrent(ctx context.Context, events []types.Event) { + g.rw.Lock() + defer g.rw.Unlock() var updated bool for _, event := range events { - if event.Resource == nil || event.Resource.GetKind() != types.KindProxy { - p.Logger.WarnContext(ctx, "Received unexpected event", "event", logutils.StringerAttr(event)) + if event.Resource == nil || event.Resource.GetKind() != g.ResourceKind { + g.Logger.WarnContext(ctx, "Received unexpected event", "event", logutils.StringerAttr(event)) continue } switch event.Type { case types.OpDelete: - delete(p.current, event.Resource.GetName()) - // Always broadcast when a proxy is deleted. + // On delete events, the server description is populated with the host ID. + delete(g.current, event.Resource.GetMetadata().Description+event.Resource.GetName()) + // Always broadcast when a resource is deleted. 
updated = true case types.OpPut: - server, ok := event.Resource.(types.Server) + resource, ok := event.Resource.(T) if !ok { - p.Logger.WarnContext(ctx, "Received unexpected type", "resource", event.Resource.GetKind()) + g.Logger.WarnContext(ctx, "Received unexpected type", "resource", event.Resource.GetKind()) continue } - current, exists := p.current[server.GetName()] - p.current[server.GetName()] = server - if !exists || (p.ProxyDiffer != nil && p.ProxyDiffer(current, server)) { - updated = true - } + + key := g.ResourceKey(resource) + current := g.current[key] + g.current[key] = resource + updated = g.ResourceDiffer(current, resource) default: - p.Logger.WarnContext(ctx, "Skipping unsupported event type", "event_type", event.Type) + g.Logger.WarnContext(ctx, "Skipping unsupported event type", "event_type", event.Type) } } if updated { - p.broadcastUpdate(ctx) + g.broadcastUpdate(ctx) } } -// broadcastUpdate broadcasts information about updating the proxy set. -func (p *proxyCollector) broadcastUpdate(ctx context.Context) { - names := make([]string, 0, len(p.current)) - for k := range p.current { +// broadcastUpdate broadcasts information about updating the resource set. +func (g *genericCollector[T, R]) broadcastUpdate(ctx context.Context) { + if g.DisableUpdateBroadcast { + return + } + + names := make([]string, 0, len(g.current)) + for k := range g.current { names = append(names, k) } - p.Logger.DebugContext(ctx, "List of known proxies updated", "proxies", names) + g.Logger.DebugContext(ctx, "List of known resources updated", "resources", names) select { - case p.ProxiesC <- serverMapValues(p.current): + case g.ResourcesC <- resourcesToSlice(g.current, g.CloneFunc): case <-ctx.Done(): } } // isInitialized is used to check that the cache has done its initial // sync -func (p *proxyCollector) initializationChan() <-chan struct{} { - return p.initializationC +func (g *genericCollector[T, R]) initializationChan() <-chan struct{} { + return g.initializationC } -func (p *proxyCollector) isInitialized() bool { +func (g *genericCollector[T, R]) isInitialized() bool { select { - case <-p.initializationC: + case <-g.initializationC: return true default: return false } } -func (p *proxyCollector) notifyStale() {} - -func serverMapValues(serverMap map[string]types.Server) []types.Server { - servers := make([]types.Server, 0, len(serverMap)) - for _, server := range serverMap { - servers = append(servers, server) - } - return servers +func (g *genericCollector[T, R]) notifyStale() { + g.stale.Store(true) } // LockWatcherConfig is a LockWatcher configuration. type LockWatcherConfig struct { - ResourceWatcherConfig LockGetter + ResourceWatcherConfig } // CheckAndSetDefaults checks parameters and sets default values. @@ -622,15 +946,15 @@ type lockCollector struct { LockWatcherConfig // current holds a map of the currently known locks (keyed by lock name). current map[string]types.Lock - // isStale indicates whether the local lock view (current) is stale. - isStale bool - // currentRW is a mutex protecting both current and isStale. - currentRW sync.RWMutex // fanout provides support for multiple subscribers to the lock updates. fanout *FanoutV2 // initializationC is used to check whether the initial sync has completed initializationC chan struct{} - once sync.Once + // currentRW is a mutex protecting both current and isStale. + currentRW sync.RWMutex + once sync.Once + // isStale indicates whether the local lock view (current) is stale. 
+ isStale bool } // IsStale is used to check whether the lock watcher is stale. @@ -817,858 +1141,37 @@ func lockMapValues(lockMap map[string]types.Lock) []types.Lock { return locks } -// DatabaseWatcherConfig is a DatabaseWatcher configuration. -type DatabaseWatcherConfig struct { +func resourcesToSlice[T any](resources map[string]T, cloneFunc func(T) T) (slice []T) { + for _, resource := range resources { + slice = append(slice, cloneFunc(resource)) + } + return slice +} + +// CertAuthorityWatcherConfig is a CertAuthorityWatcher configuration. +type CertAuthorityWatcherConfig struct { // ResourceWatcherConfig is the resource watcher configuration. ResourceWatcherConfig - // DatabaseGetter is responsible for fetching database resources. - DatabaseGetter - // DatabasesC receives up-to-date list of all database resources. - DatabasesC chan types.Databases + // AuthorityGetter is responsible for fetching cert authority resources. + AuthorityGetter + // Types restricts which cert authority types are retrieved via the AuthorityGetter. + Types []types.CertAuthType } // CheckAndSetDefaults checks parameters and sets default values. -func (cfg *DatabaseWatcherConfig) CheckAndSetDefaults() error { +func (cfg *CertAuthorityWatcherConfig) CheckAndSetDefaults() error { if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { return trace.Wrap(err) } - if cfg.DatabaseGetter == nil { - getter, ok := cfg.Client.(DatabaseGetter) + if cfg.AuthorityGetter == nil { + getter, ok := cfg.Client.(AuthorityGetter) if !ok { - return trace.BadParameter("missing parameter DatabaseGetter and Client not usable as DatabaseGetter") + return trace.BadParameter("missing parameter AuthorityGetter and Client not usable as AuthorityGetter") } - cfg.DatabaseGetter = getter + cfg.AuthorityGetter = getter } - if cfg.DatabasesC == nil { - cfg.DatabasesC = make(chan types.Databases) - } - return nil -} - -// NewDatabaseWatcher returns a new instance of DatabaseWatcher. -func NewDatabaseWatcher(ctx context.Context, cfg DatabaseWatcherConfig) (*DatabaseWatcher, error) { - if err := cfg.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - collector := &databaseCollector{ - DatabaseWatcherConfig: cfg, - initializationC: make(chan struct{}), - } - watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) - if err != nil { - return nil, trace.Wrap(err) - } - return &DatabaseWatcher{watcher, collector}, nil -} - -// DatabaseWatcher is built on top of resourceWatcher to monitor database resources. -type DatabaseWatcher struct { - *resourceWatcher - *databaseCollector -} - -// databaseCollector accompanies resourceWatcher when monitoring database resources. -type databaseCollector struct { - // DatabaseWatcherConfig is the watcher configuration. - DatabaseWatcherConfig - // current holds a map of the currently known database resources. - current map[string]types.Database - // lock protects the "current" map. - lock sync.RWMutex - // initializationC is used to check that the - initializationC chan struct{} - once sync.Once -} - -// resourceKinds specifies the resource kind to watch. -func (p *databaseCollector) resourceKinds() []types.WatchKind { - return []types.WatchKind{{Kind: types.KindDatabase}} -} - -// isInitialized is used to check that the cache has done its initial -// sync -func (p *databaseCollector) initializationChan() <-chan struct{} { - return p.initializationC -} - -// getResourcesAndUpdateCurrent refreshes the list of current resources. 
-func (p *databaseCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { - databases, err := p.DatabaseGetter.GetDatabases(ctx) - if err != nil { - return trace.Wrap(err) - } - newCurrent := make(map[string]types.Database, len(databases)) - for _, database := range databases { - newCurrent[database.GetName()] = database - } - p.lock.Lock() - defer p.lock.Unlock() - p.current = newCurrent - p.defineCollectorAsInitialized() - - select { - case <-ctx.Done(): - return trace.Wrap(ctx.Err()) - case p.DatabasesC <- databases: - } - - return nil -} - -func (p *databaseCollector) defineCollectorAsInitialized() { - p.once.Do(func() { - // mark watcher as initialized. - close(p.initializationC) - }) -} - -// processEventsAndUpdateCurrent is called when a watcher event is received. -func (p *databaseCollector) processEventsAndUpdateCurrent(ctx context.Context, events []types.Event) { - p.lock.Lock() - defer p.lock.Unlock() - - var updated bool - for _, event := range events { - if event.Resource == nil || event.Resource.GetKind() != types.KindDatabase { - p.Logger.WarnContext(ctx, "Received unexpected event", "event", logutils.StringerAttr(event)) - continue - } - switch event.Type { - case types.OpDelete: - delete(p.current, event.Resource.GetName()) - updated = true - case types.OpPut: - database, ok := event.Resource.(types.Database) - if !ok { - p.Logger.WarnContext(ctx, "Received unexpected resource type", "resource", event.Resource.GetKind()) - continue - } - p.current[database.GetName()] = database - updated = true - default: - p.Logger.WarnContext(ctx, "Received unsupported event type", "event_type", event.Type) - } - } - - if updated { - select { - case <-ctx.Done(): - case p.DatabasesC <- resourcesToSlice(p.current): - } - } -} - -func (*databaseCollector) notifyStale() {} - -type DynamicWindowsDesktopGetter interface { - ListDynamicWindowsDesktops(ctx context.Context, pageSize int, pageToken string) ([]types.DynamicWindowsDesktop, string, error) -} - -// DynamicWindowsDesktopWatcherConfig is a DynamicWindowsDesktopWatcher configuration. -type DynamicWindowsDesktopWatcherConfig struct { - // ResourceWatcherConfig is the resource watcher configuration. - ResourceWatcherConfig - // DynamicWindowsDesktopGetter is responsible for fetching DynamicWindowsDesktop resources. - DynamicWindowsDesktopGetter - // DynamicWindowsDesktopsC receives up-to-date list of all DynamicWindowsDesktop resources. - DynamicWindowsDesktopsC chan types.DynamicWindowsDesktops -} - -// CheckAndSetDefaults checks parameters and sets default values. -func (cfg *DynamicWindowsDesktopWatcherConfig) CheckAndSetDefaults() error { - if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { - return trace.Wrap(err) - } - if cfg.DynamicWindowsDesktopGetter == nil { - getter, ok := cfg.Client.(DynamicWindowsDesktopGetter) - if !ok { - return trace.BadParameter("missing parameter DynamicWindowsDesktopGetter and Client %T not usable as DynamicWindowsDesktopGetter", cfg.Client) - } - cfg.DynamicWindowsDesktopGetter = getter - } - if cfg.DynamicWindowsDesktopsC == nil { - cfg.DynamicWindowsDesktopsC = make(chan types.DynamicWindowsDesktops) - } - return nil -} - -// NewDynamicWindowsDesktopWatcher returns a new instance of DynamicWindowsDesktopWatcher. 
-func NewDynamicWindowsDesktopWatcher(ctx context.Context, cfg DynamicWindowsDesktopWatcherConfig) (*DynamicWindowsDesktopWatcher, error) { - if err := cfg.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - collector := &dynamicWindowsDesktopCollector{ - DynamicWindowsDesktopWatcherConfig: cfg, - initializationC: make(chan struct{}), - } - watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) - if err != nil { - return nil, trace.Wrap(err) - } - return &DynamicWindowsDesktopWatcher{watcher, collector}, nil -} - -// DynamicWindowsDesktopWatcher is built on top of resourceWatcher to monitor DynamicWindowsDesktop resources. -type DynamicWindowsDesktopWatcher struct { - *resourceWatcher - *dynamicWindowsDesktopCollector -} - -// dynamicWindowsDesktopCollector accompanies resourceWatcher when monitoring DynamicWindowsDesktop resources. -type dynamicWindowsDesktopCollector struct { - // DynamicWindowsDesktopWatcherConfig is the watcher configuration. - DynamicWindowsDesktopWatcherConfig - // current holds a map of the currently known DynamicWindowsDesktop resources. - current map[string]types.DynamicWindowsDesktop - // lock protects the "current" map. - lock sync.RWMutex - // initializationC is used to check that the - initializationC chan struct{} - once sync.Once -} - -// resourceKinds specifies the resource kind to watch. -func (p *dynamicWindowsDesktopCollector) resourceKinds() []types.WatchKind { - return []types.WatchKind{{Kind: types.KindDynamicWindowsDesktop}} -} - -// isInitialized is used to check that the cache has done its initial -// sync -func (p *dynamicWindowsDesktopCollector) initializationChan() <-chan struct{} { - return p.initializationC -} - -// getResourcesAndUpdateCurrent refreshes the list of current resources. -func (p *dynamicWindowsDesktopCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { - var dynamicWindowsDesktops []types.DynamicWindowsDesktop - next := "" - for { - desktops, token, err := p.DynamicWindowsDesktopGetter.ListDynamicWindowsDesktops(ctx, defaults.MaxIterationLimit, next) - if err != nil { - return trace.Wrap(err) - } - dynamicWindowsDesktops = append(dynamicWindowsDesktops, desktops...) - if token == "" { - break - } - next = token - } - newCurrent := make(map[string]types.DynamicWindowsDesktop, len(dynamicWindowsDesktops)) - for _, dynamicWindowsDesktop := range dynamicWindowsDesktops { - newCurrent[dynamicWindowsDesktop.GetName()] = dynamicWindowsDesktop - } - p.lock.Lock() - defer p.lock.Unlock() - p.current = newCurrent - p.defineCollectorAsInitialized() - - select { - case <-ctx.Done(): - return trace.Wrap(ctx.Err()) - case p.DynamicWindowsDesktopsC <- dynamicWindowsDesktops: - } - - return nil -} - -func (p *dynamicWindowsDesktopCollector) defineCollectorAsInitialized() { - p.once.Do(func() { - // mark watcher as initialized. - close(p.initializationC) - }) -} - -// processEventsAndUpdateCurrent is called when a watcher event is received. 
-func (p *dynamicWindowsDesktopCollector) processEventsAndUpdateCurrent(ctx context.Context, events []types.Event) { - p.lock.Lock() - defer p.lock.Unlock() - - var updated bool - for _, event := range events { - if event.Resource == nil || event.Resource.GetKind() != types.KindDynamicWindowsDesktop { - p.Logger.WarnContext(ctx, "Received unexpected event", "event", logutils.StringerAttr(event)) - continue - } - switch event.Type { - case types.OpDelete: - delete(p.current, event.Resource.GetName()) - updated = true - case types.OpPut: - dynamicWindowsDesktop, ok := event.Resource.(types.DynamicWindowsDesktop) - if !ok { - p.Logger.WarnContext(ctx, "Received unexpected resource type", "resource", event.Resource.GetKind()) - continue - } - p.current[dynamicWindowsDesktop.GetName()] = dynamicWindowsDesktop - updated = true - default: - p.Logger.WarnContext(ctx, "Received unsupported event type", "event_type", event.Type) - } - } - - if updated { - select { - case <-ctx.Done(): - case p.DynamicWindowsDesktopsC <- resourcesToSlice(p.current): - } - } -} - -func (*dynamicWindowsDesktopCollector) notifyStale() {} - -// AppWatcherConfig is an AppWatcher configuration. -type AppWatcherConfig struct { - // ResourceWatcherConfig is the resource watcher configuration. - ResourceWatcherConfig - // AppGetter is responsible for fetching application resources. - AppGetter - // AppsC receives up-to-date list of all application resources. - AppsC chan types.Apps -} - -// CheckAndSetDefaults checks parameters and sets default values. -func (cfg *AppWatcherConfig) CheckAndSetDefaults() error { - if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { - return trace.Wrap(err) - } - if cfg.AppGetter == nil { - getter, ok := cfg.Client.(AppGetter) - if !ok { - return trace.BadParameter("missing parameter AppGetter and Client not usable as AppGetter") - } - cfg.AppGetter = getter - } - if cfg.AppsC == nil { - cfg.AppsC = make(chan types.Apps) - } - return nil -} - -// NewAppWatcher returns a new instance of AppWatcher. -func NewAppWatcher(ctx context.Context, cfg AppWatcherConfig) (*AppWatcher, error) { - if err := cfg.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - collector := &appCollector{ - AppWatcherConfig: cfg, - initializationC: make(chan struct{}), - } - watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) - if err != nil { - return nil, trace.Wrap(err) - } - return &AppWatcher{watcher, collector}, nil -} - -// AppWatcher is built on top of resourceWatcher to monitor application resources. -type AppWatcher struct { - *resourceWatcher - *appCollector -} - -// appCollector accompanies resourceWatcher when monitoring application resources. -type appCollector struct { - // AppWatcherConfig is the watcher configuration. - AppWatcherConfig - // current holds a map of the currently known application resources. - current map[string]types.Application - // lock protects the "current" map. - lock sync.RWMutex - // initializationC is used to check whether the initial sync has completed - initializationC chan struct{} - once sync.Once -} - -// resourceKinds specifies the resource kind to watch. -func (p *appCollector) resourceKinds() []types.WatchKind { - return []types.WatchKind{{Kind: types.KindApp}} -} - -// isInitialized is used to check that the cache has done its initial -// sync -func (p *appCollector) initializationChan() <-chan struct{} { - return p.initializationC -} - -// getResourcesAndUpdateCurrent refreshes the list of current resources. 
-func (p *appCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { - apps, err := p.AppGetter.GetApps(ctx) - if err != nil { - return trace.Wrap(err) - } - newCurrent := make(map[string]types.Application, len(apps)) - for _, app := range apps { - newCurrent[app.GetName()] = app - } - p.lock.Lock() - defer p.lock.Unlock() - p.current = newCurrent - p.defineCollectorAsInitialized() - select { - case <-ctx.Done(): - return trace.Wrap(ctx.Err()) - case p.AppsC <- apps: - } - return nil -} - -func (p *appCollector) defineCollectorAsInitialized() { - p.once.Do(func() { - // mark watcher as initialized. - close(p.initializationC) - }) -} - -// processEventsAndUpdateCurrent is called when a watcher event is received. -func (p *appCollector) processEventsAndUpdateCurrent(ctx context.Context, events []types.Event) { - p.lock.Lock() - defer p.lock.Unlock() - for _, event := range events { - if event.Resource == nil || event.Resource.GetKind() != types.KindApp { - p.Logger.WarnContext(ctx, "Received unexpected event", "event", logutils.StringerAttr(event)) - continue - } - switch event.Type { - case types.OpDelete: - delete(p.current, event.Resource.GetName()) - p.AppsC <- resourcesToSlice(p.current) - - select { - case <-ctx.Done(): - case p.AppsC <- resourcesToSlice(p.current): - } - - case types.OpPut: - app, ok := event.Resource.(types.Application) - if !ok { - p.Logger.WarnContext(ctx, "Received unexpected resource type", "resource", event.Resource.GetKind()) - continue - } - p.current[app.GetName()] = app - - select { - case <-ctx.Done(): - case p.AppsC <- resourcesToSlice(p.current): - } - default: - p.Logger.WarnContext(ctx, "Received unsupported event type", "event_type", event.Type) - } - } -} - -func (*appCollector) notifyStale() {} - -func resourcesToSlice[T any](resources map[string]T) (slice []T) { - for _, resource := range resources { - slice = append(slice, resource) - } - return slice -} - -// KubeClusterWatcherConfig is an KubeClusterWatcher configuration. -type KubeClusterWatcherConfig struct { - // ResourceWatcherConfig is the resource watcher configuration. - ResourceWatcherConfig - // KubernetesGetter is responsible for fetching kube_cluster resources. - KubernetesClusterGetter - // KubeClustersC receives up-to-date list of all kube_cluster resources. - KubeClustersC chan types.KubeClusters -} - -// CheckAndSetDefaults checks parameters and sets default values. -func (cfg *KubeClusterWatcherConfig) CheckAndSetDefaults() error { - if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { - return trace.Wrap(err) - } - if cfg.KubernetesClusterGetter == nil { - getter, ok := cfg.Client.(KubernetesClusterGetter) - if !ok { - return trace.BadParameter("missing parameter KubernetesGetter and Client not usable as KubernetesGetter") - } - cfg.KubernetesClusterGetter = getter - } - if cfg.KubeClustersC == nil { - cfg.KubeClustersC = make(chan types.KubeClusters) - } - return nil -} - -// NewKubeClusterWatcher returns a new instance of KubeClusterWatcher. 
-func NewKubeClusterWatcher(ctx context.Context, cfg KubeClusterWatcherConfig) (*KubeClusterWatcher, error) { - if err := cfg.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - collector := &kubeCollector{ - KubeClusterWatcherConfig: cfg, - initializationC: make(chan struct{}), - } - watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) - if err != nil { - return nil, trace.Wrap(err) - } - return &KubeClusterWatcher{watcher, collector}, nil -} - -// KubeClusterWatcher is built on top of resourceWatcher to monitor kube_cluster resources. -type KubeClusterWatcher struct { - *resourceWatcher - *kubeCollector -} - -// kubeCollector accompanies resourceWatcher when monitoring kube_cluster resources. -type kubeCollector struct { - // KubeClusterWatcherConfig is the watcher configuration. - KubeClusterWatcherConfig - // current holds a map of the currently known kube_cluster resources. - current map[string]types.KubeCluster - // lock protects the "current" map. - lock sync.RWMutex - // initializationC is used to check whether the initial sync has completed - initializationC chan struct{} - once sync.Once -} - -// isInitialized is used to check that the cache has done its initial -// sync -func (k *kubeCollector) initializationChan() <-chan struct{} { - return k.initializationC -} - -// resourceKinds specifies the resource kind to watch. -func (k *kubeCollector) resourceKinds() []types.WatchKind { - return []types.WatchKind{{Kind: types.KindKubernetesCluster}} -} - -// getResourcesAndUpdateCurrent refreshes the list of current resources. -func (k *kubeCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { - clusters, err := k.KubernetesClusterGetter.GetKubernetesClusters(ctx) - if err != nil { - return trace.Wrap(err) - } - newCurrent := make(map[string]types.KubeCluster, len(clusters)) - for _, cluster := range clusters { - newCurrent[cluster.GetName()] = cluster - } - k.lock.Lock() - defer k.lock.Unlock() - k.current = newCurrent - - select { - case <-ctx.Done(): - return trace.Wrap(ctx.Err()) - case k.KubeClustersC <- clusters: - } - - k.defineCollectorAsInitialized() - - return nil -} - -func (k *kubeCollector) defineCollectorAsInitialized() { - k.once.Do(func() { - // mark watcher as initialized. - close(k.initializationC) - }) -} - -// processEventsAndUpdateCurrent is called when a watcher event is received. 
-func (k *kubeCollector) processEventsAndUpdateCurrent(ctx context.Context, events []types.Event) { - k.lock.Lock() - defer k.lock.Unlock() - for _, event := range events { - if event.Resource == nil || event.Resource.GetKind() != types.KindKubernetesCluster { - k.Logger.WarnContext(ctx, "Received unexpected event", "event", logutils.StringerAttr(event)) - continue - } - switch event.Type { - case types.OpDelete: - delete(k.current, event.Resource.GetName()) - k.KubeClustersC <- resourcesToSlice(k.current) - - select { - case <-ctx.Done(): - case k.KubeClustersC <- resourcesToSlice(k.current): - } - - case types.OpPut: - cluster, ok := event.Resource.(types.KubeCluster) - if !ok { - k.Logger.WarnContext(ctx, "Received unexpected resource type", "resource", event.Resource.GetKind()) - continue - } - k.current[cluster.GetName()] = cluster - - select { - case <-ctx.Done(): - case k.KubeClustersC <- resourcesToSlice(k.current): - } - default: - k.Logger.WarnContext(ctx, "Received unsupported event type", "event_type", event.Type) - } - } -} - -func (*kubeCollector) notifyStale() {} - -// KubeServerWatcherConfig is an KubeServerWatcher configuration. -type KubeServerWatcherConfig struct { - // ResourceWatcherConfig is the resource watcher configuration. - ResourceWatcherConfig - // KubernetesServerGetter is responsible for fetching kube_server resources. - KubernetesServerGetter -} - -// CheckAndSetDefaults checks parameters and sets default values. -func (cfg *KubeServerWatcherConfig) CheckAndSetDefaults() error { - if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { - return trace.Wrap(err) - } - if cfg.KubernetesServerGetter == nil { - getter, ok := cfg.Client.(KubernetesServerGetter) - if !ok { - return trace.BadParameter("missing parameter KubernetesServerGetter and Client not usable as KubernetesServerGetter") - } - cfg.KubernetesServerGetter = getter - } - return nil -} - -// NewKubeServerWatcher returns a new instance of KubeServerWatcher. -func NewKubeServerWatcher(ctx context.Context, cfg KubeServerWatcherConfig) (*KubeServerWatcher, error) { - if err := cfg.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - cache, err := utils.NewFnCache(utils.FnCacheConfig{ - Context: ctx, - TTL: 3 * time.Second, - Clock: cfg.Clock, - }) - if err != nil { - return nil, trace.Wrap(err) - } - collector := &kubeServerCollector{ - KubeServerWatcherConfig: cfg, - initializationC: make(chan struct{}), - cache: cache, - } - // start the collector as staled. - collector.stale.Store(true) - watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) - if err != nil { - return nil, trace.Wrap(err) - } - return &KubeServerWatcher{watcher, collector}, nil -} - -// KubeServerWatcher is built on top of resourceWatcher to monitor kube_server resources. -type KubeServerWatcher struct { - *resourceWatcher - *kubeServerCollector -} - -// GetKubeServersByClusterName returns a list of kubernetes servers for the specified cluster. 
-func (k *KubeServerWatcher) GetKubeServersByClusterName(ctx context.Context, clusterName string) ([]types.KubeServer, error) { - k.refreshStaleKubeServers(ctx) - - k.lock.RLock() - defer k.lock.RUnlock() - var servers []types.KubeServer - for _, server := range k.current { - if server.GetCluster().GetName() == clusterName { - servers = append(servers, server.Copy()) - } - } - if len(servers) == 0 { - return nil, trace.NotFound("no kubernetes servers found for cluster %q", clusterName) - } - - return servers, nil -} - -// GetKubernetesServers returns a list of kubernetes servers for all clusters. -func (k *KubeServerWatcher) GetKubernetesServers(ctx context.Context) ([]types.KubeServer, error) { - k.refreshStaleKubeServers(ctx) - - k.lock.RLock() - defer k.lock.RUnlock() - servers := make([]types.KubeServer, 0, len(k.current)) - for _, server := range k.current { - servers = append(servers, server.Copy()) - } - return servers, nil -} - -// kubeServerCollector accompanies resourceWatcher when monitoring kube_server resources. -type kubeServerCollector struct { - // KubeServerWatcherConfig is the watcher configuration. - KubeServerWatcherConfig - // current holds a map of the currently known kube_server resources. - current map[kubeServersKey]types.KubeServer - // lock protects the "current" map. - lock sync.RWMutex - // initializationC is used to check whether the initial sync has completed - initializationC chan struct{} - once sync.Once - // stale is used to indicate that the watcher is stale and needs to be - // refreshed. - stale atomic.Bool - // cache is a helper for temporarily storing the results of GetKubernetesServers. - // It's used to limit the amount of calls to the backend. - cache *utils.FnCache -} - -// kubeServersKey is used to uniquely identify a kube_server resource. -type kubeServersKey struct { - hostID string - resourceName string -} - -// isInitialized is used to check that the cache has done its initial -// sync -func (k *kubeServerCollector) initializationChan() <-chan struct{} { - return k.initializationC -} - -// resourceKinds specifies the resource kind to watch. -func (k *kubeServerCollector) resourceKinds() []types.WatchKind { - return []types.WatchKind{{Kind: types.KindKubeServer}} -} - -// getResourcesAndUpdateCurrent refreshes the list of current resources. -func (k *kubeServerCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { - newCurrent, err := k.getResources(ctx) - if err != nil { - return trace.Wrap(err) - } - - k.lock.Lock() - k.current = newCurrent - k.lock.Unlock() - - k.stale.Store(false) - - k.defineCollectorAsInitialized() - return nil -} - -// getResourcesAndUpdateCurrent gets the list of current resources. -func (k *kubeServerCollector) getResources(ctx context.Context) (map[kubeServersKey]types.KubeServer, error) { - servers, err := k.KubernetesServerGetter.GetKubernetesServers(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - current := make(map[kubeServersKey]types.KubeServer, len(servers)) - for _, server := range servers { - key := kubeServersKey{ - hostID: server.GetHostID(), - resourceName: server.GetName(), - } - current[key] = server - } - return current, nil -} - -func (k *kubeServerCollector) defineCollectorAsInitialized() { - k.once.Do(func() { - // mark watcher as initialized. - close(k.initializationC) - }) -} - -// processEventsAndUpdateCurrent is called when a watcher event is received. 
-func (k *kubeServerCollector) processEventsAndUpdateCurrent(ctx context.Context, events []types.Event) { - k.lock.Lock() - defer k.lock.Unlock() - - for _, event := range events { - if event.Resource == nil || event.Resource.GetKind() != types.KindKubeServer { - k.Logger.WarnContext(ctx, "Received unexpected event", "event", logutils.StringerAttr(event)) - continue - } - - switch event.Type { - case types.OpDelete: - key := kubeServersKey{ - // On delete events, the server description is populated with the host ID. - hostID: event.Resource.GetMetadata().Description, - resourceName: event.Resource.GetName(), - } - delete(k.current, key) - case types.OpPut: - server, ok := event.Resource.(types.KubeServer) - if !ok { - k.Logger.WarnContext(ctx, "Received unexpected resource type", "resource", event.Resource.GetKind()) - continue - } - - key := kubeServersKey{ - hostID: server.GetHostID(), - resourceName: server.GetName(), - } - k.current[key] = server - default: - k.Logger.WarnContext(ctx, "Received unsupported event type", "event_type", event.Type) - } - } -} - -func (k *kubeServerCollector) notifyStale() { - k.stale.Store(true) -} - -// refreshStaleKubeServers attempts to reload kube servers from the cache if -// the collector is stale. This ensures that no matter the health of -// the collector callers will be returned the most up to date node -// set as possible. -func (k *kubeServerCollector) refreshStaleKubeServers(ctx context.Context) error { - if !k.stale.Load() { - return nil - } - - _, err := utils.FnCacheGet(ctx, k.cache, "kube_servers", func(ctx context.Context) (any, error) { - current, err := k.getResources(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - - // There is a chance that the watcher reinitialized while - // getting kube servers happened above. Check if we are still stale - if k.stale.CompareAndSwap(true, false) { - k.lock.Lock() - k.current = current - k.lock.Unlock() - } - - return nil, nil - }) - - return trace.Wrap(err) -} - -// CertAuthorityWatcherConfig is a CertAuthorityWatcher configuration. -type CertAuthorityWatcherConfig struct { - // ResourceWatcherConfig is the resource watcher configuration. - ResourceWatcherConfig - // AuthorityGetter is responsible for fetching cert authority resources. - AuthorityGetter - // Types restricts which cert authority types are retrieved via the AuthorityGetter. - Types []types.CertAuthType -} - -// CheckAndSetDefaults checks parameters and sets default values. -func (cfg *CertAuthorityWatcherConfig) CheckAndSetDefaults() error { - if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { - return trace.Wrap(err) - } - if cfg.AuthorityGetter == nil { - getter, ok := cfg.Client.(AuthorityGetter) - if !ok { - return trace.BadParameter("missing parameter AuthorityGetter and Client not usable as AuthorityGetter") - } - cfg.AuthorityGetter = getter - } - if len(cfg.Types) == 0 { - return trace.BadParameter("missing parameter Types") + if len(cfg.Types) == 0 { + return trace.BadParameter("missing parameter Types") } return nil } @@ -1712,17 +1215,15 @@ type CertAuthorityWatcher struct { // caCollector accompanies resourceWatcher when monitoring cert authority resources. 
type caCollector struct { - CertAuthorityWatcherConfig fanout *FanoutV2 - - // lock protects concurrent access to cas - lock sync.RWMutex - // cas maps ca type -> cluster -> ca - cas map[types.CertAuthType]map[string]types.CertAuthority + cas map[types.CertAuthType]map[string]types.CertAuthority // initializationC is used to check whether the initial sync has completed initializationC chan struct{} - once sync.Once filter types.CertAuthorityFilter + CertAuthorityWatcherConfig + // lock protects concurrent access to cas + lock sync.RWMutex + once sync.Once } // Subscribe is used to subscribe to the lock updates. @@ -1859,285 +1360,40 @@ func (c *caCollector) notifyStale() {} // NodeWatcherConfig is a NodeWatcher configuration. type NodeWatcherConfig struct { - ResourceWatcherConfig // NodesGetter is used to directly fetch the list of active nodes. NodesGetter -} - -// CheckAndSetDefaults checks parameters and sets default values. -func (cfg *NodeWatcherConfig) CheckAndSetDefaults() error { - if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { - return trace.Wrap(err) - } - if cfg.NodesGetter == nil { - getter, ok := cfg.Client.(NodesGetter) - if !ok { - return trace.BadParameter("missing parameter NodesGetter and Client not usable as NodesGetter") - } - cfg.NodesGetter = getter - } - return nil + ResourceWatcherConfig } // NewNodeWatcher returns a new instance of NodeWatcher. -func NewNodeWatcher(ctx context.Context, cfg NodeWatcherConfig) (*NodeWatcher, error) { - if err := cfg.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - - cache, err := utils.NewFnCache(utils.FnCacheConfig{ - Context: ctx, - TTL: 3 * time.Second, - Clock: cfg.Clock, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - collector := &nodeCollector{ - NodeWatcherConfig: cfg, - current: map[string]types.Server{}, - initializationC: make(chan struct{}), - cache: cache, - stale: true, - } - - watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) - if err != nil { - return nil, trace.Wrap(err) - } - - return &NodeWatcher{resourceWatcher: watcher, nodeCollector: collector}, nil -} - -// NodeWatcher is built on top of resourceWatcher to monitor additions -// and deletions to the set of nodes. -type NodeWatcher struct { - *resourceWatcher - *nodeCollector -} - -// nodeCollector accompanies resourceWatcher when monitoring nodes. -type nodeCollector struct { - NodeWatcherConfig - - // initializationC is used to check whether the initial sync has completed - initializationC chan struct{} - once sync.Once - - cache *utils.FnCache - - rw sync.RWMutex - // current holds a map of the currently known nodes keyed by server name - current map[string]types.Server - stale bool -} - -// Node is a readonly subset of the types.Server interface which -// users may filter by in GetNodes. -type Node interface { - // ResourceWithLabels provides common resource headers - types.ResourceWithLabels - // GetTeleportVersion returns the teleport version the server is running on - GetTeleportVersion() string - // GetAddr return server address - GetAddr() string - // GetPublicAddrs returns all public addresses where this server can be reached. - GetPublicAddrs() []string - // GetHostname returns server hostname - GetHostname() string - // GetNamespace returns server namespace - GetNamespace() string - // GetCmdLabels gets command labels - GetCmdLabels() map[string]types.CommandLabel - // GetRotation gets the state of certificate authority rotation. 
- GetRotation() types.Rotation - // GetUseTunnel gets if a reverse tunnel should be used to connect to this node. - GetUseTunnel() bool - // GetProxyIDs returns a list of proxy ids this server is connected to. - GetProxyIDs() []string - // IsEICE returns whether the Node is an EICE instance. - // Must be `openssh-ec2-ice` subkind and have the AccountID and InstanceID information (AWS Metadata or Labels). - IsEICE() bool -} - -// GetNodes allows callers to retrieve a subset of nodes that match the filter provided. The -// returned servers are a copy and can be safely modified. It is intentionally hard to retrieve -// the full set of nodes to reduce the number of copies needed since the number of nodes can get -// quite large and doing so can be expensive. -func (n *nodeCollector) GetNodes(ctx context.Context, fn func(n Node) bool) []types.Server { - // Attempt to freshen our data first. - n.refreshStaleNodes(ctx) - - n.rw.RLock() - defer n.rw.RUnlock() - - var matched []types.Server - for _, server := range n.current { - if fn(server) { - matched = append(matched, server.DeepCopy()) - } - } - - return matched -} - -// GetNode allows callers to retrieve a node based on its name. The -// returned server are a copy and can be safely modified. -func (n *nodeCollector) GetNode(ctx context.Context, name string) (types.Server, error) { - // Attempt to freshen our data first. - n.refreshStaleNodes(ctx) - - n.rw.RLock() - defer n.rw.RUnlock() - - server, found := n.current[name] - if !found { - return nil, trace.NotFound("server does not exist") - } - return server.DeepCopy(), nil -} - -// refreshStaleNodes attempts to reload nodes from the NodeGetter if -// the collecter is stale. This ensures that no matter the health of -// the collecter callers will be returned the most up to date node -// set as possible. -func (n *nodeCollector) refreshStaleNodes(ctx context.Context) error { - n.rw.RLock() - if !n.stale { - n.rw.RUnlock() - return nil - } - n.rw.RUnlock() - - _, err := utils.FnCacheGet(ctx, n.cache, "nodes", func(ctx context.Context) (any, error) { - current, err := n.getNodes(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - - n.rw.Lock() - defer n.rw.Unlock() - - // There is a chance that the watcher reinitialized while - // getting nodes happened above. Check if we are still stale - // now that the lock is held to ensure that the refresh is - // still necessary. - if !n.stale { - return nil, nil - } - - n.current = current - return nil, trace.Wrap(err) - }) - - return trace.Wrap(err) -} - -func (n *nodeCollector) NodeCount() int { - n.rw.RLock() - defer n.rw.RUnlock() - return len(n.current) -} - -// resourceKinds specifies the resource kind to watch. -func (n *nodeCollector) resourceKinds() []types.WatchKind { - return []types.WatchKind{{Kind: types.KindNode}} -} - -// getResourcesAndUpdateCurrent is called when the resources should be -// (re-)fetched directly. 
-func (n *nodeCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { - newCurrent, err := n.getNodes(ctx) - if err != nil { - return trace.Wrap(err) - } - defer n.defineCollectorAsInitialized() - - if len(newCurrent) == 0 { - return nil - } - - n.rw.Lock() - defer n.rw.Unlock() - n.current = newCurrent - n.stale = false - return nil -} - -func (n *nodeCollector) getNodes(ctx context.Context) (map[string]types.Server, error) { - nodes, err := n.NodesGetter.GetNodes(ctx, apidefaults.Namespace) - if err != nil { - return nil, trace.Wrap(err) - } - - if len(nodes) == 0 { - return map[string]types.Server{}, nil - } - - current := make(map[string]types.Server, len(nodes)) - for _, node := range nodes { - current[node.GetName()] = node +func NewNodeWatcher(ctx context.Context, cfg NodeWatcherConfig) (*GenericWatcher[types.Server, readonly.Server], error) { + if cfg.NodesGetter == nil { + return nil, trace.BadParameter("NodesGetter must be provided") } - return current, nil -} - -func (n *nodeCollector) defineCollectorAsInitialized() { - n.once.Do(func() { - // mark watcher as initialized. - close(n.initializationC) + w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.Server, readonly.Server]{ + ResourceWatcherConfig: cfg.ResourceWatcherConfig, + ResourceKind: types.KindNode, + ResourceGetter: func(ctx context.Context) ([]types.Server, error) { + return cfg.NodesGetter.GetNodes(ctx, apidefaults.Namespace) + }, + ResourceKey: types.Server.GetName, + DisableUpdateBroadcast: true, + CloneFunc: types.Server.DeepCopy, }) -} - -// processEventsAndUpdateCurrent is called when a watcher event is received. -func (n *nodeCollector) processEventsAndUpdateCurrent(ctx context.Context, events []types.Event) { - n.rw.Lock() - defer n.rw.Unlock() - - for _, event := range events { - if event.Resource == nil || event.Resource.GetKind() != types.KindNode { - n.Logger.WarnContext(ctx, "Received unexpected event", "event", logutils.StringerAttr(event)) - continue - } - - switch event.Type { - case types.OpDelete: - delete(n.current, event.Resource.GetName()) - case types.OpPut: - server, ok := event.Resource.(types.Server) - if !ok { - n.Logger.WarnContext(ctx, "Received unexpected type", "resource", event.Resource.GetKind()) - continue - } - - n.current[server.GetName()] = server - default: - n.Logger.WarnContext(ctx, "Skipping unsupported event type", "event_type", event.Type) - } - } -} - -func (n *nodeCollector) initializationChan() <-chan struct{} { - return n.initializationC -} - -func (n *nodeCollector) notifyStale() { - n.rw.Lock() - defer n.rw.Unlock() - n.stale = true + return w, trace.Wrap(err) } // AccessRequestWatcherConfig is a AccessRequestWatcher configuration. type AccessRequestWatcherConfig struct { - // ResourceWatcherConfig is the resource watcher configuration. - ResourceWatcherConfig // AccessRequestGetter is responsible for fetching access request resources. AccessRequestGetter - // Filter is the filter to use to monitor access requests. - Filter types.AccessRequestFilter // AccessRequestsC receives up-to-date list of all access request resources. AccessRequestsC chan types.AccessRequests + // ResourceWatcherConfig is the resource watcher configuration. + ResourceWatcherConfig + // Filter is the filter to use to monitor access requests. + Filter types.AccessRequestFilter } // CheckAndSetDefaults checks parameters and sets default values. 
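Editor's note: as an illustration of the call-site change this hunk implies, here is a minimal sketch of consuming the generalized node watcher. It assumes only the NewNodeWatcher signature, the readonly.Server filter type, and the CurrentResourcesWithFilter helper introduced in this patch; the wrapper function, its name, and how the getter and watcher config are obtained are hypothetical.

package example

import (
	"context"

	"github.com/gravitational/trace"

	"github.com/gravitational/teleport/api/types"
	"github.com/gravitational/teleport/lib/services"
	"github.com/gravitational/teleport/lib/services/readonly"
)

// tunnelNodes is a hypothetical helper showing the post-refactor pattern: the old
// NodeWatcher.GetNodes(ctx, filter) call becomes CurrentResourcesWithFilter with a
// readonly.Server predicate, and the results come back as deep copies.
func tunnelNodes(ctx context.Context, getter services.NodesGetter, rwCfg services.ResourceWatcherConfig) ([]types.Server, error) {
	w, err := services.NewNodeWatcher(ctx, services.NodeWatcherConfig{
		ResourceWatcherConfig: rwCfg,
		NodesGetter:           getter,
	})
	if err != nil {
		return nil, trace.Wrap(err)
	}
	defer w.Close()

	// Filter against the read-only view; only matching servers are cloned and returned.
	servers, err := w.CurrentResourcesWithFilter(ctx, func(n readonly.Server) bool {
		return n.GetUseTunnel()
	})
	return servers, trace.Wrap(err)
}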
@@ -2186,11 +1442,11 @@ type accessRequestCollector struct { AccessRequestWatcherConfig // current holds a map of the currently known access request resources. current map[string]types.AccessRequest - // lock protects the "current" map. - lock sync.RWMutex // initializationC is used to check that the watcher has been initialized properly. initializationC chan struct{} - once sync.Once + // lock protects the "current" map. + lock sync.RWMutex + once sync.Once } // resourceKinds specifies the resource kind to watch. @@ -2250,7 +1506,7 @@ func (p *accessRequestCollector) processEventsAndUpdateCurrent(ctx context.Conte delete(p.current, event.Resource.GetName()) select { case <-ctx.Done(): - case p.AccessRequestsC <- resourcesToSlice(p.current): + case p.AccessRequestsC <- resourcesToSlice(p.current, types.AccessRequest.Copy): } case types.OpPut: accessRequest, ok := event.Resource.(types.AccessRequest) @@ -2261,7 +1517,7 @@ func (p *accessRequestCollector) processEventsAndUpdateCurrent(ctx context.Conte p.current[accessRequest.GetName()] = accessRequest select { case <-ctx.Done(): - case p.AccessRequestsC <- resourcesToSlice(p.current): + case p.AccessRequestsC <- resourcesToSlice(p.current, types.AccessRequest.Copy): } default: @@ -2274,14 +1530,14 @@ func (*accessRequestCollector) notifyStale() {} // OktaAssignmentWatcherConfig is a OktaAssignmentWatcher configuration. type OktaAssignmentWatcherConfig struct { - // RWCfg is the resource watcher configuration. - RWCfg ResourceWatcherConfig // OktaAssignments is responsible for fetching Okta assignments. OktaAssignments OktaAssignmentsGetter - // PageSize is the number of Okta assignments to list at a time. - PageSize int // OktaAssignmentsC receives up-to-date list of all Okta assignment resources. OktaAssignmentsC chan types.OktaAssignments + // RWCfg is the resource watcher configuration. + RWCfg ResourceWatcherConfig + // PageSize is the number of Okta assignments to list at a time. + PageSize int } // CheckAndSetDefaults checks parameters and sets default values. @@ -2346,16 +1602,16 @@ func (o *OktaAssignmentWatcher) Done() <-chan struct{} { // oktaAssignmentCollector accompanies resourceWatcher when monitoring Okta assignment resources. type oktaAssignmentCollector struct { - logger *slog.Logger // OktaAssignmentWatcherConfig is the watcher configuration. - cfg OktaAssignmentWatcherConfig - // mu guards "current" - mu sync.RWMutex + cfg OktaAssignmentWatcherConfig + logger *slog.Logger // current holds a map of the currently known Okta assignment resources. current map[string]types.OktaAssignment // initializationC is used to check that the watcher has been initialized properly. initializationC chan struct{} - once sync.Once + // mu guards "current" + mu sync.RWMutex + once sync.Once } // resourceKinds specifies the resource kind to watch. 
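Editor's note: a brief aside on the cloning pattern used throughout this diff. resourcesToSlice now takes a clone callback, and call sites pass method expressions such as types.AccessRequest.Copy or types.Server.DeepCopy, which Go expands into func(T) T values. The self-contained sketch below (the Cloner interface and item type are made up for illustration) shows why that shape lines up.

package example

import "fmt"

// Cloner stands in for interfaces like types.AccessRequest whose Copy method
// returns a deep copy; it is a made-up type for this illustration.
type Cloner interface {
	Copy() Cloner
	Name() string
}

type item struct{ name string }

func (i *item) Copy() Cloner { c := *i; return &c }
func (i *item) Name() string { return i.name }

// resourcesToSlice mirrors the helper added in this patch: it flattens a map of
// resources into a slice, cloning each entry so callers cannot mutate the
// watcher's internal state.
func resourcesToSlice[T any](resources map[string]T, cloneFunc func(T) T) (slice []T) {
	for _, resource := range resources {
		slice = append(slice, cloneFunc(resource))
	}
	return slice
}

func main() {
	current := map[string]Cloner{"a": &item{"a"}, "b": &item{"b"}}
	// Cloner.Copy is a method expression with type func(Cloner) Cloner, exactly
	// the cloneFunc shape the generic watcher expects.
	for _, r := range resourcesToSlice(current, Cloner.Copy) {
		fmt.Println(r.Name())
	}
}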
@@ -2423,7 +1679,7 @@ func (c *oktaAssignmentCollector) processEventsAndUpdateCurrent(ctx context.Cont switch event.Type { case types.OpDelete: delete(c.current, event.Resource.GetName()) - resources := resourcesToSlice(c.current) + resources := resourcesToSlice(c.current, types.OktaAssignment.Copy) select { case <-ctx.Done(): case c.cfg.OktaAssignmentsC <- resources: @@ -2435,7 +1691,7 @@ func (c *oktaAssignmentCollector) processEventsAndUpdateCurrent(ctx context.Cont continue } c.current[oktaAssignment.GetName()] = oktaAssignment - resources := resourcesToSlice(c.current) + resources := resourcesToSlice(c.current, types.OktaAssignment.Copy) select { case <-ctx.Done(): diff --git a/lib/services/watcher_test.go b/lib/services/watcher_test.go index 3ffe202bb7087..730c7430696a4 100644 --- a/lib/services/watcher_test.go +++ b/lib/services/watcher_test.go @@ -45,6 +45,7 @@ import ( "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/tlsca" ) @@ -131,7 +132,8 @@ func TestProxyWatcher(t *testing.T) { Events: local.NewEventsService(bk), }, }, - ProxiesC: make(chan []types.Server, 10), + ProxyGetter: presence, + ProxiesC: make(chan []types.Server, 10), }) require.NoError(t, err) t.Cleanup(w.Close) @@ -143,7 +145,7 @@ func TestProxyWatcher(t *testing.T) { // The first event is always the current list of proxies. select { - case changeset := <-w.ProxiesC: + case changeset := <-w.ResourcesC: require.Len(t, changeset, 1) require.Empty(t, resourceDiff(changeset[0], proxy)) case <-w.Done(): @@ -158,7 +160,7 @@ func TestProxyWatcher(t *testing.T) { // Watcher should detect the proxy list change. select { - case changeset := <-w.ProxiesC: + case changeset := <-w.ResourcesC: require.Len(t, changeset, 2) case <-w.Done(): t.Fatal("Watcher has unexpectedly exited.") @@ -171,7 +173,7 @@ func TestProxyWatcher(t *testing.T) { // Watcher should detect the proxy list change. select { - case changeset := <-w.ProxiesC: + case changeset := <-w.ResourcesC: require.Len(t, changeset, 1) require.Empty(t, resourceDiff(changeset[0], proxy2)) case <-w.Done(): @@ -185,7 +187,7 @@ func TestProxyWatcher(t *testing.T) { // Watcher should detect the proxy list change. select { - case changeset := <-w.ProxiesC: + case changeset := <-w.ResourcesC: require.Empty(t, changeset) case <-w.Done(): t.Fatal("Watcher has unexpectedly exited.") @@ -562,14 +564,15 @@ func TestDatabaseWatcher(t *testing.T) { Events: local.NewEventsService(bk), }, }, - DatabasesC: make(chan types.Databases, 10), + DatabaseGetter: databasesService, + DatabasesC: make(chan []types.Database, 10), }) require.NoError(t, err) t.Cleanup(w.Close) // Initially there are no databases so watcher should send an empty list. select { - case changeset := <-w.DatabasesC: + case changeset := <-w.ResourcesC: require.Empty(t, changeset) case <-w.Done(): t.Fatal("Watcher has unexpectedly exited.") @@ -583,7 +586,7 @@ func TestDatabaseWatcher(t *testing.T) { // The first event is always the current list of databases. select { - case changeset := <-w.DatabasesC: + case changeset := <-w.ResourcesC: require.Len(t, changeset, 1) require.Empty(t, resourceDiff(changeset[0], database1)) case <-w.Done(): @@ -598,7 +601,7 @@ func TestDatabaseWatcher(t *testing.T) { // Watcher should detect the database list change. 
select { - case changeset := <-w.DatabasesC: + case changeset := <-w.ResourcesC: require.Len(t, changeset, 2) case <-w.Done(): t.Fatal("Watcher has unexpectedly exited.") @@ -611,7 +614,7 @@ func TestDatabaseWatcher(t *testing.T) { // Watcher should detect the database list change. select { - case changeset := <-w.DatabasesC: + case changeset := <-w.ResourcesC: require.Len(t, changeset, 1) require.Empty(t, resourceDiff(changeset[0], database2)) case <-w.Done(): @@ -661,14 +664,15 @@ func TestAppWatcher(t *testing.T) { Events: local.NewEventsService(bk), }, }, - AppsC: make(chan types.Apps, 10), + AppGetter: appService, + AppsC: make(chan []types.Application, 10), }) require.NoError(t, err) t.Cleanup(w.Close) // Initially there are no apps so watcher should send an empty list. select { - case changeset := <-w.AppsC: + case changeset := <-w.ResourcesC: require.Empty(t, changeset) case <-w.Done(): t.Fatal("Watcher has unexpectedly exited.") @@ -682,7 +686,7 @@ func TestAppWatcher(t *testing.T) { // The first event is always the current list of apps. select { - case changeset := <-w.AppsC: + case changeset := <-w.ResourcesC: require.Len(t, changeset, 1) require.Empty(t, resourceDiff(changeset[0], app1)) case <-w.Done(): @@ -697,7 +701,7 @@ func TestAppWatcher(t *testing.T) { // Watcher should detect the app list change. select { - case changeset := <-w.AppsC: + case changeset := <-w.ResourcesC: require.Len(t, changeset, 2) case <-w.Done(): t.Fatal("Watcher has unexpectedly exited.") @@ -710,7 +714,7 @@ func TestAppWatcher(t *testing.T) { // Watcher should detect the database list change. select { - case changeset := <-w.AppsC: + case changeset := <-w.ResourcesC: require.Len(t, changeset, 1) require.Empty(t, resourceDiff(changeset[0], app2)) case <-w.Done(): @@ -909,6 +913,7 @@ func TestNodeWatcherFallback(t *testing.T) { }, MaxStaleness: time.Minute, }, + NodesGetter: presence, }) require.NoError(t, err) t.Cleanup(w.Close) @@ -922,15 +927,14 @@ func TestNodeWatcherFallback(t *testing.T) { nodes = append(nodes, node) } - require.Empty(t, w.NodeCount()) + require.Empty(t, w.ResourceCount()) require.False(t, w.IsInitialized()) - got := w.GetNodes(ctx, func(n services.Node) bool { - return true - }) + got, err := w.CurrentResources(ctx) + require.NoError(t, err) require.Len(t, nodes, len(got)) - require.Len(t, nodes, w.NodeCount()) + require.Len(t, nodes, w.ResourceCount()) require.False(t, w.IsInitialized()) } @@ -961,6 +965,7 @@ func TestNodeWatcher(t *testing.T) { }, MaxStaleness: time.Minute, }, + NodesGetter: presence, }) require.NoError(t, err) t.Cleanup(w.Close) @@ -974,25 +979,27 @@ func TestNodeWatcher(t *testing.T) { nodes = append(nodes, node) } - require.Eventually(t, func() bool { - filtered := w.GetNodes(ctx, func(n services.Node) bool { - return true - }) - return len(filtered) == len(nodes) + require.EventuallyWithT(t, func(t *assert.CollectT) { + filtered, err := w.CurrentResources(ctx) + assert.NoError(t, err) + assert.Len(t, filtered, len(nodes)) }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive nodes.") - require.Len(t, w.GetNodes(ctx, func(n services.Node) bool { return n.GetUseTunnel() }), 3) + filtered, err := w.CurrentResourcesWithFilter(ctx, func(n readonly.Server) bool { return n.GetUseTunnel() }) + require.NoError(t, err) + require.Len(t, filtered, 3) require.NoError(t, presence.DeleteNode(ctx, apidefaults.Namespace, nodes[0].GetName())) - require.Eventually(t, func() bool { - filtered := w.GetNodes(ctx, func(n services.Node) bool { - return true 
- })
- return len(filtered) == len(nodes)-1
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
+ filtered, err := w.CurrentResources(ctx)
+ assert.NoError(t, err)
+ assert.Len(t, filtered, len(nodes)-1)
 }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive nodes.")

- require.Empty(t, w.GetNodes(ctx, func(n services.Node) bool { return n.GetName() == nodes[0].GetName() }))
+ filtered, err = w.CurrentResourcesWithFilter(ctx, func(n readonly.Server) bool { return n.GetName() == nodes[0].GetName() })
+ require.NoError(t, err)
+ require.Empty(t, filtered)
 }

 func newNodeServer(t *testing.T, name, hostname, addr string, tunnel bool) types.Server {
@@ -1032,6 +1039,7 @@ func TestKubeServerWatcher(t *testing.T) {
 },
 MaxStaleness: time.Minute,
 },
+ KubernetesServerGetter: presence,
 })
 require.NoError(t, err)
 t.Cleanup(w.Close)
@@ -1057,55 +1065,66 @@ func TestKubeServerWatcher(t *testing.T) {
 kubeServers = append(kubeServers, kubeServer)
 }

- require.Eventually(t, func() bool {
- filtered, err := w.GetKubernetesServers(context.Background())
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
+ filtered, err := w.CurrentResources(context.Background())
 assert.NoError(t, err)
- return len(filtered) == len(kubeServers)
+ assert.Len(t, filtered, len(kubeServers))
 }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive kube servers.")

 // Test filtering by cluster name.
- filtered, err := w.GetKubeServersByClusterName(context.Background(), kubeServers[0].GetName())
+ filtered, err := w.CurrentResourcesWithFilter(context.Background(), func(ks readonly.KubeServer) bool {
+ return ks.GetName() == kubeServers[0].GetName()
+ })
 require.NoError(t, err)
 require.Len(t, filtered, 1)

 // Test Deleting a kube server.
 require.NoError(t, presence.DeleteKubernetesServer(ctx, kubeServers[0].GetHostID(), kubeServers[0].GetName()))
- require.Eventually(t, func() bool {
- kube, err := w.GetKubernetesServers(context.Background())
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
+ kube, err := w.CurrentResources(context.Background())
 assert.NoError(t, err)
- return len(kube) == len(kubeServers)-1
+ assert.Len(t, kube, len(kubeServers)-1)
 }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive the delete event.")

- filtered, err = w.GetKubeServersByClusterName(context.Background(), kubeServers[0].GetName())
- require.Error(t, err)
+ filtered, err = w.CurrentResourcesWithFilter(context.Background(), func(ks readonly.KubeServer) bool {
+ return ks.GetName() == kubeServers[0].GetName()
+ })
+ require.NoError(t, err)
 require.Empty(t, filtered)

 // Test adding a kube server with the same name as an existing one.
 kubeServer := newKubeServer(t, kubeServers[1].GetName(), "addr", uuid.NewString())
 _, err = presence.UpsertKubernetesServer(ctx, kubeServer)
 require.NoError(t, err)
- require.Eventually(t, func() bool {
- filtered, err := w.GetKubeServersByClusterName(context.Background(), kubeServers[1].GetName())
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
+ filtered, err := w.CurrentResourcesWithFilter(context.Background(), func(ks readonly.KubeServer) bool {
+ return ks.GetName() == kubeServers[1].GetName()
+ })
 assert.NoError(t, err)
- return len(filtered) == 2
- }, time.Second, time.Millisecond, "Timeout waiting for watcher to the new registered kube server.")
+ assert.Len(t, filtered, 2)
+ }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive the newly registered kube server.")

 // Test deleting all kube servers with the same name.
- filtered, err = w.GetKubeServersByClusterName(context.Background(), kubeServers[1].GetName()) + filtered, err = w.CurrentResourcesWithFilter(context.Background(), func(ks readonly.KubeServer) bool { + return ks.GetName() == kubeServers[1].GetName() + }) assert.NoError(t, err) for _, server := range filtered { require.NoError(t, presence.DeleteKubernetesServer(ctx, server.GetHostID(), server.GetName())) } - require.Eventually(t, func() bool { - filtered, err := w.GetKubeServersByClusterName(context.Background(), kubeServers[1].GetName()) - return len(filtered) == 0 && err != nil + require.EventuallyWithT(t, func(t *assert.CollectT) { + filtered, err := w.CurrentResourcesWithFilter(context.Background(), func(ks readonly.KubeServer) bool { + return ks.GetName() == kubeServers[1].GetName() + }) + assert.NoError(t, err) + assert.Empty(t, filtered) }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive the two delete events.") require.NoError(t, presence.DeleteAllKubernetesServers(ctx)) - require.Eventually(t, func() bool { - filtered, err := w.GetKubernetesServers(context.Background()) + require.EventuallyWithT(t, func(t *assert.CollectT) { + filtered, err := w.CurrentResources(context.Background()) assert.NoError(t, err) - return len(filtered) == 0 + assert.Empty(t, filtered) }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive all delete events.") } diff --git a/lib/srv/app/server.go b/lib/srv/app/server.go index 83684289cdb3b..393086b69c1c2 100644 --- a/lib/srv/app/server.go +++ b/lib/srv/app/server.go @@ -40,6 +40,7 @@ import ( "github.com/gravitational/teleport/lib/labels" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/utils" ) @@ -153,7 +154,7 @@ type Server struct { reconcileCh chan struct{} // watcher monitors changes to application resources. - watcher *services.AppWatcher + watcher *services.GenericWatcher[types.Application, readonly.Application] } // monitoredApps is a collection of applications from different sources diff --git a/lib/srv/app/watcher.go b/lib/srv/app/watcher.go index ac355fd6b9fd2..c88e73dd7f5ab 100644 --- a/lib/srv/app/watcher.go +++ b/lib/srv/app/watcher.go @@ -28,6 +28,7 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/utils" ) @@ -67,7 +68,7 @@ func (s *Server) startReconciler(ctx context.Context) error { // startResourceWatcher starts watching changes to application resources and // registers/unregisters the proxied applications accordingly. 
-func (s *Server) startResourceWatcher(ctx context.Context) (*services.AppWatcher, error) { +func (s *Server) startResourceWatcher(ctx context.Context) (*services.GenericWatcher[types.Application, readonly.Application], error) { if len(s.c.ResourceMatchers) == 0 { s.log.DebugContext(ctx, "Not initializing application resource watcher.") return nil, nil @@ -80,6 +81,7 @@ func (s *Server) startResourceWatcher(ctx context.Context) (*services.AppWatcher // Log: s.log, Client: s.c.AccessPoint, }, + AppGetter: s.c.AccessPoint, }) if err != nil { return nil, trace.Wrap(err) @@ -88,7 +90,7 @@ func (s *Server) startResourceWatcher(ctx context.Context) (*services.AppWatcher defer watcher.Close() for { select { - case apps := <-watcher.AppsC: + case apps := <-watcher.ResourcesC: appsWithAddr := make(types.Apps, 0, len(apps)) for _, app := range apps { appsWithAddr = append(appsWithAddr, s.guessPublicAddr(app)) diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go index 292615fe9f46e..10f54db7343f8 100644 --- a/lib/srv/db/server.go +++ b/lib/srv/db/server.go @@ -49,6 +49,7 @@ import ( "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/srv/db/cassandra" "github.com/gravitational/teleport/lib/srv/db/clickhouse" @@ -313,7 +314,7 @@ type Server struct { // heartbeats holds heartbeats for database servers. heartbeats map[string]srv.HeartbeatI // watcher monitors changes to database resources. - watcher *services.DatabaseWatcher + watcher *services.GenericWatcher[types.Database, readonly.Database] // proxiedDatabases contains databases this server currently is proxying. // Proxied databases are reconciled against monitoredDatabases below. proxiedDatabases map[string]types.Database diff --git a/lib/srv/db/watcher.go b/lib/srv/db/watcher.go index 58010eb6ad8d9..b65386a981247 100644 --- a/lib/srv/db/watcher.go +++ b/lib/srv/db/watcher.go @@ -27,6 +27,7 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" discovery "github.com/gravitational/teleport/lib/srv/discovery/common" dbfetchers "github.com/gravitational/teleport/lib/srv/discovery/fetchers/db" "github.com/gravitational/teleport/lib/utils" @@ -69,7 +70,7 @@ func (s *Server) startReconciler(ctx context.Context) error { // startResourceWatcher starts watching changes to database resources and // registers/unregisters the proxied databases accordingly. 
-func (s *Server) startResourceWatcher(ctx context.Context) (*services.DatabaseWatcher, error) { +func (s *Server) startResourceWatcher(ctx context.Context) (*services.GenericWatcher[types.Database, readonly.Database], error) { if len(s.cfg.ResourceMatchers) == 0 { s.log.DebugContext(ctx, "Not starting database resource watcher.") return nil, nil @@ -81,6 +82,7 @@ func (s *Server) startResourceWatcher(ctx context.Context) (*services.DatabaseWa Logger: s.log, Client: s.cfg.AccessPoint, }, + DatabaseGetter: s.cfg.AccessPoint, }) if err != nil { return nil, trace.Wrap(err) @@ -90,7 +92,7 @@ func (s *Server) startResourceWatcher(ctx context.Context) (*services.DatabaseWa defer watcher.Close() for { select { - case databases := <-watcher.DatabasesC: + case databases := <-watcher.ResourcesC: s.monitoredDatabases.setResources(databases) select { case s.reconcileCh <- struct{}{}: diff --git a/lib/srv/desktop/discovery.go b/lib/srv/desktop/discovery.go index c44bcc0a9cabb..9b4fcd921f9b1 100644 --- a/lib/srv/desktop/discovery.go +++ b/lib/srv/desktop/discovery.go @@ -39,6 +39,7 @@ import ( "github.com/gravitational/teleport/lib/auth/windows" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/utils" ) @@ -311,7 +312,7 @@ func (s *WindowsService) ldapEntryToWindowsDesktop(ctx context.Context, entry *l // startDynamicReconciler starts resource watcher and reconciler that registers/unregisters Windows desktops // according to the up-to-date list of dynamic Windows desktops resources. -func (s *WindowsService) startDynamicReconciler(ctx context.Context) (*services.DynamicWindowsDesktopWatcher, error) { +func (s *WindowsService) startDynamicReconciler(ctx context.Context) (*services.GenericWatcher[types.DynamicWindowsDesktop, readonly.DynamicWindowsDesktop], error) { if len(s.cfg.ResourceMatchers) == 0 { s.cfg.Logger.DebugContext(ctx, "Not starting dynamic desktop resource watcher.") return nil, nil @@ -354,7 +355,7 @@ func (s *WindowsService) startDynamicReconciler(ctx context.Context) (*services. defer watcher.Close() for { select { - case desktops := <-watcher.DynamicWindowsDesktopsC: + case desktops := <-watcher.ResourcesC: newResources = make(map[string]types.WindowsDesktop) for _, dynamicDesktop := range desktops { desktop, err := s.toWindowsDesktop(dynamicDesktop) diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go index 6e93a8a8eddc6..095da62b6475f 100644 --- a/lib/srv/discovery/discovery.go +++ b/lib/srv/discovery/discovery.go @@ -55,6 +55,7 @@ import ( "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/integrations/awsoidc" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/srv/discovery/common" "github.com/gravitational/teleport/lib/srv/discovery/fetchers" aws_sync "github.com/gravitational/teleport/lib/srv/discovery/fetchers/aws-sync" @@ -267,7 +268,7 @@ type Server struct { // cancelfn is used with ctx when stopping the discovery server cancelfn context.CancelFunc // nodeWatcher is a node watcher. - nodeWatcher *services.NodeWatcher + nodeWatcher *services.GenericWatcher[types.Server, readonly.Server] // ec2Watcher periodically retrieves EC2 instances. 
ec2Watcher *server.Watcher @@ -777,13 +778,16 @@ func (s *Server) initGCPWatchers(ctx context.Context, matchers []types.GCPMatche return nil } -func (s *Server) filterExistingEC2Nodes(instances *server.EC2Instances) { - nodes := s.nodeWatcher.GetNodes(s.ctx, func(n services.Node) bool { +func (s *Server) filterExistingEC2Nodes(instances *server.EC2Instances) error { + nodes, err := s.nodeWatcher.CurrentResourcesWithFilter(s.ctx, func(n readonly.Server) bool { labels := n.GetAllLabels() _, accountOK := labels[types.AWSAccountIDLabel] _, instanceOK := labels[types.AWSInstanceIDLabel] return accountOK && instanceOK }) + if err != nil { + return trace.Wrap(err) + } var filtered []server.EC2Instance outer: @@ -800,6 +804,7 @@ outer: filtered = append(filtered, inst) } instances.Instances = filtered + return nil } func genEC2InstancesLogStr(instances []server.EC2Instance) string { @@ -850,7 +855,9 @@ func (s *Server) handleEC2Instances(instances *server.EC2Instances) error { // EICE Nodes must never be filtered, so that we can extend their expiration and sync labels. totalInstancesFound := len(instances.Instances) if !instances.Rotation && instances.EnrollMode != types.InstallParamEnrollMode_INSTALL_PARAM_ENROLL_MODE_EICE { - s.filterExistingEC2Nodes(instances) + if err := s.filterExistingEC2Nodes(instances); err != nil { + return trace.Wrap(err) + } } instancesAlreadyEnrolled := totalInstancesFound - len(instances.Instances) @@ -904,12 +911,24 @@ func (s *Server) heartbeatEICEInstance(instances *server.EC2Instances) { continue } - existingNode, err := s.nodeWatcher.GetNode(s.ctx, eiceNode.GetName()) + existingNodes, err := s.nodeWatcher.CurrentResourcesWithFilter(s.ctx, func(s readonly.Server) bool { + return s.GetName() == eiceNode.GetName() + }) if err != nil && !trace.IsNotFound(err) { s.Log.Warnf("Error finding the existing node with name %q: %v", eiceNode.GetName(), err) continue } + var existingNode types.Server + switch len(existingNodes) { + case 0: + case 1: + existingNode = existingNodes[0] + default: + s.Log.Warnf("Found multiple matching nodes with name %q", eiceNode.GetName()) + continue + } + // EICE Node's Name are deterministic (based on the Account and Instance ID). 
// // To reduce load, nodes are skipped if @@ -1064,7 +1083,7 @@ func (s *Server) findUnrotatedEC2Nodes(ctx context.Context) ([]types.Server, err if err != nil { return nil, trace.Wrap(err) } - found := s.nodeWatcher.GetNodes(ctx, func(n services.Node) bool { + found, err := s.nodeWatcher.CurrentResourcesWithFilter(ctx, func(n readonly.Server) bool { if n.GetSubKind() != types.SubKindOpenSSHNode { return false } @@ -1077,6 +1096,9 @@ func (s *Server) findUnrotatedEC2Nodes(ctx context.Context) ([]types.Server, err return mostRecentCertRotation.After(n.GetRotation().LastRotated) }) + if err != nil { + return nil, trace.Wrap(err) + } if len(found) == 0 { return nil, trace.NotFound("no unrotated nodes found") @@ -1118,13 +1140,18 @@ func (s *Server) handleEC2Discovery() { } } -func (s *Server) filterExistingAzureNodes(instances *server.AzureInstances) { - nodes := s.nodeWatcher.GetNodes(s.ctx, func(n services.Node) bool { +func (s *Server) filterExistingAzureNodes(instances *server.AzureInstances) error { + nodes, err := s.nodeWatcher.CurrentResourcesWithFilter(s.ctx, func(n readonly.Server) bool { labels := n.GetAllLabels() _, subscriptionOK := labels[types.SubscriptionIDLabel] _, vmOK := labels[types.VMIDLabel] return subscriptionOK && vmOK }) + + if err != nil { + return trace.Wrap(err) + } + var filtered []*armcompute.VirtualMachine outer: for _, inst := range instances.Instances { @@ -1144,6 +1171,7 @@ outer: filtered = append(filtered, inst) } instances.Instances = filtered + return nil } func (s *Server) handleAzureInstances(instances *server.AzureInstances) error { @@ -1151,7 +1179,9 @@ func (s *Server) handleAzureInstances(instances *server.AzureInstances) error { if err != nil { return trace.Wrap(err) } - s.filterExistingAzureNodes(instances) + if err := s.filterExistingAzureNodes(instances); err != nil { + return trace.Wrap(err) + } if len(instances.Instances) == 0 { return trace.Wrap(errNoInstances) } @@ -1206,14 +1236,19 @@ func (s *Server) handleAzureDiscovery() { } } -func (s *Server) filterExistingGCPNodes(instances *server.GCPInstances) { - nodes := s.nodeWatcher.GetNodes(s.ctx, func(n services.Node) bool { +func (s *Server) filterExistingGCPNodes(instances *server.GCPInstances) error { + nodes, err := s.nodeWatcher.CurrentResourcesWithFilter(s.ctx, func(n readonly.Server) bool { labels := n.GetAllLabels() _, projectIDOK := labels[types.ProjectIDLabelDiscovery] _, zoneOK := labels[types.ZoneLabelDiscovery] _, nameOK := labels[types.NameLabelDiscovery] return projectIDOK && zoneOK && nameOK }) + + if err != nil { + return trace.Wrap(err) + } + var filtered []*gcpimds.Instance outer: for _, inst := range instances.Instances { @@ -1230,6 +1265,7 @@ outer: filtered = append(filtered, inst) } instances.Instances = filtered + return nil } func (s *Server) handleGCPInstances(instances *server.GCPInstances) error { @@ -1237,7 +1273,9 @@ func (s *Server) handleGCPInstances(instances *server.GCPInstances) error { if err != nil { return trace.Wrap(err) } - s.filterExistingGCPNodes(instances) + if err := s.filterExistingGCPNodes(instances); err != nil { + return trace.Wrap(err) + } if len(instances.Instances) == 0 { return trace.Wrap(errNoInstances) } @@ -1730,6 +1768,7 @@ func (s *Server) initTeleportNodeWatcher() (err error) { Client: s.AccessPoint, MaxStaleness: time.Minute, }, + NodesGetter: s.AccessPoint, }) return trace.Wrap(err) diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index 974dd1bfcdad1..ab384e65d74f5 100644 --- 
a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -852,11 +852,10 @@ func TestDiscoveryServerConcurrency(t *testing.T) { // We must get only one EC2 EICE Node. // Even when two servers are discovering the same EC2 Instance, they will use the same name when converting to EICE Node. - require.Eventually(t, func() bool { + require.EventuallyWithT(t, func(t *assert.CollectT) { allNodes, err := tlsServer.Auth().GetNodes(ctx, "default") - require.NoError(t, err) - - return len(allNodes) == 1 + assert.NoError(t, err) + assert.Len(t, allNodes, 1) }, 1*time.Second, 50*time.Millisecond) // We should never get a duplicate instance. diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index 42df4c9d4017c..8d6a64154ee9b 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -70,6 +70,7 @@ import ( "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" sess "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/sshutils" @@ -2865,12 +2866,13 @@ func newLockWatcher(ctx context.Context, t testing.TB, client types.Events) *ser return lockWatcher } -func newNodeWatcher(ctx context.Context, t *testing.T, client types.Events) *services.NodeWatcher { +func newNodeWatcher(ctx context.Context, t *testing.T, client *authclient.Client) *services.GenericWatcher[types.Server, readonly.Server] { nodeWatcher, err := services.NewNodeWatcher(ctx, services.NodeWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ Component: "test", Client: client, }, + NodesGetter: client, }) require.NoError(t, err) t.Cleanup(nodeWatcher.Close) diff --git a/lib/utils/fncache.go b/lib/utils/fncache.go index e45a8b3a2d821..84f5be17478bb 100644 --- a/lib/utils/fncache.go +++ b/lib/utils/fncache.go @@ -245,6 +245,8 @@ func FnCacheGetWithTTL[T any](ctx context.Context, cache *FnCache, key any, ttl switch { case err != nil: return ret, err + case t == nil: + return ret, nil case !ok: return ret, trace.BadParameter("value retrieved was %T, expected %T", t, ret) } diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index faaffa2ff3e1e..d45effc964395 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -98,6 +98,7 @@ import ( "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/secret" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" @@ -164,7 +165,7 @@ type Handler struct { // nodeWatcher is a services.NodeWatcher used by Assist to lookup nodes from // the proxy's cache and get nodes in real time. - nodeWatcher *services.NodeWatcher + nodeWatcher *services.GenericWatcher[types.Server, readonly.Server] // tracer is used to create spans. tracer oteltrace.Tracer @@ -300,7 +301,7 @@ type Config struct { // NodeWatcher is a services.NodeWatcher used by Assist to lookup nodes from // the proxy's cache and get nodes in real time. - NodeWatcher *services.NodeWatcher + NodeWatcher *services.GenericWatcher[types.Server, readonly.Server] // PresenceChecker periodically runs the mfa ceremony for moderated // sessions. 
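The field changes above follow the same pattern as the rest of the patch: wherever a *services.NodeWatcher (or AppWatcher, DatabaseWatcher, and so on) was stored, the code now holds a services.GenericWatcher parameterized on the resource type and its read-only view, and construction passes an explicit getter alongside the event client. A minimal sketch of that construction and lookup flow, using the packages already imported in the hunks above; the function name, "example" component string, and serverID are illustrative placeholders, not names from the patch:

	// lookupNodeByName is a sketch of the new watcher API surface.
	func lookupNodeByName(ctx context.Context, client *authclient.Client, serverID string) (types.Server, error) {
		// Construction now requires both the event client and an explicit
		// NodesGetter; here the same client is assumed to satisfy both.
		w, err := services.NewNodeWatcher(ctx, services.NodeWatcherConfig{
			ResourceWatcherConfig: services.ResourceWatcherConfig{
				Component: "example", // placeholder component name
				Client:    client,
			},
			NodesGetter: client,
		})
		if err != nil {
			return nil, trace.Wrap(err)
		}
		defer w.Close()

		// CurrentResourcesWithFilter replaces the removed GetNodes helper
		// and, unlike it, returns an error the caller must handle.
		matches, err := w.CurrentResourcesWithFilter(ctx, func(n readonly.Server) bool {
			return n.GetName() == serverID
		})
		if err != nil {
			return nil, trace.Wrap(err)
		}
		if len(matches) != 1 {
			return nil, trace.NotFound("unable to resolve server %s", serverID)
		}
		return matches[0], nil
	}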
@@ -3544,9 +3545,12 @@ func (h *Handler) siteNodeConnect( WebsocketConn: ws, SSHDialTimeout: dialTimeout, HostNameResolver: func(serverID string) (string, error) { - matches := nw.GetNodes(r.Context(), func(n services.Node) bool { + matches, err := nw.CurrentResourcesWithFilter(r.Context(), func(n readonly.Server) bool { return n.GetName() == serverID }) + if err != nil { + return "", trace.Wrap(err) + } if len(matches) != 1 { return "", trace.NotFound("unable to resolve hostname for server %s", serverID) diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 5e3fea922d245..a585579f7ddda 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -381,6 +381,7 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { Component: teleport.ComponentProxy, Client: s.proxyClient, }, + NodesGetter: s.proxyClient, }) require.NoError(t, err) @@ -8185,6 +8186,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula Component: teleport.ComponentProxy, Client: client, }, + NodesGetter: client, }) require.NoError(t, err) t.Cleanup(proxyNodeWatcher.Close) @@ -9075,6 +9077,7 @@ func startKubeWithoutCleanup(ctx context.Context, t *testing.T, cfg startKubeOpt Client: client, Clock: clock, }, + KubernetesServerGetter: client, }) require.NoError(t, err)
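The event path is generalized the same way: the per-resource channels such as AppsC, DatabasesC, and DynamicWindowsDesktopsC are replaced by the shared ResourcesC field, and, as the call sites above suggest, each delivery carries the current resource set rather than a delta. A minimal sketch of the consumer loop used by the app and database services, with onUpdate and reconcileCh standing in for each service's own bookkeeping (both are placeholders, not names from the patch):

	// consumeDatabases is a sketch of the unified ResourcesC consumer loop.
	func consumeDatabases(ctx context.Context, watcher *services.GenericWatcher[types.Database, readonly.Database],
		onUpdate func([]types.Database), reconcileCh chan<- struct{}) {
		defer watcher.Close()
		for {
			select {
			case databases := <-watcher.ResourcesC:
				// Replace the monitored set wholesale, then nudge the
				// reconciler, mirroring setResources/reconcileCh above.
				onUpdate(databases)
				select {
				case reconcileCh <- struct{}{}:
				case <-ctx.Done():
					return
				}
			case <-ctx.Done():
				return
			}
		}
	}

A single exported channel per watcher is what lets the one generic implementation serve nodes, apps, databases, kube servers, and dynamic desktops without per-resource plumbing.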