diff --git a/field/composite.go b/field/composite.go index 639e414..203646c 100644 --- a/field/composite.go +++ b/field/composite.go @@ -234,17 +234,23 @@ func (f *Composite) Marshal(v interface{}) error { defer f.mu.Unlock() rv := reflect.ValueOf(v) - if rv.Kind() != reflect.Ptr || rv.IsNil() { - return errors.New("data is not a pointer or nil") + if rv.Kind() != reflect.Ptr { + return errors.New("data is not a pointer") } - // get the struct from the pointer - dataStruct := rv.Elem() + elemType := rv.Type().Elem() + if elemType.Kind() != reflect.Struct { + return errors.New("data must be a pointer to struct") + } - if dataStruct.Kind() != reflect.Struct { - return errors.New("data is not a struct") + // If nil, create a new instance of the struct + if rv.IsNil() { + rv = reflect.New(elemType) } + // get the struct from the pointer + dataStruct := rv.Elem() + // iterate over struct fields for i := 0; i < dataStruct.NumField(); i++ { indexTag := NewIndexTag(dataStruct.Type().Field(i)) @@ -258,7 +264,7 @@ func (f *Composite) Marshal(v interface{}) error { } dataField := dataStruct.Field(i) - if dataField.IsZero() { + if dataField.IsZero() && !indexTag.KeepZero { continue } diff --git a/field/composite_test.go b/field/composite_test.go index 7651b88..3d176d1 100644 --- a/field/composite_test.go +++ b/field/composite_test.go @@ -351,10 +351,23 @@ type SubConstructedTLVTestData struct { } func TestCompositeField_Marshal(t *testing.T) { + t.Run("Marshal returns error on nil value", func(t *testing.T) { + composite := NewComposite(compositeTestSpec) + err := composite.Marshal(nil) + require.EqualError(t, err, "data is not a pointer") + }) + + t.Run("Marshal doesn't return an error when nil pointer is of struct type", func(t *testing.T) { + composite := NewComposite(compositeTestSpec) + var data *CompositeTestData + err := composite.Marshal(data) + require.NoError(t, err) + }) + t.Run("Marshal returns an error on provision of primitive type data", func(t *testing.T) { composite := NewComposite(compositeTestSpec) err := composite.Marshal("primitive str") - require.EqualError(t, err, "data is not a pointer or nil") + require.EqualError(t, err, "data is not a pointer") }) t.Run("Marshal skips fields without index tag", func(t *testing.T) { diff --git a/field/index_tag.go b/field/index_tag.go index 237a136..e2aa9a1 100644 --- a/field/index_tag.go +++ b/field/index_tag.go @@ -3,6 +3,7 @@ package field import ( "reflect" "regexp" + "slices" "strconv" "strings" ) @@ -20,6 +21,18 @@ type IndexTag struct { } func NewIndexTag(field reflect.StructField) IndexTag { + indexTag := extractTagInfo(field) + + if indexTag.Tag == "" { + id, tag := extractIdAndTagFromName(field.Name) + indexTag.ID = id + indexTag.Tag = tag + } + + return indexTag +} + +func extractTagInfo(field reflect.StructField) IndexTag { // value of the key "index" in the tag var value string @@ -34,66 +47,52 @@ func NewIndexTag(field reflect.StructField) IndexTag { // format of the value is "id[,keep_zero_value]" // id is the id of the field // let's parse it - if value != "" { - tag, opts := parseTag(value) - - id, err := strconv.Atoi(tag) - if err != nil { - id = -1 - } - + if value == "" { return IndexTag{ - ID: id, - Tag: tag, - KeepZero: opts.Contains("keepzero"), + ID: -1, } } - dataFieldName := field.Name - if len(dataFieldName) > 0 && fieldNameIndexRe.MatchString(dataFieldName) { - indexStr := dataFieldName[1:] + tag, opts := parseTag(value) + id, err := strconv.Atoi(tag) + if err != nil { + id = -1 + } + + return IndexTag{ + ID: id, + Tag: tag, + KeepZero: opts.Contains("keepzero"), + } +} + +func extractIdAndTagFromName(fieldName string) (int, string) { + if len(fieldName) > 0 && fieldNameIndexRe.MatchString(fieldName) { + indexStr := fieldName[1:] fieldIndex, err := strconv.Atoi(indexStr) if err != nil { - return IndexTag{ - ID: -1, - Tag: indexStr, - } + return -1, indexStr } - return IndexTag{ - ID: fieldIndex, - Tag: indexStr, - } + return fieldIndex, indexStr } - return IndexTag{ - ID: -1, - } + return -1, "" } -type tagOptions string +type tagOptions []string // parseTag splits a struct field's index tag into its id and // comma-separated options. func parseTag(tag string) (string, tagOptions) { tag, opt, _ := strings.Cut(tag, ",") - return tag, tagOptions(opt) + + return tag, tagOptions(strings.Split(opt, ",")) } // Contains reports whether a comma-separated list of options // contains a particular substr flag. substr must be surrounded by a // string boundary or commas. func (o tagOptions) Contains(optionName string) bool { - if len(o) == 0 { - return false - } - s := string(o) - for s != "" { - var name string - name, s, _ = strings.Cut(s, ",") - if name == optionName { - return true - } - } - return false + return slices.Contains(o, optionName) } diff --git a/message_test.go b/message_test.go index 09fd185..32d192f 100644 --- a/message_test.go +++ b/message_test.go @@ -367,6 +367,79 @@ func TestMessage(t *testing.T) { wantMsg := []byte("01007000000000000000164242424242424242123456000000000100") require.Equal(t, wantMsg, rawMsg) }) + + t.Run("Clone, set zero values and reset fields", func(t *testing.T) { + type TestISOF3Data struct { + F1 *field.String + F2 *field.String + F3 *field.String + } + + type ISO87Data struct { + F0 *field.String + F2 *field.String + F3 *TestISOF3Data + F4 *field.String + } + + message := NewMessage(spec) + err := message.Marshal(&ISO87Data{ + F0: field.NewStringValue("0100"), + F2: field.NewStringValue("4242424242424242"), + F3: &TestISOF3Data{ + F1: field.NewStringValue("12"), + F2: field.NewStringValue("34"), + F3: field.NewStringValue("56"), + }, + F4: field.NewStringValue("100"), + }) + require.NoError(t, err) + + // clone the message and reset some fields + clone, err := message.Clone() + require.NoError(t, err) + + // reset the fields + // first, check that the fields are set + data := &ISO87Data{} + require.NoError(t, clone.Unmarshal(data)) + + require.Equal(t, "0100", data.F0.Value()) + require.Equal(t, "4242424242424242", data.F2.Value()) + require.Equal(t, "12", data.F3.F1.Value()) + require.Equal(t, "34", data.F3.F2.Value()) + require.Equal(t, "56", data.F3.F3.Value()) + require.Equal(t, "100", data.F4.Value()) + + // reset the fields + err = clone.Marshal(&struct { + F2 *field.String `iso8583:",keepzero"` + F3 *struct { + F2 *field.String `iso8583:",keepzero"` + } `iso8583:",keepzero"` + }{}) + require.NoError(t, err) + + // check that the field values are set to zero values + data = &ISO87Data{} + require.NoError(t, clone.Unmarshal(data)) + + // check that fields are set + require.NotNil(t, data.F2) + require.NotNil(t, data.F3) + require.NotNil(t, data.F3.F2) + + // check the zero values + require.Equal(t, "", data.F2.Value()) + require.Equal(t, "", data.F3.F2.Value()) + + // check the reset fields in the message + require.Equal(t, "0100", data.F0.Value()) + require.Equal(t, "12", data.F3.F1.Value()) + require.Equal(t, "56", data.F3.F3.Value()) + require.Equal(t, "100", data.F4.Value()) + }) + } func TestPackUnpack(t *testing.T) {