diff --git a/lib/autoupdate/agent/installer.go b/lib/autoupdate/agent/installer.go index 727763dec094a..957e90779c2ab 100644 --- a/lib/autoupdate/agent/installer.go +++ b/lib/autoupdate/agent/installer.go @@ -29,6 +29,7 @@ import ( "log/slog" "net/http" "os" + "path" "path/filepath" "runtime" "syscall" @@ -42,26 +43,23 @@ import ( ) const ( - checksumType = "sha256" + // checksumType for Teleport tgzs + checksumType = "sha256" + // checksumHexLen is the length of the Teleport checksum. checksumHexLen = sha256.Size * 2 // bytes to hex + // maxServiceFileSize is the maximum size allowed for a systemd service file. + maxServiceFileSize = 1_000_000 // 1 MB + // configFileMode is the mode used for new configuration files. + configFileMode = 0644 + // systemDirMode is the mode used for new directories. + systemDirMode = 0755 ) var ( - // tgzExtractPaths describes how to extract the Teleport tgz. - // See utils.Extract for more details on how this list is parsed. - // Paths must use tarball-style / separators (not filepath). - tgzExtractPaths = []utils.ExtractPath{ - {Src: "teleport/examples/systemd/teleport.service", Dst: "etc/systemd/teleport.service", DirMode: 0755}, - {Src: "teleport/examples", Skip: true, DirMode: 0755}, - {Src: "teleport/install", Skip: true, DirMode: 0755}, - {Src: "teleport/README.md", Dst: "share/README.md", DirMode: 0755}, - {Src: "teleport/CHANGELOG.md", Dst: "share/CHANGELOG.md", DirMode: 0755}, - {Src: "teleport/VERSION", Dst: "share/VERSION", DirMode: 0755}, - {Src: "teleport", Dst: "bin", DirMode: 0755}, - } - - // servicePath contains the path to the Teleport SystemD service within the version directory. - servicePath = filepath.Join("etc", "systemd", "teleport.service") + // serviceDir contains the relative path to the Teleport SystemD service dir. + serviceDir = filepath.Join("lib", "systemd", "system") + // serviceName contains the name of the Teleport SystemD service file. + serviceName = "teleport.service" ) // LocalInstaller manages the creation and removal of installations @@ -71,8 +69,12 @@ type LocalInstaller struct { InstallDir string // LinkBinDir contains symlinks to the linked installation's binaries. LinkBinDir string - // LinkServiceDir contains a symlink to the linked installation's systemd service. + // LinkServiceDir contains a copy of the linked installation's systemd service. LinkServiceDir string + // SystemBinDir contains binaries for the system (packaged) install of Teleport. + SystemBinDir string + // SystemServiceDir contains the systemd service file for the system (packaged) install of Teleport. + SystemServiceDir string // HTTP is an HTTP client for downloading Teleport. HTTP *http.Client // Log contains a logger. @@ -210,11 +212,11 @@ func (li *LocalInstaller) Install(ctx context.Context, version, template string, }() // Extract tgz into version directory. - if err := li.extract(ctx, versionDir, f, n); err != nil { + if err := li.extract(ctx, versionDir, f, n, flags); err != nil { return trace.Errorf("failed to extract teleport: %w", err) } // Write the checksum last. This marks the version directory as valid. - err = renameio.WriteFile(sumPath, []byte(hex.EncodeToString(newSum)), 0755) + err = renameio.WriteFile(sumPath, []byte(hex.EncodeToString(newSum)), configFileMode) if err != nil { return trace.Errorf("failed to write checksum: %w", err) } @@ -335,8 +337,8 @@ func (li *LocalInstaller) download(ctx context.Context, w io.Writer, max int64, return shaReader.Sum(nil), nil } -func (li *LocalInstaller) extract(ctx context.Context, dstDir string, src io.Reader, max int64) error { - if err := os.MkdirAll(dstDir, 0755); err != nil { +func (li *LocalInstaller) extract(ctx context.Context, dstDir string, src io.Reader, max int64, flags InstallFlags) error { + if err := os.MkdirAll(dstDir, systemDirMode); err != nil { return trace.Wrap(err) } free, err := utils.FreeDiskWithReserve(dstDir, li.ReservedFreeInstallDisk) @@ -353,13 +355,32 @@ func (li *LocalInstaller) extract(ctx context.Context, dstDir string, src io.Rea } li.Log.InfoContext(ctx, "Extracting Teleport tarball.", "path", dstDir, "size", max) - err = utils.Extract(zr, dstDir, tgzExtractPaths...) + err = utils.Extract(zr, dstDir, tgzExtractPaths(flags&(FlagEnterprise|FlagFIPS) != 0)...) if err != nil { return trace.Wrap(err) } return nil } +// tgzExtractPaths describes how to extract the Teleport tgz. +// See utils.Extract for more details on how this list is parsed. +// Paths must use tarball-style / separators (not filepath). +func tgzExtractPaths(ent bool) []utils.ExtractPath { + prefix := "teleport" + if ent { + prefix += "-ent" + } + return []utils.ExtractPath{ + {Src: path.Join(prefix, "examples/systemd/teleport.service"), Dst: filepath.Join(serviceDir, serviceName), DirMode: systemDirMode}, + {Src: path.Join(prefix, "examples"), Skip: true, DirMode: systemDirMode}, + {Src: path.Join(prefix, "install"), Skip: true, DirMode: systemDirMode}, + {Src: path.Join(prefix, "README.md"), Dst: "share/README.md", DirMode: systemDirMode}, + {Src: path.Join(prefix, "CHANGELOG.md"), Dst: "share/CHANGELOG.md", DirMode: systemDirMode}, + {Src: path.Join(prefix, "VERSION"), Dst: "share/VERSION", DirMode: systemDirMode}, + {Src: prefix, Dst: "bin", DirMode: systemDirMode}, + } +} + func uncompressedSize(f io.Reader) (int64, error) { // NOTE: The gzip length trailer is very unreliable, // but we could optimize this in the future if @@ -395,24 +416,76 @@ func (li *LocalInstaller) List(ctx context.Context) (versions []string, err erro // The revert function restores the previous linking. // See Installer interface for additional specs. func (li *LocalInstaller) Link(ctx context.Context, version string) (revert func(context.Context) bool, err error) { - // setup revert function - type symlink struct { - old, new string + revert = func(context.Context) bool { return true } + versionDir, err := li.versionDir(version) + if err != nil { + return revert, trace.Wrap(err) } - var revertLinks []symlink + revert, err = li.forceLinks(ctx, + filepath.Join(versionDir, "bin"), + filepath.Join(versionDir, serviceDir), + ) + if err != nil { + return revert, trace.Wrap(err) + } + return revert, nil +} + +// LinkSystem links the system (package) version into the system LinkBinDir and LinkServiceDir. +// The revert function restores the previous linking. +// See Installer interface for additional specs. +func (li *LocalInstaller) LinkSystem(ctx context.Context) (revert func(context.Context) bool, err error) { + revert, err = li.forceLinks(ctx, li.SystemBinDir, li.SystemServiceDir) + return revert, trace.Wrap(err) +} + +// symlink from oldname to newname +type symlink struct { + oldname, newname string +} + +// smallFile is a file small enough to be stored in memory. +type smallFile struct { + name string + data []byte + mode os.FileMode +} + +// forceLinks replaces binary links and service files using files in binDir and svcDir. +// Existing links and files are replaced, but mismatched links and files will result in error. +// forceLinks will revert any overridden links or files if it hits an error. +// If successful, forceLinks may also be reverted after it returns by calling revert. +// The revert function returns true if reverting succeeds. +func (li *LocalInstaller) forceLinks(ctx context.Context, binDir, svcDir string) (revert func(context.Context) bool, err error) { + // setup revert function + var ( + revertLinks []symlink + revertFiles []smallFile + ) revert = func(ctx context.Context) bool { // This function is safe to call repeatedly. - // Returns true only when all symlinks are successfully reverted. - var keep []symlink + // Returns true only when all changes are successfully reverted. + var ( + keepLinks []symlink + keepFiles []smallFile + ) for _, l := range revertLinks { - err := renameio.Symlink(l.old, l.new) + err := renameio.Symlink(l.oldname, l.newname) + if err != nil { + keepLinks = append(keepLinks, l) + li.Log.ErrorContext(ctx, "Failed to revert symlink", "oldname", l.oldname, "newname", l.newname, errorKey, err) + } + } + for _, f := range revertFiles { + err := renameio.WriteFile(f.name, f.data, f.mode) if err != nil { - keep = append(keep, l) - li.Log.ErrorContext(ctx, "Failed to revert symlink", "old", l.old, "new", l.new, "err", err) + keepFiles = append(keepFiles, f) + li.Log.ErrorContext(ctx, "Failed to revert files", "name", f.name, errorKey, err) } } - revertLinks = keep - return len(revertLinks) == 0 + revertLinks = keepLinks + revertFiles = keepFiles + return len(revertLinks) == 0 && len(revertFiles) == 0 } // revert immediately on error, so caller can ignore revert arg defer func() { @@ -421,24 +494,18 @@ func (li *LocalInstaller) Link(ctx context.Context, version string) (revert func } }() - versionDir, err := li.versionDir(version) - if err != nil { - return revert, trace.Wrap(err) - } - // ensure target directories exist before trying to create links - err = os.MkdirAll(li.LinkBinDir, 0755) + err = os.MkdirAll(li.LinkBinDir, systemDirMode) if err != nil { return revert, trace.Wrap(err) } - err = os.MkdirAll(li.LinkServiceDir, 0755) + err = os.MkdirAll(li.LinkServiceDir, systemDirMode) if err != nil { return revert, trace.Wrap(err) } // create binary links - binDir := filepath.Join(versionDir, "bin") entries, err := os.ReadDir(binDir) if err != nil { return revert, trace.Errorf("failed to find Teleport binary directory: %w", err) @@ -450,14 +517,14 @@ func (li *LocalInstaller) Link(ctx context.Context, version string) (revert func } oldname := filepath.Join(binDir, entry.Name()) newname := filepath.Join(li.LinkBinDir, entry.Name()) - orig, err := tryLink(oldname, newname) - if err != nil { + orig, err := forceLink(oldname, newname) + if err != nil && !errors.Is(err, os.ErrExist) { return revert, trace.Errorf("failed to create symlink for %s: %w", filepath.Base(oldname), err) } if orig != "" { revertLinks = append(revertLinks, symlink{ - old: orig, - new: newname, + oldname: orig, + newname: newname, }) } linked++ @@ -466,46 +533,220 @@ func (li *LocalInstaller) Link(ctx context.Context, version string) (revert func return revert, trace.Errorf("no binaries available to link") } - // create systemd service link + // create systemd service file - oldname := filepath.Join(versionDir, servicePath) - newname := filepath.Join(li.LinkServiceDir, filepath.Base(servicePath)) - orig, err := tryLink(oldname, newname) - if err != nil { - return revert, trace.Errorf("failed to create symlink for %s: %w", filepath.Base(oldname), err) + src := filepath.Join(svcDir, serviceName) + dst := filepath.Join(li.LinkServiceDir, serviceName) + orig, err := forceCopy(dst, src, maxServiceFileSize) + if err != nil && !errors.Is(err, os.ErrExist) { + return revert, trace.Errorf("failed to create file for %s: %w", serviceName, err) } - if orig != "" { - revertLinks = append(revertLinks, symlink{ - old: orig, - new: newname, - }) + if orig != nil { + revertFiles = append(revertFiles, *orig) } return revert, nil } -// tryLink attempts to create a symlink, atomically replacing an existing link if already present. -// If a non-symlink file or directory exists in newname already, tryLink errors. -func tryLink(oldname, newname string) (orig string, err error) { +// forceLink attempts to create a symlink, atomically replacing an existing link if already present. +// If a non-symlink file or directory exists in newname already, forceLink errors. +// If the link is already present with the desired oldname, forceLink returns os.ErrExist. +func forceLink(oldname, newname string) (orig string, err error) { + exec, err := isExecutable(oldname) + if err != nil { + return "", trace.Wrap(err) + } + if !exec { + return "", trace.Errorf("%s is not a regular executable file", oldname) + } orig, err = os.Readlink(newname) if errors.Is(err, os.ErrInvalid) || errors.Is(err, syscall.EINVAL) { // workaround missing ErrInvalid wrapper // important: do not attempt to replace a non-linked install of Teleport - return orig, trace.Errorf("refusing to replace file at %s", newname) + return "", trace.Errorf("refusing to replace file at %s", newname) } if err != nil && !errors.Is(err, os.ErrNotExist) { - return orig, trace.Wrap(err) + return "", trace.Wrap(err) } if orig == oldname { - return "", nil + return "", trace.Wrap(os.ErrExist) } - // TODO(sclevine): verify oldname is valid binary err = renameio.Symlink(oldname, newname) if err != nil { - return orig, trace.Wrap(err) + return "", trace.Wrap(err) } return orig, nil } +// isExecutable returns true for regular files that are executable by all users (0111). +func isExecutable(path string) (bool, error) { + fi, err := os.Lstat(path) + if err != nil { + return false, trace.Wrap(err) + } + // TODO(sclevine): verify path is valid binary + return fi.Mode().IsRegular() && + fi.Mode()&0111 == 0111, nil +} + +// forceCopy atomically copies a file from src to dst, replacing an existing file at dst if needed. +// Both src and dst must be smaller than n. +// forceCopy returns the original file path, mode, and contents as orig. +// If an irregular file, too large file, or directory exists in path already, forceCopy errors. +// If the file is already present with the desired contents, forceCopy returns os.ErrExist. +func forceCopy(dst, src string, n int64) (orig *smallFile, err error) { + srcData, err := readFileN(src, n) + if err != nil { + return nil, trace.Wrap(err) + } + fi, err := os.Lstat(dst) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return nil, trace.Wrap(err) + } + if err == nil { + orig = &smallFile{ + name: dst, + mode: fi.Mode(), + } + if !orig.mode.IsRegular() { + return nil, trace.Errorf("refusing to replace irregular file at %s", dst) + } + orig.data, err = readFileN(dst, n) + if err != nil { + return nil, trace.Wrap(err) + } + if bytes.Equal(srcData, orig.data) { + return nil, trace.Wrap(os.ErrExist) + } + } + err = renameio.WriteFile(dst, srcData, configFileMode) + if err != nil { + return nil, trace.Wrap(err) + } + return orig, nil +} + +// readFileN reads a file up to n, or errors if it is too large. +func readFileN(name string, n int64) ([]byte, error) { + f, err := os.Open(name) + if err != nil { + return nil, err + } + defer f.Close() + data, err := utils.ReadAtMost(f, n) + return data, trace.Wrap(err) +} + +// TryLink links the specified version, but only in the case that +// no installation of Teleport is already linked or partially linked. +// See Installer interface for additional specs. +func (li *LocalInstaller) TryLink(ctx context.Context, version string) error { + versionDir, err := li.versionDir(version) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(li.tryLinks(ctx, + filepath.Join(versionDir, "bin"), + filepath.Join(versionDir, serviceDir), + )) +} + +// TryLinkSystem links the system installation, but only in the case that +// no installation of Teleport is already linked or partially linked. +// See Installer interface for additional specs. +func (li *LocalInstaller) TryLinkSystem(ctx context.Context) error { + return trace.Wrap(li.tryLinks(ctx, li.SystemBinDir, li.SystemServiceDir)) +} + +// tryLinks create binary and service links for files in binDir and svcDir if links are not already present. +// Existing links that point to files outside binDir or svcDir, as well as existing non-link files, will error. +// tryLinks will not attempt to create any links if linking could result in an error. +// However, concurrent changes to links may result in an error with partially-complete linking. +func (li *LocalInstaller) tryLinks(ctx context.Context, binDir, svcDir string) error { + // ensure target directories exist before trying to create links + err := os.MkdirAll(li.LinkBinDir, systemDirMode) + if err != nil { + return trace.Wrap(err) + } + err = os.MkdirAll(li.LinkServiceDir, systemDirMode) + if err != nil { + return trace.Wrap(err) + } + + // validate that we can link all system binaries before attempting linking + + var links []symlink + var linked int + entries, err := os.ReadDir(binDir) + if err != nil { + return trace.Errorf("failed to find Teleport binary directory: %w", err) + } + for _, entry := range entries { + if entry.IsDir() { + continue + } + oldname := filepath.Join(binDir, entry.Name()) + newname := filepath.Join(li.LinkBinDir, entry.Name()) + ok, err := needsLink(oldname, newname) + if err != nil { + return trace.Errorf("error evaluating link for %s: %w", filepath.Base(oldname), err) + } + if ok { + links = append(links, symlink{oldname, newname}) + } + linked++ + } + // bail if no binaries can be linked + if linked == 0 { + return trace.Errorf("no binaries available to link") + } + + // link binaries that are missing links + for _, link := range links { + if err := os.Symlink(link.oldname, link.newname); err != nil { + return trace.Errorf("failed to create symlink for %s: %w", filepath.Base(link.oldname), err) + } + } + + // if any binaries are linked from binDir, always link the service from svcDir + src := filepath.Join(svcDir, serviceName) + dst := filepath.Join(li.LinkServiceDir, serviceName) + _, err = forceCopy(dst, src, maxServiceFileSize) + if err != nil && !errors.Is(err, os.ErrExist) { + return trace.Errorf("error writing %s: %w", serviceName, err) + } + + return nil +} + +// needsLink returns true when a symlink from oldname to newname needs to be created, or false if it exists. +// If a non-symlink file or directory exists at newname, needsLink errors. +// If a symlink to a different location exists, needsLink errors with ErrLinked. +func needsLink(oldname, newname string) (ok bool, err error) { + exec, err := isExecutable(oldname) + if err != nil { + return false, trace.Wrap(err) + } + if !exec { + return false, trace.Errorf("%s is not a regular executable file", oldname) + } + orig, err := os.Readlink(newname) + if errors.Is(err, os.ErrInvalid) || + errors.Is(err, syscall.EINVAL) { // workaround missing ErrInvalid wrapper + // important: do not attempt to replace a non-linked install of Teleport + return false, trace.Errorf("refusing to replace file at %s", newname) + } + if errors.Is(err, os.ErrNotExist) { + return true, nil + } + if err != nil { + return false, trace.Wrap(err) + } + if orig != oldname { + return false, trace.Errorf("refusing to replace link at %s: %w", newname, ErrLinked) + } + return false, nil +} + // versionDir returns the storage directory for a Teleport version. // versionDir will fail if the version cannot be used to construct the directory name. // For example, it ensures that ".." cannot be provided to return a system directory. @@ -516,7 +757,7 @@ func (li *LocalInstaller) versionDir(version string) (string, error) { } versionDir := filepath.Join(installDir, version) if filepath.Dir(versionDir) != filepath.Clean(installDir) { - return "", trace.Errorf("refusing to directory outside of version directory") + return "", trace.Errorf("refusing to link directory outside of version directory") } return versionDir, nil } @@ -541,10 +782,13 @@ func (li *LocalInstaller) isLinked(versionDir string) (bool, error) { return true, nil } } - v, err := os.Readlink(filepath.Join(li.LinkServiceDir, filepath.Base(servicePath))) + linkData, err := readFileN(filepath.Join(li.LinkServiceDir, serviceName), maxServiceFileSize) + if err != nil { + return false, nil + } + versionData, err := readFileN(filepath.Join(versionDir, serviceDir, serviceName), maxServiceFileSize) if err != nil { return false, nil } - return filepath.Clean(v) == - filepath.Join(versionDir, servicePath), nil + return bytes.Equal(linkData, versionData), nil } diff --git a/lib/autoupdate/agent/installer_test.go b/lib/autoupdate/agent/installer_test.go index d4f58f782dc62..ee9701ac2202c 100644 --- a/lib/autoupdate/agent/installer_test.go +++ b/lib/autoupdate/agent/installer_test.go @@ -84,7 +84,6 @@ func TestLocalInstaller_Install(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { dir := t.TempDir() @@ -137,7 +136,7 @@ func TestLocalInstaller_Install(t *testing.T) { require.Equal(t, expectedPath+"."+checksumType, shaPath) for _, p := range []string{ - filepath.Join(dir, version, "etc", "systemd", "teleport.service"), + filepath.Join(dir, version, "lib", "systemd", "system", "teleport.service"), filepath.Join(dir, version, "bin", "teleport"), filepath.Join(dir, version, "bin", "tsh"), } { @@ -194,15 +193,17 @@ func testTGZ(t *testing.T, version string) (tgz *bytes.Buffer, shasum string) { func TestLocalInstaller_Link(t *testing.T) { t.Parallel() const version = "new-version" + servicePath := filepath.Join(serviceDir, serviceName) tests := []struct { - name string - installDirs []string - installFiles []string - existingLinks []string - existingFiles []string + name string + installDirs []string + installFiles []string + installFileMode os.FileMode + existingLinks []string + existingFiles []string - resultLinks []string + resultPaths []string errMatch string }{ { @@ -210,9 +211,9 @@ func TestLocalInstaller_Link(t *testing.T) { installDirs: []string{ "bin", "bin/somedir", - "etc", - "etc/systemd", - "etc/systemd/somedir", + "lib", + "lib/systemd", + "lib/systemd/system", "somedir", }, installFiles: []string{ @@ -222,22 +223,44 @@ func TestLocalInstaller_Link(t *testing.T) { servicePath, "README", }, + installFileMode: os.ModePerm, - resultLinks: []string{ + resultPaths: []string{ "bin/teleport", "bin/tsh", "bin/tbot", "lib/systemd/system/teleport.service", }, }, + { + name: "present with non-executable files", + installDirs: []string{ + "bin", + "bin/somedir", + "lib", + "lib/systemd", + "lib/systemd/system", + "somedir", + }, + installFiles: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + servicePath, + "README", + }, + installFileMode: 0644, + + errMatch: "executable", + }, { name: "present with existing links", installDirs: []string{ "bin", "bin/somedir", - "etc", - "etc/systemd", - "etc/systemd/somedir", + "lib", + "lib/systemd", + "lib/systemd/system", "somedir", }, installFiles: []string{ @@ -247,14 +270,17 @@ func TestLocalInstaller_Link(t *testing.T) { servicePath, "README", }, + installFileMode: os.ModePerm, existingLinks: []string{ "bin/teleport", "bin/tsh", "bin/tbot", + }, + existingFiles: []string{ "lib/systemd/system/teleport.service", }, - resultLinks: []string{ + resultPaths: []string{ "bin/teleport", "bin/tsh", "bin/tbot", @@ -266,9 +292,9 @@ func TestLocalInstaller_Link(t *testing.T) { installDirs: []string{ "bin", "bin/somedir", - "etc", - "etc/systemd", - "etc/systemd/somedir", + "lib", + "lib/systemd", + "lib/systemd/system", "somedir", }, installFiles: []string{ @@ -278,12 +304,11 @@ func TestLocalInstaller_Link(t *testing.T) { servicePath, "README", }, + installFileMode: os.ModePerm, existingLinks: []string{ "bin/teleport", "bin/tsh", "bin/tbot", - }, - existingFiles: []string{ "lib/systemd/system/teleport.service", }, @@ -294,9 +319,9 @@ func TestLocalInstaller_Link(t *testing.T) { installDirs: []string{ "bin", "bin/somedir", - "etc", - "etc/systemd", - "etc/systemd/somedir", + "lib", + "lib/systemd", + "lib/systemd/system", "somedir", }, installFiles: []string{ @@ -306,6 +331,7 @@ func TestLocalInstaller_Link(t *testing.T) { servicePath, "README", }, + installFileMode: os.ModePerm, existingLinks: []string{ "bin/teleport", "bin/tbot", @@ -333,7 +359,6 @@ func TestLocalInstaller_Link(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { versionsDir := t.TempDir() versionDir := filepath.Join(versionsDir, version) @@ -346,7 +371,7 @@ func TestLocalInstaller_Link(t *testing.T) { require.NoError(t, err) } for _, n := range tt.installFiles { - err := os.WriteFile(filepath.Join(versionDir, n), []byte(filepath.Base(n)), os.ModePerm) + err := os.WriteFile(filepath.Join(versionDir, n), []byte(filepath.Base(n)), tt.installFileMode) require.NoError(t, err) } @@ -397,7 +422,7 @@ func TestLocalInstaller_Link(t *testing.T) { require.NoError(t, err) // verify links - for _, link := range tt.resultLinks { + for _, link := range tt.resultPaths { v, err := os.ReadFile(filepath.Join(linkDir, link)) require.NoError(t, err) require.Equal(t, filepath.Base(link), string(v)) @@ -420,9 +445,234 @@ func TestLocalInstaller_Link(t *testing.T) { } } +func TestLocalInstaller_TryLink(t *testing.T) { + t.Parallel() + const version = "new-version" + servicePath := filepath.Join(serviceDir, serviceName) + + tests := []struct { + name string + installDirs []string + installFiles []string + installFileMode os.FileMode + existingLinks []string + existingFiles []string + + resultPaths []string + errMatch string + }{ + { + name: "present with new links", + installDirs: []string{ + "bin", + "bin/somedir", + "lib", + "lib/systemd", + "lib/systemd/system", + "somedir", + }, + installFiles: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + servicePath, + "README", + }, + installFileMode: os.ModePerm, + + resultPaths: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + "lib/systemd/system/teleport.service", + }, + }, + { + name: "present with non-executable files", + installDirs: []string{ + "bin", + "bin/somedir", + "lib", + "lib/systemd", + "lib/systemd/system", + "somedir", + }, + installFiles: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + servicePath, + "README", + }, + installFileMode: 0644, + + errMatch: "executable", + }, + { + name: "present with existing links", + installDirs: []string{ + "bin", + "bin/somedir", + "lib", + "lib/systemd", + "lib/systemd/system", + "somedir", + }, + installFiles: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + servicePath, + "README", + }, + installFileMode: os.ModePerm, + existingLinks: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + }, + existingFiles: []string{ + "lib/systemd/system/teleport.service", + }, + + errMatch: "refusing", + }, + { + name: "conflicting systemd files", + installDirs: []string{ + "bin", + "bin/somedir", + "lib", + "lib/systemd", + "lib/systemd/system", + "somedir", + }, + installFiles: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + servicePath, + "README", + }, + installFileMode: os.ModePerm, + existingLinks: []string{ + "lib/systemd/system/teleport.service", + }, + + errMatch: "replace irregular file", + }, + { + name: "conflicting bin files", + installDirs: []string{ + "bin", + "bin/somedir", + "lib", + "lib/systemd", + "lib/systemd/system", + "somedir", + }, + installFiles: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + servicePath, + "README", + }, + installFileMode: os.ModePerm, + existingFiles: []string{ + "bin/tsh", + }, + + errMatch: "replace file", + }, + { + name: "no links", + installFiles: []string{"README"}, + installDirs: []string{"bin"}, + + errMatch: "no binaries", + }, + { + name: "no bin directory", + installFiles: []string{"README"}, + + errMatch: "binary directory", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + versionsDir := t.TempDir() + versionDir := filepath.Join(versionsDir, version) + err := os.MkdirAll(versionDir, 0o755) + require.NoError(t, err) + + // setup files in version directory + for _, d := range tt.installDirs { + err := os.Mkdir(filepath.Join(versionDir, d), os.ModePerm) + require.NoError(t, err) + } + for _, n := range tt.installFiles { + err := os.WriteFile(filepath.Join(versionDir, n), []byte(filepath.Base(n)), tt.installFileMode) + require.NoError(t, err) + } + + // setup files in system links directory + linkDir := t.TempDir() + for _, n := range tt.existingLinks { + err := os.MkdirAll(filepath.Dir(filepath.Join(linkDir, n)), os.ModePerm) + require.NoError(t, err) + err = os.Symlink(filepath.Base(n)+".old", filepath.Join(linkDir, n)) + require.NoError(t, err) + } + for _, n := range tt.existingFiles { + err := os.MkdirAll(filepath.Dir(filepath.Join(linkDir, n)), os.ModePerm) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(linkDir, n), []byte(filepath.Base(n)), os.ModePerm) + require.NoError(t, err) + } + + installer := &LocalInstaller{ + InstallDir: versionsDir, + LinkBinDir: filepath.Join(linkDir, "bin"), + LinkServiceDir: filepath.Join(linkDir, "lib/systemd/system"), + Log: slog.Default(), + } + ctx := context.Background() + err = installer.TryLink(ctx, version) + if tt.errMatch != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMatch) + + // verify no changes + for _, link := range tt.existingLinks { + v, err := os.Readlink(filepath.Join(linkDir, link)) + require.NoError(t, err) + require.Equal(t, filepath.Base(link)+".old", v) + } + for _, n := range tt.existingFiles { + v, err := os.ReadFile(filepath.Join(linkDir, n)) + require.NoError(t, err) + require.Equal(t, filepath.Base(n), string(v)) + } + return + } + require.NoError(t, err) + + // verify links + for _, link := range tt.resultPaths { + v, err := os.ReadFile(filepath.Join(linkDir, link)) + require.NoError(t, err) + require.Equal(t, filepath.Base(link), string(v)) + } + }) + } +} + func TestLocalInstaller_Remove(t *testing.T) { t.Parallel() const version = "existing-version" + servicePath := filepath.Join(serviceDir, serviceName) tests := []struct { name string @@ -457,8 +707,8 @@ func TestLocalInstaller_Remove(t *testing.T) { }, { name: "version linked", - dirs: []string{"bin", "bin/somedir", "somedir"}, - files: []string{checksumType, "bin/teleport", "bin/tsh", "bin/tbot", "README"}, + dirs: []string{"bin", "bin/somedir", "somedir", "lib", "lib/systemd", "lib/systemd/system"}, + files: []string{checksumType, "bin/teleport", "bin/tsh", "bin/tbot", "README", servicePath}, createVersion: version, linkedVersion: version, removeVersion: version, @@ -504,7 +754,6 @@ func TestLocalInstaller_Remove(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { versionsDir := t.TempDir() versionDir := filepath.Join(versionsDir, tt.createVersion) diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/insecure_URL.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/insecure_URL.golden new file mode 100644 index 0000000000000..d3da980a1afde --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Enable/insecure_URL.golden @@ -0,0 +1,10 @@ +version: v1 +kind: update_config +spec: + proxy: "" + group: "" + url_template: http://example.com + enabled: false +status: + active_version: "" + backup_version: "" diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/install_error.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/install_error.golden new file mode 100644 index 0000000000000..2ddb840b01794 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Enable/install_error.golden @@ -0,0 +1,10 @@ +version: v1 +kind: update_config +spec: + proxy: "" + group: "" + url_template: "" + enabled: false +status: + active_version: "" + backup_version: "" diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/invalid_metadata.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/invalid_metadata.golden new file mode 100644 index 0000000000000..df0c99fe5fe7e --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Enable/invalid_metadata.golden @@ -0,0 +1,10 @@ +version: "" +kind: "" +spec: + proxy: "" + group: "" + url_template: "" + enabled: false +status: + active_version: "" + backup_version: "" diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/insecure_URL.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/insecure_URL.golden new file mode 100644 index 0000000000000..6ff42e075b57a --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/insecure_URL.golden @@ -0,0 +1,10 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + group: "" + url_template: http://example.com + enabled: true +status: + active_version: "" + backup_version: "" diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/install_error.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/install_error.golden new file mode 100644 index 0000000000000..3b9e19637eef5 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/install_error.golden @@ -0,0 +1,10 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + group: "" + url_template: "" + enabled: true +status: + active_version: "" + backup_version: "" diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/invalid_metadata.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/invalid_metadata.golden new file mode 100644 index 0000000000000..e47fe44a13da0 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/invalid_metadata.golden @@ -0,0 +1,10 @@ +version: "" +kind: "" +spec: + proxy: localhost + group: "" + url_template: "" + enabled: false +status: + active_version: "" + backup_version: "" diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/reload_fails.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/reload_fails.golden new file mode 100644 index 0000000000000..3628297dd9443 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/reload_fails.golden @@ -0,0 +1,10 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + group: "" + url_template: https://example.com + enabled: true +status: + active_version: old-version + backup_version: backup-version diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/sync_fails.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/sync_fails.golden new file mode 100644 index 0000000000000..3628297dd9443 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/sync_fails.golden @@ -0,0 +1,10 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + group: "" + url_template: https://example.com + enabled: true +status: + active_version: old-version + backup_version: backup-version diff --git a/lib/autoupdate/agent/updater.go b/lib/autoupdate/agent/updater.go index f2358ce60c0fd..9625481df2cd2 100644 --- a/lib/autoupdate/agent/updater.go +++ b/lib/autoupdate/agent/updater.go @@ -41,6 +41,13 @@ import ( libutils "github.com/gravitational/teleport/lib/utils" ) +const ( + // DefaultLinkDir is the default location where Teleport is linked. + DefaultLinkDir = "/usr/local" + // DefaultSystemDir is the location where packaged Teleport binaries and services are installed. + DefaultSystemDir = "/usr/local/teleport-system" +) + const ( // cdnURITemplate is the default template for the Teleport tgz download. cdnURITemplate = "https://cdn.teleport.dev/teleport{{if .Enterprise}}-ent{{end}}-v{{.Version}}-{{.OS}}-{{.Arch}}{{if .FIPS}}-fips{{end}}-bin.tar.gz" @@ -50,6 +57,14 @@ const ( reservedFreeDisk = 10_000_000 // 10 MB ) +// Log keys +const ( + targetVersionKey = "target_version" + activeVersionKey = "active_version" + backupVersionKey = "backup_version" + errorKey = "error" +) + const ( // updateConfigName specifies the name of the file inside versionsDirName containing configuration for the teleport update. updateConfigName = "update.yaml" @@ -59,14 +74,6 @@ const ( updateConfigKind = "update_config" ) -// Log keys -const ( - targetVersionKey = "target_version" - activeVersionKey = "active_version" - backupVersionKey = "backup_version" - errorKey = "error" -) - // UpdateConfig describes the update.yaml file schema. type UpdateConfig struct { // Version of the configuration file @@ -124,7 +131,10 @@ func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) { cfg.Log = slog.Default() } if cfg.LinkDir == "" { - cfg.LinkDir = "/usr/local" + cfg.LinkDir = DefaultLinkDir + } + if cfg.SystemDir == "" { + cfg.SystemDir = DefaultSystemDir } if cfg.VersionsDir == "" { cfg.VersionsDir = filepath.Join(libdefaults.DataDir, "versions") @@ -135,12 +145,15 @@ func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) { InsecureSkipVerify: cfg.InsecureSkipVerify, ConfigPath: filepath.Join(cfg.VersionsDir, updateConfigName), Installer: &LocalInstaller{ - InstallDir: cfg.VersionsDir, - LinkBinDir: filepath.Join(cfg.LinkDir, "bin"), - LinkServiceDir: filepath.Join(cfg.LinkDir, "lib", "systemd", "system"), - HTTP: client, - Log: cfg.Log, - + InstallDir: cfg.VersionsDir, + LinkBinDir: filepath.Join(cfg.LinkDir, "bin"), + // For backwards-compatibility with symlinks created by package-based installs, we always + // link into /lib/systemd/system, even though, e.g., /usr/local/lib/systemd/system would work. + LinkServiceDir: filepath.Join("/", serviceDir), + SystemBinDir: filepath.Join(cfg.SystemDir, "bin"), + SystemServiceDir: filepath.Join(cfg.SystemDir, serviceDir), + HTTP: client, + Log: cfg.Log, ReservedFreeTmpDisk: reservedFreeDisk, ReservedFreeInstallDisk: reservedFreeDisk, }, @@ -165,6 +178,8 @@ type LocalUpdaterConfig struct { VersionsDir string // LinkDir for installing Teleport (usually /usr/local). LinkDir string + // SystemDir for package-installed Teleport installations (usually /usr/local/teleport-system). + SystemDir string } // Updater implements the agent-local logic for Teleport agent auto-updates. @@ -188,11 +203,22 @@ type Installer interface { // Install the Teleport agent at version from the download template. // Install must be idempotent. Install(ctx context.Context, version, template string, flags InstallFlags) error - // Link the Teleport agent at the specified version into the system location. + // Link the Teleport agent at the specified version of Teleport into the linking locations. // The revert function must restore the previous linking, returning false on any failure. - // Link must be idempotent. - // Link's revert function must be idempotent. + // Link must be idempotent. Link's revert function must be idempotent. Link(ctx context.Context, version string) (revert func(context.Context) bool, err error) + // LinkSystem links the system installation of Teleport into the linking locations. + // The revert function must restore the previous linking, returning false on any failure. + // LinkSystem must be idempotent. LinkSystem's revert function must be idempotent. + LinkSystem(ctx context.Context) (revert func(context.Context) bool, err error) + // TryLink links the specified version of Teleport into the linking locations. + // Unlike Link, TryLink will fail if existing links to other locations are present. + // TryLink must be idempotent. + TryLink(ctx context.Context, version string) error + // TryLinkSystem links the system installation of Teleport into the linking locations. + // Unlike LinkSystem, TryLinkSystem will fail if existing links to other locations are present. + // TryLinkSystem must be idempotent. + TryLinkSystem(ctx context.Context) error // List the installed versions of Teleport. List(ctx context.Context) (versions []string, err error) // Remove the Teleport agent at version. @@ -480,8 +506,9 @@ func (u *Updater) update(ctx context.Context, cfg *UpdateConfig, targetVersion s u.Log.ErrorContext(ctx, "Failed to revert Teleport symlinks. Installation likely broken.") } else if err := u.Process.Sync(ctx); err != nil { u.Log.ErrorContext(ctx, "Failed to sync configuration after failed restart.", errorKey, err) + } else { + u.Log.WarnContext(ctx, "Teleport updater encountered a configuration error and successfully reverted the installation.") } - u.Log.WarnContext(ctx, "Teleport updater encountered a configuration error and successfully reverted the installation.") return trace.Errorf("failed to validate configuration for new version %q of Teleport: %w", targetVersion, err) } @@ -502,8 +529,9 @@ func (u *Updater) update(ctx context.Context, cfg *UpdateConfig, targetVersion s u.Log.ErrorContext(ctx, "Invalid configuration found after reverting Teleport to older version. Installation likely broken.", errorKey, err) } else if err := u.Process.Reload(ctx); err != nil && !errors.Is(err, ErrNotNeeded) { u.Log.ErrorContext(ctx, "Failed to revert Teleport to older version. Installation likely broken.", errorKey, err) + } else { + u.Log.WarnContext(ctx, "Teleport updater encountered a configuration error and successfully reverted the installation.") } - u.Log.WarnContext(ctx, "Teleport updater encountered a configuration error and successfully reverted the installation.") return trace.Errorf("failed to start new version %q of Teleport: %w", targetVersion, err) } @@ -558,7 +586,7 @@ func readConfig(path string) (*UpdateConfig, error) { // writeConfig writes UpdateConfig to a file atomically, ensuring the file cannot be corrupted. func writeConfig(filename string, cfg *UpdateConfig) error { opts := []renameio.Option{ - renameio.WithPermissions(0755), + renameio.WithPermissions(configFileMode), renameio.WithExistingPermissions(), } t, err := renameio.NewPendingFile(filename, opts...) @@ -589,3 +617,37 @@ func validateConfigSpec(spec *UpdateSpec, override OverrideConfig) error { } return nil } + +// LinkPackage creates links from the system (package) installation of Teleport, if they are needed. +// LinkPackage returns nils and warns if an auto-updates version is already linked, but auto-updates is disabled. +// LinkPackage returns an error only if an unknown version of Teleport is present (e.g., manually copied files). +// This function is idempotent. +func (u *Updater) LinkPackage(ctx context.Context) error { + cfg, err := readConfig(u.ConfigPath) + if err != nil { + return trace.Errorf("failed to read %s: %w", updateConfigName, err) + } + if err := validateConfigSpec(&cfg.Spec, OverrideConfig{}); err != nil { + return trace.Wrap(err) + } + activeVersion := cfg.Status.ActiveVersion + if cfg.Spec.Enabled { + u.Log.InfoContext(ctx, "Automatic updates enabled. Skipping system package link.", activeVersionKey, activeVersion) + return nil + } + // If an active version is set, but auto-updates is disabled, try to link the system installation in case the config is stale. + // If any links are present, this will return ErrLinked and not create any system links. + // This state is important to log as a warning, + if err := u.Installer.TryLinkSystem(ctx); errors.Is(err, ErrLinked) { + u.Log.WarnContext(ctx, "Automatic updates disabled, but a non-package version of Teleport is linked.", activeVersionKey, activeVersion) + return nil + } else if err != nil { + return trace.Errorf("failed to link system package installation: %w", err) + } + // TODO(sclevine): only if systemd files change + if err := u.Process.Sync(ctx); err != nil { + return trace.Errorf("failed to validate configuration for packaged installation of Teleport: %w", err) + } + u.Log.InfoContext(ctx, "Successfully linked system package installation.") + return nil +} diff --git a/lib/autoupdate/agent/updater_test.go b/lib/autoupdate/agent/updater_test.go index 16943d8a3e7e1..1197ac3d5a795 100644 --- a/lib/autoupdate/agent/updater_test.go +++ b/lib/autoupdate/agent/updater_test.go @@ -81,7 +81,6 @@ func TestUpdater_Disable(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { dir := t.TempDir() cfgPath := filepath.Join(dir, "update.yaml") @@ -138,6 +137,7 @@ func TestUpdater_Update(t *testing.T) { removedVersion string installedVersion string installedTemplate string + linkedVersion string requestGroup string syncCalls int reloadCalls int @@ -162,6 +162,7 @@ func TestUpdater_Update(t *testing.T) { installedVersion: "16.3.0", installedTemplate: "https://example.com", + linkedVersion: "16.3.0", requestGroup: "group", syncCalls: 1, reloadCalls: 1, @@ -233,14 +234,15 @@ func TestUpdater_Update(t *testing.T) { Version: updateConfigVersion, Kind: updateConfigKind, Spec: UpdateSpec{ - URLTemplate: "https://example.com", - Enabled: true, + Enabled: true, }, }, inWindow: true, installErr: errors.New("install error"), - errMatch: "install error", + installedVersion: "16.3.0", + installedTemplate: cdnURITemplate, + errMatch: "install error", }, { name: "version already installed in window", @@ -289,6 +291,7 @@ func TestUpdater_Update(t *testing.T) { installedVersion: "16.3.0", installedTemplate: "https://example.com", + linkedVersion: "16.3.0", removedVersion: "backup-version", syncCalls: 1, reloadCalls: 1, @@ -331,6 +334,7 @@ func TestUpdater_Update(t *testing.T) { installedVersion: "16.3.0", installedTemplate: "https://example.com", + linkedVersion: "16.3.0", removedVersion: "backup-version", syncCalls: 1, reloadCalls: 1, @@ -359,6 +363,7 @@ func TestUpdater_Update(t *testing.T) { installedVersion: "16.3.0", installedTemplate: "https://example.com", + linkedVersion: "16.3.0", removedVersion: "backup-version", syncCalls: 2, reloadCalls: 0, @@ -384,6 +389,7 @@ func TestUpdater_Update(t *testing.T) { installedVersion: "16.3.0", installedTemplate: "https://example.com", + linkedVersion: "16.3.0", removedVersion: "backup-version", syncCalls: 2, reloadCalls: 2, @@ -393,7 +399,6 @@ func TestUpdater_Update(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { var requestedGroup string server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -481,12 +486,12 @@ func TestUpdater_Update(t *testing.T) { if tt.errMatch != "" { require.Error(t, err) assert.Contains(t, err.Error(), tt.errMatch) - return + } else { + require.NoError(t, err) } - require.NoError(t, err) require.Equal(t, tt.installedVersion, installedVersion) require.Equal(t, tt.installedTemplate, installedTemplate) - require.Equal(t, tt.installedVersion, linkedVersion) + require.Equal(t, tt.linkedVersion, linkedVersion) require.Equal(t, tt.removedVersion, removedVersion) require.Equal(t, tt.flags, installedFlags) require.Equal(t, tt.requestGroup, requestedGroup) @@ -495,6 +500,8 @@ func TestUpdater_Update(t *testing.T) { require.Equal(t, tt.revertCalls, revertCalls) if tt.cfg == nil { + _, err := os.Stat(cfgPath) + require.Error(t, err) return } @@ -510,6 +517,128 @@ func TestUpdater_Update(t *testing.T) { } } +func TestUpdater_LinkPackage(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *UpdateConfig // nil -> file not present + tryLinkSystemErr error + + syncCalls int + tryLinkSystemCalls int + errMatch string + }{ + { + name: "updates enabled", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Enabled: true, + }, + }, + + tryLinkSystemCalls: 0, + syncCalls: 0, + }, + { + name: "updates disabled", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Enabled: false, + }, + }, + + tryLinkSystemCalls: 1, + syncCalls: 1, + }, + { + name: "already linked", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Enabled: false, + }, + }, + tryLinkSystemErr: ErrLinked, + + tryLinkSystemCalls: 1, + syncCalls: 0, + }, + { + name: "link error", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Enabled: false, + }, + }, + tryLinkSystemErr: errors.New("bad"), + + tryLinkSystemCalls: 1, + syncCalls: 0, + errMatch: "bad", + }, + { + name: "no config", + tryLinkSystemCalls: 1, + syncCalls: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, "update.yaml") + + // Create config file only if provided in test case + if tt.cfg != nil { + b, err := yaml.Marshal(tt.cfg) + require.NoError(t, err) + err = os.WriteFile(cfgPath, b, 0600) + require.NoError(t, err) + } + + updater, err := NewLocalUpdater(LocalUpdaterConfig{ + InsecureSkipVerify: true, + VersionsDir: dir, + }) + require.NoError(t, err) + + var tryLinkSystemCalls int + updater.Installer = &testInstaller{ + FuncTryLinkSystem: func(_ context.Context) error { + tryLinkSystemCalls++ + return tt.tryLinkSystemErr + }, + } + var syncCalls int + updater.Process = &testProcess{ + FuncSync: func(_ context.Context) error { + syncCalls++ + return nil + }, + } + + ctx := context.Background() + err = updater.LinkPackage(ctx) + if tt.errMatch != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMatch) + } else { + require.NoError(t, err) + } + require.Equal(t, tt.tryLinkSystemCalls, tryLinkSystemCalls) + require.Equal(t, tt.syncCalls, syncCalls) + }) + } +} + func TestUpdater_Enable(t *testing.T) { t.Parallel() @@ -525,6 +654,7 @@ func TestUpdater_Enable(t *testing.T) { removedVersion string installedVersion string installedTemplate string + linkedVersion string requestGroup string syncCalls int reloadCalls int @@ -547,6 +677,7 @@ func TestUpdater_Enable(t *testing.T) { installedVersion: "16.3.0", installedTemplate: "https://example.com", + linkedVersion: "16.3.0", requestGroup: "group", syncCalls: 1, reloadCalls: 1, @@ -572,6 +703,7 @@ func TestUpdater_Enable(t *testing.T) { installedVersion: "new-version", installedTemplate: "https://example.com/new", + linkedVersion: "new-version", syncCalls: 1, reloadCalls: 1, }, @@ -590,6 +722,7 @@ func TestUpdater_Enable(t *testing.T) { installedVersion: "16.3.0", installedTemplate: cdnURITemplate, + linkedVersion: "16.3.0", syncCalls: 1, reloadCalls: 1, }, @@ -610,13 +743,13 @@ func TestUpdater_Enable(t *testing.T) { cfg: &UpdateConfig{ Version: updateConfigVersion, Kind: updateConfigKind, - Spec: UpdateSpec{ - URLTemplate: "https://example.com", - }, }, installErr: errors.New("install error"), - errMatch: "install error", + installedVersion: "16.3.0", + linkedVersion: "", + installedTemplate: cdnURITemplate, + errMatch: "install error", }, { name: "version already installed", @@ -630,6 +763,7 @@ func TestUpdater_Enable(t *testing.T) { installedVersion: "16.3.0", installedTemplate: cdnURITemplate, + linkedVersion: "16.3.0", syncCalls: 1, reloadCalls: 0, }, @@ -646,6 +780,7 @@ func TestUpdater_Enable(t *testing.T) { installedVersion: "16.3.0", installedTemplate: cdnURITemplate, + linkedVersion: "16.3.0", removedVersion: "backup-version", syncCalls: 1, reloadCalls: 1, @@ -663,6 +798,7 @@ func TestUpdater_Enable(t *testing.T) { installedVersion: "16.3.0", installedTemplate: cdnURITemplate, + linkedVersion: "16.3.0", removedVersion: "", syncCalls: 1, reloadCalls: 0, @@ -672,6 +808,7 @@ func TestUpdater_Enable(t *testing.T) { installedVersion: "16.3.0", installedTemplate: cdnURITemplate, + linkedVersion: "16.3.0", syncCalls: 1, reloadCalls: 1, }, @@ -680,6 +817,7 @@ func TestUpdater_Enable(t *testing.T) { flags: FlagEnterprise | FlagFIPS, installedVersion: "16.3.0", installedTemplate: cdnURITemplate, + linkedVersion: "16.3.0", syncCalls: 1, reloadCalls: 1, }, @@ -694,6 +832,7 @@ func TestUpdater_Enable(t *testing.T) { installedVersion: "16.3.0", installedTemplate: cdnURITemplate, + linkedVersion: "16.3.0", syncCalls: 2, reloadCalls: 0, revertCalls: 1, @@ -705,6 +844,7 @@ func TestUpdater_Enable(t *testing.T) { installedVersion: "16.3.0", installedTemplate: cdnURITemplate, + linkedVersion: "16.3.0", syncCalls: 2, reloadCalls: 2, revertCalls: 1, @@ -713,7 +853,6 @@ func TestUpdater_Enable(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { dir := t.TempDir() cfgPath := filepath.Join(dir, "update.yaml") @@ -803,12 +942,12 @@ func TestUpdater_Enable(t *testing.T) { if tt.errMatch != "" { require.Error(t, err) assert.Contains(t, err.Error(), tt.errMatch) - return + } else { + require.NoError(t, err) } - require.NoError(t, err) require.Equal(t, tt.installedVersion, installedVersion) require.Equal(t, tt.installedTemplate, installedTemplate) - require.Equal(t, tt.installedVersion, linkedVersion) + require.Equal(t, tt.linkedVersion, linkedVersion) require.Equal(t, tt.removedVersion, removedVersion) require.Equal(t, tt.flags, installedFlags) require.Equal(t, tt.requestGroup, requestedGroup) @@ -816,6 +955,12 @@ func TestUpdater_Enable(t *testing.T) { require.Equal(t, tt.reloadCalls, reloadCalls) require.Equal(t, tt.revertCalls, revertCalls) + if tt.cfg == nil && err != nil { + _, err := os.Stat(cfgPath) + require.Error(t, err) + return + } + data, err := os.ReadFile(cfgPath) require.NoError(t, err) data = blankTestAddr(data) @@ -835,10 +980,13 @@ func blankTestAddr(s []byte) []byte { } type testInstaller struct { - FuncInstall func(ctx context.Context, version, template string, flags InstallFlags) error - FuncRemove func(ctx context.Context, version string) error - FuncLink func(ctx context.Context, version string) (revert func(context.Context) bool, err error) - FuncList func(ctx context.Context) (versions []string, err error) + FuncInstall func(ctx context.Context, version, template string, flags InstallFlags) error + FuncRemove func(ctx context.Context, version string) error + FuncLink func(ctx context.Context, version string) (revert func(context.Context) bool, err error) + FuncLinkSystem func(ctx context.Context) (revert func(context.Context) bool, err error) + FuncTryLink func(ctx context.Context, version string) error + FuncTryLinkSystem func(ctx context.Context) error + FuncList func(ctx context.Context) (versions []string, err error) } func (ti *testInstaller) Install(ctx context.Context, version, template string, flags InstallFlags) error { @@ -853,6 +1001,18 @@ func (ti *testInstaller) Link(ctx context.Context, version string) (revert func( return ti.FuncLink(ctx, version) } +func (ti *testInstaller) LinkSystem(ctx context.Context) (revert func(context.Context) bool, err error) { + return ti.FuncLinkSystem(ctx) +} + +func (ti *testInstaller) TryLink(ctx context.Context, version string) error { + return ti.FuncTryLink(ctx, version) +} + +func (ti *testInstaller) TryLinkSystem(ctx context.Context) error { + return ti.FuncTryLinkSystem(ctx) +} + func (ti *testInstaller) List(ctx context.Context) (versions []string, err error) { return ti.FuncList(ctx) } diff --git a/tool/teleport-update/main.go b/tool/teleport-update/main.go index 26acd11f7c1de..d559ad3e75cdd 100644 --- a/tool/teleport-update/main.go +++ b/tool/teleport-update/main.go @@ -20,6 +20,7 @@ package main import ( "context" + "errors" "log/slog" "os" "os/signal" @@ -61,8 +62,6 @@ const ( versionsDirName = "versions" // lockFileName specifies the name of the file inside versionsDirName containing the flock lock preventing concurrent updater execution. lockFileName = ".lock" - // defaultLinkDir is the default location where Teleport binaries and services are linked. - defaultLinkDir = "/usr/local" ) var plog = logutils.NewPackageLogger(teleport.ComponentKey, teleport.ComponentUpdater) @@ -92,15 +91,15 @@ func Run(args []string) error { ctx := context.Background() ctx, _ = signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) - app := libutils.InitCLIParser("teleport-updater", appHelp).Interspersed(false) + app := libutils.InitCLIParser("teleport-update", appHelp).Interspersed(false) app.Flag("debug", "Verbose logging to stdout."). Short('d').BoolVar(&ccfg.Debug) app.Flag("data-dir", "Teleport data directory. Access to this directory should be limited."). Default(libdefaults.DataDir).StringVar(&ccfg.DataDir) app.Flag("log-format", "Controls the format of output logs. Can be `json` or `text`. Defaults to `text`."). Default(libutils.LogFormatText).EnumVar(&ccfg.LogFormat, libutils.LogFormatJSON, libutils.LogFormatText) - app.Flag("link-dir", "Directory to create system symlinks to binaries and services."). - Default(defaultLinkDir).Hidden().StringVar(&ccfg.LinkDir) + app.Flag("link-dir", "Directory to link the active Teleport installation into."). + Default(autoupdate.DefaultLinkDir).Hidden().StringVar(&ccfg.LinkDir) app.HelpFlag.Short('h') @@ -121,6 +120,8 @@ func Run(args []string) error { updateCmd := app.Command("update", "Update agent to the latest version, if a new version is available.") + linkCmd := app.Command("link", "Link the system installation of Teleport from the Teleport package, if auto-updates is disabled.") + libutils.UpdateAppUsageTemplate(app, args) command, err := app.Parse(args) if err != nil { @@ -140,6 +141,8 @@ func Run(args []string) error { err = cmdDisable(ctx, &ccfg) case updateCmd.FullCommand(): err = cmdUpdate(ctx, &ccfg) + case linkCmd.FullCommand(): + err = cmdLink(ctx, &ccfg) case versionCmd.FullCommand(): modules.GetModules().PrintVersion() default: @@ -186,6 +189,7 @@ func cmdDisable(ctx context.Context, ccfg *cliConfig) error { updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ VersionsDir: versionsDir, LinkDir: ccfg.LinkDir, + SystemDir: autoupdate.DefaultSystemDir, Log: plog, }) if err != nil { @@ -218,6 +222,7 @@ func cmdEnable(ctx context.Context, ccfg *cliConfig) error { updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ VersionsDir: versionsDir, LinkDir: ccfg.LinkDir, + SystemDir: autoupdate.DefaultSystemDir, Log: plog, }) if err != nil { @@ -250,6 +255,7 @@ func cmdUpdate(ctx context.Context, ccfg *cliConfig) error { updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ VersionsDir: versionsDir, LinkDir: ccfg.LinkDir, + SystemDir: autoupdate.DefaultSystemDir, Log: plog, }) if err != nil { @@ -260,3 +266,36 @@ func cmdUpdate(ctx context.Context, ccfg *cliConfig) error { } return nil } + +// cmdLink creates system package links if no version is linked and auto-updates is disabled. +func cmdLink(ctx context.Context, ccfg *cliConfig) error { + versionsDir := filepath.Join(ccfg.DataDir, versionsDirName) + + // Skip operation and warn if the updater is currently running. + unlock, err := libutils.FSTryReadLock(filepath.Join(versionsDir, lockFileName)) + if errors.Is(err, libutils.ErrUnsuccessfulLockTry) { + plog.WarnContext(ctx, "Updater is currently running. Skipping package linking.") + return nil + } + if err != nil { + return trace.Errorf("failed to grab concurrent execution lock: %w", err) + } + defer func() { + if err := unlock(); err != nil { + plog.DebugContext(ctx, "Failed to close lock file", "error", err) + } + }() + updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ + VersionsDir: versionsDir, + LinkDir: ccfg.LinkDir, + SystemDir: autoupdate.DefaultSystemDir, + Log: plog, + }) + if err != nil { + return trace.Errorf("failed to setup updater: %w", err) + } + if err := updater.LinkPackage(ctx); err != nil { + return trace.Wrap(err) + } + return nil +}