Skip to content

Commit

Permalink
feat: add sync.Map (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
ginokent authored Sep 11, 2023
2 parents 867625a + 189a2a1 commit 37e385e
Show file tree
Hide file tree
Showing 18 changed files with 945 additions and 218 deletions.
21 changes: 21 additions & 0 deletions database/sql/common_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package sqlz

import "bytes"

type testUser struct {
*bytes.Buffer // Anonymous

UserID int `testdb:"user_id"`
Username string `testdb:"username"`
NullString *string `testdb:"null_string"`
Hyphen string `testdb:"-"`
NoTag string
}

var (
_testUserTableName = "test_user"
_testUserColumns = []string{"user_id", "username", "null_string"}
)

func (*testUser) TableName() string { return _testUserTableName }
func (*testUser) Columns() []string { return _testUserColumns }
9 changes: 0 additions & 9 deletions database/sql/queryer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package sqlz

import (
"context"
"database/sql"
"fmt"
)

Expand Down Expand Up @@ -81,13 +80,5 @@ func (qc *queryerContext) queryRowContext(rows sqlRows, queryContextErr error, d
}
defer rows.Close()

// behaver like *sql.Row
if !rows.Next() {
if err := rows.Err(); err != nil {
return err //nolint:wrapcheck
}
return sql.ErrNoRows
}

return ScanRows(rows, qc.structTag, dst)
}
42 changes: 6 additions & 36 deletions database/sql/queryer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,7 @@ func Test_DB_QueryContext(t *testing.T) {
t.Parallel()
t.Run("failure,sql.ErrNoRows", func(t *testing.T) {
t.Parallel()
type user struct {
UserID int `testdb:"user_id"`
Username string `testdb:"username"`
NullString *string `testdb:"null_string"`
}
var u []*user
var u []*testUser
if err := NewDB(&sqlDBMock{Rows: nil, Error: sql.ErrNoRows}).QueryContext(context.Background(), &u, "SELECT * FROM users"); !errors.Is(err, sql.ErrNoRows) {
t.Fatalf("❌: QueryContext: %v", err)
}
Expand All @@ -32,12 +27,7 @@ func Test_DB_queryContext(t *testing.T) {
t.Parallel()
t.Run("success", func(t *testing.T) {
t.Parallel()
type user struct {
UserID int `testdb:"user_id"`
Username string `testdb:"username"`
NullString *string `testdb:"null_string"`
}
var u []user
var u []testUser
db := newDB(&sqlDBMock{}, WithNewDBOptionStructTag("testdb"))
i := 0
columns := []string{"user_id", "username", "null_string"}
Expand Down Expand Up @@ -87,12 +77,7 @@ func Test_DB_QueryRowContext(t *testing.T) {
t.Parallel()
t.Run("failure,sql.ErrNoRows", func(t *testing.T) {
t.Parallel()
type user struct {
UserID int `testdb:"user_id"`
Username string `testdb:"username"`
NullString *string `testdb:"null_string"`
}
var u user
var u testUser
if err := NewDB(&sqlDBMock{Rows: nil, Error: sql.ErrNoRows}).QueryRowContext(context.Background(), &u, "SELECT * FROM users"); !errors.Is(err, sql.ErrNoRows) {
t.Fatalf("❌: QueryRowContext: %v", err)
}
Expand All @@ -103,12 +88,7 @@ func Test_DB_queryRowContext(t *testing.T) {
t.Parallel()
t.Run("success", func(t *testing.T) {
t.Parallel()
type user struct {
UserID int `testdb:"user_id"`
Username string `testdb:"username"`
NullString *string `testdb:"null_string"`
}
var u user
var u testUser
db := newDB(&sqlDBMock{}, WithNewDBOptionStructTag("testdb"))
i := 0
columns := []string{"user_id", "username", "null_string"}
Expand Down Expand Up @@ -140,12 +120,7 @@ func Test_DB_queryRowContext(t *testing.T) {
})
t.Run("failure,sql.ErrNoRows", func(t *testing.T) {
t.Parallel()
type user struct {
UserID int `testdb:"user_id"`
Username string `testdb:"username"`
NullString *string `testdb:"null_string"`
}
var u user
var u testUser
db := newDB(&sqlDBMock{}, WithNewDBOptionStructTag("testdb"))
rows := &sqlRowsMock{
NextFunc: func() bool { return false },
Expand All @@ -157,12 +132,7 @@ func Test_DB_queryRowContext(t *testing.T) {
})
t.Run("failure,context.Canceled", func(t *testing.T) {
t.Parallel()
type user struct {
UserID int `testdb:"user_id"`
Username string `testdb:"username"`
NullString *string `testdb:"null_string"`
}
var u user
var u testUser
db := newDB(&sqlDBMock{}, WithNewDBOptionStructTag("testdb"))
rows := &sqlRowsMock{
NextFunc: func() bool { return false },
Expand Down
96 changes: 79 additions & 17 deletions database/sql/rows.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
package sqlz

import (
"database/sql"
"fmt"
"reflect"
"strings"
"sync"
)

// ScanRows scans rows to dst.
//
// structTag is used to get column names from struct tags.
//
// dst must be a pointer.
// If dst is a pointer to a struct or a slice of struct, column names are got from structTag.
// If dst is a pointer to a slice of primitive, ignore structTag.
//
//nolint:cyclop
func ScanRows(rows sqlRows, structTag string, dst interface{}) error {
pointer := reflect.ValueOf(dst) // expect *Type or *[]Type or *[]*Type
if pointer.Kind() != reflect.Pointer {
Expand All @@ -25,6 +33,13 @@ func ScanRows(rows sqlRows, structTag string, dst interface{}) error {
return fmt.Errorf("scanRowsToSlice: type=%T: %w", dst, err)
}
case reflect.Struct:
// behaver like *sql.Row
if !rows.Next() {
if err := rows.Err(); err != nil {
return err //nolint:wrapcheck
}
return sql.ErrNoRows
}
columns, err := rows.Columns()
if err != nil {
return fmt.Errorf("rows.Columns: %w", err)
Expand All @@ -34,39 +49,83 @@ func ScanRows(rows sqlRows, structTag string, dst interface{}) error {
if err := scanRowsToStruct(rows, columns, dests, tags, deref); err != nil { // expect Type (or *Type)
return fmt.Errorf("scanRowsToStruct: type=%T: %w", dst, err)
}
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
reflect.Float32, reflect.Float64,
reflect.String: // primitives
// behaver like *sql.Row
if !rows.Next() {
if err := rows.Err(); err != nil {
return err //nolint:wrapcheck
}
return sql.ErrNoRows
}
if err := rows.Scan(dst); err != nil {
return fmt.Errorf("rows.Scan: %w", err)
}
return nil
default:
return fmt.Errorf("type=%T: %w", dst, ErrDataTypeNotSupported)
}
return nil
}

func scanRowsToSlice(rows sqlRows, structTag string, destStructSlice reflect.Value) error { // destStructSlice: []Type (or []*Type)
structType := destStructSlice.Type().Elem() // Type (or *Type) <- []Type (or []*Type)
var sliceContentIsPointer bool
if structType.Kind() == reflect.Pointer {
sliceContentIsPointer = true
structType = structType.Elem() // Type <- *Type
elementType := destStructSlice.Type().Elem() // Type (or *Type) <- []Type (or []*Type)
var elementIsPointer bool
if elementType.Kind() == reflect.Pointer {
elementIsPointer = true
elementType = elementType.Elem() // Type <- *Type
}

if structType.Kind() != reflect.Struct {
switch elementType.Kind() { //nolint:exhaustive
case reflect.Struct:
if err := scanRowsToStructSlice(rows, structTag, elementType, elementIsPointer, destStructSlice); err != nil { // expect []Type (or []*Type)
return fmt.Errorf("scanRowsToStructSlice: %w", err)
}
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
reflect.Float32, reflect.Float64,
reflect.String: // primitives
slice, put := getReflectValueSlice()
defer put()
for rows.Next() {
v := reflect.Indirect(reflect.New(elementType))
if err := rows.Scan(v.Addr().Interface()); err != nil {
return fmt.Errorf("rows.Scan: %w", err)
}
if elementIsPointer {
slice.Slice = append(slice.Slice, v.Addr())
} else {
slice.Slice = append(slice.Slice, v)
}
}
destStructSlice.Set(reflect.Append(destStructSlice, slice.Slice...))
return nil
default:
// TODO: support other types
return fmt.Errorf("elem=%s, expected=%s: %w", structType.Kind(), reflect.Struct, ErrDataTypeNotSupported)
return fmt.Errorf("elem=%s, expected=%s: %w", elementType.Kind(), reflect.Struct, ErrDataTypeNotSupported)
}
return nil
}

func scanRowsToStructSlice(rows sqlRows, structTag string, elementType reflect.Type, elementIsPointer bool, destStructSlice reflect.Value) error { // destStructSlice: []Type (or []*Type)
columns, err := rows.Columns()
if err != nil {
return fmt.Errorf("rows.Columns: %w", err)
}
dests := make([]interface{}, len(columns))
tags := getStructTags(structType, structTag)
tags := getStructTags(elementType, structTag)
slice, put := getReflectValueSlice()
defer put()
for rows.Next() {
v := reflect.Indirect(reflect.New(structType))
v := reflect.Indirect(reflect.New(elementType))
if err := scanRowsToStruct(rows, columns, dests, tags, v); err != nil {
return fmt.Errorf("scanRowsToStruct: %w", err)
}
if sliceContentIsPointer {
if elementIsPointer {
slice.Slice = append(slice.Slice, v.Addr())
} else {
slice.Slice = append(slice.Slice, v)
Expand Down Expand Up @@ -94,18 +153,21 @@ func scanRowsToStruct(rows sqlRows, columns []string, dests []interface{}, tags
}

//nolint:gochecknoglobals
var tagsMap sync.Map
var (
tagsCache sync.Map
)

func getStructTags(t reflect.Type, structTag string) []string {
if tags, ok := tagsMap.Load(t); ok {
func getStructTags(structType reflect.Type, structTag string) []string {
if tags, ok := tagsCache.Load(structType); ok {
return tags.([]string) //nolint:forcetypeassert
}

tags := make([]string, t.NumField())
for i := 0; t.NumField() > i; i++ {
tags[i] = t.Field(i).Tag.Get(structTag)
tags := make([]string, structType.NumField())
for i := 0; structType.NumField() > i; i++ {
rawTag := structType.Field(i).Tag.Get(structTag)
tags[i] = strings.Split(rawTag, ",")[0]
}
tagsMap.Store(t, tags)
tagsCache.Store(structType, tags)
return tags
}

Expand Down
Loading

0 comments on commit 37e385e

Please sign in to comment.