diff --git a/decode.go b/decode.go index 1b18c9b..2eb8260 100644 --- a/decode.go +++ b/decode.go @@ -163,11 +163,12 @@ func Decode(dest, src interface{}, os ...Option) error { // the destination. func decodeFrom(opts *options, src interface{}) Into { return func(dest interface{}) error { + var fieldHooks FieldHookFunc hooks := opts.DecodeHooks // fieldHook goes first because it may replace the source data map. if len(opts.FieldHooks) > 0 { - hooks = append(hooks, fieldHook(opts)) + fieldHooks = composeFieldHooks(opts.FieldHooks) } hooks = append( @@ -186,7 +187,8 @@ func decodeFrom(opts *options, src interface{}) Into { DecodeHook: fromDecodeHookFunc( supportPointers(composeDecodeHooks(hooks)), ), - TagName: opts.TagName, + FieldHook: mapstructure.FieldHookFunc(fieldHooks), + TagName: opts.TagName, } decoder, err := mapstructure.NewDecoder(&cfg) diff --git a/decode_test.go b/decode_test.go index aaa46d3..cf7f7e4 100644 --- a/decode_test.go +++ b/decode_test.go @@ -238,7 +238,11 @@ func TestDecode(t *testing.T) { } func TestFieldHook(t *testing.T) { + type embeddedStruct struct { + SomeOtherInt int + } type myStruct struct { + embeddedStruct SomeInt int SomeString string PtrToPtrToString **string @@ -293,12 +297,12 @@ func TestFieldHook(t *testing.T) { "PtrToPtrToString": "hello", }, setupHook: func(h *mockFieldHook) { - h.Expect(_typeOfEmptyInterface, structField{ + h.Expect(structField{ Name: "SomeInt", Type: typeOfInt, }, reflectEq{1}).Return(valueOf(42), nil) - h.Expect(_typeOfEmptyInterface, structField{ + h.Expect(structField{ Name: "PtrToPtrToString", Type: typeOfPtrPtrString, }, reflectEq{"hello"}).Return(valueOf("world"), nil) @@ -308,12 +312,29 @@ func TestFieldHook(t *testing.T) { PtrToPtrToString: ptrToPtrToString("world"), }, }, + { + desc: "embedded updates", + give: map[string]interface{}{ + "someOtherInt": 1, + }, + setupHook: func(h *mockFieldHook) { + h.Expect(structField{ + Name: "SomeOtherInt", + Type: typeOfInt, + }, reflectEq{1}).Return(valueOf(42), nil) + }, + want: myStruct{ + embeddedStruct: embeddedStruct{ + SomeOtherInt: 42, + }, + }, + }, { desc: "field name override", give: map[string]interface{}{"yamlKey": "foo"}, giveOpts: []Option{YAML()}, setupHook: func(h *mockFieldHook) { - h.Expect(_typeOfEmptyInterface, structField{ + h.Expect(structField{ Name: "YAMLField", Type: typeOfString, Tag: `yaml:"yamlKey"`, @@ -326,7 +347,7 @@ func TestFieldHook(t *testing.T) { give: map[string]interface{}{"YAMLKEY": "foo"}, giveOpts: []Option{YAML()}, setupHook: func(h *mockFieldHook) { - h.Expect(_typeOfEmptyInterface, structField{ + h.Expect(structField{ Name: "YAMLField", Type: typeOfString, Tag: `yaml:"yamlKey"`, @@ -341,12 +362,12 @@ func TestFieldHook(t *testing.T) { "PtrToPtrToString": "hello", }, setupHook: func(h *mockFieldHook) { - h.Expect(_typeOfEmptyInterface, structField{ + h.Expect(structField{ Name: "SomeInt", Type: typeOfInt, }, reflectEq{1}).Return(reflect.Value{}, errors.New("great sadness")) - h.Expect(_typeOfEmptyInterface, structField{ + h.Expect(structField{ Name: "PtrToPtrToString", Type: typeOfPtrPtrString, }, reflectEq{"hello"}).Return(reflect.Value{}, errors.New("more sadness")) @@ -363,12 +384,12 @@ func TestFieldHook(t *testing.T) { "someString": 3, }, setupHook: func(h *mockFieldHook) { - h.Expect(typeOfInt, structField{ + h.Expect(structField{ Name: "SomeInt", Type: typeOfInt, }, reflectEq{42}).Return(reflect.ValueOf(100), nil) - h.Expect(typeOfInt, structField{ + h.Expect(structField{ Name: "SomeString", Type: typeOfString, }, reflectEq{3}).Return(reflect.ValueOf("hello"), nil) diff --git a/hooks.go b/hooks.go index 6086394..06cf75e 100644 --- a/hooks.go +++ b/hooks.go @@ -24,12 +24,9 @@ import ( "fmt" "reflect" "strconv" - "strings" "time" "github.com/uber-go/mapdecode/internal/mapstructure" - - "go.uber.org/multierr" ) var ( @@ -38,23 +35,19 @@ var ( ) // FieldHookFunc is a hook called while decoding a specific struct field. It -// receives the source type, information about the target field, and the -// source data. -type FieldHookFunc func(from reflect.Type, to reflect.StructField, data reflect.Value) (reflect.Value, error) +// receives information about the target field, and the source data. +type FieldHookFunc func(dest reflect.StructField, srcData reflect.Value) (reflect.Value, error) func composeFieldHooks(hooks []FieldHookFunc) FieldHookFunc { - return func(from reflect.Type, to reflect.StructField, data reflect.Value) (reflect.Value, error) { + return func(dest reflect.StructField, srcData reflect.Value) (reflect.Value, error) { var err error for _, hook := range hooks { - data, err = hook(from, to, data) + srcData, err = hook(dest, srcData) if err != nil { - return data, err + return srcData, err } - - // Update the `from` type to reflect changes made by the hook. - from = data.Type() } - return data, err + return srcData, err } } @@ -206,129 +199,3 @@ func strconvHook(from, to reflect.Type, data reflect.Value) (reflect.Value, erro return data, nil } - -// fieldHook applies the user-specified FieldHookFunc to all struct fields. -func fieldHook(opts *options) DecodeHookFunc { - hook := composeFieldHooks(opts.FieldHooks) - return func(from, to reflect.Type, data reflect.Value) (reflect.Value, error) { - if to.Kind() != reflect.Struct || from.Kind() != reflect.Map { - return data, nil - } - - // We can only decode map[string]* and map[interface{}]* into structs. - if k := from.Key().Kind(); k != reflect.String && k != reflect.Interface { - return data, nil - } - - // This map tracks type-changing updates to items in the map. - // - // If the source map has a rigid type for values (map[string]string - // rather than map[string]interface{}), we can't make replacements to - // values in-place if a hook changed the type of a value. So we will - // make a copy of the source map with a more liberal type and inject - // these updates into the copy. - updates := make(map[interface{}]interface{}) - - var errors []error - for i := 0; i < to.NumField(); i++ { - structField := to.Field(i) - if structField.PkgPath != "" && !structField.Anonymous { - // This field is not exported so we won't be able to decode - // into it. - continue - } - - // This field resolution logic is adapted from mapstructure's own - // logic. - // - // See https://github.com/mitchellh/mapstructure/blob/53818660ed4955e899c0bcafa97299a388bd7c8e/mapstructure.go#L741 - - fieldName := structField.Name - - // Field name override was specified. - tagParts := strings.Split(structField.Tag.Get(opts.TagName), ",") - if tagParts[0] != "" { - fieldName = tagParts[0] - } - - // Get the value for this field from the source map, if any. - key := reflect.ValueOf(fieldName) - value := data.MapIndex(key) - if !value.IsValid() { - // Case-insensitive linear search if the name doesn't match - // as-is. - for _, kV := range data.MapKeys() { - // Kind() == Interface if map[interface{}]* so we use - // Interface().(string) to handle interface{} and string - // keys. - k, ok := kV.Interface().(string) - if !ok { - continue - } - - if strings.EqualFold(k, fieldName) { - key = kV - value = data.MapIndex(kV) - break - } - } - } - - if !value.IsValid() { - // No value specified for this field in source map. - continue - } - - newValue, err := hook(value.Type(), structField, value) - if err != nil { - errors = append(errors, fmt.Errorf( - "error reading into field %q: %v", fieldName, err)) - continue - } - - if newValue == value { - continue - } - - // If we can, assign in-place. - if newValue.Type().AssignableTo(value.Type()) { - // XXX(abg): Is it okay to make updates to the source map? - data.SetMapIndex(key, newValue) - } else { - updates[key.Interface()] = newValue.Interface() - } - } - - if len(errors) > 0 { - return data, multierr.Combine(errors...) - } - - // No more changes to make. - if len(updates) == 0 { - return data, nil - } - - // Equivalent to, - // - // newData := make(map[$key]interface{}) - // for k, v := range data { - // if newV, ok := updates[k]; ok { - // newData[k] = newV - // } else { - // newData[k] = v - // } - // } - newData := reflect.MakeMap(reflect.MapOf(from.Key(), _typeOfEmptyInterface)) - for _, key := range data.MapKeys() { - var value reflect.Value - if v, ok := updates[key.Interface()]; ok { - value = reflect.ValueOf(v) - } else { - value = data.MapIndex(key) - } - newData.SetMapIndex(key, value) - } - - return newData, nil - } -} diff --git a/hooks_test.go b/hooks_test.go index 5676e99..052b0fc 100644 --- a/hooks_test.go +++ b/hooks_test.go @@ -43,14 +43,14 @@ func TestMultipleFieldHooks(t *testing.T) { typeOfInt := reflect.TypeOf(42) hook1. - Expect(_typeOfEmptyInterface, structField{ + Expect(structField{ Name: "Int", Type: typeOfInt, }, reflectEq{"FOO"}). Return(valueOf("BAR"), nil) hook2. - Expect(reflect.TypeOf(""), structField{ + Expect(structField{ Name: "Int", Type: typeOfInt, }, reflectEq{"BAR"}). diff --git a/internal/mapstructure/mapstructure.go b/internal/mapstructure/mapstructure.go index 688feef..72f48c0 100644 --- a/internal/mapstructure/mapstructure.go +++ b/internal/mapstructure/mapstructure.go @@ -40,6 +40,10 @@ type DecodeHookFuncType func(reflect.Type, reflect.Type, interface{}) (interface // source and target types. type DecodeHookFuncKind func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error) +// FieldHookFunc is the callback function that can be used to alter +// StructFields before they are decoded into the struct objects. +type FieldHookFunc func(reflect.StructField, reflect.Value) (reflect.Value, error) + // DecoderConfig is the configuration that is used to create a new decoder // and allows customization of various aspects of decoding. type DecoderConfig struct { @@ -51,6 +55,15 @@ type DecoderConfig struct { // error. DecodeHook DecodeHookFunc + // FieldHook, if set, will be called before decoding StructFields + // from the source data. This lets you read struct tags and apply + // mutations to the values before they're set down onto the + // resulting struct. + // + // If an error is returned, the entire decode will fail with that + // error. + FieldHook FieldHookFunc + // If ErrorUnused is true, then it is an error for there to exist // keys in the original map that were unused in the decoding process // (extra keys). @@ -806,6 +819,17 @@ func (d *Decoder) decodeStruct(name string, data interface{}, val reflect.Value) fieldName = fmt.Sprintf("%s.%s", name, fieldName) } + // Apply struct field hooks + if d.config.FieldHook != nil { + newValue, err := d.config.FieldHook(*fieldType, rawMapVal) + if err != nil { + errors = appendErrors(errors, fmt.Errorf( + "error reading into field %q: %v", fieldName, err)) + continue + } + rawMapVal = newValue + } + if err := d.decode(fieldName, rawMapVal.Interface(), field); err != nil { errors = appendErrors(errors, err) } diff --git a/mock_hooks_test.go b/mock_hooks_test.go index 3173596..4e260dd 100644 --- a/mock_hooks_test.go +++ b/mock_hooks_test.go @@ -71,12 +71,12 @@ func (m *mockFieldHook) Hook() FieldHookFunc { } // Expect sets up a call expectation on the hook. -func (m *mockFieldHook) Expect(from, to, data interface{}) *gomock.Call { - return m.c.RecordCall(m, "Call", from, to, data) +func (m *mockFieldHook) Expect(dest, srcData interface{}) *gomock.Call { + return m.c.RecordCall(m, "Call", dest, srcData) } -func (m *mockFieldHook) Call(from reflect.Type, to reflect.StructField, data reflect.Value) (reflect.Value, error) { - results := m.c.Call(m, "Call", from, to, data) +func (m *mockFieldHook) Call(dest reflect.StructField, srcData reflect.Value) (reflect.Value, error) { + results := m.c.Call(m, "Call", dest, srcData) out := results[0].(reflect.Value) err, _ := results[1].(error) return out, err