From 67e4a91b874a26e1623f816e310dd95ee9e1f676 Mon Sep 17 00:00:00 2001 From: liujian Date: Wed, 17 Jan 2024 15:14:28 +0800 Subject: [PATCH] Update xsql enum --- src/xsql/db_test.go | 36 ++++++++++++++++++++++++------------ src/xsql/fetcher.go | 12 ++++++------ 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/src/xsql/db_test.go b/src/xsql/db_test.go index f4f0d81..74f0c2a 100644 --- a/src/xsql/db_test.go +++ b/src/xsql/db_test.go @@ -13,11 +13,14 @@ import ( "time" ) +type Enum int32 + type Test struct { - Id int `xsql:"id"` - Foo string `xsql:"foo"` - Bar time.Time `xsql:"bar"` - Baz bool `xsql:"baz" json:"-"` + Id int `xsql:"id"` + Foo string `xsql:"foo"` + Bar time.Time `xsql:"bar"` + Bool bool `xsql:"bool" json:"-"` + Enum Enum `xsql:"enum" json:"-"` } func (t Test) TableName() string { @@ -63,11 +66,12 @@ CREATE TABLE #xsql# ( #id# int unsigned NOT NULL AUTO_INCREMENT, #foo# varchar(255) DEFAULT NULL, #bar# datetime DEFAULT NULL, - #baz# int NOT NULL DEFAULT '0', + #bool# int NOT NULL DEFAULT '0', + #enum# int NOT NULL DEFAULT '0', PRIMARY KEY (#id#) ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; -INSERT INTO #xsql# (#id#, #foo#, #bar#, #baz#) VALUES (1, 'v', '2022-04-14 23:49:48', 1); -INSERT INTO #xsql# (#id#, #foo#, #bar#, #baz#) VALUES (2, 'v1', '2022-04-14 23:50:00', 1); +INSERT INTO #xsql# (#id#, #foo#, #bar#, #bool#, #enum#) VALUES (1, 'v', '2022-04-14 23:49:48', 1, 1); +INSERT INTO #xsql# (#id#, #foo#, #bar#, #bool#, #enum#) VALUES (2, 'v1', '2022-04-14 23:50:00', 1, 1); ` DB := newDB() _, err := DB.Exec(strings.ReplaceAll(q, "#", "`")) @@ -286,7 +290,11 @@ func TestFirst(t *testing.T) { b, _ := json.Marshal(test) a.Equal(string(b), `{"Id":1,"Foo":"v","Bar":"2022-04-14T23:49:48Z"}`) - a.Equal(test.Baz, true) + // bool + a.Equal(test.Bool, true) + // enum + a.IsType(Enum(0), test.Enum) + a.Equal(Enum(1), test.Enum) } func TestFirstEmbedding(t *testing.T) { @@ -315,7 +323,8 @@ func TestFirstPart(t *testing.T) { log.Fatal(err) } - a.Equal(fmt.Sprintf("%+v", test), "{Id:0 Foo:v Bar:0001-01-01 00:00:00 +0000 UTC Baz:false}") + b, _ := json.Marshal(test) + a.Equal(string(b), "{\"Id\":0,\"Foo\":\"v\",\"Bar\":\"0001-01-01T00:00:00Z\"}") } func TestFirstTableKey(t *testing.T) { @@ -344,7 +353,8 @@ func TestFind(t *testing.T) { log.Fatal(err) } - a.Equal(fmt.Sprintf("%+v", tests), `[{Id:1 Foo:v Bar:2022-04-14 23:49:48 +0000 UTC Baz:true} {Id:2 Foo:v1 Bar:2022-04-14 23:50:00 +0000 UTC Baz:true}]`) + b, _ := json.Marshal(tests) + a.Equal(string(b), "[{\"Id\":1,\"Foo\":\"v\",\"Bar\":\"2022-04-14T23:49:48Z\"},{\"Id\":2,\"Foo\":\"v1\",\"Bar\":\"2022-04-14T23:50:00Z\"}]") } func TestEmbeddingFind(t *testing.T) { @@ -373,7 +383,8 @@ func TestFindPart(t *testing.T) { log.Fatal(err) } - a.Equal(fmt.Sprintf("%+v", tests), `[{Id:0 Foo:v Bar:0001-01-01 00:00:00 +0000 UTC Baz:false} {Id:0 Foo:v1 Bar:0001-01-01 00:00:00 +0000 UTC Baz:false}]`) + b, _ := json.Marshal(tests) + a.Equal(string(b), "[{\"Id\":0,\"Foo\":\"v\",\"Bar\":\"0001-01-01T00:00:00Z\"},{\"Id\":0,\"Foo\":\"v1\",\"Bar\":\"0001-01-01T00:00:00Z\"}]") } func TestFindTableKey(t *testing.T) { @@ -387,7 +398,8 @@ func TestFindTableKey(t *testing.T) { log.Fatal(err) } - a.Equal(fmt.Sprintf("%+v", tests), `[{Id:1 Foo:v Bar:2022-04-14 23:49:48 +0000 UTC Baz:true} {Id:2 Foo:v1 Bar:2022-04-14 23:50:00 +0000 UTC Baz:true}]`) + b, _ := json.Marshal(tests) + a.Equal(string(b), "[{\"Id\":1,\"Foo\":\"v\",\"Bar\":\"2022-04-14T23:49:48Z\"},{\"Id\":2,\"Foo\":\"v1\",\"Bar\":\"2022-04-14T23:50:00Z\"}]") } func TestTxCommit(t *testing.T) { diff --git a/src/xsql/fetcher.go b/src/xsql/fetcher.go index 83cd641..69c2ce8 100644 --- a/src/xsql/fetcher.go +++ b/src/xsql/fetcher.go @@ -22,7 +22,7 @@ func (t *Fetcher) First(i interface{}) error { return errors.New("xsql: argument can only be pointer type") } rootValue := value.Elem() - rootType := reflect.TypeOf(i).Elem() + rootType := rootValue.Type() rows, err := t.Rows() if err != nil { @@ -297,18 +297,18 @@ func (t *Fetcher) foreach(row *Row, value reflect.Value, typ reflect.Type) error if !row.Exist(tag) { continue } - if err := t.mapped(fieldValue, row, tag); err != nil { + if err := t.mapped(row, tag, fieldValue, fieldValue.Type()); err != nil { return err } } return nil } -func (t *Fetcher) mapped(field reflect.Value, row *Row, tag string) (err error) { +func (t *Fetcher) mapped(row *Row, tag string, value reflect.Value, typ reflect.Type) (err error) { res := row.Get(tag) v := res.Value() - switch field.Kind() { + switch value.Kind() { case reflect.Int: v = int(res.Int()) break @@ -347,7 +347,7 @@ func (t *Fetcher) mapped(field reflect.Value, row *Row, tag string) (err error) break default: if !res.Empty() && - field.Type().String() == "time.Time" && + typ.String() == "time.Time" && reflect.ValueOf(v).Type().String() != "time.Time" { if t, e := time.ParseInLocation(t.options.TimeLayout, res.String(), t.options.TimeLocation); e == nil { v = t @@ -363,7 +363,7 @@ func (t *Fetcher) mapped(field reflect.Value, row *Row, tag string) (err error) err = fmt.Errorf("type mismatch for field %s: %v", tag, e) } }() - field.Set(reflect.ValueOf(v)) + value.Set(reflect.ValueOf(v).Convert(value.Type())) return }