diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml new file mode 100644 index 0000000..cc5b3aa --- /dev/null +++ b/.github/workflows/main.yaml @@ -0,0 +1,64 @@ +name: Go + +on: + push: + branches: ["master"] + pull_request: + branches: ["master"] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: "stable" + + - name: Build + run: go build -v ./... + + build-readme: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: "stable" + + - name: Build go snippets in readme + run: | + mkdir -p ~/.local/bin/ + wget -O ~/.local/bin/lintdown.sh https://raw.githubusercontent.com/ChillerDragon/lintdown.sh/master/lintdown.sh + chmod +x ~/.local/bin/lintdown.sh + lintdown.sh README.md + + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: "stable" + + - name: Test + run: go test -v -race -count=1 ./... 
+ + format: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: "stable" + + - name: Format + run: diff -u <(echo -n) <(gofmt -d ./) diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..e03e852 --- /dev/null +++ b/LICENSE @@ -0,0 +1,7 @@ +Copyright 2024 John Behm + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..7cc46c9 --- /dev/null +++ b/README.md @@ -0,0 +1,40 @@ +# varint + +varint is a simple variable-length integer encoding. It is a way to store integers in a space-efficient manner. +This variant of varint is space efficient for small integers and is used in the Teeworlds network protocol. + +Additionally, this library also provides functions that operate on 64-bit integers, which are out of scope of the Teeworlds protocol. +These variants may be used for security research or other purposes. 
+ +```text +// Format: ESDDDDDD EDDDDDDD EDD... Extended, Sign, Data +// E: whether the next byte is part of the current integer +// S: sign of the integer, 0 = positive, 1 = negative +// D: data, the integer bits that follow the sign +``` + +## Example + +```go +package main + +import ( + "fmt" + + "github.com/teeworlds-go/varint" +) + +func main() { + buf := make([]byte, varint.MaxVarintLen32) + written := varint.PutVarint(buf, 33) + out, read := varint.Varint(buf) + fmt.Println("written:", written) + fmt.Println("read:", read) + fmt.Println("value:", out) + // Output: + // written: 1 + // read: 1 + // value: 33 +} +``` + diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..7b4a434 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/teeworlds-go/varint + +go 1.22.5 diff --git a/internal/testutils/require/compare.go b/internal/testutils/require/compare.go new file mode 100644 index 0000000..a813009 --- /dev/null +++ b/internal/testutils/require/compare.go @@ -0,0 +1,108 @@ +package require + +import ( + "cmp" + "fmt" + "reflect" + "testing" +) + +func GreaterOrEqual(t *testing.T, expected, actual any, msgAndArgs ...any) { + t.Helper() + + if compare(t, actual, expected) < 0 { + FailNow(t, fmt.Sprintf("expected actual value: %v to be greater or equal to: %v", actual, expected), msgAndArgs...) + } +} + +func Greater(t *testing.T, expected, actual any, msgAndArgs ...any) { + t.Helper() + + if compare(t, actual, expected) <= 0 { + FailNow(t, fmt.Sprintf("expected actual value: %v to be greater than: %v", actual, expected), msgAndArgs...) + } +} + +func LessOrEqual(t *testing.T, expected, actual any, msgAndArgs ...any) { + t.Helper() + + if compare(t, actual, expected) > 0 { + FailNow(t, fmt.Sprintf("expected actual value: %v to be less or equal to: %v", actual, expected), msgAndArgs...) 
+ } +} + +func Less(t *testing.T, expected, actual any, msgAndArgs ...any) { + t.Helper() + + if compare(t, actual, expected) >= 0 { + FailNow(t, fmt.Sprintf("expected actual value: %v to be less than: %v", actual, expected), msgAndArgs...) + } +} + +// compare returns +// +// -1 if expected is less than actual, +// +// 0 if expected equals actual, +// +// +1 if expected is greater than actual. +func compare(t *testing.T, expected, actual any) int { + t.Helper() + + e := reflect.ValueOf(expected) + a := reflect.ValueOf(actual) + + if e.Kind() != a.Kind() { + FailNow(t, "type mismatch: expected %T, got %T", expected, actual) + } + + if !e.Comparable() { + FailNow(t, "expected value is not comparable") + } + + if !a.Comparable() { + FailNow(t, "actual value is not comparable") + } + + if e.Kind() != a.Kind() { + FailNow(t, "type mismatch: expected %T, got %T", expected, actual) + } + + switch e.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + { + ev := e.Convert(reflect.TypeOf(int64(0))).Interface().(int64) + av := a.Convert(reflect.TypeOf(int64(0))).Interface().(int64) + return cmp.Compare(ev, av) + } + + case reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + { + ev := e.Convert(reflect.TypeOf(uint64(0))).Interface().(uint64) + av := a.Convert(reflect.TypeOf(uint64(0))).Interface().(uint64) + return cmp.Compare(ev, av) + } + + case reflect.Float32, reflect.Float64: + { + ev := e.Convert(reflect.TypeOf(float64(0))).Interface().(float64) + av := a.Convert(reflect.TypeOf(float64(0))).Interface().(float64) + return cmp.Compare(ev, av) + } + case reflect.String: + { + ev := e.Convert(reflect.TypeOf(string(""))).Interface().(string) + av := a.Convert(reflect.TypeOf(string(""))).Interface().(string) + return cmp.Compare(ev, av) + } + case reflect.Uintptr: + { + ev := e.Convert(reflect.TypeOf(uintptr(0))).Interface().(uintptr) + av := a.Convert(reflect.TypeOf(uintptr(0))).Interface().(uintptr) + return 
cmp.Compare(ev, av) + } + } + + FailNow(t, "type not supported: %T", expected) + return 0 // should not be reached +} diff --git a/internal/testutils/require/compare_test.go b/internal/testutils/require/compare_test.go new file mode 100644 index 0000000..a16fabd --- /dev/null +++ b/internal/testutils/require/compare_test.go @@ -0,0 +1,20 @@ +package require + +import "testing" + +func TestCompare(t *testing.T) { + GreaterOrEqual(t, 1, 2) + GreaterOrEqual(t, 1, 1) + Greater(t, 1, 2) + LessOrEqual(t, 2, 1) + LessOrEqual(t, 1, 1) + Less(t, 2, 1) + + require := New(t) + require.GreaterOrEqual(1, 2) + require.GreaterOrEqual(1, 1) + require.Greater(1, 2) + require.LessOrEqual(2, 1) + require.LessOrEqual(1, 1) + require.Less(2, 1) +} diff --git a/internal/testutils/require/fail.go b/internal/testutils/require/fail.go new file mode 100644 index 0000000..1da48f2 --- /dev/null +++ b/internal/testutils/require/fail.go @@ -0,0 +1,31 @@ +package require + +import ( + "strings" + "testing" +) + +func FailNow(t *testing.T, errMsg string, msgAndArgs ...any) { + t.Helper() + + labeledMessages := labeledMessages{ + { + label: "Error Trace", + message: strings.Join(CallStack(), "\n\t\t\t"), + }, + { + label: "Error", + message: errMsg, + }, + } + + message := msgOrFmtMsg(msgAndArgs...) 
// msgOrFmtMsg turns the optional trailing arguments of an assertion into a
// single message string:
//
//   - no arguments yield "",
//   - a single string is returned as is, any other single value is rendered
//     with %+v,
//   - several arguments are treated as a format string followed by its
//     operands. If the first argument is not a string, all arguments are
//     rendered with %+v instead of panicking on the type assertion.
func msgOrFmtMsg(msgAndArgs ...any) string {
	switch len(msgAndArgs) {
	case 0:
		return ""
	case 1:
		if msg, ok := msgAndArgs[0].(string); ok {
			return msg
		}
		return fmt.Sprintf("%+v", msgAndArgs[0])
	}

	format, ok := msgAndArgs[0].(string)
	if !ok {
		// A failing assertion must report the failure, not crash while
		// building its own message.
		return fmt.Sprintf("%+v", msgAndArgs)
	}
	return fmt.Sprintf(format, msgAndArgs[1:]...)
}

// labeledMessage is one "label: message" section of a failure report.
type labeledMessage struct {
	label   string
	message string
}

// labeledMessages is an ordered list of report sections.
type labeledMessages []labeledMessage

// String renders all sections, one per label, with the messages aligned in a
// column behind the longest label. Continuation lines of multi-line messages
// are indented to the same column.
func (lm labeledMessages) String() string {
	longestLabel := 0
	numLabels := len(lm)
	msgSizeTotal := 0
	for _, v := range lm {
		if len(v.label) > longestLabel {
			longestLabel = len(v.label)
		}
		msgSizeTotal += len(v.message)
	}

	var sb strings.Builder
	// Rough capacity estimate: message bytes plus label, padding and separators.
	sb.Grow(msgSizeTotal + numLabels*(longestLabel+8))
	sb.WriteString("\n")

	for _, v := range lm {
		sb.WriteString("\t")
		sb.WriteString(v.label)
		sb.WriteString(":")
		sb.WriteString(strings.Repeat(" ", longestLabel-len(v.label)))
		sb.WriteString("\t")

		// Write the message line by line so continuation lines can be indented.
		for i, scanner := 0, bufio.NewScanner(strings.NewReader(v.message)); scanner.Scan(); i++ {
			// The first line already starts at the right column (after the label).
			if i != 0 {
				// longestLabel+1 spaces line up with "<longest label>:" before the tab.
				sb.WriteString("\n\t")
				sb.WriteString(strings.Repeat(" ", longestLabel+1))
				sb.WriteString("\t")
			}
			sb.WriteString(scanner.Text())
		}
		sb.WriteString("\n")
	}

	return sb.String()
}
// CallStack collects "file:line" entries for the frames between this call and
// the running test, so a failure can point at the calling test code instead
// of at the assertion helper that reported it.
//
// Frames whose file sits in a directory named "require" are left out. The
// walk stops at testing.tRunner, or right after the first function whose
// name starts with Test, Benchmark or Example.
func CallStack() []string {
	frames := []string{}

	for depth := 0; ; depth++ {
		pc, file, line, ok := runtime.Caller(depth)
		if !ok || file == "" {
			break
		}

		fn := runtime.FuncForPC(pc)
		if fn == nil {
			break
		}

		qualified := fn.Name()
		// testing.tRunner is the standard library function that calls tests.
		if qualified == "testing.tRunner" {
			break
		}

		// Skip frames that live in this assertion package.
		if parts := strings.Split(file, "/"); len(parts) > 1 && parts[len(parts)-2] != "require" {
			frames = append(frames, fmt.Sprintf("%s:%d", file, line))
		}

		// Strip the package qualifier and stop once the test entry point is reached.
		short := qualified[strings.LastIndex(qualified, ".")+1:]
		if strings.HasPrefix(short, "Test") ||
			strings.HasPrefix(short, "Benchmark") ||
			strings.HasPrefix(short, "Example") {
			break
		}
	}

	return frames
}
+} +func (r *Require) Less(expected, actual any, msgAndArgs ...any) { + r.t.Helper() + Less(r.t, expected, actual, msgAndArgs...) +} + +func (r *Require) Zero(a any, msgAndArgs ...any) { + r.t.Helper() + Zero(r.t, a, msgAndArgs...) +} + +func (r *Require) NotZero(a any, msgAndArgs ...any) { + r.t.Helper() + NotZero(r.t, a, msgAndArgs...) +} + +func Equal(t *testing.T, expected, actual any, msgAndArgs ...any) { + t.Helper() + if !reflect.DeepEqual(expected, actual) { + FailNow(t, fmt.Sprintf("expected: %[1]v (%[1]T), got: %[2]v (%[2]T)", expected, actual), msgAndArgs...) + } +} + +func NoError(t *testing.T, err error, msgAndArgs ...any) { + t.Helper() + if err != nil { + FailNow(t, fmt.Sprintf("expected no error, got: %v", err), msgAndArgs...) + } +} + +func Error(t *testing.T, err error, msgAndArgs ...any) { + t.Helper() + if err != nil { + return + } + FailNow(t, "expected error, got nil", msgAndArgs...) +} + +func ErrorIs(t *testing.T, expected, actual error, msgAndArgs ...any) { + t.Helper() + if errors.Is(actual, expected) { + return + } + FailNow(t, fmt.Sprintf("expected error: %v, got: %v", expected, actual), msgAndArgs...) +} + +func NotNil(t *testing.T, a any, msgAndArgs ...any) { + t.Helper() + if a != nil { + return + } + FailNow(t, "expected not nil, got nil", msgAndArgs...) +} + +func True(t *testing.T, b bool, msgAndArgs ...any) { + t.Helper() + if b { + return + } + FailNow(t, "expected true, got false", msgAndArgs...) +} + +func False(t *testing.T, b bool, msgAndArgs ...any) { + t.Helper() + if !b { + return + } + FailNow(t, "expected false, got true", msgAndArgs...) +} + +func Zero(t *testing.T, a any, msgAndArgs ...any) { + t.Helper() + if reflect.ValueOf(a).IsZero() { + return + } + FailNow(t, fmt.Sprintf("expected zero value, got %v", a), msgAndArgs...) +} + +func NotZero(t *testing.T, a any, msgAndArgs ...any) { + t.Helper() + if !reflect.ValueOf(a).IsZero() { + return + } + FailNow(t, "expected not zero value, got zero", msgAndArgs...) 
const (
	// MaxVarintLen32 is the maximum number of bytes a varint-encoded 32-bit
	// integer can occupy.
	MaxVarintLen32 = 5
	// MaxVarintLen64 is the maximum number of bytes a varint-encoded 64-bit
	// integer can occupy.
	MaxVarintLen64 = 10
)

// PutVarint encodes x, which must fit into an int32, into buf and returns the
// number of bytes written.
//
// PutVarint panics if x is outside the int32 range or if buf is too small for
// the encoded value. A buffer of MaxVarintLen32 bytes is always large enough.
//
// Format: ESDDDDDD EDDDDDDD EDD... (Extended, Sign, Data)
//
//	E: set if another byte of the same integer follows
//	S: sign of the integer (1 = negative)
//	D: integer bits, least significant group first
func PutVarint(buf []byte, x int) int {
	if x < math.MinInt32 || math.MaxInt32 < x {
		panic("ERROR: value to Pack is out of bounds, should be within range [-2147483648:2147483647] (32bit)")
	}

	// Encode into a fixed-size scratch array first so that buf is left
	// untouched when it turns out to be too small.
	var data [MaxVarintLen32]byte

	var sign byte
	if x < 0 {
		sign = 0b01000000
		x = ^x // negative values store the complement, so -1 encodes as 0 plus the sign bit
	}

	data[0] = sign | byte(x)&0b00111111 // first byte carries the sign and 6 data bits
	x >>= 6

	size := 1
	for x != 0 {
		data[size-1] |= 0b10000000        // mark the previous byte as extended
		data[size] = byte(x) & 0b01111111 // following bytes carry 7 data bits each
		x >>= 7
		size++
	}

	if len(buf) < size {
		panic(fmt.Sprintf("varint buffer needs to have at least %d bytes but has %d", size, len(buf)))
	}

	return copy(buf, data[:size])
}
+// If an error occurred, the value is 0 and the number of bytes n is <= 0 with the following meaning: +// +// n == 0: buf too small +// n < 0: value larger than 32 bits (overflow) +// and -n is the number of bytes read +func Varint(buf []byte) (i int, n int) { + + if len(buf) == 0 { + return 0, 0 + } + + index := 0 + // handle first byte (most right side) + sign := int((buf[index] >> 6) & 0b00000001) + value := int(buf[index] & 0b00111111) + + // no E bit set, return after parsing first byte + if buf[index] < 0b10000000 { + value ^= -sign // if(sign) value = ~(value) + index++ + return value, index + } + + // handle 2nd - nth byte + buf = buf[1:] + const maxAllowedLen = MaxVarintLen32 - 1 + for i, b := range buf { + index++ + // overflow check + // 1 sign bit + 6 data bits = 7 bits + // 7 bits * 4 bytes = 28 bits + // 7 + 28 = 35 bits, 3 too many + // last byte can only have 4 bits + if i == maxAllowedLen-1 && b > 0b00001111 { + return 0, -(i + 1) + } + + value |= int(b&0b01111111) << (6 + 7*i) + if b < 0b10000000 { + // no extend bit set + break + } + } + + value ^= -sign // if(sign) value = ~(value) + index++ + + return value, index +} + +// AppendVarint appends the varint-encoded form of x, as generated by PutVarint, to buf and returns the extended buffer. +func AppendVarint(buf []byte, x int) []byte { + arr := [MaxVarintLen32]byte{} + sbuf := arr[:] + n := PutVarint(sbuf, x) + sbuf = sbuf[:n] + return append(buf, sbuf...) 
+} + +// ReadVarint can decode a stream of bytes +func ReadVarint(r io.ByteReader) (int, error) { + b, err := r.ReadByte() + if err != nil { + return 0, err + } + + index := 0 + // handle first byte (most right side) + sign := int((b >> 6) & 0b00000001) + value := int(b & 0b00111111) + + // no E bit set, return after parsing first byte + if b < 0b10000000 { + value ^= -sign // if(sign) value = ~(value) + index++ + return value, nil + } + + // handle 2nd - nth byte + const maxAllowedLen = MaxVarintLen32 - 1 + for i := 0; i < MaxVarintLen32; i++ { + b, err := r.ReadByte() + if err != nil { + if errors.Is(err, io.EOF) { + return value, io.ErrUnexpectedEOF + } + return value, nil + } + index++ + + // overflow check + // 1 sign bit + 6 data bits = 7 bits + // 7 bits * 4 bytes = 28 bits + // 7 + 28 = 35 bits, 3 too many + // 7 - 3 = 7 - 3 = last byte can only have 4 bits + if i == maxAllowedLen-1 && b > 0b00001111 { + return 0, errors.New("overflow due to invalid last byte") + } + + value |= int(b&0b01111111) << (6 + 7*i) + if b < 0b10000000 { + break + } + + } + + value ^= -sign // if(sign) value = ~(value) + index++ + + return value, nil +} + +// PutBigVarint is a variant of varint that can encode 64-bit integers. +// It is out of the scope of the Teeworlds protocol but can be used to check for potential security issues. 
+func PutBigVarint(buf []byte, x int64) int { + const intSize = 8 + + // stack allocated buffer + data := [MaxVarintLen64]byte{} // predefined content of zeroes + index := 0 + + data[index] = byte(x>>(intSize*8-7)) & 0b01000000 // set sign bit if i<0 + x = x ^ (x >> (intSize*8 - 1)) // if(i<0) i = ~i + + data[index] |= byte(x) & 0b00111111 // pack 6bit into data + x >>= 6 // discard 6 bits + + if x != 0 { + data[index] |= 0b10000000 // set extend bit + + for { + index++ + data[index] = byte(x) & 0b01111111 // pack 7 bits + x >>= 7 // discard 7 bits + + if x != 0 { + data[index] |= 1 << 7 // set extend bit + } else { + break // break if x is 0 + } + } + } + + size := index + 1 + + if len(buf) < size { + panic(fmt.Sprintf("varint buffer needs to have at least %d bytes but has %d", len(data), len(buf))) + } + + return copy(buf, data[:size]) +} + +// Big decodes an int from buf and returns that value and the number of bytes read (> 0). +// If an error occurred, the value is 0 and the number of bytes n is <= 0 with the following meaning: +// +// n == 0: buf too small +// n < 0: value larger than 64 bits (overflow/underflow) +// and -n is the number of bytes read +func BigVarint(buf []byte) (i int64, n int) { + + if len(buf) == 0 { + return 0, 0 + } + + index := 0 + // handle first byte (most right side) + sign := int64((buf[index] >> 6) & 0b00000001) + value := int64(buf[index] & 0b00111111) + + // no E bit set, return after parsing first byte + if buf[index] < 0b10000000 { + value ^= -sign // if(sign) value = ~(value) + index++ + return value, index + } + + // handle 2nd - nth byte + buf = buf[1:] + const maxAllowedLen = MaxVarintLen64 - 1 + for i, b := range buf { + index++ + // overflow check + // 1 sign bit + 6 data bits = 7 bits + // 7 bits * 9 bytes = 63 bits + // 7 + 63 = 70 bits, 6 too many + // 7 - 6 = last byte can only have 1 bit + if i == maxAllowedLen-1 && b > 0b00000001 { + return 0, -(i + 1) + } + + value |= int64(b&0b01111111) << (6 + 7*i) + if b < 
0b10000000 { + // no extend bit set + break + } + } + + value ^= -sign // if(sign) value = ~(value) + index++ + + return value, index +} + +// AppendBigVarint appends the varint-encoded form of x, as generated by PutBigVarint, to buf and returns the extended buffer. +// It is out of the scope of the Teeworlds protocol but can be used to check for potential security issues. +func AppendBigVarint(buf []byte, x int64) []byte { + arr := [MaxVarintLen64]byte{} + sbuf := arr[:] + n := PutBigVarint(sbuf, x) + sbuf = sbuf[:n] + return append(buf, sbuf...) +} + +// ReadVarint can decode a stream of bytes +func ReadBigVarint(r io.ByteReader) (int64, error) { + b, err := r.ReadByte() + if err != nil { + return 0, err + } + + index := 0 + // handle first byte (most right side) + sign := int64((b >> 6) & 0b00000001) + value := int64(b & 0b00111111) + + // no E bit set, return after parsing first byte + if b < 0b10000000 { + value ^= -sign // if(sign) value = ~(value) + index++ + return value, nil + } + + // handle 2nd - nth byte + const maxAllowedLen = MaxVarintLen64 - 1 + for i := 0; i < MaxVarintLen64; i++ { + b, err := r.ReadByte() + if err != nil { + if errors.Is(err, io.EOF) { + return value, io.ErrUnexpectedEOF + } + return value, nil + } + index++ + + // overflow check + // 1 sign bit + 6 data bits = 7 bits + // 7 bits * 4 bytes = 28 bits + // 7 + 28 = 35 bits, 3 too many + // 7 - 3 = 7 - 3 = last byte can only have 4 bits + if i == maxAllowedLen-1 && b > 0b00000001 { + return 0, errors.New("overflow due to invalid last byte") + } + + value |= int64(b&0b01111111) << (6 + 7*i) + if b < 0b10000000 { + break + } + + } + + value ^= -sign // if(sign) value = ~(value) + index++ + + return value, nil +} diff --git a/varint_example_test.go b/varint_example_test.go new file mode 100644 index 0000000..54c4aa5 --- /dev/null +++ b/varint_example_test.go @@ -0,0 +1,42 @@ +package varint_test + +import ( + "fmt" + + "github.com/teeworlds-go/varint" +) + +func Example() { + buf := 
make([]byte, varint.MaxVarintLen32) + written := varint.PutVarint(buf, 33) + out, read := varint.Varint(buf) + fmt.Println("written:", written) + fmt.Println("read:", read) + fmt.Println("value:", out) + // Output: + // written: 1 + // read: 1 + // value: 33 +} + +func ExamplePutVarint() { + buf := make([]byte, varint.MaxVarintLen32) + written := varint.PutVarint(buf, 63) + fmt.Println(written) + fmt.Printf("%b\n", buf[:written]) + // Output: + // 1 + // [111111] +} + +func ExampleVarint() { + // 0b1xxxxxxx - extend bit set + // 0bx0xxxxxx - positive sign + buf := []byte{0b10111111, 0b00000001} + out, read := varint.Varint(buf) + fmt.Println("read:", read) + fmt.Println("value:", out) + // Output: + // read: 2 + // value: 127 +} diff --git a/varint_test.go b/varint_test.go new file mode 100644 index 0000000..e50ccd6 --- /dev/null +++ b/varint_test.go @@ -0,0 +1,200 @@ +package varint_test + +import ( + "bytes" + "io" + "math" + "testing" + + "github.com/teeworlds-go/varint" + "github.com/teeworlds-go/varint/internal/testutils/require" +) + +func TestVarintBoundaries(t *testing.T) { + t.Parallel() + + table := []struct { + inNumber int + expectedBytes int + }{ + // positive + {63, 1}, // 2^6 -1 + {64, 2}, // 2^6 + {8191, 2}, // 2^(6+7) -1 + {8192, 3}, // 2^(6+7) + {1048576 - 1, 3}, // 2^(6+7+7) -1 + {1048576, 4}, // 2^(6+7+7) + {134217728 - 1, 4}, // 2^(6+7+7+7) -1 + {134217728, 5}, // 2^(6+7+7+7) + // negative + {-8191, 2}, // (2^(6+7) -1) *-1 + {-8192, 2}, // (2^(6+7)) *-1 + {-8193, 3}, // (2^(6+7) +1) *-1 + {-1048575, 3}, // (2^(6+7+7) -1) *-1 + {-1048576, 3}, // (2^(6+7+7)) *-1 + {-1048577, 4}, // (2^(6+7+7) +1) *-1 + {-134217727, 4}, // (2^(6+7+7+7) -1) *-1 + {-134217728, 4}, // (2^(6+7+7+7)) *-1 + {-134217729, 5}, // (2^(6+7+7+7) +1) *-1 + // int32 boundaries + {math.MaxInt32, 5}, // 2^31 -1 = 2147483647 + {math.MinInt32, 5}, // -2^31 = -2147483648 + } + + for _, row := range table { + varintWriteRead(t, row.inNumber, row.expectedBytes) + } + + // tested w/o 
panic code + //varintWriteRead(t, math.MaxInt32+1, 5) // 2^31 = 2147483648 + //varintWriteRead(t, math.MinInt32-1, 5) // -2^31-1 = -2147483649 +} + +func TestOverflowVarint(t *testing.T) { + t.Parallel() + + require := require.New(t) + + table := [][]byte{ + {0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b10000001}, + {0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b0010000}, + {0b11000001, 0b10000001, 0b10000001, 0b10000001, 0b0010000}, // underflow + } + + for _, row := range table { + buf := bytes.NewBuffer(row) + _, err := varint.ReadVarint(buf) + require.Error(err) + + _, n := varint.Varint(row) + require.Less(0, n) + } +} + +func TestOverflowBigVarint(t *testing.T) { + t.Parallel() + + require := require.New(t) + + table := [][]byte{ + {0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b0000001}, + {0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b00000010}, + {0b11000001, 0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b10000001, 0b00000010}, // underflow + } + + for _, row := range table { + buf := bytes.NewBuffer(row) + _, err := varint.ReadBigVarint(buf) + require.Error(err) + _, n := varint.BigVarint(row) + require.Less(0, n) + } +} + +func TestEOFVarint(t *testing.T) { + t.Parallel() + + require := require.New(t) + + buf := []byte{0b10000001, 0b10000001, 0b10000001, 0b00000001} + b := bytes.NewBuffer(buf) + + i, err := varint.ReadVarint(b) + require.NoError(err) + require.NotZero(i) + + _, err = varint.ReadVarint(b) + require.ErrorIs(err, io.EOF) +} + +func TestBigVarintBoundaries(t *testing.T) { + t.Parallel() + + table := []struct { + inNumber int64 + expectedBytes int + }{ + // positive + {63, 1}, // 2^6 -1 + {64, 2}, // 2^6 + {8191, 2}, // 2^(6+7) -1 + {8192, 3}, // 2^(6+7) + {1048576 - 1, 3}, // 2^(6+7+7) -1 + {1048576, 4}, // 2^(6+7+7) + 
{134217728 - 1, 4}, // 2^(6+7+7+7) -1 + {134217728, 5}, // 2^(6+7+7+7) + // big positive + {17179869183, 5}, // 2^(6+4*7) -1 + {17179869184, 6}, // 2^(6+4*7) + {2199023255552 - 1, 6}, // 2^(6+5*7) -1 + {2199023255552, 7}, // 2^(6+5*7) + {281474976710656 - 1, 7}, // 2^(6+6*7) -1 + {281474976710656, 8}, // 2^(6+6*7) + {36028797018963968 - 1, 8}, // 2^(6+7*7) -1 + {36028797018963968, 9}, // 2^(6+7*7) + {4611686018427387904 - 1, 9}, // 2^(6+8*7) -1 + {4611686018427387904, 10}, // 2^(6+8*7) + // negative + {-8191, 2}, // (2^(6+7) -1) *-1 + {-8192, 2}, // (2^(6+7)) *-1 + {-8193, 3}, // (2^(6+7) +1) *-1 + {-1048575, 3}, // (2^(6+7+7) -1) *-1 + {-1048576, 3}, // (2^(6+7+7)) *-1 + {-1048577, 4}, // (2^(6+7+7) +1) *-1 + {-134217727, 4}, // (2^(6+7+7+7) -1) *-1 + {-134217728, 4}, // (2^(6+7+7+7)) *-1 + {-13421779, 4}, // (2^(6+7+7+7) +1) *-1 + // big negative + {-17179869183, 5}, // (2^(6+4*7) -1) *-1 + {-17179869184, 5}, // (2^(6+4*7)) *-1 + {-17179869185, 6}, // (2^(6+4*7) +1) *-1 + {-2199023255551, 6}, // (2^(6+5*7) -1) *-1 + {-2199023255552, 6}, // (2^(6+5*7)) *-1 + {-2199023255553, 7}, // (2^(6+5*7) +1) *-1 + {-281474976710655, 7}, // (2^(6+6*7) -1) *-1 + {-281474976710656, 7}, // (2^(6+6*7)) *-1 + {-281474976710657, 8}, // (2^(6+6*7) +1) *-1 + {-36028797018963967, 8}, // (2^(6+7*7) -1) *-1 + {-36028797018963968, 8}, // (2^(6+7*7)) *-1 + {-36028797018963969, 9}, // (2^(6+7*7) +1) *-1 + {-4611686018427387903, 9}, // (2^(6+8*7) -1) *-1 + {-4611686018427387904, 9}, // (2^(6+8*7)) *-1 + {-4611686018427387905, 10}, // (2^(6+8*7) +1) *-1 + // int32 boundaries + {math.MaxInt32, 5}, // 2^31 -1 = 2147483647 + {math.MinInt32, 5}, // -2^31 = -2147483648 + // int64 boundaries + {math.MaxInt64, 10}, // 2^63 -1 + {math.MinInt64, 10}, // -2^63 + } + + for _, row := range table { + bigVarintWriteRead(t, row.inNumber, row.expectedBytes) + } +} + +func varintWriteRead(t *testing.T, inNumber int, expectedBytes int) { + require := require.New(t) + + buf := make([]byte, 
varint.MaxVarintLen32) + written := varint.PutVarint(buf, inNumber) + require.Equal(expectedBytes, written) + out, read := varint.Varint(buf) + + require.GreaterOrEqual(1, read, "read must be at least 1 byte") + require.Equal(inNumber, out, "out == in") + require.Equal(written, read, "read == written") +} + +func bigVarintWriteRead(t *testing.T, inNumber int64, expectedBytes int) { + require := require.New(t) + + buf := make([]byte, varint.MaxVarintLen64) + written := varint.PutBigVarint(buf, inNumber) + require.Equal(expectedBytes, written) + out, read := varint.BigVarint(buf) + + require.GreaterOrEqual(1, read, "read must be at least 1 byte") + require.Equal(inNumber, out, "out == in") + require.Equal(written, read, "read == written") +}