From f18fb0ebff2f57184e4e5a91669854af4e869bf1 Mon Sep 17 00:00:00 2001 From: Evan Gibler <20933572+egibs@users.noreply.github.com> Date: Wed, 18 Dec 2024 09:40:16 -0600 Subject: [PATCH] Split up archive.go into type-specific files; add wider zlib support (#723) * Split up archive.go into type-specific files Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> * Move archive code to new package; leave tests to avoid circular imports Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> --------- Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> --- pkg/action/archive.go | 736 --------------------------------- pkg/action/archive_test.go | 36 +- pkg/action/diff.go | 5 +- pkg/action/scan.go | 7 +- pkg/archive/archive.go | 197 +++++++++ pkg/archive/bz2.go | 57 +++ pkg/archive/deb.go | 98 +++++ pkg/archive/gzip.go | 51 +++ pkg/{action => archive}/oci.go | 6 +- pkg/archive/rpm.go | 110 +++++ pkg/archive/tar.go | 156 +++++++ pkg/archive/zip.go | 87 ++++ pkg/archive/zlib.go | 51 +++ pkg/programkind/programkind.go | 53 +++ 14 files changed, 889 insertions(+), 761 deletions(-) delete mode 100644 pkg/action/archive.go create mode 100644 pkg/archive/archive.go create mode 100644 pkg/archive/bz2.go create mode 100644 pkg/archive/deb.go create mode 100644 pkg/archive/gzip.go rename pkg/{action => archive}/oci.go (91%) create mode 100644 pkg/archive/rpm.go create mode 100644 pkg/archive/tar.go create mode 100644 pkg/archive/zip.go create mode 100644 pkg/archive/zlib.go diff --git a/pkg/action/archive.go b/pkg/action/archive.go deleted file mode 100644 index c52cf123..00000000 --- a/pkg/action/archive.go +++ /dev/null @@ -1,736 +0,0 @@ -package action - -import ( - "archive/tar" - "archive/zip" - "compress/bzip2" - "compress/gzip" - "compress/zlib" - "context" - "errors" - "fmt" - "io" - "io/fs" - "os" - "path/filepath" - "regexp" - "strings" - "sync" - - "github.com/cavaliergopher/cpio" - "github.com/cavaliergopher/rpm" - "github.com/chainguard-dev/clog" - "github.com/chainguard-dev/malcontent/pkg/programkind" - - "github.com/egibs/go-debian/deb" - - "github.com/ulikunitz/xz" -) - -var archiveMap = map[string]bool{ - ".apk": true, - ".bz2": true, - ".bzip2": true, - ".deb": true, - ".gem": true, - ".gz": true, - ".jar": true, - ".rpm": true, - ".tar": true, - ".tar.gz": true, - ".tar.xz": true, - ".tgz": true, - ".whl": true, - ".xz": true, - ".zip": true, -} - -// isSupportedArchive returns whether a path can be processed by our archive extractor. -func isSupportedArchive(path string) bool { - return archiveMap[getExt(path)] -} - -// isValidPath checks if the target file is within the given directory. -func isValidPath(target, dir string) bool { - return strings.HasPrefix(filepath.Clean(target), filepath.Clean(dir)) -} - -// getExt returns the extension of a file path -// and attempts to avoid including fragments of filenames with other dots before the extension. -func getExt(path string) string { - base := filepath.Base(path) - - // Handle files with version numbers in the name - // e.g. 
file1.2.3.tar.gz -> .tar.gz - re := regexp.MustCompile(`\d+\.\d+\.\d+$`) - base = re.ReplaceAllString(base, "") - - ext := filepath.Ext(base) - - if ext != "" && strings.Contains(base, ".") { - parts := strings.Split(base, ".") - if len(parts) > 2 { - subExt := fmt.Sprintf(".%s%s", parts[len(parts)-2], ext) - if isValidExt := func(ext string) bool { - _, ok := archiveMap[ext] - return ok - }(subExt); isValidExt { - return subExt - } - } - } - - return ext -} - -const maxBytes = 1 << 29 // 512MB - -// extractTar extracts .apk and .tar* archives. -func extractTar(ctx context.Context, d string, f string) error { - logger := clog.FromContext(ctx).With("dir", d, "file", f) - logger.Debug("extracting tar") - - // Check if the file is valid - _, err := os.Stat(f) - if err != nil { - return fmt.Errorf("failed to stat file: %w", err) - } - - filename := filepath.Base(f) - tf, err := os.Open(f) - if err != nil { - return fmt.Errorf("failed to open file: %w", err) - } - defer tf.Close() - // Set offset to the file origin regardless of type - _, err = tf.Seek(0, io.SeekStart) - if err != nil { - return fmt.Errorf("failed to seek to start: %w", err) - } - - var tr *tar.Reader - - switch { - case strings.Contains(f, ".apk") || strings.Contains(f, ".tar.gz") || strings.Contains(f, ".tgz"): - gzStream, err := gzip.NewReader(tf) - if err != nil { - return fmt.Errorf("failed to create gzip reader: %w", err) - } - defer gzStream.Close() - tr = tar.NewReader(gzStream) - case strings.Contains(filename, ".tar.xz"): - xzStream, err := xz.NewReader(tf) - if err != nil { - return fmt.Errorf("failed to create xz reader: %w", err) - } - tr = tar.NewReader(xzStream) - case strings.Contains(filename, ".xz"): - xzStream, err := xz.NewReader(tf) - if err != nil { - return fmt.Errorf("failed to create xz reader: %w", err) - } - uncompressed := strings.Trim(filepath.Base(f), ".xz") - target := filepath.Join(d, uncompressed) - if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil { - return fmt.Errorf("failed to create directory for file: %w", err) - } - - // #nosec G115 // ignore Type conversion which leads to integer overflow - // header.Mode is int64 and FileMode is uint32 - f, err := os.OpenFile(target, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o600) - if err != nil { - return fmt.Errorf("failed to create file: %w", err) - } - defer f.Close() - if _, err = io.Copy(f, xzStream); err != nil { - return fmt.Errorf("failed to write decompressed xz output: %w", err) - } - return nil - case strings.Contains(filename, ".tar.bz2") || strings.Contains(filename, ".tbz"): - br := bzip2.NewReader(tf) - tr = tar.NewReader(br) - default: - tr = tar.NewReader(tf) - } - - for { - header, err := tr.Next() - - if errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) { - break - } - - if err != nil { - return fmt.Errorf("failed to read tar header: %w", err) - } - - clean := filepath.Clean(header.Name) - if filepath.IsAbs(clean) || strings.Contains(clean, "../") { - return fmt.Errorf("path is absolute or contains a relative path traversal: %s", clean) - } - - target := filepath.Join(d, clean) - if !isValidPath(target, d) { - return fmt.Errorf("invalid file path: %s", target) - } - - switch header.Typeflag { - case tar.TypeDir: - // #nosec G115 // ignore Type conversion which leads to integer overflow - // header.Mode is int64 and FileMode is uint32 - if err := os.MkdirAll(target, os.FileMode(header.Mode)); err != nil { - return fmt.Errorf("failed to create directory: %w", err) - } - case tar.TypeReg: - if err := 
os.MkdirAll(filepath.Dir(target), 0o700); err != nil { - return fmt.Errorf("failed to create parent directory: %w", err) - } - - // #nosec G115 // ignore Type conversion which leads to integer overflow - // header.Mode is int64 and FileMode is uint32 - out, err := os.OpenFile(target, os.O_RDWR|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode)) - if err != nil { - return fmt.Errorf("failed to create file: %w", err) - } - - if _, err := io.Copy(out, io.LimitReader(tr, maxBytes)); err != nil { - out.Close() - return fmt.Errorf("failed to copy file: %w", err) - } - - if err := out.Close(); err != nil { - return fmt.Errorf("failed to close file: %w", err) - } - case tar.TypeSymlink: - // Skip symlinks for targets that do not exist - _, err = os.Readlink(target) - if os.IsNotExist(err) { - continue - } - // Ensure that symlinks are not relative path traversals - // #nosec G305 // L208 handles the check - linkReal, err := filepath.EvalSymlinks(filepath.Join(d, header.Linkname)) - if err != nil { - return fmt.Errorf("failed to evaluate symlink: %w", err) - } - if !isValidPath(target, d) { - return fmt.Errorf("symlink points outside temporary directory: %s", linkReal) - } - if err := os.Symlink(linkReal, target); err != nil { - return fmt.Errorf("failed to create symlink: %w", err) - } - } - } - return nil -} - -// extractGzip extracts .gz archives. -func extractGzip(ctx context.Context, d string, f string) error { - logger := clog.FromContext(ctx).With("dir", d, "file", f) - - // Check if the file is valid - _, err := os.Stat(f) - if err != nil { - return fmt.Errorf("failed to stat file: %w", err) - } - - gf, err := os.Open(f) - if err != nil { - return fmt.Errorf("failed to open file: %w", err) - } - defer gf.Close() - - // Determine if we're extracting a gzip- or zlib-compressed file - ft, err := programkind.File(f) - if err != nil { - return fmt.Errorf("failed to determine file type: %w", err) - } - - logger.Debugf("extracting %s", ft.Ext) - - base := filepath.Base(f) - target := filepath.Join(d, base[:len(base)-len(filepath.Ext(base))]) - - switch ft.Ext { - case "gzip": - gr, err := gzip.NewReader(gf) - if err != nil { - return fmt.Errorf("failed to create gzip reader: %w", err) - } - defer gr.Close() - - ef, err := os.Create(target) - if err != nil { - return fmt.Errorf("failed to create extracted file: %w", err) - } - defer ef.Close() - - if _, err := io.Copy(ef, io.LimitReader(gr, maxBytes)); err != nil { - return fmt.Errorf("failed to copy file: %w", err) - } - case "Z": - zr, err := zlib.NewReader(gf) - if err != nil { - return fmt.Errorf("failed to create zlib reader: %w", err) - } - defer zr.Close() - - ef, err := os.Create(target) - if err != nil { - return fmt.Errorf("failed to create extracted file: %w", err) - } - defer ef.Close() - - if _, err := io.Copy(ef, io.LimitReader(zr, maxBytes)); err != nil { - return fmt.Errorf("failed to copy file: %w", err) - } - } - - return nil -} - -// extractZip extracts .jar and .zip archives. 
-func extractZip(ctx context.Context, d string, f string) error { - logger := clog.FromContext(ctx).With("dir", d, "file", f) - logger.Debug("extracting zip") - - // Check if the file is valid - _, err := os.Stat(f) - if err != nil { - return fmt.Errorf("failed to stat file %s: %w", f, err) - } - - read, err := zip.OpenReader(f) - if err != nil { - return fmt.Errorf("failed to open zip file %s: %w", f, err) - } - defer read.Close() - - for _, file := range read.File { - clean := filepath.Clean(filepath.ToSlash(file.Name)) - if strings.Contains(clean, "..") { - logger.Warnf("skipping potentially unsafe file path: %s", file.Name) - continue - } - - name := filepath.Join(d, clean) - if !isValidPath(name, d) { - logger.Warnf("skipping file path outside extraction directory: %s", name) - continue - } - - // Check if a directory with the same name exists - if info, err := os.Stat(name); err == nil && info.IsDir() { - continue - } - - if file.Mode().IsDir() { - mode := file.Mode() | 0o700 - err := os.MkdirAll(name, mode) - if err != nil { - return fmt.Errorf("failed to create directory: %w", err) - } - continue - } - - open, err := file.Open() - if err != nil { - return fmt.Errorf("failed to open file in zip: %w", err) - } - - err = os.MkdirAll(filepath.Dir(name), 0o700) - if err != nil { - open.Close() - return fmt.Errorf("failed to create directory: %w", err) - } - - mode := file.Mode() | 0o200 - create, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode) - if err != nil { - open.Close() - return fmt.Errorf("failed to create file: %w", err) - } - - if _, err = io.Copy(create, io.LimitReader(open, maxBytes)); err != nil { - open.Close() - create.Close() - return fmt.Errorf("failed to copy file: %w", err) - } - - open.Close() - create.Close() - } - return nil -} - -// extractRPM extracts .rpm packages. 
-func extractRPM(ctx context.Context, d, f string) error { - logger := clog.FromContext(ctx).With("dir", d, "file", f) - logger.Debug("extracting rpm") - - rpmFile, err := os.Open(f) - if err != nil { - return fmt.Errorf("failed to open RPM file: %w", err) - } - defer rpmFile.Close() - - pkg, err := rpm.Read(rpmFile) - if err != nil { - return fmt.Errorf("failed to read RPM package headers: %w", err) - } - - if format := pkg.PayloadFormat(); format != "cpio" { - return fmt.Errorf("unsupported payload format: %s", format) - } - - payloadOffset, err := rpmFile.Seek(0, io.SeekCurrent) - if err != nil { - return fmt.Errorf("failed to get payload offset: %w", err) - } - - if _, err := rpmFile.Seek(payloadOffset, io.SeekStart); err != nil { - return fmt.Errorf("failed to seek to payload: %w", err) - } - - var cr *cpio.Reader - switch compression := pkg.PayloadCompression(); compression { - case "gzip": - gzStream, err := gzip.NewReader(rpmFile) - if err != nil { - return fmt.Errorf("failed to create gzip reader: %w", err) - } - defer gzStream.Close() - cr = cpio.NewReader(gzStream) - case "xz": - xzStream, err := xz.NewReader(rpmFile) - if err != nil { - return fmt.Errorf("failed to create xz reader: %w", err) - } - cr = cpio.NewReader(xzStream) - default: - return fmt.Errorf("unsupported compression format: %s", compression) - } - - for { - header, err := cr.Next() - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { - break - } - if err != nil { - return fmt.Errorf("failed to read cpio header: %w", err) - } - - clean := filepath.Clean(header.Name) - if filepath.IsAbs(clean) || strings.Contains(clean, "../") { - return fmt.Errorf("path is absolute or contains a relative path traversal: %s", clean) - } - - target := filepath.Join(d, clean) - - if header.FileInfo().IsDir() { - if err := os.MkdirAll(target, os.FileMode(header.Mode)); err != nil { - return fmt.Errorf("failed to create directory: %w", err) - } - continue - } - - if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil { - return fmt.Errorf("failed to create parent directory: %w", err) - } - - out, err := os.OpenFile(target, os.O_RDWR|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode)) - if err != nil { - return fmt.Errorf("failed to create file: %w", err) - } - - if _, err := io.Copy(out, io.LimitReader(cr, maxBytes)); err != nil { - out.Close() - return fmt.Errorf("failed to copy file: %w", err) - } - - if err := out.Close(); err != nil { - return fmt.Errorf("failed to close file: %w", err) - } - } - - return nil -} - -// extractDeb extracts .deb packages. 
-func extractDeb(ctx context.Context, d, f string) error { - logger := clog.FromContext(ctx).With("dir", d, "file", f) - logger.Debug("extracting deb") - - fd, err := os.Open(f) - if err != nil { - panic(err) - } - defer fd.Close() - - df, err := deb.Load(fd, f) - if err != nil { - panic(err) - } - defer df.Close() - - for { - header, err := df.Data.Next() - if errors.Is(err, io.EOF) { - break - } - if err != nil { - return fmt.Errorf("failed to read tar header: %w", err) - } - - clean := filepath.Clean(header.Name) - if filepath.IsAbs(clean) || strings.Contains(clean, "../") { - return fmt.Errorf("path is absolute or contains a relative path traversal: %s", clean) - } - - target := filepath.Join(d, clean) - - switch header.Typeflag { - case tar.TypeDir: - // #nosec G115 // ignore Type conversion which leads to integer overflow - // header.Mode is int64 and FileMode is uint32 - if err := os.MkdirAll(target, os.FileMode(header.Mode)); err != nil { - return fmt.Errorf("failed to create directory: %w", err) - } - case tar.TypeReg: - if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil { - return fmt.Errorf("failed to create parent directory: %w", err) - } - - // #nosec G115 - out, err := os.OpenFile(target, os.O_RDWR|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode)) - if err != nil { - return fmt.Errorf("failed to create file: %w", err) - } - - if _, err := io.Copy(out, io.LimitReader(df.Data, maxBytes)); err != nil { - out.Close() - return fmt.Errorf("failed to copy file: %w", err) - } - - if err := out.Close(); err != nil { - return fmt.Errorf("failed to close file: %w", err) - } - case tar.TypeSymlink: - // Skip symlinks for targets that do not exist - _, err = os.Readlink(target) - if os.IsNotExist(err) { - continue - } - // Ensure that symlinks are not relative path traversals - // #nosec G305 // L208 handles the check - linkReal, err := filepath.EvalSymlinks(filepath.Join(d, header.Linkname)) - if err != nil { - return fmt.Errorf("failed to evaluate symlink: %w", err) - } - if !isValidPath(linkReal, d) { - return fmt.Errorf("symlink points outside temporary directory: %s", linkReal) - } - if err := os.Symlink(linkReal, target); err != nil { - return fmt.Errorf("failed to create symlink: %w", err) - } - } - } - - return nil -} - -func extractBz2(ctx context.Context, d, f string) error { - logger := clog.FromContext(ctx).With("dir", d, "file", f) - logger.Debug("extracting bzip2 file") - - // Check if the file is valid - _, err := os.Stat(f) - if err != nil { - return fmt.Errorf("failed to stat file: %w", err) - } - - tf, err := os.Open(f) - if err != nil { - return fmt.Errorf("failed to open file: %w", err) - } - defer tf.Close() - // Set offset to the file origin regardless of type - _, err = tf.Seek(0, io.SeekStart) - if err != nil { - return fmt.Errorf("failed to seek to start: %w", err) - } - - br := bzip2.NewReader(tf) - uncompressed := strings.TrimSuffix(filepath.Base(f), ".bz2") - uncompressed = strings.TrimSuffix(uncompressed, ".bzip2") - target := filepath.Join(d, uncompressed) - if err := os.MkdirAll(d, 0o700); err != nil { - return fmt.Errorf("failed to create directory for file: %w", err) - } - - // #nosec G115 // ignore Type conversion which leads to integer overflow - // header.Mode is int64 and FileMode is uint32 - out, err := os.OpenFile(target, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o600) - if err != nil { - return fmt.Errorf("failed to create file: %w", err) - } - defer out.Close() - if _, err := io.Copy(out, io.LimitReader(br, maxBytes)); err != nil { - 
out.Close() - return fmt.Errorf("failed to copy file: %w", err) - } - return nil -} - -func extractNestedArchive( - ctx context.Context, - d string, - f string, - extracted *sync.Map, -) error { - isArchive := false - ext := getExt(f) - if _, ok := archiveMap[ext]; ok { - isArchive = true - } - //nolint:nestif // ignore complexity of 8 - if isArchive { - // Ensure the file was extracted and exists - fullPath := filepath.Join(d, f) - if _, err := os.Stat(fullPath); os.IsNotExist(err) { - return fmt.Errorf("file does not exist: %w", err) - } - extract := extractionMethod(ext) - if extract == nil { - return fmt.Errorf("unsupported archive type: %s", ext) - } - - err := extract(ctx, d, fullPath) - if err != nil { - return fmt.Errorf("extract nested archive: %w", err) - } - // Mark the file as extracted - extracted.Store(f, true) - - // Remove the nested archive file - // This is done to prevent the file from being scanned - if err := os.Remove(fullPath); err != nil { - return fmt.Errorf("failed to remove file: %w", err) - } - - // Check if there are any newly extracted files that are also archives - files, err := os.ReadDir(d) - if err != nil { - return fmt.Errorf("failed to read directory after extraction: %w", err) - } - for _, file := range files { - relPath := filepath.Join(d, file.Name()) - if _, isExtracted := extracted.Load(relPath); !isExtracted { - if err := extractNestedArchive(ctx, d, file.Name(), extracted); err != nil { - return fmt.Errorf("failed to extract nested archive %s: %w", file.Name(), err) - } - } - } - } - return nil -} - -// extractArchiveToTempDir creates a temporary directory and extracts the archive file for scanning. -func extractArchiveToTempDir(ctx context.Context, path string) (string, error) { - logger := clog.FromContext(ctx).With("path", path) - logger.Debug("creating temp dir") - - tmpDir, err := os.MkdirTemp("", filepath.Base(path)) - if err != nil { - return "", fmt.Errorf("failed to create temp dir: %w", err) - } - - ext := getExt(path) - extract := extractionMethod(ext) - if extract == nil { - return "", fmt.Errorf("unsupported archive type: %s", path) - } - err = extract(ctx, tmpDir, path) - if err != nil { - return "", fmt.Errorf("failed to extract %s: %w", path, err) - } - - var extractedFiles sync.Map - files, err := os.ReadDir(tmpDir) - if err != nil { - return "", fmt.Errorf("failed to read files in directory %s: %w", tmpDir, err) - } - for _, file := range files { - extractedFiles.Store(filepath.Join(tmpDir, file.Name()), false) - } - - extractedFiles.Range(func(key, _ any) bool { - if key == nil { - return true - } - //nolint: nestif // ignoring complexity of 11 - if file, ok := key.(string); ok { - ext := getExt(file) - info, err := os.Stat(file) - if err != nil { - return false - } - switch mode := info.Mode(); { - case mode.IsDir(): - err = filepath.WalkDir(file, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - rel, err := filepath.Rel(tmpDir, path) - if err != nil { - return fmt.Errorf("filepath.Rel: %w", err) - } - if !d.IsDir() { - if err := extractNestedArchive(ctx, tmpDir, rel, &extractedFiles); err != nil { - return fmt.Errorf("failed to extract nested archive %s: %w", rel, err) - } - } - - return nil - }) - if err != nil { - return false - } - return true - case mode.IsRegular(): - if _, ok := archiveMap[ext]; ok { - rel, err := filepath.Rel(tmpDir, file) - if err != nil { - return false - } - if err := extractNestedArchive(ctx, tmpDir, rel, &extractedFiles); err != nil { - return false - } - } - 
return true - } - } - return true - }) - - return tmpDir, nil -} - -func extractionMethod(ext string) func(context.Context, string, string) error { - switch ext { - case ".jar", ".zip", ".whl": - return extractZip - case ".gz": - return extractGzip - case ".apk", ".gem", ".tar", ".tar.bz2", ".tar.gz", ".tgz", ".tar.xz", ".tbz", ".xz": - return extractTar - case ".bz2", ".bzip2": - return extractBz2 - case ".rpm": - return extractRPM - case ".deb": - return extractDeb - default: - return nil - } -} diff --git a/pkg/action/archive_test.go b/pkg/action/archive_test.go index 571f903e..1f3f5d23 100644 --- a/pkg/action/archive_test.go +++ b/pkg/action/archive_test.go @@ -12,7 +12,9 @@ import ( "github.com/chainguard-dev/clog" "github.com/chainguard-dev/clog/slogtest" + "github.com/chainguard-dev/malcontent/pkg/archive" "github.com/chainguard-dev/malcontent/pkg/malcontent" + "github.com/chainguard-dev/malcontent/pkg/programkind" "github.com/chainguard-dev/malcontent/pkg/render" "github.com/chainguard-dev/malcontent/rules" thirdparty "github.com/chainguard-dev/malcontent/third_party" @@ -25,23 +27,23 @@ func TestExtractionMethod(t *testing.T) { ext string want func(context.Context, string, string) error }{ - {"apk", ".apk", extractTar}, - {"gem", ".gem", extractTar}, - {"gzip", ".gz", extractGzip}, - {"jar", ".jar", extractZip}, - {"tar.gz", ".tar.gz", extractTar}, - {"tar.xz", ".tar.xz", extractTar}, - {"tar", ".tar", extractTar}, - {"tgz", ".tgz", extractTar}, + {"apk", ".apk", archive.ExtractTar}, + {"gem", ".gem", archive.ExtractTar}, + {"gzip", ".gz", archive.ExtractGzip}, + {"jar", ".jar", archive.ExtractZip}, + {"tar.gz", ".tar.gz", archive.ExtractTar}, + {"tar.xz", ".tar.xz", archive.ExtractTar}, + {"tar", ".tar", archive.ExtractTar}, + {"tgz", ".tgz", archive.ExtractTar}, {"unknown", ".unknown", nil}, {"upx", ".upx", nil}, - {"zip", ".zip", extractZip}, + {"zip", ".zip", archive.ExtractZip}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - got := extractionMethod(tt.ext) + got := archive.ExtractionMethod(tt.ext) if (got == nil) != (tt.want == nil) { t.Errorf("extractionMethod() for extension %v did not return expected result", tt.ext) } @@ -75,7 +77,7 @@ func TestExtractionMultiple(t *testing.T) { t.Run(tt.path, func(t *testing.T) { t.Parallel() ctx := context.Background() - dir, err := extractArchiveToTempDir(ctx, tt.path) + dir, err := archive.ExtractArchiveToTempDir(ctx, tt.path) if err != nil { t.Fatal(err) } @@ -102,7 +104,7 @@ func TestExtractionMultiple(t *testing.T) { func TestExtractTar(t *testing.T) { t.Parallel() ctx := context.Background() - dir, err := extractArchiveToTempDir(ctx, filepath.Join("testdata", "apko.tar.gz")) + dir, err := archive.ExtractArchiveToTempDir(ctx, filepath.Join("testdata", "apko.tar.gz")) if err != nil { t.Fatal(err) } @@ -130,7 +132,7 @@ func TestExtractTar(t *testing.T) { func TestExtractGzip(t *testing.T) { t.Parallel() ctx := context.Background() - dir, err := extractArchiveToTempDir(ctx, filepath.Join("testdata", "apko.gz")) + dir, err := archive.ExtractArchiveToTempDir(ctx, filepath.Join("testdata", "apko.gz")) if err != nil { t.Fatal(err) } @@ -158,7 +160,7 @@ func TestExtractGzip(t *testing.T) { func TestExtractZip(t *testing.T) { t.Parallel() ctx := context.Background() - dir, err := extractArchiveToTempDir(ctx, filepath.Join("testdata", "apko.zip")) + dir, err := archive.ExtractArchiveToTempDir(ctx, filepath.Join("testdata", "apko.zip")) if err != nil { t.Fatal(err) } @@ -186,7 +188,7 @@ func TestExtractZip(t 
*testing.T) { func TestExtractNestedArchive(t *testing.T) { t.Parallel() ctx := context.Background() - dir, err := extractArchiveToTempDir(ctx, filepath.Join("testdata", "apko_nested.tar.gz")) + dir, err := archive.ExtractArchiveToTempDir(ctx, filepath.Join("testdata", "apko_nested.tar.gz")) if err != nil { t.Fatal(err) } @@ -325,7 +327,7 @@ func TestGetExt(t *testing.T) { for _, tt := range tests { t.Run(tt.path, func(t *testing.T) { t.Parallel() - if got := getExt(tt.path); got != tt.want { + if got := programkind.GetExt(tt.path); got != tt.want { t.Errorf("Ext() = %v, want %v", got, tt.want) } }) @@ -402,7 +404,7 @@ func TestIsValidPath(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := isValidPath(tt.target, tt.baseDir) + result := archive.IsValidPath(tt.target, tt.baseDir) if result != tt.expected { t.Errorf("isValidPath(%q, %q) = %v, want %v", tt.target, tt.baseDir, result, tt.expected) } diff --git a/pkg/action/diff.go b/pkg/action/diff.go index 3daa893a..c801921f 100644 --- a/pkg/action/diff.go +++ b/pkg/action/diff.go @@ -16,6 +16,7 @@ import ( "github.com/agext/levenshtein" "github.com/chainguard-dev/clog" "github.com/chainguard-dev/malcontent/pkg/malcontent" + "github.com/chainguard-dev/malcontent/pkg/programkind" orderedmap "github.com/wk8/go-ordered-map/v2" "golang.org/x/sync/errgroup" ) @@ -151,8 +152,8 @@ func Diff(ctx context.Context, c malcontent.Config) (*malcontent.Report, error) var srcBase, destBase string srcCh := make(chan map[string]*malcontent.FileReport, 1) destCh := make(chan map[string]*malcontent.FileReport, 1) - srcIsArchive := isSupportedArchive(srcPath) - destIsArchive := isSupportedArchive(destPath) + srcIsArchive := programkind.IsSupportedArchive(srcPath) + destIsArchive := programkind.IsSupportedArchive(destPath) srcInfo, err := os.Stat(srcPath) if err != nil { diff --git a/pkg/action/scan.go b/pkg/action/scan.go index 9aab19ab..2b29465b 100644 --- a/pkg/action/scan.go +++ b/pkg/action/scan.go @@ -17,6 +17,7 @@ import ( "sync" "github.com/chainguard-dev/clog" + "github.com/chainguard-dev/malcontent/pkg/archive" "github.com/chainguard-dev/malcontent/pkg/compile" "github.com/chainguard-dev/malcontent/pkg/malcontent" "github.com/chainguard-dev/malcontent/pkg/programkind" @@ -294,7 +295,7 @@ func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report if c.OCI { // store the image URI for later use imageURI = scanPath - ociExtractPath, err = oci(ctx, imageURI) + ociExtractPath, err = archive.OCI(ctx, imageURI) logger.Debug("oci image", slog.Any("scanPath", scanPath), slog.Any("ociExtractPath", ociExtractPath)) if err != nil { return nil, fmt.Errorf("failed to prepare OCI image for scanning: %w", err) @@ -432,7 +433,7 @@ func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report case <-scanCtx.Done(): return scanCtx.Err() default: - if isSupportedArchive(path) { + if programkind.IsSupportedArchive(path) { return handleArchive(path) } return handleFile(path) @@ -493,7 +494,7 @@ func processArchive(ctx context.Context, c malcontent.Config, rfs []fs.FS, archi var err error var frs sync.Map - tmpRoot, err := extractArchiveToTempDir(ctx, archivePath) + tmpRoot, err := archive.ExtractArchiveToTempDir(ctx, archivePath) if err != nil { return nil, fmt.Errorf("extract to temp: %w", err) } diff --git a/pkg/archive/archive.go b/pkg/archive/archive.go new file mode 100644 index 00000000..c4bd936b --- /dev/null +++ b/pkg/archive/archive.go @@ -0,0 +1,197 @@ +package archive + +import ( + 
"context" + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/chainguard-dev/clog" + "github.com/chainguard-dev/malcontent/pkg/programkind" +) + +// isValidPath checks if the target file is within the given directory. +func IsValidPath(target, dir string) bool { + return strings.HasPrefix(filepath.Clean(target), filepath.Clean(dir)) +} + +const maxBytes = 1 << 29 // 512MB + +func extractNestedArchive( + ctx context.Context, + d string, + f string, + extracted *sync.Map, +) error { + isArchive := false + // zlib-compressed files are also archives + ft, err := programkind.File(f) + if err != nil { + return fmt.Errorf("failed to determine file type: %w", err) + } + if ft != nil && ft.MIME == "application/zlib" { + isArchive = true + } + if _, ok := programkind.ArchiveMap[programkind.GetExt(f)]; ok { + isArchive = true + } + //nolint:nestif // ignore complexity of 8 + if isArchive { + // Ensure the file was extracted and exists + fullPath := filepath.Join(d, f) + if _, err := os.Stat(fullPath); os.IsNotExist(err) { + return fmt.Errorf("file does not exist: %w", err) + } + + var extract func(context.Context, string, string) error + // Check for zlib-compressed files first and use the zlib-specific function + ft, err := programkind.File(fullPath) + if err != nil { + return fmt.Errorf("failed to determine file type: %w", err) + } + if ft != nil && ft.MIME == "application/zlib" { + extract = ExtractZlib + } else { + extract = ExtractionMethod(programkind.GetExt(fullPath)) + } + err = extract(ctx, d, fullPath) + if err != nil { + return fmt.Errorf("extract nested archive: %w", err) + } + // Mark the file as extracted + extracted.Store(f, true) + + // Remove the nested archive file + // This is done to prevent the file from being scanned + if err := os.Remove(fullPath); err != nil { + return fmt.Errorf("failed to remove file: %w", err) + } + + // Check if there are any newly extracted files that are also archives + files, err := os.ReadDir(d) + if err != nil { + return fmt.Errorf("failed to read directory after extraction: %w", err) + } + for _, file := range files { + relPath := filepath.Join(d, file.Name()) + if _, isExtracted := extracted.Load(relPath); !isExtracted { + if err := extractNestedArchive(ctx, d, file.Name(), extracted); err != nil { + return fmt.Errorf("failed to extract nested archive %s: %w", file.Name(), err) + } + } + } + } + return nil +} + +// extractArchiveToTempDir creates a temporary directory and extracts the archive file for scanning. 
+func ExtractArchiveToTempDir(ctx context.Context, path string) (string, error) { + logger := clog.FromContext(ctx).With("path", path) + logger.Debug("creating temp dir") + + tmpDir, err := os.MkdirTemp("", filepath.Base(path)) + if err != nil { + return "", fmt.Errorf("failed to create temp dir: %w", err) + } + + var extract func(context.Context, string, string) error + // Check for zlib-compressed files first and use the zlib-specific function + ft, err := programkind.File(path) + if err != nil { + return "", fmt.Errorf("failed to determine file type: %w", err) + } + if ft != nil && ft.MIME == "application/zlib" { + extract = ExtractZlib + } else { + extract = ExtractionMethod(programkind.GetExt(path)) + } + if extract == nil { + return "", fmt.Errorf("unsupported archive type: %s", path) + } + err = extract(ctx, tmpDir, path) + if err != nil { + return "", fmt.Errorf("failed to extract %s: %w", path, err) + } + + var extractedFiles sync.Map + files, err := os.ReadDir(tmpDir) + if err != nil { + return "", fmt.Errorf("failed to read files in directory %s: %w", tmpDir, err) + } + for _, file := range files { + extractedFiles.Store(filepath.Join(tmpDir, file.Name()), false) + } + + extractedFiles.Range(func(key, _ any) bool { + if key == nil { + return true + } + //nolint: nestif // ignoring complexity of 11 + if file, ok := key.(string); ok { + ext := programkind.GetExt(file) + info, err := os.Stat(file) + if err != nil { + return false + } + switch mode := info.Mode(); { + case mode.IsDir(): + err = filepath.WalkDir(file, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + rel, err := filepath.Rel(tmpDir, path) + if err != nil { + return fmt.Errorf("filepath.Rel: %w", err) + } + if !d.IsDir() { + if err := extractNestedArchive(ctx, tmpDir, rel, &extractedFiles); err != nil { + return fmt.Errorf("failed to extract nested archive %s: %w", rel, err) + } + } + + return nil + }) + if err != nil { + return false + } + return true + case mode.IsRegular(): + if _, ok := programkind.ArchiveMap[ext]; ok { + rel, err := filepath.Rel(tmpDir, file) + if err != nil { + return false + } + if err := extractNestedArchive(ctx, tmpDir, rel, &extractedFiles); err != nil { + return false + } + } + return true + } + } + return true + }) + + return tmpDir, nil +} + +func ExtractionMethod(ext string) func(context.Context, string, string) error { + switch ext { + case ".jar", ".zip", ".whl": + return ExtractZip + case ".gz": + return ExtractGzip + case ".apk", ".gem", ".tar", ".tar.bz2", ".tar.gz", ".tgz", ".tar.xz", ".tbz", ".xz": + return ExtractTar + case ".bz2", ".bzip2": + return ExtractBz2 + case ".rpm": + return ExtractRPM + case ".deb": + return ExtractDeb + default: + return nil + } +} diff --git a/pkg/archive/bz2.go b/pkg/archive/bz2.go new file mode 100644 index 00000000..75313f27 --- /dev/null +++ b/pkg/archive/bz2.go @@ -0,0 +1,57 @@ +package archive + +import ( + "compress/bzip2" + "context" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/chainguard-dev/clog" +) + +// Extract Bz2 extracts bzip2 files. 
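+// The output name is derived by trimming the .bz2/.bzip2 suffix from the input file name, and the decompressed data is capped at maxBytes.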
+func ExtractBz2(ctx context.Context, d, f string) error { + logger := clog.FromContext(ctx).With("dir", d, "file", f) + logger.Debug("extracting bzip2 file") + + // Check if the file is valid + _, err := os.Stat(f) + if err != nil { + return fmt.Errorf("failed to stat file: %w", err) + } + + tf, err := os.Open(f) + if err != nil { + return fmt.Errorf("failed to open file: %w", err) + } + defer tf.Close() + // Set offset to the file origin regardless of type + _, err = tf.Seek(0, io.SeekStart) + if err != nil { + return fmt.Errorf("failed to seek to start: %w", err) + } + + br := bzip2.NewReader(tf) + uncompressed := strings.TrimSuffix(filepath.Base(f), ".bz2") + uncompressed = strings.TrimSuffix(uncompressed, ".bzip2") + target := filepath.Join(d, uncompressed) + if err := os.MkdirAll(d, 0o700); err != nil { + return fmt.Errorf("failed to create directory for file: %w", err) + } + + // #nosec G115 // ignore Type conversion which leads to integer overflow + // header.Mode is int64 and FileMode is uint32 + out, err := os.OpenFile(target, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o600) + if err != nil { + return fmt.Errorf("failed to create file: %w", err) + } + defer out.Close() + if _, err := io.Copy(out, io.LimitReader(br, maxBytes)); err != nil { + out.Close() + return fmt.Errorf("failed to copy file: %w", err) + } + return nil +} diff --git a/pkg/archive/deb.go b/pkg/archive/deb.go new file mode 100644 index 00000000..cf94e963 --- /dev/null +++ b/pkg/archive/deb.go @@ -0,0 +1,98 @@ +package archive + +import ( + "archive/tar" + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/chainguard-dev/clog" + "github.com/egibs/go-debian/deb" +) + +// ExtractDeb extracts .deb packages. +func ExtractDeb(ctx context.Context, d, f string) error { + logger := clog.FromContext(ctx).With("dir", d, "file", f) + logger.Debug("extracting deb") + + fd, err := os.Open(f) + if err != nil { + panic(err) + } + defer fd.Close() + + df, err := deb.Load(fd, f) + if err != nil { + panic(err) + } + defer df.Close() + + for { + header, err := df.Data.Next() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return fmt.Errorf("failed to read tar header: %w", err) + } + + clean := filepath.Clean(header.Name) + if filepath.IsAbs(clean) || strings.Contains(clean, "../") { + return fmt.Errorf("path is absolute or contains a relative path traversal: %s", clean) + } + + target := filepath.Join(d, clean) + + switch header.Typeflag { + case tar.TypeDir: + // #nosec G115 // ignore Type conversion which leads to integer overflow + // header.Mode is int64 and FileMode is uint32 + if err := os.MkdirAll(target, os.FileMode(header.Mode)); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + case tar.TypeReg: + if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil { + return fmt.Errorf("failed to create parent directory: %w", err) + } + + // #nosec G115 + out, err := os.OpenFile(target, os.O_RDWR|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode)) + if err != nil { + return fmt.Errorf("failed to create file: %w", err) + } + + if _, err := io.Copy(out, io.LimitReader(df.Data, maxBytes)); err != nil { + out.Close() + return fmt.Errorf("failed to copy file: %w", err) + } + + if err := out.Close(); err != nil { + return fmt.Errorf("failed to close file: %w", err) + } + case tar.TypeSymlink: + // Skip symlinks for targets that do not exist + _, err = os.Readlink(target) + if os.IsNotExist(err) { + continue + } + // Ensure that symlinks are not relative path 
traversals + // #nosec G305 // L208 handles the check + linkReal, err := filepath.EvalSymlinks(filepath.Join(d, header.Linkname)) + if err != nil { + return fmt.Errorf("failed to evaluate symlink: %w", err) + } + if !IsValidPath(linkReal, d) { + return fmt.Errorf("symlink points outside temporary directory: %s", linkReal) + } + if err := os.Symlink(linkReal, target); err != nil { + return fmt.Errorf("failed to create symlink: %w", err) + } + } + } + + return nil +} diff --git a/pkg/archive/gzip.go b/pkg/archive/gzip.go new file mode 100644 index 00000000..9260253c --- /dev/null +++ b/pkg/archive/gzip.go @@ -0,0 +1,51 @@ +package archive + +import ( + "compress/gzip" + "context" + "fmt" + "io" + "os" + "path/filepath" + + "github.com/chainguard-dev/clog" +) + +// extractGzip extracts .gz archives. +func ExtractGzip(ctx context.Context, d string, f string) error { + logger := clog.FromContext(ctx).With("dir", d, "file", f) + logger.Debug("extracting gzip") + + // Check if the file is valid + _, err := os.Stat(f) + if err != nil { + return fmt.Errorf("failed to stat file: %w", err) + } + + gf, err := os.Open(f) + if err != nil { + return fmt.Errorf("failed to open file: %w", err) + } + defer gf.Close() + + base := filepath.Base(f) + target := filepath.Join(d, base[:len(base)-len(filepath.Ext(base))]) + + gr, err := gzip.NewReader(gf) + if err != nil { + return fmt.Errorf("failed to create gzip reader: %w", err) + } + defer gr.Close() + + ef, err := os.Create(target) + if err != nil { + return fmt.Errorf("failed to create extracted file: %w", err) + } + defer ef.Close() + + if _, err := io.Copy(ef, io.LimitReader(gr, maxBytes)); err != nil { + return fmt.Errorf("failed to copy file: %w", err) + } + + return nil +} diff --git a/pkg/action/oci.go b/pkg/archive/oci.go similarity index 91% rename from pkg/action/oci.go rename to pkg/archive/oci.go index 59c6499d..8d2d4a05 100644 --- a/pkg/action/oci.go +++ b/pkg/archive/oci.go @@ -1,4 +1,4 @@ -package action +package archive import ( "context" @@ -40,13 +40,13 @@ func prepareImage(ctx context.Context, d string) (string, *os.File, error) { } // return a directory with the extracted image directories/files in it. -func oci(ctx context.Context, path string) (string, error) { +func OCI(ctx context.Context, path string) (string, error) { tmpDir, tmpFile, err := prepareImage(ctx, path) if err != nil { return "", fmt.Errorf("failed to prepare image: %w", err) } - err = extractTar(ctx, tmpDir, tmpFile.Name()) + err = ExtractTar(ctx, tmpDir, tmpFile.Name()) if err != nil { return "", fmt.Errorf("extract tar: %w", err) } diff --git a/pkg/archive/rpm.go b/pkg/archive/rpm.go new file mode 100644 index 00000000..56a38ec8 --- /dev/null +++ b/pkg/archive/rpm.go @@ -0,0 +1,110 @@ +package archive + +import ( + "compress/gzip" + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/cavaliergopher/cpio" + "github.com/cavaliergopher/rpm" + "github.com/chainguard-dev/clog" + "github.com/ulikunitz/xz" +) + +// extractRPM extracts .rpm packages. 
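+// Only cpio payloads compressed with gzip or xz are supported; any other payload format or compression returns an error.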
+func ExtractRPM(ctx context.Context, d, f string) error { + logger := clog.FromContext(ctx).With("dir", d, "file", f) + logger.Debug("extracting rpm") + + rpmFile, err := os.Open(f) + if err != nil { + return fmt.Errorf("failed to open RPM file: %w", err) + } + defer rpmFile.Close() + + pkg, err := rpm.Read(rpmFile) + if err != nil { + return fmt.Errorf("failed to read RPM package headers: %w", err) + } + + if format := pkg.PayloadFormat(); format != "cpio" { + return fmt.Errorf("unsupported payload format: %s", format) + } + + payloadOffset, err := rpmFile.Seek(0, io.SeekCurrent) + if err != nil { + return fmt.Errorf("failed to get payload offset: %w", err) + } + + if _, err := rpmFile.Seek(payloadOffset, io.SeekStart); err != nil { + return fmt.Errorf("failed to seek to payload: %w", err) + } + + var cr *cpio.Reader + switch compression := pkg.PayloadCompression(); compression { + case "gzip": + gzStream, err := gzip.NewReader(rpmFile) + if err != nil { + return fmt.Errorf("failed to create gzip reader: %w", err) + } + defer gzStream.Close() + cr = cpio.NewReader(gzStream) + case "xz": + xzStream, err := xz.NewReader(rpmFile) + if err != nil { + return fmt.Errorf("failed to create xz reader: %w", err) + } + cr = cpio.NewReader(xzStream) + default: + return fmt.Errorf("unsupported compression format: %s", compression) + } + + for { + header, err := cr.Next() + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + break + } + if err != nil { + return fmt.Errorf("failed to read cpio header: %w", err) + } + + clean := filepath.Clean(header.Name) + if filepath.IsAbs(clean) || strings.Contains(clean, "../") { + return fmt.Errorf("path is absolute or contains a relative path traversal: %s", clean) + } + + target := filepath.Join(d, clean) + + if header.FileInfo().IsDir() { + if err := os.MkdirAll(target, os.FileMode(header.Mode)); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + continue + } + + if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil { + return fmt.Errorf("failed to create parent directory: %w", err) + } + + out, err := os.OpenFile(target, os.O_RDWR|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode)) + if err != nil { + return fmt.Errorf("failed to create file: %w", err) + } + + if _, err := io.Copy(out, io.LimitReader(cr, maxBytes)); err != nil { + out.Close() + return fmt.Errorf("failed to copy file: %w", err) + } + + if err := out.Close(); err != nil { + return fmt.Errorf("failed to close file: %w", err) + } + } + + return nil +} diff --git a/pkg/archive/tar.go b/pkg/archive/tar.go new file mode 100644 index 00000000..aa7fade7 --- /dev/null +++ b/pkg/archive/tar.go @@ -0,0 +1,156 @@ +package archive + +import ( + "archive/tar" + "compress/bzip2" + "compress/gzip" + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/chainguard-dev/clog" + "github.com/ulikunitz/xz" +) + +// extractTar extracts .apk and .tar* archives. 
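+// The compression is inferred from the file name: gzip for .apk/.tar.gz/.tgz, xz for .tar.xz, bzip2 for .tar.bz2/.tbz, and uncompressed tar otherwise. A bare .xz file is decompressed directly to the destination without reading a tar stream.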
+func ExtractTar(ctx context.Context, d string, f string) error { + logger := clog.FromContext(ctx).With("dir", d, "file", f) + logger.Debug("extracting tar") + + // Check if the file is valid + _, err := os.Stat(f) + if err != nil { + return fmt.Errorf("failed to stat file: %w", err) + } + + filename := filepath.Base(f) + tf, err := os.Open(f) + if err != nil { + return fmt.Errorf("failed to open file: %w", err) + } + defer tf.Close() + // Set offset to the file origin regardless of type + _, err = tf.Seek(0, io.SeekStart) + if err != nil { + return fmt.Errorf("failed to seek to start: %w", err) + } + + var tr *tar.Reader + + switch { + case strings.Contains(f, ".apk") || strings.Contains(f, ".tar.gz") || strings.Contains(f, ".tgz"): + gzStream, err := gzip.NewReader(tf) + if err != nil { + return fmt.Errorf("failed to create gzip reader: %w", err) + } + defer gzStream.Close() + tr = tar.NewReader(gzStream) + case strings.Contains(filename, ".tar.xz"): + xzStream, err := xz.NewReader(tf) + if err != nil { + return fmt.Errorf("failed to create xz reader: %w", err) + } + tr = tar.NewReader(xzStream) + case strings.Contains(filename, ".xz"): + xzStream, err := xz.NewReader(tf) + if err != nil { + return fmt.Errorf("failed to create xz reader: %w", err) + } + uncompressed := strings.Trim(filepath.Base(f), ".xz") + target := filepath.Join(d, uncompressed) + if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil { + return fmt.Errorf("failed to create directory for file: %w", err) + } + + // #nosec G115 // ignore Type conversion which leads to integer overflow + // header.Mode is int64 and FileMode is uint32 + f, err := os.OpenFile(target, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o600) + if err != nil { + return fmt.Errorf("failed to create file: %w", err) + } + defer f.Close() + if _, err = io.Copy(f, xzStream); err != nil { + return fmt.Errorf("failed to write decompressed xz output: %w", err) + } + return nil + case strings.Contains(filename, ".tar.bz2") || strings.Contains(filename, ".tbz"): + br := bzip2.NewReader(tf) + tr = tar.NewReader(br) + default: + tr = tar.NewReader(tf) + } + + for { + header, err := tr.Next() + + if errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) { + break + } + + if err != nil { + return fmt.Errorf("failed to read tar header: %w", err) + } + + clean := filepath.Clean(header.Name) + if filepath.IsAbs(clean) || strings.Contains(clean, "../") { + return fmt.Errorf("path is absolute or contains a relative path traversal: %s", clean) + } + + target := filepath.Join(d, clean) + if !IsValidPath(target, d) { + return fmt.Errorf("invalid file path: %s", target) + } + + switch header.Typeflag { + case tar.TypeDir: + // #nosec G115 // ignore Type conversion which leads to integer overflow + // header.Mode is int64 and FileMode is uint32 + if err := os.MkdirAll(target, os.FileMode(header.Mode)); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + case tar.TypeReg: + if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil { + return fmt.Errorf("failed to create parent directory: %w", err) + } + + // #nosec G115 // ignore Type conversion which leads to integer overflow + // header.Mode is int64 and FileMode is uint32 + out, err := os.OpenFile(target, os.O_RDWR|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode)) + if err != nil { + return fmt.Errorf("failed to create file: %w", err) + } + + if _, err := io.Copy(out, io.LimitReader(tr, maxBytes)); err != nil { + out.Close() + return fmt.Errorf("failed to copy file: %w", err) + } + + 
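// Propagate any error from closing the file so delayed write failures are not lost. +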
if err := out.Close(); err != nil { + return fmt.Errorf("failed to close file: %w", err) + } + case tar.TypeSymlink: + // Skip symlinks for targets that do not exist + _, err = os.Readlink(target) + if os.IsNotExist(err) { + continue + } + // Ensure that symlinks are not relative path traversals + // #nosec G305 // L208 handles the check + linkReal, err := filepath.EvalSymlinks(filepath.Join(d, header.Linkname)) + if err != nil { + return fmt.Errorf("failed to evaluate symlink: %w", err) + } + if !IsValidPath(target, d) { + return fmt.Errorf("symlink points outside temporary directory: %s", linkReal) + } + if err := os.Symlink(linkReal, target); err != nil { + return fmt.Errorf("failed to create symlink: %w", err) + } + } + } + return nil +} diff --git a/pkg/archive/zip.go b/pkg/archive/zip.go new file mode 100644 index 00000000..4714ceae --- /dev/null +++ b/pkg/archive/zip.go @@ -0,0 +1,87 @@ +package archive + +import ( + "archive/zip" + "context" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/chainguard-dev/clog" +) + +// extractZip extracts .jar and .zip archives. +func ExtractZip(ctx context.Context, d string, f string) error { + logger := clog.FromContext(ctx).With("dir", d, "file", f) + logger.Debug("extracting zip") + + // Check if the file is valid + _, err := os.Stat(f) + if err != nil { + return fmt.Errorf("failed to stat file %s: %w", f, err) + } + + read, err := zip.OpenReader(f) + if err != nil { + return fmt.Errorf("failed to open zip file %s: %w", f, err) + } + defer read.Close() + + for _, file := range read.File { + clean := filepath.Clean(filepath.ToSlash(file.Name)) + if strings.Contains(clean, "..") { + logger.Warnf("skipping potentially unsafe file path: %s", file.Name) + continue + } + + name := filepath.Join(d, clean) + if !IsValidPath(name, d) { + logger.Warnf("skipping file path outside extraction directory: %s", name) + continue + } + + // Check if a directory with the same name exists + if info, err := os.Stat(name); err == nil && info.IsDir() { + continue + } + + if file.Mode().IsDir() { + mode := file.Mode() | 0o700 + err := os.MkdirAll(name, mode) + if err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + continue + } + + open, err := file.Open() + if err != nil { + return fmt.Errorf("failed to open file in zip: %w", err) + } + + err = os.MkdirAll(filepath.Dir(name), 0o700) + if err != nil { + open.Close() + return fmt.Errorf("failed to create directory: %w", err) + } + + mode := file.Mode() | 0o200 + create, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode) + if err != nil { + open.Close() + return fmt.Errorf("failed to create file: %w", err) + } + + if _, err = io.Copy(create, io.LimitReader(open, maxBytes)); err != nil { + open.Close() + create.Close() + return fmt.Errorf("failed to copy file: %w", err) + } + + open.Close() + create.Close() + } + return nil +} diff --git a/pkg/archive/zlib.go b/pkg/archive/zlib.go new file mode 100644 index 00000000..1ff356db --- /dev/null +++ b/pkg/archive/zlib.go @@ -0,0 +1,51 @@ +package archive + +import ( + "compress/zlib" + "context" + "fmt" + "io" + "os" + "path/filepath" + + "github.com/chainguard-dev/clog" +) + +// extractZlib extracts extension-agnostic zlib-compressed files. 
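+// A zlib stream carries no original file name, so detection relies on content sniffing (programkind.File reporting application/zlib) and the output name is the input name with its extension dropped.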
+func ExtractZlib(ctx context.Context, d string, f string) error { + logger := clog.FromContext(ctx).With("dir", d, "file", f) + logger.Debugf("extracting zlib") + + // Check if the file is valid + _, err := os.Stat(f) + if err != nil { + return fmt.Errorf("failed to stat file: %w", err) + } + + gf, err := os.Open(f) + if err != nil { + return fmt.Errorf("failed to open file: %w", err) + } + defer gf.Close() + + base := filepath.Base(f) + target := filepath.Join(d, base[:len(base)-len(filepath.Ext(base))]) + + zr, err := zlib.NewReader(gf) + if err != nil { + return fmt.Errorf("failed to create zlib reader: %w", err) + } + defer zr.Close() + + ef, err := os.Create(target) + if err != nil { + return fmt.Errorf("failed to create extracted file: %w", err) + } + defer ef.Close() + + if _, err := io.Copy(ef, io.LimitReader(zr, maxBytes)); err != nil { + return fmt.Errorf("failed to copy file: %w", err) + } + + return nil +} diff --git a/pkg/programkind/programkind.go b/pkg/programkind/programkind.go index e0fe6e15..f0d700a7 100644 --- a/pkg/programkind/programkind.go +++ b/pkg/programkind/programkind.go @@ -10,11 +10,31 @@ import ( "io/fs" "os" "path/filepath" + "regexp" "strings" "github.com/gabriel-vasile/mimetype" ) +// Supported archive extensions. +var ArchiveMap = map[string]bool{ + ".apk": true, + ".bz2": true, + ".bzip2": true, + ".deb": true, + ".gem": true, + ".gz": true, + ".jar": true, + ".rpm": true, + ".tar": true, + ".tar.gz": true, + ".tar.xz": true, + ".tgz": true, + ".whl": true, + ".xz": true, + ".zip": true, +} + // file extension to MIME type, if it's a good scanning target. var supportedKind = map[string]string{ "7z": "", @@ -78,6 +98,39 @@ type FileType struct { MIME string } +// IsSupportedArchive returns whether a path can be processed by our archive extractor. +func IsSupportedArchive(path string) bool { + return ArchiveMap[GetExt(path)] +} + +// getExt returns the extension of a file path +// and attempts to avoid including fragments of filenames with other dots before the extension. +func GetExt(path string) string { + base := filepath.Base(path) + + // Handle files with version numbers in the name + // e.g. file1.2.3.tar.gz -> .tar.gz + re := regexp.MustCompile(`\d+\.\d+\.\d+$`) + base = re.ReplaceAllString(base, "") + + ext := filepath.Ext(base) + + if ext != "" && strings.Contains(base, ".") { + parts := strings.Split(base, ".") + if len(parts) > 2 { + subExt := fmt.Sprintf(".%s%s", parts[len(parts)-2], ext) + if isValidExt := func(ext string) bool { + _, ok := ArchiveMap[ext] + return ok + }(subExt); isValidExt { + return subExt + } + } + } + + return ext +} + func makeFileType(path string, ext string, mime string) *FileType { ext = strings.TrimPrefix(ext, ".")