From 4c9eb737ac3f981909d8cf323d245cdd2de98081 Mon Sep 17 00:00:00 2001 From: Evan Gibler <20933572+egibs@users.noreply.github.com> Date: Thu, 2 Jan 2025 19:39:08 -0600 Subject: [PATCH] Address more extraction edge cases; improve naming and consistency (#733) * Address more gzip, tar, and tar.gz edge cases Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> * Add ordering comment; move tar case to top Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> * Add .gzip for completeness, move above .zip Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> * Address PR comments Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> * Fix deb.go symlink validation Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> * Move deb.go and tar.go extraction logic to helper functions Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> * Re-add check for nonexistent targets Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> * Universal naming/consistency changes Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> * Run go mod tidy Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> --------- Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> --- go.mod | 2 +- pkg/archive/archive.go | 78 +++++++++++++++++++++++++++++++++++++++--- pkg/archive/bz2.go | 14 ++++++-- pkg/archive/deb.go | 44 +++++------------------- pkg/archive/gzip.go | 26 ++++++++++++-- pkg/archive/rpm.go | 12 +++++-- pkg/archive/tar.go | 62 +++++++++++---------------------- pkg/archive/upx.go | 3 ++ pkg/archive/zip.go | 39 ++++++++++++--------- pkg/archive/zlib.go | 16 +++++---- 10 files changed, 181 insertions(+), 115 deletions(-) diff --git a/go.mod b/go.mod index fdf7cba6..d0f15f4d 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/google/go-cmp v0.6.0 github.com/google/go-containerregistry v0.20.2 github.com/hillu/go-yara/v4 v4.3.3 + github.com/klauspost/compress v1.17.11 github.com/olekukonko/tablewriter v0.0.5 github.com/shirou/gopsutil/v4 v4.24.11 github.com/ulikunitz/xz v0.5.12 @@ -40,7 +41,6 @@ require ( github.com/ebitengine/purego v0.8.1 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/go-ole/go-ole v1.3.0 // indirect - github.com/klauspost/compress v1.17.11 // indirect github.com/kr/pretty v0.2.1 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/lufia/plan9stats v0.0.0-20240909124753-873cd0166683 // indirect diff --git a/pkg/archive/archive.go b/pkg/archive/archive.go index 14c91450..d2188642 100644 --- a/pkg/archive/archive.go +++ b/pkg/archive/archive.go @@ -1,8 +1,10 @@ package archive import ( + "archive/tar" "context" "fmt" + "io" "io/fs" "os" "path/filepath" @@ -190,13 +192,16 @@ func ExtractArchiveToTempDir(ctx context.Context, path string) (string, error) { } func ExtractionMethod(ext string) func(context.Context, string, string) error { + // The ordering of these statements is important, especially for extensions + // that are substrings of other extensions (e.g., `.gz` and `.tar.gz` or `.tgz`) switch ext { - case ".jar", ".zip", ".whl": - return ExtractZip - case ".gz": - return ExtractGzip + // New cases should go below this line so that the lengthier tar extensions are evaluated first case ".apk", ".gem", ".tar", ".tar.bz2", ".tar.gz", ".tgz", ".tar.xz", ".tbz", ".xz": return ExtractTar + case ".gz", ".gzip": + return ExtractGzip + case ".jar", ".zip", ".whl": + return ExtractZip case ".bz2", ".bzip2": return ExtractBz2 case ".rpm": @@ -207,3 +212,68 @@ func ExtractionMethod(ext string) func(context.Context, string, string) error { return nil } } + +// handleDirectory extracts valid directories within .deb or .tar archives. +func handleDirectory(target string) error { + if err := os.MkdirAll(target, 0o700); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + return nil +} + +// handleFile extracts valid files within .deb or .tar archives. +func handleFile(target string, tr *tar.Reader) error { + if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil { + return fmt.Errorf("failed to create parent directory: %w", err) + } + + out, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600) + if err != nil { + return fmt.Errorf("failed to create file: %w", err) + } + defer out.Close() + + written, err := io.Copy(out, io.LimitReader(tr, maxBytes)) + if err != nil { + return fmt.Errorf("failed to copy file: %w", err) + } + if written >= maxBytes { + return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target) + } + + return nil +} + +// handleSymlink creates valid symlinks when extracting .deb or .tar archives. +func handleSymlink(dir, linkName, target string) error { + // Skip symlinks for targets that do not exist + _, err := os.Readlink(target) + if os.IsNotExist(err) { + return nil + } + + fullLink := filepath.Join(dir, linkName) + + // Remove existing symlinks + if _, err := os.Lstat(fullLink); err == nil { + if err := os.Remove(fullLink); err != nil { + return fmt.Errorf("failed to remove existing symlink: %w", err) + } + } + + if err := os.Symlink(target, fullLink); err != nil { + return fmt.Errorf("failed to create symlink: %w", err) + } + + linkReal, err := filepath.EvalSymlinks(fullLink) + if err != nil { + os.Remove(fullLink) + return fmt.Errorf("failed to evaluate symlink: %w", err) + } + if !IsValidPath(linkReal, dir) { + os.Remove(fullLink) + return fmt.Errorf("symlink points outside temporary directory: %s", linkReal) + } + + return nil +} diff --git a/pkg/archive/bz2.go b/pkg/archive/bz2.go index 75313f27..d99e5b57 100644 --- a/pkg/archive/bz2.go +++ b/pkg/archive/bz2.go @@ -38,20 +38,28 @@ func ExtractBz2(ctx context.Context, d, f string) error { uncompressed := strings.TrimSuffix(filepath.Base(f), ".bz2") uncompressed = strings.TrimSuffix(uncompressed, ".bzip2") target := filepath.Join(d, uncompressed) + if !IsValidPath(target, d) { + return fmt.Errorf("invalid file path: %s", target) + } if err := os.MkdirAll(d, 0o700); err != nil { return fmt.Errorf("failed to create directory for file: %w", err) } // #nosec G115 // ignore Type conversion which leads to integer overflow // header.Mode is int64 and FileMode is uint32 - out, err := os.OpenFile(target, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o600) + out, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600) if err != nil { return fmt.Errorf("failed to create file: %w", err) } defer out.Close() - if _, err := io.Copy(out, io.LimitReader(br, maxBytes)); err != nil { - out.Close() + + written, err := io.Copy(out, io.LimitReader(br, maxBytes)) + if err != nil { return fmt.Errorf("failed to copy file: %w", err) } + if written >= maxBytes { + return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target) + } + return nil } diff --git a/pkg/archive/deb.go b/pkg/archive/deb.go index cf94e963..d1c32061 100644 --- a/pkg/archive/deb.go +++ b/pkg/archive/deb.go @@ -46,49 +46,21 @@ func ExtractDeb(ctx context.Context, d, f string) error { } target := filepath.Join(d, clean) + if !IsValidPath(target, d) { + return fmt.Errorf("invalid file path: %s", target) + } switch header.Typeflag { case tar.TypeDir: - // #nosec G115 // ignore Type conversion which leads to integer overflow - // header.Mode is int64 and FileMode is uint32 - if err := os.MkdirAll(target, os.FileMode(header.Mode)); err != nil { - return fmt.Errorf("failed to create directory: %w", err) + if err := handleDirectory(target); err != nil { + return fmt.Errorf("failed to extract directory: %w", err) } case tar.TypeReg: - if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil { - return fmt.Errorf("failed to create parent directory: %w", err) - } - - // #nosec G115 - out, err := os.OpenFile(target, os.O_RDWR|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode)) - if err != nil { - return fmt.Errorf("failed to create file: %w", err) - } - - if _, err := io.Copy(out, io.LimitReader(df.Data, maxBytes)); err != nil { - out.Close() - return fmt.Errorf("failed to copy file: %w", err) - } - - if err := out.Close(); err != nil { - return fmt.Errorf("failed to close file: %w", err) + if err := handleFile(target, df.Data); err != nil { + return fmt.Errorf("failed to extract file: %w", err) } case tar.TypeSymlink: - // Skip symlinks for targets that do not exist - _, err = os.Readlink(target) - if os.IsNotExist(err) { - continue - } - // Ensure that symlinks are not relative path traversals - // #nosec G305 // L208 handles the check - linkReal, err := filepath.EvalSymlinks(filepath.Join(d, header.Linkname)) - if err != nil { - return fmt.Errorf("failed to evaluate symlink: %w", err) - } - if !IsValidPath(linkReal, d) { - return fmt.Errorf("symlink points outside temporary directory: %s", linkReal) - } - if err := os.Symlink(linkReal, target); err != nil { + if err := handleSymlink(d, header.Linkname, target); err != nil { return fmt.Errorf("failed to create symlink: %w", err) } } diff --git a/pkg/archive/gzip.go b/pkg/archive/gzip.go index 9260253c..41976505 100644 --- a/pkg/archive/gzip.go +++ b/pkg/archive/gzip.go @@ -9,10 +9,23 @@ import ( "path/filepath" "github.com/chainguard-dev/clog" + "github.com/chainguard-dev/malcontent/pkg/programkind" ) // extractGzip extracts .gz archives. func ExtractGzip(ctx context.Context, d string, f string) error { + // Check whether the provided file is a valid gzip archive + var isGzip bool + if ft, err := programkind.File(f); err == nil && ft != nil { + if ft.MIME == "application/gzip" { + isGzip = true + } + } + + if !isGzip { + return fmt.Errorf("not a valid gzip archive") + } + logger := clog.FromContext(ctx).With("dir", d, "file", f) logger.Debug("extracting gzip") @@ -30,6 +43,9 @@ func ExtractGzip(ctx context.Context, d string, f string) error { base := filepath.Base(f) target := filepath.Join(d, base[:len(base)-len(filepath.Ext(base))]) + if !IsValidPath(target, d) { + return fmt.Errorf("invalid file path: %s", target) + } gr, err := gzip.NewReader(gf) if err != nil { @@ -37,15 +53,19 @@ func ExtractGzip(ctx context.Context, d string, f string) error { } defer gr.Close() - ef, err := os.Create(target) + out, err := os.Create(target) if err != nil { return fmt.Errorf("failed to create extracted file: %w", err) } - defer ef.Close() + defer out.Close() - if _, err := io.Copy(ef, io.LimitReader(gr, maxBytes)); err != nil { + written, err := io.Copy(out, io.LimitReader(gr, maxBytes)) + if err != nil { return fmt.Errorf("failed to copy file: %w", err) } + if written >= maxBytes { + return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target) + } return nil } diff --git a/pkg/archive/rpm.go b/pkg/archive/rpm.go index 10d04c06..e0629c6a 100644 --- a/pkg/archive/rpm.go +++ b/pkg/archive/rpm.go @@ -86,6 +86,9 @@ func ExtractRPM(ctx context.Context, d, f string) error { } target := filepath.Join(d, clean) + if !IsValidPath(target, d) { + return fmt.Errorf("invalid file path: %s", target) + } if header.FileInfo().IsDir() { if err := os.MkdirAll(target, os.FileMode(header.Mode)); err != nil { @@ -98,15 +101,18 @@ func ExtractRPM(ctx context.Context, d, f string) error { return fmt.Errorf("failed to create parent directory: %w", err) } - out, err := os.OpenFile(target, os.O_RDWR|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode)) + out, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600) if err != nil { return fmt.Errorf("failed to create file: %w", err) } - if _, err := io.Copy(out, io.LimitReader(cr, maxBytes)); err != nil { - out.Close() + written, err := io.Copy(out, io.LimitReader(cr, maxBytes)) + if err != nil { return fmt.Errorf("failed to copy file: %w", err) } + if written >= maxBytes { + return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target) + } if err := out.Close(); err != nil { return fmt.Errorf("failed to close file: %w", err) diff --git a/pkg/archive/tar.go b/pkg/archive/tar.go index aa7fade7..60c12c8d 100644 --- a/pkg/archive/tar.go +++ b/pkg/archive/tar.go @@ -13,6 +13,7 @@ import ( "strings" "github.com/chainguard-dev/clog" + "github.com/chainguard-dev/malcontent/pkg/programkind" "github.com/ulikunitz/xz" ) @@ -33,6 +34,15 @@ func ExtractTar(ctx context.Context, d string, f string) error { return fmt.Errorf("failed to open file: %w", err) } defer tf.Close() + + isTGZ := strings.Contains(f, ".tar.gz") || strings.Contains(f, ".tgz") + var isGzip bool + if ft, err := programkind.File(f); err == nil && ft != nil { + if ft.MIME == "application/gzip" { + isGzip = true + } + } + // Set offset to the file origin regardless of type _, err = tf.Seek(0, io.SeekStart) if err != nil { @@ -40,9 +50,8 @@ func ExtractTar(ctx context.Context, d string, f string) error { } var tr *tar.Reader - switch { - case strings.Contains(f, ".apk") || strings.Contains(f, ".tar.gz") || strings.Contains(f, ".tgz"): + case strings.Contains(f, ".apk") || (isTGZ && isGzip): gzStream, err := gzip.NewReader(tf) if err != nil { return fmt.Errorf("failed to create gzip reader: %w", err) @@ -68,12 +77,13 @@ func ExtractTar(ctx context.Context, d string, f string) error { // #nosec G115 // ignore Type conversion which leads to integer overflow // header.Mode is int64 and FileMode is uint32 - f, err := os.OpenFile(target, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o600) + out, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600) if err != nil { return fmt.Errorf("failed to create file: %w", err) } - defer f.Close() - if _, err = io.Copy(f, xzStream); err != nil { + defer out.Close() + + if _, err = io.Copy(out, xzStream); err != nil { return fmt.Errorf("failed to write decompressed xz output: %w", err) } return nil @@ -107,47 +117,15 @@ func ExtractTar(ctx context.Context, d string, f string) error { switch header.Typeflag { case tar.TypeDir: - // #nosec G115 // ignore Type conversion which leads to integer overflow - // header.Mode is int64 and FileMode is uint32 - if err := os.MkdirAll(target, os.FileMode(header.Mode)); err != nil { - return fmt.Errorf("failed to create directory: %w", err) + if err := handleDirectory(target); err != nil { + return fmt.Errorf("failed to extract directory: %w", err) } case tar.TypeReg: - if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil { - return fmt.Errorf("failed to create parent directory: %w", err) - } - - // #nosec G115 // ignore Type conversion which leads to integer overflow - // header.Mode is int64 and FileMode is uint32 - out, err := os.OpenFile(target, os.O_RDWR|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode)) - if err != nil { - return fmt.Errorf("failed to create file: %w", err) - } - - if _, err := io.Copy(out, io.LimitReader(tr, maxBytes)); err != nil { - out.Close() - return fmt.Errorf("failed to copy file: %w", err) - } - - if err := out.Close(); err != nil { - return fmt.Errorf("failed to close file: %w", err) + if err := handleFile(target, tr); err != nil { + return fmt.Errorf("failed to extract file: %w", err) } case tar.TypeSymlink: - // Skip symlinks for targets that do not exist - _, err = os.Readlink(target) - if os.IsNotExist(err) { - continue - } - // Ensure that symlinks are not relative path traversals - // #nosec G305 // L208 handles the check - linkReal, err := filepath.EvalSymlinks(filepath.Join(d, header.Linkname)) - if err != nil { - return fmt.Errorf("failed to evaluate symlink: %w", err) - } - if !IsValidPath(target, d) { - return fmt.Errorf("symlink points outside temporary directory: %s", linkReal) - } - if err := os.Symlink(linkReal, target); err != nil { + if err := handleSymlink(d, header.Linkname, target); err != nil { return fmt.Errorf("failed to create symlink: %w", err) } } diff --git a/pkg/archive/upx.go b/pkg/archive/upx.go index 073d1565..fad9ea32 100644 --- a/pkg/archive/upx.go +++ b/pkg/archive/upx.go @@ -34,6 +34,9 @@ func ExtractUPX(ctx context.Context, d, f string) error { base := filepath.Base(f) target := filepath.Join(d, base[:len(base)-len(filepath.Ext(base))]) + if !IsValidPath(target, d) { + return fmt.Errorf("invalid file path: %s", target) + } // copy the file to the temporary directory before decompressing tf, err := os.ReadFile(f) diff --git a/pkg/archive/zip.go b/pkg/archive/zip.go index 4714ceae..619ae6dd 100644 --- a/pkg/archive/zip.go +++ b/pkg/archive/zip.go @@ -36,52 +36,57 @@ func ExtractZip(ctx context.Context, d string, f string) error { continue } - name := filepath.Join(d, clean) - if !IsValidPath(name, d) { - logger.Warnf("skipping file path outside extraction directory: %s", name) + target := filepath.Join(d, clean) + if !IsValidPath(target, d) { + logger.Warnf("skipping file path outside extraction directory: %s", target) continue } // Check if a directory with the same name exists - if info, err := os.Stat(name); err == nil && info.IsDir() { + if info, err := os.Stat(target); err == nil && info.IsDir() { continue } if file.Mode().IsDir() { - mode := file.Mode() | 0o700 - err := os.MkdirAll(name, mode) + err := os.MkdirAll(target, 0o700) if err != nil { return fmt.Errorf("failed to create directory: %w", err) } continue } - open, err := file.Open() + zf, err := file.Open() if err != nil { return fmt.Errorf("failed to open file in zip: %w", err) } - err = os.MkdirAll(filepath.Dir(name), 0o700) + err = os.MkdirAll(filepath.Dir(target), 0o700) if err != nil { - open.Close() + zf.Close() return fmt.Errorf("failed to create directory: %w", err) } - mode := file.Mode() | 0o200 - create, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode) + out, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600) if err != nil { - open.Close() + out.Close() return fmt.Errorf("failed to create file: %w", err) } - if _, err = io.Copy(create, io.LimitReader(open, maxBytes)); err != nil { - open.Close() - create.Close() + written, err := io.Copy(out, io.LimitReader(zf, maxBytes)) + if err != nil { return fmt.Errorf("failed to copy file: %w", err) } + if written >= maxBytes { + return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target) + } - open.Close() - create.Close() + if err := out.Close(); err != nil { + return fmt.Errorf("failed to close file: %w", err) + } + + if err := zf.Close(); err != nil { + return fmt.Errorf("failed to close file: %w", err) + } } return nil } diff --git a/pkg/archive/zlib.go b/pkg/archive/zlib.go index 1ff356db..2def0b4c 100644 --- a/pkg/archive/zlib.go +++ b/pkg/archive/zlib.go @@ -22,30 +22,34 @@ func ExtractZlib(ctx context.Context, d string, f string) error { return fmt.Errorf("failed to stat file: %w", err) } - gf, err := os.Open(f) + zf, err := os.Open(f) if err != nil { return fmt.Errorf("failed to open file: %w", err) } - defer gf.Close() + defer zf.Close() base := filepath.Base(f) target := filepath.Join(d, base[:len(base)-len(filepath.Ext(base))]) - zr, err := zlib.NewReader(gf) + zr, err := zlib.NewReader(zf) if err != nil { return fmt.Errorf("failed to create zlib reader: %w", err) } defer zr.Close() - ef, err := os.Create(target) + out, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600) if err != nil { return fmt.Errorf("failed to create extracted file: %w", err) } - defer ef.Close() + defer out.Close() - if _, err := io.Copy(ef, io.LimitReader(zr, maxBytes)); err != nil { + written, err := io.Copy(out, io.LimitReader(zr, maxBytes)) + if err != nil { return fmt.Errorf("failed to copy file: %w", err) } + if written >= maxBytes { + return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target) + } return nil }