Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add max recursion depth #15

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ func (kvs KVS) MarshalJSON() ([]byte, error) {
type Decoder struct {
*scanner
emitDepth int
maxDepth int
emitKV bool
emitRecursive bool
objectAsKVS bool
Expand Down Expand Up @@ -137,6 +138,14 @@ func (d *Decoder) Pos() int { return int(d.pos) }
// Err returns the most recent decoder error if any, or nil
func (d *Decoder) Err() error { return d.err }

// MaxDepth will set the maximum recursion depth.
// If the maximum depth is exceeded, ErrMaxDepth is returned.
// Less than or 0 means no limit (default).
func (d *Decoder) MaxDepth(n int) *Decoder {
d.maxDepth = n
return d
}

// Decode parses the JSON-encoded data and returns an interface value
func (d *Decoder) decode() {
defer close(d.metaCh)
Expand Down Expand Up @@ -191,7 +200,7 @@ func (d *Decoder) any() (interface{}, ValueType, error) {
i, err := d.number()
return i, Number, err
case '-':
if c = d.next(); c < '0' && c > '9' {
if c = d.next(); c < '0' || c > '9' {
return nil, Unknown, d.mkError(ErrSyntax, "in negative numeric literal")
}
n, err := d.number()
Expand Down Expand Up @@ -365,7 +374,7 @@ func (d *Decoder) number() (float64, error) {
d.scratch.add(c)

// first char following must be digit
if c = d.next(); c < '0' && c > '9' {
if c = d.next(); c < '0' || c > '9' {
return 0, d.mkError(ErrSyntax, "after decimal point in numeric literal")
}
d.scratch.add(c)
Expand Down Expand Up @@ -417,6 +426,9 @@ func (d *Decoder) number() (float64, error) {
// array accept valid JSON array value
func (d *Decoder) array() ([]interface{}, error) {
d.depth++
if d.maxDepth > 0 && d.depth > d.maxDepth {
return nil, ErrMaxDepth
}

var (
c byte
Expand Down Expand Up @@ -458,6 +470,9 @@ out:
// object accept valid JSON array value
func (d *Decoder) object() (map[string]interface{}, error) {
d.depth++
if d.maxDepth > 0 && d.depth > d.maxDepth {
return nil, ErrMaxDepth
}

var (
c byte
Expand Down Expand Up @@ -543,6 +558,9 @@ out:
// object (ordered) accept valid JSON array value
func (d *Decoder) objectOrdered() (KVS, error) {
d.depth++
if d.maxDepth > 0 && d.depth > d.maxDepth {
return nil, ErrMaxDepth
}

var (
c byte
Expand Down
39 changes: 39 additions & 0 deletions decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,42 @@ func TestDecoderReaderFailure(t *testing.T) {
t.Fatalf("missing expected underlying reader error")
}
}

func TestDecoderMaxDepth(t *testing.T) {
tests := []struct {
input string
maxDepth int
mustFail bool
}{
// No limit
{input: `[{"bio":"bada bing bada boom","id":1,"name":"Charles","falseVal":false}]`, maxDepth: 0, mustFail: false},
// Array + object = depth 2 = false
{input: `[{"bio":"bada bing bada boom","id":1,"name":"Charles","falseVal":false}]`, maxDepth: 1, mustFail: true},
// Depth 2 = ok
{input: `[{"bio":"bada bing bada boom","id":1,"name":"Charles","falseVal":false}]`, maxDepth: 2, mustFail: false},
// Arrays:
{input: `[[[[[[[[[[[[[[[[[[[[[["ok"]]]]]]]]]]]]]]]]]]]]]]`, maxDepth: 2, mustFail: true},
{input: `[[[[[[[[[[[[[[[[[[[[[["ok"]]]]]]]]]]]]]]]]]]]]]]`, maxDepth: 10, mustFail: true},
{input: `[[[[[[[[[[[[[[[[[[[[[["ok"]]]]]]]]]]]]]]]]]]]]]]`, maxDepth: 100, mustFail: false},
// Objects:
{input: `{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"ok":false}}}}}}}}}}}}}}}}}}}}}}`, maxDepth: 2, mustFail: true},
{input: `{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"ok":false}}}}}}}}}}}}}}}}}}}}}}`, maxDepth: 10, mustFail: true},
{input: `{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"ok":false}}}}}}}}}}}}}}}}}}}}}}`, maxDepth: 100, mustFail: false},
}

for _, test := range tests {
decoder := NewDecoder(mkReader(test.input), 0).MaxDepth(test.maxDepth)
var mv *MetaValue
for mv = range decoder.Stream() {
t.Logf("depth=%d offset=%d len=%d (%v)", mv.Depth, mv.Offset, mv.Length, mv.Value)
}

err := decoder.Err()
if test.mustFail && err != ErrMaxDepth {
t.Fatalf("missing expected decoder error, got %q", err)
}
if !test.mustFail && err != nil {
t.Fatalf("unexpected error: %q", err)
}
}
}
1 change: 1 addition & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
var (
ErrSyntax = DecoderError{msg: "invalid character"}
ErrUnexpectedEOF = DecoderError{msg: "unexpected end of JSON input"}
ErrMaxDepth = DecoderError{msg: "maximum recursion depth exceeded"}
)

type errPos [2]int // line number, byte offset where error occurred
Expand Down