diff --git a/length.go b/length.go index 8e2ae4d..750e8f4 100644 --- a/length.go +++ b/length.go @@ -38,6 +38,9 @@ func readLength(reader io.Reader) (length int, read int, err error) { if lengthBytes > 8 { return 0, read, errors.New("long-form length overflow") } + + // Accumulate into a 64-bit variable + var length64 int64 for i := 0; i < lengthBytes; i++ { b, err = readByte(reader) if err != nil { @@ -49,8 +52,15 @@ func readLength(reader io.Reader) (length int, read int, err error) { read++ // x.600, 8.1.3.5 - length <<= 8 - length |= int(b) + length64 <<= 8 + length64 |= int64(b) + } + + // Cast to a platform-specific integer + length = int(length64) + // Ensure we didn't overflow + if int64(length) != length64 { + return 0, read, errors.New("long-form length overflow") } default: diff --git a/length_test.go b/length_test.go index afe0e80..289510a 100644 --- a/length_test.go +++ b/length_test.go @@ -11,7 +11,7 @@ func TestReadLength(t *testing.T) { testcases := map[string]struct { Data []byte - ExpectedLength int + ExpectedLength int64 ExpectedBytesRead int ExpectedError string }{ @@ -68,7 +68,19 @@ func TestReadLength(t *testing.T) { ExpectedLength: 127, ExpectedBytesRead: 2, }, - "long-definite-form max length": { + "long-definite-form max length (32-bit)": { + Data: []byte{ + LengthLongFormBitmask | 4, + 0x7F, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + }, + ExpectedLength: math.MaxInt32, + ExpectedBytesRead: 5, + }, + "long-definite-form max length (64-bit)": { Data: []byte{ LengthLongFormBitmask | 8, 0x7F, @@ -86,6 +98,11 @@ func TestReadLength(t *testing.T) { } for k, tc := range testcases { + // Skip tests requiring 64-bit integers on platforms that don't support them + if tc.ExpectedLength != int64(int(tc.ExpectedLength)) { + continue + } + reader := bytes.NewBuffer(tc.Data) length, read, err := readLength(reader) @@ -104,7 +121,7 @@ func TestReadLength(t *testing.T) { t.Errorf("%s: expected read %d, got %d", k, tc.ExpectedBytesRead, read) } - if length != tc.ExpectedLength { + if int64(length) != tc.ExpectedLength { t.Errorf("%s: expected length %d, got %d", k, tc.ExpectedLength, length) } } @@ -112,7 +129,7 @@ func TestReadLength(t *testing.T) { func TestEncodeLength(t *testing.T) { testcases := map[string]struct { - Length int + Length int64 ExpectedBytes []byte }{ "0": { @@ -133,7 +150,18 @@ func TestEncodeLength(t *testing.T) { ExpectedBytes: []byte{LengthLongFormBitmask | 1, 128}, }, - "max long-form length": { + "max long-form length (32-bit)": { + Length: math.MaxInt32, + ExpectedBytes: []byte{ + LengthLongFormBitmask | 4, + 0x7F, + 0xFF, + 0xFF, + 0xFF, + }, + }, + + "max long-form length (64-bit)": { Length: math.MaxInt64, ExpectedBytes: []byte{ LengthLongFormBitmask | 8, @@ -150,7 +178,12 @@ func TestEncodeLength(t *testing.T) { } for k, tc := range testcases { - b := encodeLength(tc.Length) + // Skip tests requiring 64-bit integers on platforms that don't support them + if tc.Length != int64(int(tc.Length)) { + continue + } + + b := encodeLength(int(tc.Length)) if bytes.Compare(tc.ExpectedBytes, b) != 0 { t.Errorf("%s: Expected\n\t%#v\ngot\n\t%#v", k, tc.ExpectedBytes, b) }