diff --git a/kaitai/stream.go b/kaitai/stream.go index 05f626b..d06a980 100644 --- a/kaitai/stream.go +++ b/kaitai/stream.go @@ -55,7 +55,7 @@ func (k *Stream) EOF() (bool, error) { } // Size returns the number of bytes of the stream. -func (k *Stream) Size() (int64, error) { +func (k *Stream) Size() (size int64, err error) { // Go has no internal ReadSeeker function to get current ReadSeeker size, // thus we use the following trick. // Remember our current position @@ -63,19 +63,20 @@ func (k *Stream) Size() (int64, error) { if err != nil { return 0, err } + // Deferred Seek back to the current position + defer func() { + if _, serr := k.Seek(curPos, io.SeekStart); serr != nil { + err = fmt.Errorf("failed to seek to the initial position %v: %w", curPos, serr) + } + }() // Seek to the end of the File object _, err = k.Seek(0, io.SeekEnd) if err != nil { return 0, err } - // Remember position, which is equal to the full length - fullSize, err := k.Pos() - if err != nil { - return fullSize, err - } - // Seek back to the current position - _, err = k.Seek(curPos, io.SeekStart) - return fullSize, err + + // Return the current position, which is equal to the full length + return k.Pos() } // Pos returns the current position of the stream. diff --git a/kaitai/stream_test.go b/kaitai/stream_test.go index 29bf027..64501ce 100644 --- a/kaitai/stream_test.go +++ b/kaitai/stream_test.go @@ -58,7 +58,8 @@ func TestStream_Size(t *testing.T) { want int64 wantErr bool }{ - {"Size", NewStream(bytes.NewReader([]byte("test"))), 4, false}, + {"Zero size", NewStream(bytes.NewReader([]byte{})), 0, false}, + {"Small size", NewStream(bytes.NewReader([]byte("test"))), 4, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -74,6 +75,98 @@ func TestStream_Size(t *testing.T) { } } +type artificialError struct{} + +func (e artificialError) Error() string { + return "artificial error when seeking with io.SeekCurrent after seeking to end" +} + +type failingReader struct { + pos int64 + mustFail func(fr failingReader, offset int64, whence int) bool +} + +func (fr *failingReader) Read(p []byte) (n int, err error) { return 0, nil } +func (fr *failingReader) Seek(offset int64, whence int) (int64, error) { + if fr.mustFail(*fr, offset, whence) { + return 0, artificialError{} + } + + switch { + case whence == io.SeekCurrent: + return fr.pos, nil + case whence == io.SeekStart: + fr.pos = offset + default: // whence == io.SeekEnd + fr.pos = -1 + } + + return fr.pos, nil +} + +// No regression test for issue #26 +func TestErrorHandlingInStream_Size(t *testing.T) { + tests := map[string]struct { + initialPos int64 + failingCondition func(fr failingReader, offset int64, whence int) bool + errorCheck func(err error) bool + wantFinalPos int64 + }{ + "fails to get initial position": { + initialPos: 5, + failingCondition: func(fr failingReader, offset int64, whence int) bool { + return whence == io.SeekCurrent && offset == 0 + }, + errorCheck: func(err error) bool { + _, ok := err.(artificialError) + return ok + }, + wantFinalPos: 5, + }, + "seek to the end fails": { + initialPos: 5, + failingCondition: func(fr failingReader, offset int64, whence int) bool { + return whence == io.SeekEnd + }, + errorCheck: func(err error) bool { + _, ok := err.(artificialError) + return ok + }, + wantFinalPos: 5, + }, + "deferred seek to the initial pos fails": { + initialPos: 5, + failingCondition: func(fr failingReader, offset int64, whence int) bool { + return whence == io.SeekStart && fr.pos == -1 + }, + errorCheck: func(err error) bool { + _, ok := err.(artificialError) + return !ok + }, + wantFinalPos: -1, + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + fr := &failingReader{tt.initialPos, tt.failingCondition} + s := NewStream(fr) + _, err := s.Size() + + if err == nil { + t.Fatal("Expected error, got nothing") + } + + if !tt.errorCheck(err) { + t.Fatalf("Expected error of type %T, got one of type %T", artificialError{}, err) + } + + if fr.pos != tt.wantFinalPos { + t.Fatalf("Expected position to be %v, got %v", tt.wantFinalPos, fr.pos) + } + }) + } +} + func TestStream_Pos(t *testing.T) { tests := []struct { name string