diff --git a/binding/bind.go b/binding/bind.go index 4283d89..02acc29 100644 --- a/binding/bind.go +++ b/binding/bind.go @@ -372,7 +372,7 @@ func (b *Binding) getOrPrepareReceiver(value reflect.Value) (*receiver, error) { return nil, b.bindErrFactory(errExprSelector.String(), errMsg) } if !recv.hasVd { - recv.hasVd, _ = b.findVdTag(ameda.DereferenceType(t), false, 20) + recv.hasVd, _ = b.findVdTag(ameda.DereferenceType(t), false, 20, map[reflect.Type]bool{}) } recv.initParams() @@ -383,13 +383,14 @@ func (b *Binding) getOrPrepareReceiver(value reflect.Value) (*receiver, error) { return recv, nil } -func (b *Binding) findVdTag(t reflect.Type, inMapOrSlice bool, depth int) (hasVd bool, err error) { - if depth <= 0 { +func (b *Binding) findVdTag(t reflect.Type, inMapOrSlice bool, depth int, exist map[reflect.Type]bool) (hasVd bool, err error) { + if depth <= 0 || exist[t] { return } depth-- switch t.Kind() { case reflect.Struct: + exist[t] = true for i := t.NumField() - 1; i >= 0; i-- { field := t.Field(i) if inMapOrSlice { @@ -400,14 +401,14 @@ func (b *Binding) findVdTag(t reflect.Type, inMapOrSlice bool, depth int) (hasVd } } } - hasVd, _ = b.findVdTag(ameda.DereferenceType(field.Type), inMapOrSlice, depth) + hasVd, _ = b.findVdTag(ameda.DereferenceType(field.Type), inMapOrSlice, depth, exist) if hasVd { return true, nil } } return false, nil case reflect.Slice, reflect.Array, reflect.Map: - return b.findVdTag(ameda.DereferenceType(t.Elem()), true, depth) + return b.findVdTag(ameda.DereferenceType(t.Elem()), true, depth, exist) default: return false, nil } diff --git a/binding/bind_test.go b/binding/bind_test.go index 2825a37..4113879 100644 --- a/binding/bind_test.go +++ b/binding/bind_test.go @@ -1181,3 +1181,18 @@ func TestDefault2(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "hello Dash", (**recv.X).Dash) } + +func TestVdTagRecursion(t *testing.T) { + type Node struct { + N1 *Node + N2 *Node + N3 *Node + } + recv := &Node{} + req, _ := http.NewRequest("get", "http://localhost/", bytes.NewReader([]byte{})) + start := time.Now() + binder := binding.New(nil) + err := binder.BindAndValidate(recv, req, new(testPathParams2)) + assert.NoError(t, err) + assert.Less(t, int64(time.Since(start)), int64(time.Second)) +}