From dfaf538402be45e6cd12064b3d49e7496d2b22f4 Mon Sep 17 00:00:00 2001 From: Matt Dainty Date: Tue, 12 Dec 2023 00:34:37 +0000 Subject: [PATCH] fix: Handle small reads in branch converters (#144) * fix: Handle reads into small slices Handle the case for `Read(p)` where `len(p)` is smaller than the minimum length required by the branch converter. * test: Make use of `iotest.OneByteReader()` This now exposes a bug in all of the branch converters. * fix: Handle small reads in branch converters Add missing benchmarks also. --- internal/bra/bra.go | 8 ++++++++ internal/bra/reader.go | 15 +++++++++++---- reader_test.go | 19 ++++++++++++++++++- 3 files changed, 37 insertions(+), 5 deletions(-) diff --git a/internal/bra/bra.go b/internal/bra/bra.go index f567b5d..ab7be4d 100644 --- a/internal/bra/bra.go +++ b/internal/bra/bra.go @@ -12,3 +12,11 @@ func max(x, y int) int { return y } + +func min(x, y int) int { + if x < y { + return x + } + + return y +} diff --git a/internal/bra/reader.go b/internal/bra/reader.go index 274fe1d..42edf15 100644 --- a/internal/bra/reader.go +++ b/internal/bra/reader.go @@ -9,6 +9,7 @@ import ( type readCloser struct { rc io.ReadCloser buf bytes.Buffer + n int conv converter } @@ -30,13 +31,19 @@ func (rc *readCloser) Read(p []byte) (int, error) { if !errors.Is(err, io.EOF) { return 0, err } - } - if n := rc.conv.Convert(rc.buf.Bytes(), false); n > 0 { - return rc.buf.Read(p[:n]) + if rc.buf.Len() < rc.conv.Size() { + rc.n = rc.buf.Len() + } } - return rc.buf.Read(p) + rc.n += rc.conv.Convert(rc.buf.Bytes()[rc.n:], false) + + n, err := rc.buf.Read(p[:min(rc.n, len(p))]) + + rc.n -= n + + return n, err } func newReader(readers []io.ReadCloser, conv converter) (io.ReadCloser, error) { diff --git a/reader_test.go b/reader_test.go index 241d32c..b3fe8f7 100644 --- a/reader_test.go +++ b/reader_test.go @@ -8,6 +8,7 @@ import ( "path/filepath" "testing" "testing/fstest" + "testing/iotest" "github.com/bodgit/sevenzip" "github.com/bodgit/sevenzip/internal/util" @@ -28,7 +29,7 @@ func readArchive(t *testing.T, r *sevenzip.ReadCloser) { h.Reset() - if _, err := io.Copy(h, rc); err != nil { + if _, err := io.Copy(h, iotest.OneByteReader(rc)); err != nil { t.Fatal(err) } @@ -332,3 +333,19 @@ func BenchmarkBrotli(b *testing.B) { func BenchmarkZstandard(b *testing.B) { benchmarkArchive(b, "zstd.7z") } + +func BenchmarkBCJ(b *testing.B) { + benchmarkArchive(b, "bcj.7z") +} + +func BenchmarkPPC(b *testing.B) { + benchmarkArchive(b, "ppc.7z") +} + +func BenchmarkARM(b *testing.B) { + benchmarkArchive(b, "arm.7z") +} + +func BenchmarkSPARC(b *testing.B) { + benchmarkArchive(b, "sparc.7z") +}