Skip to content

Commit

Permalink
Split up archive.go into type-specific files; add wider zlib support (#…
Browse files Browse the repository at this point in the history
…723)

* Split up archive.go into type-specific files

Signed-off-by: egibs <[email protected]>

* Move archive code to new package; leave tests to avoid circular imports

Signed-off-by: egibs <[email protected]>

---------

Signed-off-by: egibs <[email protected]>
  • Loading branch information
egibs authored Dec 18, 2024
1 parent 094eb42 commit f18fb0e
Show file tree
Hide file tree
Showing 14 changed files with 889 additions and 761 deletions.
736 changes: 0 additions & 736 deletions pkg/action/archive.go

This file was deleted.

36 changes: 19 additions & 17 deletions pkg/action/archive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ import (

"github.com/chainguard-dev/clog"
"github.com/chainguard-dev/clog/slogtest"
"github.com/chainguard-dev/malcontent/pkg/archive"
"github.com/chainguard-dev/malcontent/pkg/malcontent"
"github.com/chainguard-dev/malcontent/pkg/programkind"
"github.com/chainguard-dev/malcontent/pkg/render"
"github.com/chainguard-dev/malcontent/rules"
thirdparty "github.com/chainguard-dev/malcontent/third_party"
Expand All @@ -25,23 +27,23 @@ func TestExtractionMethod(t *testing.T) {
ext string
want func(context.Context, string, string) error
}{
{"apk", ".apk", extractTar},
{"gem", ".gem", extractTar},
{"gzip", ".gz", extractGzip},
{"jar", ".jar", extractZip},
{"tar.gz", ".tar.gz", extractTar},
{"tar.xz", ".tar.xz", extractTar},
{"tar", ".tar", extractTar},
{"tgz", ".tgz", extractTar},
{"apk", ".apk", archive.ExtractTar},
{"gem", ".gem", archive.ExtractTar},
{"gzip", ".gz", archive.ExtractGzip},
{"jar", ".jar", archive.ExtractZip},
{"tar.gz", ".tar.gz", archive.ExtractTar},
{"tar.xz", ".tar.xz", archive.ExtractTar},
{"tar", ".tar", archive.ExtractTar},
{"tgz", ".tgz", archive.ExtractTar},
{"unknown", ".unknown", nil},
{"upx", ".upx", nil},
{"zip", ".zip", extractZip},
{"zip", ".zip", archive.ExtractZip},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := extractionMethod(tt.ext)
got := archive.ExtractionMethod(tt.ext)
if (got == nil) != (tt.want == nil) {
t.Errorf("extractionMethod() for extension %v did not return expected result", tt.ext)
}
Expand Down Expand Up @@ -75,7 +77,7 @@ func TestExtractionMultiple(t *testing.T) {
t.Run(tt.path, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
dir, err := extractArchiveToTempDir(ctx, tt.path)
dir, err := archive.ExtractArchiveToTempDir(ctx, tt.path)
if err != nil {
t.Fatal(err)
}
Expand All @@ -102,7 +104,7 @@ func TestExtractionMultiple(t *testing.T) {
func TestExtractTar(t *testing.T) {
t.Parallel()
ctx := context.Background()
dir, err := extractArchiveToTempDir(ctx, filepath.Join("testdata", "apko.tar.gz"))
dir, err := archive.ExtractArchiveToTempDir(ctx, filepath.Join("testdata", "apko.tar.gz"))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -130,7 +132,7 @@ func TestExtractTar(t *testing.T) {
func TestExtractGzip(t *testing.T) {
t.Parallel()
ctx := context.Background()
dir, err := extractArchiveToTempDir(ctx, filepath.Join("testdata", "apko.gz"))
dir, err := archive.ExtractArchiveToTempDir(ctx, filepath.Join("testdata", "apko.gz"))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -158,7 +160,7 @@ func TestExtractGzip(t *testing.T) {
func TestExtractZip(t *testing.T) {
t.Parallel()
ctx := context.Background()
dir, err := extractArchiveToTempDir(ctx, filepath.Join("testdata", "apko.zip"))
dir, err := archive.ExtractArchiveToTempDir(ctx, filepath.Join("testdata", "apko.zip"))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -186,7 +188,7 @@ func TestExtractZip(t *testing.T) {
func TestExtractNestedArchive(t *testing.T) {
t.Parallel()
ctx := context.Background()
dir, err := extractArchiveToTempDir(ctx, filepath.Join("testdata", "apko_nested.tar.gz"))
dir, err := archive.ExtractArchiveToTempDir(ctx, filepath.Join("testdata", "apko_nested.tar.gz"))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -325,7 +327,7 @@ func TestGetExt(t *testing.T) {
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
t.Parallel()
if got := getExt(tt.path); got != tt.want {
if got := programkind.GetExt(tt.path); got != tt.want {
t.Errorf("Ext() = %v, want %v", got, tt.want)
}
})
Expand Down Expand Up @@ -402,7 +404,7 @@ func TestIsValidPath(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isValidPath(tt.target, tt.baseDir)
result := archive.IsValidPath(tt.target, tt.baseDir)
if result != tt.expected {
t.Errorf("isValidPath(%q, %q) = %v, want %v", tt.target, tt.baseDir, result, tt.expected)
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/action/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/agext/levenshtein"
"github.com/chainguard-dev/clog"
"github.com/chainguard-dev/malcontent/pkg/malcontent"
"github.com/chainguard-dev/malcontent/pkg/programkind"
orderedmap "github.com/wk8/go-ordered-map/v2"
"golang.org/x/sync/errgroup"
)
Expand Down Expand Up @@ -151,8 +152,8 @@ func Diff(ctx context.Context, c malcontent.Config) (*malcontent.Report, error)
var srcBase, destBase string
srcCh := make(chan map[string]*malcontent.FileReport, 1)
destCh := make(chan map[string]*malcontent.FileReport, 1)
srcIsArchive := isSupportedArchive(srcPath)
destIsArchive := isSupportedArchive(destPath)
srcIsArchive := programkind.IsSupportedArchive(srcPath)
destIsArchive := programkind.IsSupportedArchive(destPath)

srcInfo, err := os.Stat(srcPath)
if err != nil {
Expand Down
7 changes: 4 additions & 3 deletions pkg/action/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"sync"

"github.com/chainguard-dev/clog"
"github.com/chainguard-dev/malcontent/pkg/archive"
"github.com/chainguard-dev/malcontent/pkg/compile"
"github.com/chainguard-dev/malcontent/pkg/malcontent"
"github.com/chainguard-dev/malcontent/pkg/programkind"
Expand Down Expand Up @@ -294,7 +295,7 @@ func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report
if c.OCI {
// store the image URI for later use
imageURI = scanPath
ociExtractPath, err = oci(ctx, imageURI)
ociExtractPath, err = archive.OCI(ctx, imageURI)
logger.Debug("oci image", slog.Any("scanPath", scanPath), slog.Any("ociExtractPath", ociExtractPath))
if err != nil {
return nil, fmt.Errorf("failed to prepare OCI image for scanning: %w", err)
Expand Down Expand Up @@ -432,7 +433,7 @@ func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report
case <-scanCtx.Done():
return scanCtx.Err()
default:
if isSupportedArchive(path) {
if programkind.IsSupportedArchive(path) {
return handleArchive(path)
}
return handleFile(path)
Expand Down Expand Up @@ -493,7 +494,7 @@ func processArchive(ctx context.Context, c malcontent.Config, rfs []fs.FS, archi
var err error
var frs sync.Map

tmpRoot, err := extractArchiveToTempDir(ctx, archivePath)
tmpRoot, err := archive.ExtractArchiveToTempDir(ctx, archivePath)
if err != nil {
return nil, fmt.Errorf("extract to temp: %w", err)
}
Expand Down
197 changes: 197 additions & 0 deletions pkg/archive/archive.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
package archive

import (
"context"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"sync"

"github.com/chainguard-dev/clog"
"github.com/chainguard-dev/malcontent/pkg/programkind"
)

// isValidPath checks if the target file is within the given directory.
func IsValidPath(target, dir string) bool {
return strings.HasPrefix(filepath.Clean(target), filepath.Clean(dir))
}

const maxBytes = 1 << 29 // 512MB

func extractNestedArchive(
ctx context.Context,
d string,
f string,
extracted *sync.Map,
) error {
isArchive := false
// zlib-compressed files are also archives
ft, err := programkind.File(f)
if err != nil {
return fmt.Errorf("failed to determine file type: %w", err)
}
if ft != nil && ft.MIME == "application/zlib" {
isArchive = true
}
if _, ok := programkind.ArchiveMap[programkind.GetExt(f)]; ok {
isArchive = true
}
//nolint:nestif // ignore complexity of 8
if isArchive {
// Ensure the file was extracted and exists
fullPath := filepath.Join(d, f)
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
return fmt.Errorf("file does not exist: %w", err)
}

var extract func(context.Context, string, string) error
// Check for zlib-compressed files first and use the zlib-specific function
ft, err := programkind.File(fullPath)
if err != nil {
return fmt.Errorf("failed to determine file type: %w", err)
}
if ft != nil && ft.MIME == "application/zlib" {
extract = ExtractZlib
} else {
extract = ExtractionMethod(programkind.GetExt(fullPath))
}
err = extract(ctx, d, fullPath)
if err != nil {
return fmt.Errorf("extract nested archive: %w", err)
}
// Mark the file as extracted
extracted.Store(f, true)

// Remove the nested archive file
// This is done to prevent the file from being scanned
if err := os.Remove(fullPath); err != nil {
return fmt.Errorf("failed to remove file: %w", err)
}

// Check if there are any newly extracted files that are also archives
files, err := os.ReadDir(d)
if err != nil {
return fmt.Errorf("failed to read directory after extraction: %w", err)
}
for _, file := range files {
relPath := filepath.Join(d, file.Name())
if _, isExtracted := extracted.Load(relPath); !isExtracted {
if err := extractNestedArchive(ctx, d, file.Name(), extracted); err != nil {
return fmt.Errorf("failed to extract nested archive %s: %w", file.Name(), err)
}
}
}
}
return nil
}

// extractArchiveToTempDir creates a temporary directory and extracts the archive file for scanning.
func ExtractArchiveToTempDir(ctx context.Context, path string) (string, error) {
logger := clog.FromContext(ctx).With("path", path)
logger.Debug("creating temp dir")

tmpDir, err := os.MkdirTemp("", filepath.Base(path))
if err != nil {
return "", fmt.Errorf("failed to create temp dir: %w", err)
}

var extract func(context.Context, string, string) error
// Check for zlib-compressed files first and use the zlib-specific function
ft, err := programkind.File(path)
if err != nil {
return "", fmt.Errorf("failed to determine file type: %w", err)
}
if ft != nil && ft.MIME == "application/zlib" {
extract = ExtractZlib
} else {
extract = ExtractionMethod(programkind.GetExt(path))
}
if extract == nil {
return "", fmt.Errorf("unsupported archive type: %s", path)
}
err = extract(ctx, tmpDir, path)
if err != nil {
return "", fmt.Errorf("failed to extract %s: %w", path, err)
}

var extractedFiles sync.Map
files, err := os.ReadDir(tmpDir)
if err != nil {
return "", fmt.Errorf("failed to read files in directory %s: %w", tmpDir, err)
}
for _, file := range files {
extractedFiles.Store(filepath.Join(tmpDir, file.Name()), false)
}

extractedFiles.Range(func(key, _ any) bool {
if key == nil {
return true
}
//nolint: nestif // ignoring complexity of 11
if file, ok := key.(string); ok {
ext := programkind.GetExt(file)
info, err := os.Stat(file)
if err != nil {
return false
}
switch mode := info.Mode(); {
case mode.IsDir():
err = filepath.WalkDir(file, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
rel, err := filepath.Rel(tmpDir, path)
if err != nil {
return fmt.Errorf("filepath.Rel: %w", err)
}
if !d.IsDir() {
if err := extractNestedArchive(ctx, tmpDir, rel, &extractedFiles); err != nil {
return fmt.Errorf("failed to extract nested archive %s: %w", rel, err)
}
}

return nil
})
if err != nil {
return false
}
return true
case mode.IsRegular():
if _, ok := programkind.ArchiveMap[ext]; ok {
rel, err := filepath.Rel(tmpDir, file)
if err != nil {
return false
}
if err := extractNestedArchive(ctx, tmpDir, rel, &extractedFiles); err != nil {
return false
}
}
return true
}
}
return true
})

return tmpDir, nil
}

func ExtractionMethod(ext string) func(context.Context, string, string) error {
switch ext {
case ".jar", ".zip", ".whl":
return ExtractZip
case ".gz":
return ExtractGzip
case ".apk", ".gem", ".tar", ".tar.bz2", ".tar.gz", ".tgz", ".tar.xz", ".tbz", ".xz":
return ExtractTar
case ".bz2", ".bzip2":
return ExtractBz2
case ".rpm":
return ExtractRPM
case ".deb":
return ExtractDeb
default:
return nil
}
}
Loading

0 comments on commit f18fb0e

Please sign in to comment.