Skip to content
This repository has been archived by the owner on Oct 2, 2020. It is now read-only.

Commit

Permalink
Implement FieldHooks natively in mapstructure
Browse files Browse the repository at this point in the history
Summary: There needs to be parity between how we parse a struct at
decode time and passing it when applying the DecodeHooks. The simplest
way to do this is to apply field hooks at decode time, and stop piggy-
backing on the DecodeHooks codepath.

In the process of doing this, I also cleaned up the FieldHook API to
only expose the types it will actually use.

Test Plan: tests pass
  • Loading branch information
willhug authored May 24, 2017
1 parent 40ed007 commit 531089d
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 155 deletions.
6 changes: 4 additions & 2 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,12 @@ func Decode(dest, src interface{}, os ...Option) error {
// the destination.
func decodeFrom(opts *options, src interface{}) Into {
return func(dest interface{}) error {
var fieldHooks FieldHookFunc
hooks := opts.DecodeHooks

// fieldHook goes first because it may replace the source data map.
if len(opts.FieldHooks) > 0 {
hooks = append(hooks, fieldHook(opts))
fieldHooks = composeFieldHooks(opts.FieldHooks)
}

hooks = append(
Expand All @@ -186,7 +187,8 @@ func decodeFrom(opts *options, src interface{}) Into {
DecodeHook: fromDecodeHookFunc(
supportPointers(composeDecodeHooks(hooks)),
),
TagName: opts.TagName,
FieldHook: mapstructure.FieldHookFunc(fieldHooks),
TagName: opts.TagName,
}

decoder, err := mapstructure.NewDecoder(&cfg)
Expand Down
37 changes: 29 additions & 8 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,11 @@ func TestDecode(t *testing.T) {
}

func TestFieldHook(t *testing.T) {
type embeddedStruct struct {
SomeOtherInt int
}
type myStruct struct {
embeddedStruct
SomeInt int
SomeString string
PtrToPtrToString **string
Expand Down Expand Up @@ -293,12 +297,12 @@ func TestFieldHook(t *testing.T) {
"PtrToPtrToString": "hello",
},
setupHook: func(h *mockFieldHook) {
h.Expect(_typeOfEmptyInterface, structField{
h.Expect(structField{
Name: "SomeInt",
Type: typeOfInt,
}, reflectEq{1}).Return(valueOf(42), nil)

h.Expect(_typeOfEmptyInterface, structField{
h.Expect(structField{
Name: "PtrToPtrToString",
Type: typeOfPtrPtrString,
}, reflectEq{"hello"}).Return(valueOf("world"), nil)
Expand All @@ -308,12 +312,29 @@ func TestFieldHook(t *testing.T) {
PtrToPtrToString: ptrToPtrToString("world"),
},
},
{
desc: "embedded updates",
give: map[string]interface{}{
"someOtherInt": 1,
},
setupHook: func(h *mockFieldHook) {
h.Expect(structField{
Name: "SomeOtherInt",
Type: typeOfInt,
}, reflectEq{1}).Return(valueOf(42), nil)
},
want: myStruct{
embeddedStruct: embeddedStruct{
SomeOtherInt: 42,
},
},
},
{
desc: "field name override",
give: map[string]interface{}{"yamlKey": "foo"},
giveOpts: []Option{YAML()},
setupHook: func(h *mockFieldHook) {
h.Expect(_typeOfEmptyInterface, structField{
h.Expect(structField{
Name: "YAMLField",
Type: typeOfString,
Tag: `yaml:"yamlKey"`,
Expand All @@ -326,7 +347,7 @@ func TestFieldHook(t *testing.T) {
give: map[string]interface{}{"YAMLKEY": "foo"},
giveOpts: []Option{YAML()},
setupHook: func(h *mockFieldHook) {
h.Expect(_typeOfEmptyInterface, structField{
h.Expect(structField{
Name: "YAMLField",
Type: typeOfString,
Tag: `yaml:"yamlKey"`,
Expand All @@ -341,12 +362,12 @@ func TestFieldHook(t *testing.T) {
"PtrToPtrToString": "hello",
},
setupHook: func(h *mockFieldHook) {
h.Expect(_typeOfEmptyInterface, structField{
h.Expect(structField{
Name: "SomeInt",
Type: typeOfInt,
}, reflectEq{1}).Return(reflect.Value{}, errors.New("great sadness"))

h.Expect(_typeOfEmptyInterface, structField{
h.Expect(structField{
Name: "PtrToPtrToString",
Type: typeOfPtrPtrString,
}, reflectEq{"hello"}).Return(reflect.Value{}, errors.New("more sadness"))
Expand All @@ -363,12 +384,12 @@ func TestFieldHook(t *testing.T) {
"someString": 3,
},
setupHook: func(h *mockFieldHook) {
h.Expect(typeOfInt, structField{
h.Expect(structField{
Name: "SomeInt",
Type: typeOfInt,
}, reflectEq{42}).Return(reflect.ValueOf(100), nil)

h.Expect(typeOfInt, structField{
h.Expect(structField{
Name: "SomeString",
Type: typeOfString,
}, reflectEq{3}).Return(reflect.ValueOf("hello"), nil)
Expand Down
145 changes: 6 additions & 139 deletions hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,9 @@ import (
"fmt"
"reflect"
"strconv"
"strings"
"time"

"github.com/uber-go/mapdecode/internal/mapstructure"

"go.uber.org/multierr"
)

var (
Expand All @@ -38,23 +35,19 @@ var (
)

// FieldHookFunc is a hook called while decoding a specific struct field. It
// receives the source type, information about the target field, and the
// source data.
type FieldHookFunc func(from reflect.Type, to reflect.StructField, data reflect.Value) (reflect.Value, error)
// receives information about the target field, and the source data.
type FieldHookFunc func(dest reflect.StructField, srcData reflect.Value) (reflect.Value, error)

func composeFieldHooks(hooks []FieldHookFunc) FieldHookFunc {
return func(from reflect.Type, to reflect.StructField, data reflect.Value) (reflect.Value, error) {
return func(dest reflect.StructField, srcData reflect.Value) (reflect.Value, error) {
var err error
for _, hook := range hooks {
data, err = hook(from, to, data)
srcData, err = hook(dest, srcData)
if err != nil {
return data, err
return srcData, err
}

// Update the `from` type to reflect changes made by the hook.
from = data.Type()
}
return data, err
return srcData, err
}
}

Expand Down Expand Up @@ -206,129 +199,3 @@ func strconvHook(from, to reflect.Type, data reflect.Value) (reflect.Value, erro

return data, nil
}

// fieldHook applies the user-specified FieldHookFunc to all struct fields.
func fieldHook(opts *options) DecodeHookFunc {
hook := composeFieldHooks(opts.FieldHooks)
return func(from, to reflect.Type, data reflect.Value) (reflect.Value, error) {
if to.Kind() != reflect.Struct || from.Kind() != reflect.Map {
return data, nil
}

// We can only decode map[string]* and map[interface{}]* into structs.
if k := from.Key().Kind(); k != reflect.String && k != reflect.Interface {
return data, nil
}

// This map tracks type-changing updates to items in the map.
//
// If the source map has a rigid type for values (map[string]string
// rather than map[string]interface{}), we can't make replacements to
// values in-place if a hook changed the type of a value. So we will
// make a copy of the source map with a more liberal type and inject
// these updates into the copy.
updates := make(map[interface{}]interface{})

var errors []error
for i := 0; i < to.NumField(); i++ {
structField := to.Field(i)
if structField.PkgPath != "" && !structField.Anonymous {
// This field is not exported so we won't be able to decode
// into it.
continue
}

// This field resolution logic is adapted from mapstructure's own
// logic.
//
// See https://github.com/mitchellh/mapstructure/blob/53818660ed4955e899c0bcafa97299a388bd7c8e/mapstructure.go#L741

fieldName := structField.Name

// Field name override was specified.
tagParts := strings.Split(structField.Tag.Get(opts.TagName), ",")
if tagParts[0] != "" {
fieldName = tagParts[0]
}

// Get the value for this field from the source map, if any.
key := reflect.ValueOf(fieldName)
value := data.MapIndex(key)
if !value.IsValid() {
// Case-insensitive linear search if the name doesn't match
// as-is.
for _, kV := range data.MapKeys() {
// Kind() == Interface if map[interface{}]* so we use
// Interface().(string) to handle interface{} and string
// keys.
k, ok := kV.Interface().(string)
if !ok {
continue
}

if strings.EqualFold(k, fieldName) {
key = kV
value = data.MapIndex(kV)
break
}
}
}

if !value.IsValid() {
// No value specified for this field in source map.
continue
}

newValue, err := hook(value.Type(), structField, value)
if err != nil {
errors = append(errors, fmt.Errorf(
"error reading into field %q: %v", fieldName, err))
continue
}

if newValue == value {
continue
}

// If we can, assign in-place.
if newValue.Type().AssignableTo(value.Type()) {
// XXX(abg): Is it okay to make updates to the source map?
data.SetMapIndex(key, newValue)
} else {
updates[key.Interface()] = newValue.Interface()
}
}

if len(errors) > 0 {
return data, multierr.Combine(errors...)
}

// No more changes to make.
if len(updates) == 0 {
return data, nil
}

// Equivalent to,
//
// newData := make(map[$key]interface{})
// for k, v := range data {
// if newV, ok := updates[k]; ok {
// newData[k] = newV
// } else {
// newData[k] = v
// }
// }
newData := reflect.MakeMap(reflect.MapOf(from.Key(), _typeOfEmptyInterface))
for _, key := range data.MapKeys() {
var value reflect.Value
if v, ok := updates[key.Interface()]; ok {
value = reflect.ValueOf(v)
} else {
value = data.MapIndex(key)
}
newData.SetMapIndex(key, value)
}

return newData, nil
}
}
4 changes: 2 additions & 2 deletions hooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ func TestMultipleFieldHooks(t *testing.T) {
typeOfInt := reflect.TypeOf(42)

hook1.
Expect(_typeOfEmptyInterface, structField{
Expect(structField{
Name: "Int",
Type: typeOfInt,
}, reflectEq{"FOO"}).
Return(valueOf("BAR"), nil)

hook2.
Expect(reflect.TypeOf(""), structField{
Expect(structField{
Name: "Int",
Type: typeOfInt,
}, reflectEq{"BAR"}).
Expand Down
24 changes: 24 additions & 0 deletions internal/mapstructure/mapstructure.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ type DecodeHookFuncType func(reflect.Type, reflect.Type, interface{}) (interface
// source and target types.
type DecodeHookFuncKind func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error)

// FieldHookFunc is the callback function that can be used to alter
// StructFields before they are decoded into the struct objects.
type FieldHookFunc func(reflect.StructField, reflect.Value) (reflect.Value, error)

// DecoderConfig is the configuration that is used to create a new decoder
// and allows customization of various aspects of decoding.
type DecoderConfig struct {
Expand All @@ -51,6 +55,15 @@ type DecoderConfig struct {
// error.
DecodeHook DecodeHookFunc

// FieldHook, if set, will be called before decoding StructFields
// from the source data. This lets you read struct tags and apply
// mutations to the values before they're set down onto the
// resulting struct.
//
// If an error is returned, the entire decode will fail with that
// error.
FieldHook FieldHookFunc

// If ErrorUnused is true, then it is an error for there to exist
// keys in the original map that were unused in the decoding process
// (extra keys).
Expand Down Expand Up @@ -806,6 +819,17 @@ func (d *Decoder) decodeStruct(name string, data interface{}, val reflect.Value)
fieldName = fmt.Sprintf("%s.%s", name, fieldName)
}

// Apply struct field hooks
if d.config.FieldHook != nil {
newValue, err := d.config.FieldHook(*fieldType, rawMapVal)
if err != nil {
errors = appendErrors(errors, fmt.Errorf(
"error reading into field %q: %v", fieldName, err))
continue
}
rawMapVal = newValue
}

if err := d.decode(fieldName, rawMapVal.Interface(), field); err != nil {
errors = appendErrors(errors, err)
}
Expand Down
8 changes: 4 additions & 4 deletions mock_hooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ func (m *mockFieldHook) Hook() FieldHookFunc {
}

// Expect sets up a call expectation on the hook.
func (m *mockFieldHook) Expect(from, to, data interface{}) *gomock.Call {
return m.c.RecordCall(m, "Call", from, to, data)
func (m *mockFieldHook) Expect(dest, srcData interface{}) *gomock.Call {
return m.c.RecordCall(m, "Call", dest, srcData)
}

func (m *mockFieldHook) Call(from reflect.Type, to reflect.StructField, data reflect.Value) (reflect.Value, error) {
results := m.c.Call(m, "Call", from, to, data)
func (m *mockFieldHook) Call(dest reflect.StructField, srcData reflect.Value) (reflect.Value, error) {
results := m.c.Call(m, "Call", dest, srcData)
out := results[0].(reflect.Value)
err, _ := results[1].(error)
return out, err
Expand Down

0 comments on commit 531089d

Please sign in to comment.