[fix] - Add Size Method to BufferedReadSeeker and Refactor Context Timeout Handling in HandleFile #3307

Merged: 3 commits, Sep 23, 2024
10 changes: 2 additions & 8 deletions pkg/handlers/archive.go
@@ -23,9 +23,8 @@ const (
var (
// NOTE: This is a temporary workaround for |openArchive| incrementing depth twice per archive.
// See: https://github.com/trufflesecurity/trufflehog/issues/2942
maxDepth = 5 * 2
maxSize = 2 << 30 // 2 GB
maxTimeout = time.Duration(30) * time.Second
maxDepth = 5 * 2
maxSize = 2 << 30 // 2 GB
)

// SetArchiveMaxSize sets the maximum size of the archive.
@@ -34,9 +33,6 @@ func SetArchiveMaxSize(size int) { maxSize = size }
// SetArchiveMaxDepth sets the maximum depth of the archive.
func SetArchiveMaxDepth(depth int) { maxDepth = depth }

// SetArchiveMaxTimeout sets the maximum timeout for the archive handler.
func SetArchiveMaxTimeout(timeout time.Duration) { maxTimeout = timeout }

// archiveHandler is a handler for common archive files that are supported by the archiver library.
type archiveHandler struct{ *defaultHandler }

@@ -57,8 +53,6 @@ func (h *archiveHandler) HandleFile(ctx logContext.Context, input fileReader) (c
}

go func() {
ctx, cancel := logContext.WithTimeout(ctx, maxTimeout)
defer cancel()
defer close(dataChan)

// Update the metrics for the file processing.
20 changes: 18 additions & 2 deletions pkg/handlers/handlers.go
@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"time"

"github.com/gabriel-vasile/mimetype"
"github.com/mholt/archiver/v4"
@@ -249,6 +250,11 @@ func selectHandler(mimeT mimeType, isGenericArchive bool) FileHandler {
}
}

var maxTimeout = time.Duration(30) * time.Second

// SetArchiveMaxTimeout sets the maximum timeout for the archive handler.
func SetArchiveMaxTimeout(timeout time.Duration) { maxTimeout = timeout }

// HandleFile orchestrates the complete file handling process for a given file.
// It determines the MIME type of the file, selects the appropriate handler based on this type, and processes the file.
// This function initializes the handling process and delegates to the specific handler to manage file
@@ -279,20 +285,30 @@ func HandleFile(
}
defer rdr.Close()

size, err := rdr.Size()
if err != nil {
ctx.Logger().Error(err, "error getting file size")
}

ctx = logContext.WithValues(ctx, "mime", rdr.mime.String(), "size_bytes", size)

mimeT := mimeType(rdr.mime.String())
config := newFileHandlingConfig(options...)
if config.skipArchives && rdr.isGenericArchive {
ctx.Logger().V(5).Info("skipping archive file", "mime", mimeT)
return nil
}

processingCtx, cancel := logContext.WithTimeout(ctx, maxTimeout)
defer cancel()

handler := selectHandler(mimeT, rdr.isGenericArchive)
archiveChan, err := handler.HandleFile(ctx, rdr) // Delegate to the specific handler to process the file.
archiveChan, err := handler.HandleFile(processingCtx, rdr) // Delegate to the specific handler to process the file.
if err != nil {
return fmt.Errorf("error handling file: %w", err)
}

return handleChunks(ctx, archiveChan, chunkSkel, reporter)
return handleChunks(processingCtx, archiveChan, chunkSkel, reporter)
}

// handleChunks reads data from the handlerChan and uses it to fill chunks according to a predefined skeleton (chunkSkel).
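To illustrate the flow these changes set up: HandleFile now creates the timeout context once and hands it to both the handler goroutine and handleChunks, instead of letting the archive handler start its own timer inside its goroutine. Below is a minimal standalone sketch of that pattern; it uses the standard library context package rather than the repo's logContext wrapper, and the process function, dataChan, and chunk strings are made-up stand-ins for the real handler and chunk plumbing.

package main

import (
	"context"
	"fmt"
	"time"
)

// process mirrors the refactored flow: one deadline is created by the
// orchestrator and shared by the producer (the handler goroutine) and the
// consumer (the chunk loop), so neither side can outlive the other.
func process(ctx context.Context, timeout time.Duration) error {
	processingCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	dataChan := make(chan string)
	go func() {
		defer close(dataChan)
		for i := 0; ; i++ {
			select {
			case <-processingCtx.Done(): // producer stops when the shared deadline fires
				return
			case dataChan <- fmt.Sprintf("chunk-%d", i):
				time.Sleep(10 * time.Millisecond)
			}
		}
	}()

	for {
		select {
		case <-processingCtx.Done(): // consumer observes the same deadline
			return processingCtx.Err()
		case chunk, ok := <-dataChan:
			if !ok {
				return nil
			}
			fmt.Println("received", chunk)
		}
	}
}

func main() {
	if err := process(context.Background(), 50*time.Millisecond); err != nil {
		fmt.Println("stopped:", err)
	}
}
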
27 changes: 27 additions & 0 deletions pkg/iobuf/bufferedreaderseeker.go
@@ -2,6 +2,7 @@ package iobuf

import (
"errors"
"fmt"
"io"
"os"

@@ -355,3 +356,29 @@ func (br *BufferedReadSeeker) Close() error {
}
return nil
}

// Size returns the total size of the reader.
func (br *BufferedReadSeeker) Size() (int64, error) {
if br.sizeKnown {
return br.totalSize, nil
}

currentPos, err := br.Seek(0, io.SeekCurrent)
if err != nil {
return 0, fmt.Errorf("failed to get current position: %w", err)
}

endPos, err := br.Seek(0, io.SeekEnd)
if err != nil {
return 0, fmt.Errorf("failed to seek to end: %w", err)
}

if _, err = br.Seek(currentPos, io.SeekStart); err != nil {
return 0, fmt.Errorf("failed to restore position: %w", err)
}

br.totalSize = endPos
br.sizeKnown = true

return br.totalSize, nil
}
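A short usage sketch of the new method (illustrative only; it relies on the exported NewBufferedReaderSeeker constructor exercised in the tests below, and the import path github.com/trufflesecurity/trufflehog/v3/pkg/iobuf is assumed):

package main

import (
	"bytes"
	"fmt"
	"log"

	"github.com/trufflesecurity/trufflehog/v3/pkg/iobuf" // assumed module path
)

func main() {
	// A bytes.Buffer is not an io.Seeker, so the wrapper buffers data in order
	// to answer Size for it.
	brs := iobuf.NewBufferedReaderSeeker(bytes.NewBufferString("Hello, World!"))

	// Size seeks to the end, caches the total, and restores the prior position.
	size, err := brs.Size()
	if err != nil {
		log.Fatalf("size unavailable: %v", err)
	}
	fmt.Println("total bytes:", size) // 13

	// The read position was preserved, so reads still start at the beginning.
	buf := make([]byte, 5)
	n, _ := brs.Read(buf)
	fmt.Printf("read %q\n", buf[:n]) // "Hello"
}
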
128 changes: 128 additions & 0 deletions pkg/iobuf/bufferedreaderseeker_test.go
@@ -2,6 +2,7 @@ package iobuf

import (
"bytes"
"errors"
"io"
"strings"
"testing"
@@ -350,3 +351,130 @@ func TestBufferedReaderSeekerReadAt(t *testing.T) {
})
}
}

// TestBufferedReadSeekerSize tests the Size method of BufferedReadSeeker.
func TestBufferedReadSeekerSize(t *testing.T) {
tests := []struct {
name string
reader io.Reader
setup func(*BufferedReadSeeker)
expectedSize int64
expectError bool
verifyPosition func(*BufferedReadSeeker, int64)
}{
{
name: "size of seekable reader",
reader: strings.NewReader("Hello, World!"),
expectedSize: 13,
},
{
name: "size of non-seekable reader",
reader: bytes.NewBufferString("Hello, World!"),
expectedSize: 13,
},
{
name: "size of empty seekable reader",
reader: strings.NewReader(""),
expectedSize: 0,
},
{
name: "size of empty non-seekable reader",
reader: bytes.NewBufferString(""),
expectedSize: 0,
},
{
name: "size of non-seekable reader after partial read",
reader: bytes.NewBufferString("Partial read data"),
setup: func(brs *BufferedReadSeeker) {
// Read first 7 bytes ("Partial").
buf := make([]byte, 7)
_, _ = brs.Read(buf)
},
expectedSize: 17, // "Partial read data" is 17 bytes
expectError: false,
verifyPosition: func(brs *BufferedReadSeeker, expectedSize int64) {
// After Size is called, the read position should remain at 7
currentPos, err := brs.Seek(0, io.SeekCurrent)
assert.NoError(t, err)
assert.Equal(t, int64(7), currentPos)
},
},
{
name: "repeated Size calls",
reader: strings.NewReader("Repeated Size Calls Test"),
expectedSize: 24,
expectError: false,
setup: func(brs *BufferedReadSeeker) {
// Call Size multiple times.
size1, err1 := brs.Size()
assert.NoError(t, err1)
assert.Equal(t, int64(24), size1)

size2, err2 := brs.Size()
assert.NoError(t, err2)
assert.Equal(t, int64(24), size2)
},
},
{
name: "size with error during reading",
reader: &errorReader{
data: "Data before error",
errorAfter: 5, // Return error after reading 5 bytes
},
expectedSize: 0,
expectError: true,
},
{
name: "size with limited reader simulating EOF",
reader: io.LimitReader(strings.NewReader("Limited data"), 7),
expectedSize: 7,
expectError: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

brs := NewBufferedReaderSeeker(tt.reader)

if tt.setup != nil {
tt.setup(brs)
}

size, err := brs.Size()
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedSize, size)
}

if tt.verifyPosition != nil {
tt.verifyPosition(brs, tt.expectedSize)
}
})
}
}

// errorReader is an io.Reader that returns an error after reading a specified number of bytes.
// It's used to simulate non-EOF errors during read operations.
type errorReader struct {
data string
errorAfter int // Number of bytes to read before returning an error
readBytes int
}

func (er *errorReader) Read(p []byte) (int, error) {
if er.readBytes >= er.errorAfter {
return 0, errors.New("simulated read error")
}
remaining := er.errorAfter - er.readBytes
toRead := len(p)
if toRead > remaining {
toRead = remaining
}
copy(p, er.data[er.readBytes:er.readBytes+toRead])
er.readBytes += toRead
return toRead, nil
}
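As an aside, the helper can be sanity-checked on its own. A small extra test along these lines (not part of this PR; it only relies on identifiers already imported in this file, and the test name is made up) would confirm that exactly errorAfter bytes are surfaced before the injected failure:

// TestErrorReaderStopsAfterN verifies errorReader yields its first errorAfter
// bytes and then fails with a non-EOF error.
func TestErrorReaderStopsAfterN(t *testing.T) {
	er := &errorReader{data: "Data before error", errorAfter: 5}
	got, err := io.ReadAll(er)
	assert.Error(t, err)
	assert.Equal(t, "Data ", string(got))
}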