From 57c9c86615e6020ce5843e79660b39a884cd185e Mon Sep 17 00:00:00 2001 From: Gyuho Lee <6799218+gyuho@users.noreply.github.com> Date: Thu, 3 Oct 2024 14:47:01 +0800 Subject: [PATCH] feat(internal/server): dynamically refresh containerd, docker, kubelet components (#78) * feat(internal/server): dynamically refresh containerd, docker, kubelet components Signed-off-by: Gyuho Lee * clean up dmesg Signed-off-by: Gyuho Lee * use binary locate for default + fallback Signed-off-by: Gyuho Lee --------- Signed-off-by: Gyuho Lee --- cmd/gpud/command/command.go | 9 +- components/docker/container/config.go | 3 +- config/config.go | 7 + config/default.go | 223 +++++++++++++++++--------- internal/server/handlers.go | 11 +- internal/server/server.go | 140 +++++++++++++++- pkg/file/file.go | 33 ++++ pkg/file/file_test.go | 13 ++ 8 files changed, 347 insertions(+), 92 deletions(-) create mode 100644 pkg/file/file.go create mode 100644 pkg/file/file_test.go diff --git a/cmd/gpud/command/command.go b/cmd/gpud/command/command.go index 3f711ae8..9f7da594 100644 --- a/cmd/gpud/command/command.go +++ b/cmd/gpud/command/command.go @@ -28,7 +28,8 @@ var ( pprof bool - retentionPeriod time.Duration + retentionPeriod time.Duration + refreshComponentsInterval time.Duration webEnable bool webAdmin bool @@ -157,6 +158,12 @@ sudo rm /etc/systemd/system/gpud.service Destination: &retentionPeriod, Value: config.DefaultRetentionPeriod.Duration, }, + &cli.DurationFlag{ + Name: "refresh-components-interval", + Usage: "set the time period to refresh selected components", + Destination: &refreshComponentsInterval, + Value: config.DefaultRefreshComponentsInterval.Duration, + }, &cli.BoolTFlag{ Name: "web-enable", Usage: "enable local web interface (default: true)", diff --git a/components/docker/container/config.go b/components/docker/container/config.go index 2bf66a27..0fd5f54b 100644 --- a/components/docker/container/config.go +++ b/components/docker/container/config.go @@ -8,8 +8,7 @@ import ( ) type Config struct { - Query query_config.Config `json:"query"` - Endpoint string `json:"endpoint"` + Query query_config.Config `json:"query"` } func ParseConfig(b any, db *sql.DB) (*Config, error) { diff --git a/config/config.go b/config/config.go index 0fe46406..5b98da54 100644 --- a/config/config.go +++ b/config/config.go @@ -33,6 +33,10 @@ type Config struct { // Once elapsed, old states/metrics are purged/compacted. RetentionPeriod metav1.Duration `json:"retention_period"` + // Interval at which to refresh selected components. + // Disables refresh if not set. + RefreshComponentsInterval metav1.Duration `json:"refresh_components_interval"` + // Set true to enable profiler. Pprof bool `json:"pprof"` @@ -65,6 +69,9 @@ func (config *Config) Validate() error { if config.RetentionPeriod.Duration < time.Minute { return fmt.Errorf("retention_period must be at least 1 minute, got %d", config.RetentionPeriod.Duration) } + if config.RefreshComponentsInterval.Duration < time.Minute { + return fmt.Errorf("refresh_components_interval must be at least 1 minute, got %d", config.RefreshComponentsInterval.Duration) + } if config.Web != nil && config.Web.RefreshPeriod.Duration < time.Minute { return fmt.Errorf("web_refresh_period must be at least 1 minute, got %d", config.Web.RefreshPeriod.Duration) } diff --git a/config/default.go b/config/default.go index 0ddaf9b3..7cca5743 100644 --- a/config/default.go +++ b/config/default.go @@ -45,6 +45,7 @@ import ( component_systemd "github.com/leptonai/gpud/components/systemd" "github.com/leptonai/gpud/components/tailscale" "github.com/leptonai/gpud/log" + pkg_file "github.com/leptonai/gpud/pkg/file" pkd_systemd "github.com/leptonai/gpud/pkg/systemd" "github.com/leptonai/gpud/systemd" "github.com/leptonai/gpud/version" @@ -59,8 +60,9 @@ const ( ) var ( - DefaultRefreshPeriod = metav1.Duration{Duration: time.Minute} - DefaultRetentionPeriod = metav1.Duration{Duration: 30 * time.Minute} + DefaultRefreshPeriod = metav1.Duration{Duration: time.Minute} + DefaultRetentionPeriod = metav1.Duration{Duration: 30 * time.Minute} + DefaultRefreshComponentsInterval = metav1.Duration{Duration: time.Minute} ) func DefaultConfig(ctx context.Context, opts ...OpOption) (*Config, error) { @@ -69,8 +71,6 @@ func DefaultConfig(ctx context.Context, opts ...OpOption) (*Config, error) { return nil, err } - asRoot := stdos.Geteuid() == 0 // running as root - cfg := &Config{ APIVersion: DefaultAPIVersion, @@ -90,8 +90,9 @@ func DefaultConfig(ctx context.Context, opts ...OpOption) (*Config, error) { os.Name: nil, }, - RetentionPeriod: DefaultRetentionPeriod, - Pprof: false, + RetentionPeriod: DefaultRetentionPeriod, + RefreshComponentsInterval: DefaultRefreshComponentsInterval, + Pprof: false, Web: &Web{ Enable: true, @@ -107,86 +108,22 @@ func DefaultConfig(ctx context.Context, opts ...OpOption) (*Config, error) { cfg.Components[file.Name] = options.filesToCheck } - if runtime.GOOS == "linux" { - containerdSocketExists := false - containerdRunning := false - - if _, err := stdos.Stat(containerd_pod.DefaultSocketFile); err == nil { - log.Logger.Debugw("containerd default socket file exists, containerd installed", "file", containerd_pod.DefaultSocketFile) - containerdSocketExists = true - } else { - log.Logger.Debugw("containerd default socket file does not exist, skip containerd check", "file", containerd_pod.DefaultSocketFile, "error", err) - } - - cctx, ccancel := context.WithTimeout(ctx, 5*time.Second) - defer ccancel() - if _, _, conn, err := containerd_pod.Connect(cctx, containerd_pod.DefaultContainerRuntimeEndpoint); err == nil { - log.Logger.Debugw("containerd default cri endpoint open, containerd running", "endpoint", containerd_pod.DefaultContainerRuntimeEndpoint) - containerdRunning = true - _ = conn.Close() - } else { - log.Logger.Debugw("containerd default cri endpoint not open, skip containerd checking", "endpoint", containerd_pod.DefaultContainerRuntimeEndpoint, "error", err) - } - - if containerdSocketExists && containerdRunning { - log.Logger.Debugw("auto-detected containerd -- configuring containerd pod component") - cfg.Components[containerd_pod.Name] = containerd_pod.Config{ - Query: query_config.DefaultConfig(), - Endpoint: containerd_pod.DefaultContainerRuntimeEndpoint, - } - } - } else { - log.Logger.Debugw("ignoring default containerd pod checking since it's not linux", "os", runtime.GOOS) + if cc, exists := DefaultDockerContainerComponent(ctx); exists { + cfg.Components[docker_container.Name] = cc } - - if runtime.GOOS == "linux" { - // check if the TCP port is open/used - conn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", k8s_pod.DefaultKubeletReadOnlyPort), 3*time.Second) - if err != nil { - log.Logger.Debugw("tcp port is not open", "port", k8s_pod.DefaultKubeletReadOnlyPort, "error", err) - } else { - log.Logger.Debugw("tcp port is open", "port", k8s_pod.DefaultKubeletReadOnlyPort) - conn.Close() - - kerr := k8s_pod.CheckKubeletReadOnlyPort(ctx, k8s_pod.DefaultKubeletReadOnlyPort) - // check - if kerr != nil { - log.Logger.Debugw("kubelet readonly port is not open", "port", k8s_pod.DefaultKubeletReadOnlyPort, "error", kerr) - } else { - log.Logger.Debugw("auto-detected kubelet readonly port -- configuring k8s pod components", "port", k8s_pod.DefaultKubeletReadOnlyPort) - - // "k8s_pod" requires kubelet read-only port - // assume if kubelet is running, it opens the most common read-only port 10255 - cfg.Components[k8s_pod.Name] = k8s_pod.Config{ - Query: query_config.DefaultConfig(), - Port: k8s_pod.DefaultKubeletReadOnlyPort, - } - } - } - } else { - log.Logger.Debugw("ignoring default kubelet checking since it's not linux", "os", runtime.GOOS) + if cc, exists := DefaultContainerdComponent(ctx); exists { + cfg.Components[containerd_pod.Name] = cc } - - if docker_container.IsDockerRunning() { - log.Logger.Debugw("auto-detected docker -- configuring docker container component") - cfg.Components[docker_container.Name] = nil + if cc, exists := DefaultK8sPodComponent(ctx); exists { + cfg.Components[k8s_pod.Name] = cc } if _, err := stdos.Stat(power_supply.DefaultBatteryCapacityFile); err == nil { cfg.Components[power_supply.Name] = nil } - if runtime.GOOS == "linux" { - if dmesg.DmesgExists() { - if asRoot { - log.Logger.Debugw("auto-detected dmesg -- configuring dmesg component") - cfg.Components[dmesg.Name] = dmesg.DefaultConfig() - } else { - log.Logger.Debugw("auto-detected dmesg but running as root -- skipping") - } - } - } else { - log.Logger.Debugw("auto-detect dmesg not supported -- skipping", "os", runtime.GOOS) + if cc, exists := DefaultDmesgComponent(); exists { + cfg.Components[dmesg.Name] = cc } if runtime.GOOS == "linux" { @@ -351,3 +288,133 @@ func DefaultFifoFile() (string, error) { } return filepath.Join(f, "gpud.fifo"), nil } + +func DefaultContainerdComponent(ctx context.Context) (any, bool) { + if runtime.GOOS != "linux" { + log.Logger.Debugw("ignoring default containerd pod checking since it's not linux", "os", runtime.GOOS) + return nil, false + } + + p, err := pkg_file.LocateExecutable("containerd") + if p != "" && err == nil { + log.Logger.Debugw("containerd found in PATH", "path", p) + return containerd_pod.Config{ + Query: query_config.DefaultConfig(), + Endpoint: containerd_pod.DefaultContainerRuntimeEndpoint, + }, true + } + log.Logger.Debugw("containerd not found in PATH -- fallback to containerd run checks", "error", err) + + containerdSocketExists := false + containerdRunning := false + + if _, err := stdos.Stat(containerd_pod.DefaultSocketFile); err == nil { + log.Logger.Debugw("containerd default socket file exists, containerd installed", "file", containerd_pod.DefaultSocketFile) + containerdSocketExists = true + } else { + log.Logger.Debugw("containerd default socket file does not exist, skip containerd check", "file", containerd_pod.DefaultSocketFile, "error", err) + } + + cctx, ccancel := context.WithTimeout(ctx, 5*time.Second) + defer ccancel() + + if _, _, conn, err := containerd_pod.Connect(cctx, containerd_pod.DefaultContainerRuntimeEndpoint); err == nil { + log.Logger.Debugw("containerd default cri endpoint open, containerd running", "endpoint", containerd_pod.DefaultContainerRuntimeEndpoint) + containerdRunning = true + _ = conn.Close() + } else { + log.Logger.Debugw("containerd default cri endpoint not open, skip containerd checking", "endpoint", containerd_pod.DefaultContainerRuntimeEndpoint, "error", err) + } + + if containerdSocketExists && containerdRunning { + log.Logger.Debugw("auto-detected containerd -- configuring containerd pod component") + return containerd_pod.Config{ + Query: query_config.DefaultConfig(), + Endpoint: containerd_pod.DefaultContainerRuntimeEndpoint, + }, true + } + return nil, false +} + +func DefaultDockerContainerComponent(ctx context.Context) (any, bool) { + p, err := pkg_file.LocateExecutable("docker") + if p != "" && err == nil { + log.Logger.Debugw("docker found in PATH", "path", p) + return docker_container.Config{ + Query: query_config.DefaultConfig(), + }, true + } + log.Logger.Debugw("docker not found in PATH -- fallback to docker run checks", "error", err) + + if docker_container.IsDockerRunning() { + log.Logger.Debugw("auto-detected docker -- configuring docker container component") + return docker_container.Config{ + Query: query_config.DefaultConfig(), + }, true + } + return nil, false +} + +func DefaultK8sPodComponent(ctx context.Context) (any, bool) { + if runtime.GOOS != "linux" { + log.Logger.Debugw("ignoring default kubelet checking since it's not linux", "os", runtime.GOOS) + return nil, false + } + + p, err := pkg_file.LocateExecutable("kubelet") + if p != "" && err == nil { + log.Logger.Debugw("kubelet found in PATH", "path", p) + return k8s_pod.Config{ + Query: query_config.DefaultConfig(), + Port: k8s_pod.DefaultKubeletReadOnlyPort, + }, true + } + log.Logger.Debugw("kubelet not found in PATH -- fallback to kubelet run checks", "error", err) + + // check if the TCP port is open/used + conn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", k8s_pod.DefaultKubeletReadOnlyPort), 3*time.Second) + if err != nil { + log.Logger.Debugw("tcp port is not open", "port", k8s_pod.DefaultKubeletReadOnlyPort, "error", err) + } else { + log.Logger.Debugw("tcp port is open", "port", k8s_pod.DefaultKubeletReadOnlyPort) + conn.Close() + + kerr := k8s_pod.CheckKubeletReadOnlyPort(ctx, k8s_pod.DefaultKubeletReadOnlyPort) + // check + if kerr != nil { + log.Logger.Debugw("kubelet readonly port is not open", "port", k8s_pod.DefaultKubeletReadOnlyPort, "error", kerr) + } else { + log.Logger.Debugw("auto-detected kubelet readonly port -- configuring k8s pod components", "port", k8s_pod.DefaultKubeletReadOnlyPort) + + // "k8s_pod" requires kubelet read-only port + // assume if kubelet is running, it opens the most common read-only port 10255 + return k8s_pod.Config{ + Query: query_config.DefaultConfig(), + Port: k8s_pod.DefaultKubeletReadOnlyPort, + }, true + } + } + + return nil, false +} + +func DefaultDmesgComponent() (any, bool) { + if runtime.GOOS != "linux" { + log.Logger.Debugw("ignoring default dmesg since it's not linux", "os", runtime.GOOS) + return nil, false + } + + asRoot := stdos.Geteuid() == 0 // running as root + if !asRoot { + log.Logger.Debugw("auto-detected dmesg but running as root -- skipping") + return nil, false + } + + if dmesg.DmesgExists() { + log.Logger.Debugw("auto-detected dmesg -- configuring dmesg component") + return dmesg.DefaultConfig(), true + } + + log.Logger.Debugw("dmesg does not exist -- skipping dmesg component") + return nil, false +} diff --git a/internal/server/handlers.go b/internal/server/handlers.go index 0c1e27ec..961b299f 100644 --- a/internal/server/handlers.go +++ b/internal/server/handlers.go @@ -7,6 +7,7 @@ import ( "sort" "strconv" "strings" + "sync" "time" lep_components "github.com/leptonai/gpud/components" @@ -17,9 +18,11 @@ import ( ) type globalHandler struct { - cfg *lep_config.Config - components map[string]lep_components.Component - componentNames []string + cfg *lep_config.Config + components map[string]lep_components.Component + + componentNamesMu sync.RWMutex + componentNames []string } func newGlobalHandler(cfg *lep_config.Config, components map[string]lep_components.Component) *globalHandler { @@ -61,6 +64,8 @@ func (g *globalHandler) getReqTime(c *gin.Context) (time.Time, time.Time, error) func (g *globalHandler) getReqComponents(c *gin.Context) ([]string, error) { components := c.Query("components") if components == "" { + g.componentNamesMu.RLock() + defer g.componentNamesMu.RUnlock() return g.componentNames, nil } diff --git a/internal/server/server.go b/internal/server/server.go index ccf08bab..a58d2a7e 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -728,15 +728,15 @@ func New(ctx context.Context, config *lepconfig.Config, endpoint string, cliUID case <-ctx.Done(): return case <-ticker.C: - if err := state.Compact(ctx, db); err != nil { - log.Logger.Errorw("failed to compact state database", "error", err) - } - if err := state.RecordMetrics(ctx, db); err != nil { - log.Logger.Errorw("failed to record metrics", "error", err) - } - ticker.Reset(config.RetentionPeriod.Duration) } + + if err := state.Compact(ctx, db); err != nil { + log.Logger.Errorw("failed to compact state database", "error", err) + } + if err := state.RecordMetrics(ctx, db); err != nil { + log.Logger.Errorw("failed to record metrics", "error", err) + } } }() } @@ -747,7 +747,9 @@ func New(ctx context.Context, config *lepconfig.Config, endpoint string, cliUID } var componentNames []string + componentSet := make(map[string]struct{}) for _, c := range allComponents { + componentSet[c.Name()] = struct{}{} componentNames = append(componentNames, c.Name()) if strings.Contains(c.Name(), "nvidia") { s.nvidiaComponentsExist = true @@ -820,7 +822,8 @@ func New(ctx context.Context, config *lepconfig.Config, endpoint string, cliUID // the middleware automatically gzip-compresses the response with the response header "Content-Encoding: gzip" v1.Use(gzip.Gzip(gzip.DefaultCompression, gzip.WithExcludedPaths([]string{"/update/"}))) - registeredPaths := newGlobalHandler(config, components.GetAllComponents()).registerComponentRoutes(v1) + ghler := newGlobalHandler(config, components.GetAllComponents()) + registeredPaths := ghler.registerComponentRoutes(v1) for i := range registeredPaths { registeredPaths[i].Path = path.Join(v1.BasePath(), registeredPaths[i].Path) } @@ -871,6 +874,127 @@ func New(ctx context.Context, config *lepconfig.Config, endpoint string, cliUID } } + // refresh components in case containerd, docker, or k8s kubelet starts afterwards + if config.RefreshComponentsInterval.Duration > 0 { + go func() { + ticker := time.NewTicker(config.RefreshComponentsInterval.Duration) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + ticker.Reset(config.RefreshComponentsInterval.Duration) + } + + componentsToAdd := make([]components.Component, 0) + + // NOTE: systemd unit update still requires gpud restarts + for _, name := range []string{ + containerd_pod.Name, + docker_container.Name, + k8s_pod.Name, + } { + if _, ok := componentSet[name]; ok { + continue + } + + if cc, exists := lepconfig.DefaultContainerdComponent(ctx); exists { + ccfg := containerd_pod.Config{Query: defaultQueryCfg} + if cc != nil { + parsed, err := containerd_pod.ParseConfig(cc, db) + if err != nil { + log.Logger.Errorw("failed to parse component %s config: %w", name, err) + continue + } + ccfg = *parsed + } + if err := ccfg.Validate(); err != nil { + log.Logger.Errorw("failed to validate component %s config: %w", name, err) + continue + } + componentsToAdd = append(componentsToAdd, containerd_pod.New(ctx, ccfg)) + } + + if cc, exists := lepconfig.DefaultDockerContainerComponent(ctx); exists { + ccfg := docker_container.Config{Query: defaultQueryCfg} + if cc != nil { + parsed, err := docker_container.ParseConfig(cc, db) + if err != nil { + log.Logger.Errorw("failed to parse component %s config: %w", name, err) + continue + } + ccfg = *parsed + } + if err := ccfg.Validate(); err != nil { + log.Logger.Errorw("failed to validate component %s config: %w", name, err) + continue + } + componentsToAdd = append(componentsToAdd, docker_container.New(ctx, ccfg)) + } + + if cc, exists := lepconfig.DefaultK8sPodComponent(ctx); exists { + ccfg := k8s_pod.Config{Query: defaultQueryCfg} + if cc != nil { + parsed, err := k8s_pod.ParseConfig(cc, db) + if err != nil { + log.Logger.Errorw("failed to parse component %s config: %w", name, err) + continue + } + ccfg = *parsed + } + if err := ccfg.Validate(); err != nil { + log.Logger.Errorw("failed to validate component %s config: %w", name, err) + continue + } + componentsToAdd = append(componentsToAdd, k8s_pod.New(ctx, ccfg)) + } + } + + if len(componentsToAdd) == 0 { + continue + } + + for i := range componentsToAdd { + metrics.SetRegistered(componentsToAdd[i].Name()) + componentsToAdd[i] = metrics.NewWatchableComponent(componentsToAdd[i]) + + // fails if already registered + if err := components.RegisterComponent(componentsToAdd[i].Name(), componentsToAdd[i]); err != nil { + log.Logger.Warnw("failed to register component", "name", componentsToAdd[i].Name(), "error", err) + continue + } + + if orig, ok := componentsToAdd[i].(interface{ Unwrap() interface{} }); ok { + if prov, ok := orig.Unwrap().(components.PromRegisterer); ok { + log.Logger.Debugw("registering prometheus collectors", "component", componentsToAdd[i].Name()) + if err := prov.RegisterCollectors(promReg, db, components_metrics_state.DefaultTableName); err != nil { + log.Logger.Errorw("failed to register metrics for component", "component", componentsToAdd[i].Name(), "error", err) + } + } else { + log.Logger.Debugw("component does not implement components.PromRegisterer", "component", componentsToAdd[i].Name()) + } + } else { + log.Logger.Debugw("component does not implement interface{ Unwrap() interface{} }", "component", componentsToAdd[i].Name()) + } + } + + newComponentNames := make([]string, len(componentNames)) + copy(newComponentNames, componentNames) + for _, c := range componentsToAdd { + newComponentNames = append(newComponentNames, c.Name()) + } + if err = state.UpdateComponents(ctx, db, s.uid, strings.Join(newComponentNames, ",")); err != nil { + log.Logger.Errorw("failed to update components", "error", err) + } + + ghler.componentNamesMu.Lock() + ghler.componentNames = newComponentNames + ghler.componentNamesMu.Unlock() + } + }() + } + go s.updateToken(ctx, db, uid, endpoint) go func() { diff --git a/pkg/file/file.go b/pkg/file/file.go new file mode 100644 index 00000000..b716601c --- /dev/null +++ b/pkg/file/file.go @@ -0,0 +1,33 @@ +// Package file implements file utils. +package file + +import ( + "fmt" + "os" + "os/exec" +) + +func LocateExecutable(bin string) (string, error) { + execPath, err := exec.LookPath(bin) + if err == nil { + return execPath, CheckExecutable(execPath) + } + return "", fmt.Errorf("executable %q not found in PATH: %w", bin, err) +} + +func CheckExecutable(file string) error { + s, err := os.Stat(file) + if err != nil { + return err + } + + if s.IsDir() { + return fmt.Errorf("%q is a directory", file) + } + + if s.Mode()&0111 == 0 { + return fmt.Errorf("%q is not executable", file) + } + + return nil +} diff --git a/pkg/file/file_test.go b/pkg/file/file_test.go new file mode 100644 index 00000000..876bf363 --- /dev/null +++ b/pkg/file/file_test.go @@ -0,0 +1,13 @@ +package file + +import ( + "testing" +) + +func TestLocateExecutable(t *testing.T) { + execPath, err := LocateExecutable("ls") + if err != nil { + t.Fatalf("LocateExecutable() failed: %v", err) + } + t.Logf("found executable %q", execPath) +}