Skip to content

Commit

Permalink
Update xsql enum
Browse files Browse the repository at this point in the history
  • Loading branch information
onanying committed Jan 17, 2024
1 parent f56c368 commit 67e4a91
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 18 deletions.
36 changes: 24 additions & 12 deletions src/xsql/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@ import (
"time"
)

type Enum int32

type Test struct {
Id int `xsql:"id"`
Foo string `xsql:"foo"`
Bar time.Time `xsql:"bar"`
Baz bool `xsql:"baz" json:"-"`
Id int `xsql:"id"`
Foo string `xsql:"foo"`
Bar time.Time `xsql:"bar"`
Bool bool `xsql:"bool" json:"-"`
Enum Enum `xsql:"enum" json:"-"`
}

func (t Test) TableName() string {
Expand Down Expand Up @@ -63,11 +66,12 @@ CREATE TABLE #xsql# (
#id# int unsigned NOT NULL AUTO_INCREMENT,
#foo# varchar(255) DEFAULT NULL,
#bar# datetime DEFAULT NULL,
#baz# int NOT NULL DEFAULT '0',
#bool# int NOT NULL DEFAULT '0',
#enum# int NOT NULL DEFAULT '0',
PRIMARY KEY (#id#)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci;
INSERT INTO #xsql# (#id#, #foo#, #bar#, #baz#) VALUES (1, 'v', '2022-04-14 23:49:48', 1);
INSERT INTO #xsql# (#id#, #foo#, #bar#, #baz#) VALUES (2, 'v1', '2022-04-14 23:50:00', 1);
INSERT INTO #xsql# (#id#, #foo#, #bar#, #bool#, #enum#) VALUES (1, 'v', '2022-04-14 23:49:48', 1, 1);
INSERT INTO #xsql# (#id#, #foo#, #bar#, #bool#, #enum#) VALUES (2, 'v1', '2022-04-14 23:50:00', 1, 1);
`
DB := newDB()
_, err := DB.Exec(strings.ReplaceAll(q, "#", "`"))
Expand Down Expand Up @@ -286,7 +290,11 @@ func TestFirst(t *testing.T) {

b, _ := json.Marshal(test)
a.Equal(string(b), `{"Id":1,"Foo":"v","Bar":"2022-04-14T23:49:48Z"}`)
a.Equal(test.Baz, true)
// bool
a.Equal(test.Bool, true)
// enum
a.IsType(Enum(0), test.Enum)
a.Equal(Enum(1), test.Enum)
}

func TestFirstEmbedding(t *testing.T) {
Expand Down Expand Up @@ -315,7 +323,8 @@ func TestFirstPart(t *testing.T) {
log.Fatal(err)
}

a.Equal(fmt.Sprintf("%+v", test), "{Id:0 Foo:v Bar:0001-01-01 00:00:00 +0000 UTC Baz:false}")
b, _ := json.Marshal(test)
a.Equal(string(b), "{\"Id\":0,\"Foo\":\"v\",\"Bar\":\"0001-01-01T00:00:00Z\"}")
}

func TestFirstTableKey(t *testing.T) {
Expand Down Expand Up @@ -344,7 +353,8 @@ func TestFind(t *testing.T) {
log.Fatal(err)
}

a.Equal(fmt.Sprintf("%+v", tests), `[{Id:1 Foo:v Bar:2022-04-14 23:49:48 +0000 UTC Baz:true} {Id:2 Foo:v1 Bar:2022-04-14 23:50:00 +0000 UTC Baz:true}]`)
b, _ := json.Marshal(tests)
a.Equal(string(b), "[{\"Id\":1,\"Foo\":\"v\",\"Bar\":\"2022-04-14T23:49:48Z\"},{\"Id\":2,\"Foo\":\"v1\",\"Bar\":\"2022-04-14T23:50:00Z\"}]")
}

func TestEmbeddingFind(t *testing.T) {
Expand Down Expand Up @@ -373,7 +383,8 @@ func TestFindPart(t *testing.T) {
log.Fatal(err)
}

a.Equal(fmt.Sprintf("%+v", tests), `[{Id:0 Foo:v Bar:0001-01-01 00:00:00 +0000 UTC Baz:false} {Id:0 Foo:v1 Bar:0001-01-01 00:00:00 +0000 UTC Baz:false}]`)
b, _ := json.Marshal(tests)
a.Equal(string(b), "[{\"Id\":0,\"Foo\":\"v\",\"Bar\":\"0001-01-01T00:00:00Z\"},{\"Id\":0,\"Foo\":\"v1\",\"Bar\":\"0001-01-01T00:00:00Z\"}]")
}

func TestFindTableKey(t *testing.T) {
Expand All @@ -387,7 +398,8 @@ func TestFindTableKey(t *testing.T) {
log.Fatal(err)
}

a.Equal(fmt.Sprintf("%+v", tests), `[{Id:1 Foo:v Bar:2022-04-14 23:49:48 +0000 UTC Baz:true} {Id:2 Foo:v1 Bar:2022-04-14 23:50:00 +0000 UTC Baz:true}]`)
b, _ := json.Marshal(tests)
a.Equal(string(b), "[{\"Id\":1,\"Foo\":\"v\",\"Bar\":\"2022-04-14T23:49:48Z\"},{\"Id\":2,\"Foo\":\"v1\",\"Bar\":\"2022-04-14T23:50:00Z\"}]")
}

func TestTxCommit(t *testing.T) {
Expand Down
12 changes: 6 additions & 6 deletions src/xsql/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func (t *Fetcher) First(i interface{}) error {
return errors.New("xsql: argument can only be pointer type")
}
rootValue := value.Elem()
rootType := reflect.TypeOf(i).Elem()
rootType := rootValue.Type()

rows, err := t.Rows()
if err != nil {
Expand Down Expand Up @@ -297,18 +297,18 @@ func (t *Fetcher) foreach(row *Row, value reflect.Value, typ reflect.Type) error
if !row.Exist(tag) {
continue
}
if err := t.mapped(fieldValue, row, tag); err != nil {
if err := t.mapped(row, tag, fieldValue, fieldValue.Type()); err != nil {
return err
}
}
return nil
}

func (t *Fetcher) mapped(field reflect.Value, row *Row, tag string) (err error) {
func (t *Fetcher) mapped(row *Row, tag string, value reflect.Value, typ reflect.Type) (err error) {
res := row.Get(tag)
v := res.Value()

switch field.Kind() {
switch value.Kind() {
case reflect.Int:
v = int(res.Int())
break
Expand Down Expand Up @@ -347,7 +347,7 @@ func (t *Fetcher) mapped(field reflect.Value, row *Row, tag string) (err error)
break
default:
if !res.Empty() &&
field.Type().String() == "time.Time" &&
typ.String() == "time.Time" &&
reflect.ValueOf(v).Type().String() != "time.Time" {
if t, e := time.ParseInLocation(t.options.TimeLayout, res.String(), t.options.TimeLocation); e == nil {
v = t
Expand All @@ -363,7 +363,7 @@ func (t *Fetcher) mapped(field reflect.Value, row *Row, tag string) (err error)
err = fmt.Errorf("type mismatch for field %s: %v", tag, e)
}
}()
field.Set(reflect.ValueOf(v))
value.Set(reflect.ValueOf(v).Convert(value.Type()))

return
}

0 comments on commit 67e4a91

Please sign in to comment.