Skip to content

Commit

Permalink
Merge pull request #22 from future-architect/feature/fix-embeded-time…
Browse files Browse the repository at this point in the history
…-support

fix embeded time.Time support
  • Loading branch information
ma91n authored Jun 28, 2022
2 parents 10286e5 + 17371e9 commit c4a57b5
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 6 deletions.
26 changes: 20 additions & 6 deletions eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ func encode(dest map[string]interface{}, src interface{}) error {
}

// tagscanner does not support nest struct type.
encodeNestStructTyp(src, dest, tags)
encodeStructField(src, dest, tags)

return nil
}
Expand All @@ -201,21 +201,35 @@ func convertToMapStringAny(mp reflect.Value, dest map[string]interface{}) bool {
return true
}

func encodeNestStructTyp(src interface{}, dest map[string]interface{}, tags []string) {
srcFieldTyps := reflect.ValueOf(src).Type().Elem()
srcFieldValues := reflect.ValueOf(src).Elem()
func encodeStructField(src interface{}, dest map[string]interface{}, tags []string) {
srcFieldValues := reflect.ValueOf(src)
srcFieldTyps := srcFieldValues.Type()
if srcFieldTyps.Kind() == reflect.Pointer {
srcFieldTyps = srcFieldTyps.Elem()
srcFieldValues = srcFieldValues.Elem()
}
for i := 0; i < srcFieldTyps.NumField(); i++ {
srcFieldTyp := srcFieldTyps.Field(i)
srcFieldValue := srcFieldValues.Field(i)

tagValue := getTagValue(srcFieldTyp.Tag, tags)
if tagValue == "" {

if srcFieldTyp.Type.Kind() != reflect.Struct {
continue
}
srcFieldValue := srcFieldValues.Field(i)
switch srcFieldTyp.Type.PkgPath() {
case "database/sql":
if tagValue == "" {
continue
}
encodeSQLNullTyp(srcFieldValue, dest, tagValue)
case "time":
if tagValue == "" {
continue
}
encodeTimeTyp(srcFieldValue, dest, tagValue)
default:
encodeStructField(srcFieldValue.Interface(), dest, tags)
}
}
}
Expand Down
51 changes: 51 additions & 0 deletions eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,12 @@ func TestEvalWithMap(t *testing.T) {
}

func TestEval_NestStructTyp(t *testing.T) {
type Embed struct {
EmbedNullString sql.NullString `db:"embed_null_string"`
EmbedTime time.Time `db:"embed_time"`
EmbedPtrTime *time.Time `db:"embed_ptr_time"`
}

type SQLTypInfo struct {
NullBool sql.NullBool `db:"null_bool"`
NullFloat64 sql.NullFloat64 `db:"null_float_64"`
Expand All @@ -1131,6 +1137,9 @@ func TestEval_NestStructTyp(t *testing.T) {
NullString sql.NullString `db:"null_string"`
NullTime sql.NullTime `db:"null_time"`
Time time.Time `db:"time"`
PtrTime *time.Time `db:"ptr_time"`

Embed Embed
}

tests := []struct {
Expand Down Expand Up @@ -1215,13 +1224,55 @@ func TestEval_NestStructTyp(t *testing.T) {
wantQuery: `SELECT * FROM person WHERE value = ?/*time*/`,
wantParams: []interface{}{time.Date(2022, 7, 1, 12, 30, 30, 0, time.UTC)},
},
{
name: "bind ptr time.Time",
input: `SELECT * FROM person WHERE value = /*ptr_time*/'2022-01-01 10:00:00'`,
inputParams: SQLTypInfo{
PtrTime: func() *time.Time {
d := time.Date(2022, 7, 1, 12, 30, 30, 0, time.UTC)
return &d
}(),
},
wantQuery: `SELECT * FROM person WHERE value = ?/*ptr_time*/`,
wantParams: []interface{}{time.Date(2022, 7, 1, 12, 30, 30, 0, time.UTC)},
},
{
name: "bind initial",
input: `SELECT * FROM person WHERE value = /*null_string*/'hoge'`,
inputParams: SQLTypInfo{},
wantQuery: `SELECT * FROM person WHERE value = ?/*null_string*/`,
wantParams: []interface{}{nil},
},
{
name: "bind embed time.Time",
input: `SELECT * FROM person WHERE value = /*embed_time*/'2022-01-01 10:00:00'`,
inputParams: SQLTypInfo{
Embed: Embed{EmbedTime: time.Date(2022, 7, 1, 12, 30, 30, 0, time.UTC)},
},
wantQuery: `SELECT * FROM person WHERE value = ?/*embed_time*/`,
wantParams: []interface{}{time.Date(2022, 7, 1, 12, 30, 30, 0, time.UTC)},
},
{
name: "bind embed ptr time.Time",
input: `SELECT * FROM person WHERE value = /*embed_ptr_time*/'2022-01-01 10:00:00'`,
inputParams: SQLTypInfo{
Embed: Embed{EmbedPtrTime: func() *time.Time {
d := time.Date(2022, 7, 1, 12, 30, 30, 0, time.UTC)
return &d
}()},
},
wantQuery: `SELECT * FROM person WHERE value = ?/*embed_ptr_time*/`,
wantParams: []interface{}{time.Date(2022, 7, 1, 12, 30, 30, 0, time.UTC)},
},
{
name: "bind embed sql.NullString",
input: `SELECT * FROM person WHERE value = /*embed_null_string*/'embed_null_string'`,
inputParams: SQLTypInfo{
Embed: Embed{EmbedNullString: sql.NullString{String: "value", Valid: true}},
},
wantQuery: `SELECT * FROM person WHERE value = ?/*embed_null_string*/`,
wantParams: []interface{}{"value"},
},
{
name: "bind invalid",
input: `SELECT * FROM person WHERE value = /*null_string*/'hoge'`,
Expand Down

0 comments on commit c4a57b5

Please sign in to comment.