Skip to content

Commit

Permalink
Address more extraction edge cases; improve naming and consistency (#733
Browse files Browse the repository at this point in the history
)

* Address more gzip, tar, and tar.gz edge cases

Signed-off-by: egibs <[email protected]>

* Add ordering comment; move tar case to top

Signed-off-by: egibs <[email protected]>

* Add .gzip for completeness, move above .zip

Signed-off-by: egibs <[email protected]>

* Address PR comments

Signed-off-by: egibs <[email protected]>

* Fix deb.go symlink validation

Signed-off-by: egibs <[email protected]>

* Move deb.go and tar.go extraction logic to helper functions

Signed-off-by: egibs <[email protected]>

* Re-add check for nonexistent targets

Signed-off-by: egibs <[email protected]>

* Universal naming/consistency changes

Signed-off-by: egibs <[email protected]>

* Run go mod tidy

Signed-off-by: egibs <[email protected]>

---------

Signed-off-by: egibs <[email protected]>
  • Loading branch information
egibs authored Jan 3, 2025
1 parent 3b49925 commit 4c9eb73
Show file tree
Hide file tree
Showing 10 changed files with 181 additions and 115 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ require (
github.com/google/go-cmp v0.6.0
github.com/google/go-containerregistry v0.20.2
github.com/hillu/go-yara/v4 v4.3.3
github.com/klauspost/compress v1.17.11
github.com/olekukonko/tablewriter v0.0.5
github.com/shirou/gopsutil/v4 v4.24.11
github.com/ulikunitz/xz v0.5.12
Expand All @@ -40,7 +41,6 @@ require (
github.com/ebitengine/purego v0.8.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/kr/pretty v0.2.1 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/lufia/plan9stats v0.0.0-20240909124753-873cd0166683 // indirect
Expand Down
78 changes: 74 additions & 4 deletions pkg/archive/archive.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package archive

import (
"archive/tar"
"context"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
Expand Down Expand Up @@ -190,13 +192,16 @@ func ExtractArchiveToTempDir(ctx context.Context, path string) (string, error) {
}

func ExtractionMethod(ext string) func(context.Context, string, string) error {
// The ordering of these statements is important, especially for extensions
// that are substrings of other extensions (e.g., `.gz` and `.tar.gz` or `.tgz`)
switch ext {
case ".jar", ".zip", ".whl":
return ExtractZip
case ".gz":
return ExtractGzip
// New cases should go below this line so that the lengthier tar extensions are evaluated first
case ".apk", ".gem", ".tar", ".tar.bz2", ".tar.gz", ".tgz", ".tar.xz", ".tbz", ".xz":
return ExtractTar
case ".gz", ".gzip":
return ExtractGzip
case ".jar", ".zip", ".whl":
return ExtractZip
case ".bz2", ".bzip2":
return ExtractBz2
case ".rpm":
Expand All @@ -207,3 +212,68 @@ func ExtractionMethod(ext string) func(context.Context, string, string) error {
return nil
}
}

// handleDirectory materializes a directory entry from a .deb or .tar archive.
//
// It creates target, together with any missing parent directories, using
// mode 0o700. Because os.MkdirAll is used, a directory that already exists
// is not treated as an error. Any failure is wrapped with a
// "failed to create directory" prefix.
func handleDirectory(target string) error {
	err := os.MkdirAll(target, 0o700)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to create directory: %w", err)
}

// handleFile writes a regular-file entry from a .deb or .tar archive to disk.
//
// The entry's contents are read from tr and written to target, which is
// created (or truncated) with mode 0o600. Missing parent directories are
// created with mode 0o700.
//
// At most maxBytes are copied. Because the reader is capped at maxBytes, a
// copy that reaches the cap cannot be distinguished from a larger file, so an
// entry of exactly maxBytes is also reported as exceeding the limit. In that
// case the truncated file is left on disk.
//
// An error is returned if the parent directory or the file cannot be created,
// if the copy fails, if the size limit is reached, or if closing the file
// fails.
func handleFile(target string, tr *tar.Reader) error {
	if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil {
		return fmt.Errorf("failed to create parent directory: %w", err)
	}

	out, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
	if err != nil {
		return fmt.Errorf("failed to create file: %w", err)
	}

	written, copyErr := io.Copy(out, io.LimitReader(tr, maxBytes))
	// Close explicitly rather than via defer so that a failure to close the
	// freshly written file is reported instead of being silently dropped.
	closeErr := out.Close()

	// Copy and size errors take precedence over a close error.
	if copyErr != nil {
		return fmt.Errorf("failed to copy file: %w", copyErr)
	}
	if written >= maxBytes {
		return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target)
	}
	if closeErr != nil {
		return fmt.Errorf("failed to close file: %w", closeErr)
	}

	return nil
}

// handleSymlink creates a symlink while extracting a .deb or .tar archive and
// rejects links that resolve outside of dir.
//
// The link is created at filepath.Join(dir, linkName) and points at target.
// If target does not exist the entry is skipped and nil is returned. Anything
// already present at the link path is removed first. After creation the link
// is resolved with filepath.EvalSymlinks; if resolution fails, or IsValidPath
// rejects the resolved path for dir, the new link is removed and an error is
// returned.
//
// NOTE(review): the visible caller passes header.Linkname as linkName and the
// entry's extraction path as target. In archive/tar, Linkname is the name the
// link points to, so the arguments may be reversed relative to the intent;
// confirm against the callers before relying on this.
func handleSymlink(dir, linkName, target string) error {
	// Readlink is used only as an existence probe: errors other than
	// "not exist" (for example, target is not a symlink) fall through.
	if _, err := os.Readlink(target); os.IsNotExist(err) {
		return nil
	}

	linkPath := filepath.Join(dir, linkName)

	// Clear whatever currently occupies the link path so Symlink can succeed.
	if _, statErr := os.Lstat(linkPath); statErr == nil {
		if rmErr := os.Remove(linkPath); rmErr != nil {
			return fmt.Errorf("failed to remove existing symlink: %w", rmErr)
		}
	}

	if err := os.Symlink(target, linkPath); err != nil {
		return fmt.Errorf("failed to create symlink: %w", err)
	}

	resolved, err := filepath.EvalSymlinks(linkPath)
	if err != nil {
		// Best-effort cleanup; the resolution error is the one reported.
		_ = os.Remove(linkPath)
		return fmt.Errorf("failed to evaluate symlink: %w", err)
	}
	if IsValidPath(resolved, dir) {
		return nil
	}

	// Best-effort cleanup of a link that IsValidPath rejected for dir.
	_ = os.Remove(linkPath)
	return fmt.Errorf("symlink points outside temporary directory: %s", resolved)
}
14 changes: 11 additions & 3 deletions pkg/archive/bz2.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,28 @@ func ExtractBz2(ctx context.Context, d, f string) error {
uncompressed := strings.TrimSuffix(filepath.Base(f), ".bz2")
uncompressed = strings.TrimSuffix(uncompressed, ".bzip2")
target := filepath.Join(d, uncompressed)
if !IsValidPath(target, d) {
return fmt.Errorf("invalid file path: %s", target)
}
if err := os.MkdirAll(d, 0o700); err != nil {
return fmt.Errorf("failed to create directory for file: %w", err)
}

// #nosec G115 // ignore Type conversion which leads to integer overflow
// header.Mode is int64 and FileMode is uint32
out, err := os.OpenFile(target, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o600)
out, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer out.Close()
if _, err := io.Copy(out, io.LimitReader(br, maxBytes)); err != nil {
out.Close()

written, err := io.Copy(out, io.LimitReader(br, maxBytes))
if err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}
if written >= maxBytes {
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target)
}

return nil
}
44 changes: 8 additions & 36 deletions pkg/archive/deb.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,49 +46,21 @@ func ExtractDeb(ctx context.Context, d, f string) error {
}

target := filepath.Join(d, clean)
if !IsValidPath(target, d) {
return fmt.Errorf("invalid file path: %s", target)
}

switch header.Typeflag {
case tar.TypeDir:
// #nosec G115 // ignore Type conversion which leads to integer overflow
// header.Mode is int64 and FileMode is uint32
if err := os.MkdirAll(target, os.FileMode(header.Mode)); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
if err := handleDirectory(target); err != nil {
return fmt.Errorf("failed to extract directory: %w", err)
}
case tar.TypeReg:
if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil {
return fmt.Errorf("failed to create parent directory: %w", err)
}

// #nosec G115
out, err := os.OpenFile(target, os.O_RDWR|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode))
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}

if _, err := io.Copy(out, io.LimitReader(df.Data, maxBytes)); err != nil {
out.Close()
return fmt.Errorf("failed to copy file: %w", err)
}

if err := out.Close(); err != nil {
return fmt.Errorf("failed to close file: %w", err)
if err := handleFile(target, df.Data); err != nil {
return fmt.Errorf("failed to extract file: %w", err)
}
case tar.TypeSymlink:
// Skip symlinks for targets that do not exist
_, err = os.Readlink(target)
if os.IsNotExist(err) {
continue
}
// Ensure that symlinks are not relative path traversals
// #nosec G305 // L208 handles the check
linkReal, err := filepath.EvalSymlinks(filepath.Join(d, header.Linkname))
if err != nil {
return fmt.Errorf("failed to evaluate symlink: %w", err)
}
if !IsValidPath(linkReal, d) {
return fmt.Errorf("symlink points outside temporary directory: %s", linkReal)
}
if err := os.Symlink(linkReal, target); err != nil {
if err := handleSymlink(d, header.Linkname, target); err != nil {
return fmt.Errorf("failed to create symlink: %w", err)
}
}
Expand Down
26 changes: 23 additions & 3 deletions pkg/archive/gzip.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,23 @@ import (
"path/filepath"

"github.com/chainguard-dev/clog"
"github.com/chainguard-dev/malcontent/pkg/programkind"
)

// ExtractGzip extracts .gz archives.
func ExtractGzip(ctx context.Context, d string, f string) error {
// Check whether the provided file is a valid gzip archive
var isGzip bool
if ft, err := programkind.File(f); err == nil && ft != nil {
if ft.MIME == "application/gzip" {
isGzip = true
}
}

if !isGzip {
return fmt.Errorf("not a valid gzip archive")
}

logger := clog.FromContext(ctx).With("dir", d, "file", f)
logger.Debug("extracting gzip")

Expand All @@ -30,22 +43,29 @@ func ExtractGzip(ctx context.Context, d string, f string) error {

base := filepath.Base(f)
target := filepath.Join(d, base[:len(base)-len(filepath.Ext(base))])
if !IsValidPath(target, d) {
return fmt.Errorf("invalid file path: %s", target)
}

gr, err := gzip.NewReader(gf)
if err != nil {
return fmt.Errorf("failed to create gzip reader: %w", err)
}
defer gr.Close()

ef, err := os.Create(target)
out, err := os.Create(target)
if err != nil {
return fmt.Errorf("failed to create extracted file: %w", err)
}
defer ef.Close()
defer out.Close()

if _, err := io.Copy(ef, io.LimitReader(gr, maxBytes)); err != nil {
written, err := io.Copy(out, io.LimitReader(gr, maxBytes))
if err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}
if written >= maxBytes {
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target)
}

return nil
}
12 changes: 9 additions & 3 deletions pkg/archive/rpm.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ func ExtractRPM(ctx context.Context, d, f string) error {
}

target := filepath.Join(d, clean)
if !IsValidPath(target, d) {
return fmt.Errorf("invalid file path: %s", target)
}

if header.FileInfo().IsDir() {
if err := os.MkdirAll(target, os.FileMode(header.Mode)); err != nil {
Expand All @@ -98,15 +101,18 @@ func ExtractRPM(ctx context.Context, d, f string) error {
return fmt.Errorf("failed to create parent directory: %w", err)
}

out, err := os.OpenFile(target, os.O_RDWR|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode))
out, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}

if _, err := io.Copy(out, io.LimitReader(cr, maxBytes)); err != nil {
out.Close()
written, err := io.Copy(out, io.LimitReader(cr, maxBytes))
if err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}
if written >= maxBytes {
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target)
}

if err := out.Close(); err != nil {
return fmt.Errorf("failed to close file: %w", err)
Expand Down
Loading

0 comments on commit 4c9eb73

Please sign in to comment.