From 8d7f164b56ed1affc2628087e1b5d662e9078d6b Mon Sep 17 00:00:00 2001 From: ginokent <29125616+ginokent@users.noreply.github.com> Date: Sat, 26 Aug 2023 10:19:59 +0900 Subject: [PATCH 1/2] BREAKING CHANGE: fix database/sql funcs --- .golangci.yml | 3 +- database/sql/db.go | 19 +++++ database/sql/db_test.go | 74 +++++++++++++++++ database/sql/interface.go | 22 +++++ database/sql/mock_test.go | 150 +++++++++++++++++++++++++++++++++++ database/sql/queryer.go | 72 ++++++++++++----- database/sql/queryer_test.go | 109 ++++++++++--------------- database/sql/rows.go | 24 +++--- database/sql/rows_test.go | 34 ++++---- database/sql/sql.go | 26 ------ database/sql/tx.go | 2 +- database/sql/tx_test.go | 4 +- 12 files changed, 398 insertions(+), 141 deletions(-) create mode 100644 database/sql/db.go create mode 100644 database/sql/db_test.go create mode 100644 database/sql/interface.go create mode 100644 database/sql/mock_test.go delete mode 100644 database/sql/sql.go diff --git a/.golangci.yml b/.golangci.yml index ca408241..dade5bd1 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -19,7 +19,7 @@ linters: - ifshort # for readability - interfacer # deprecated https://github.com/mvdan/interfacer - interfacebloat # unnecessary - - ireturn # unnecessary + - ireturn # TODO: enable - lll # unnecessary - maligned # deprecated https://github.com/mdempsky/maligned - nlreturn # ignore "return with no blank line before" @@ -58,5 +58,6 @@ issues: - maintidx - noctx - revive + - testpackage - varnamelen - wrapcheck diff --git a/database/sql/db.go b/database/sql/db.go new file mode 100644 index 00000000..a000b52e --- /dev/null +++ b/database/sql/db.go @@ -0,0 +1,19 @@ +package sqlz + +import ( + "context" + "database/sql" +) + +func MustOpen(ctx context.Context, driverName string, dataSourceName string) *sql.DB { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + panic(err) + } + + if err := db.PingContext(ctx); err != nil { + panic(err) + } + + return db +} diff --git a/database/sql/db_test.go b/database/sql/db_test.go new file mode 100644 index 00000000..a2a34a84 --- /dev/null +++ b/database/sql/db_test.go @@ -0,0 +1,74 @@ +package sqlz //nolint:testpackage + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "strings" + "testing" +) + +func TestMustOpen(t *testing.T) { + t.Parallel() + + t.Run("success", func(t *testing.T) { + t.Parallel() + + sql.Register(t.Name(), &driverDriverMock{ + OpenFunc: func(name string) (driver.Conn, error) { + return &driverConnMock{ + PrepareFunc: func(query string) (driver.Stmt, error) { + return &driverStmtMock{ + CloseFunc: func() error { + return nil + }, + ExecFunc: func(args []driver.Value) (driver.Result, error) { + return &driverResultMock{}, nil + }, + QueryFunc: func(args []driver.Value) (driver.Rows, error) { + return &driverRowsMock{}, nil + }, + }, nil + }, + }, nil + }, + }) + + ctx := context.Background() + db := MustOpen(ctx, t.Name(), ":memory:") + if db == nil { + t.Fatalf("❌: MustOpen: db == nil") + } + }) + + t.Run("failure,sqlUnknownDriver", func(t *testing.T) { + t.Parallel() + + defer func() { + const expect = "sql: unknown driver" + if actual := fmt.Sprintf("%v", recover()); !strings.Contains(actual, expect) { + t.Errorf("❌: recover: expect(%v) != actual(%s)", expect, actual) + } + }() + + MustOpen(context.Background(), t.Name(), "") + }) + + t.Run("failure,sqlDriverOpenError", func(t *testing.T) { + t.Parallel() + + defer func() { + expect := context.Canceled + if actual := fmt.Sprintf("%v", recover()); fmt.Sprint(expect) != fmt.Sprint(actual) { + t.Errorf("❌: recover: expect(%v) != actual(%s)", expect, actual) + } + }() + + sql.Register(t.Name(), &driverDriverMock{}) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + MustOpen(ctx, t.Name(), "") + }) +} diff --git a/database/sql/interface.go b/database/sql/interface.go new file mode 100644 index 00000000..3da63939 --- /dev/null +++ b/database/sql/interface.go @@ -0,0 +1,22 @@ +package sqlz + +import ( + "context" + "database/sql" +) + +type sqlQueryerContext interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) +} + +type sqlRows interface { + Close() error + Columns() ([]string, error) + Next() bool + Scan(...interface{}) error + Err() error +} + +type sqlTxBeginner interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} diff --git a/database/sql/mock_test.go b/database/sql/mock_test.go new file mode 100644 index 00000000..2f768cc7 --- /dev/null +++ b/database/sql/mock_test.go @@ -0,0 +1,150 @@ +package sqlz //nolint:testpackage + +import ( + "context" + "database/sql" + "database/sql/driver" +) + +type driverDriverMock struct { + OpenFunc func(name string) (driver.Conn, error) +} + +func (m *driverDriverMock) Open(name string) (driver.Conn, error) { + return m.OpenFunc(name) +} + +var _ driver.Driver = (*driverDriverMock)(nil) + +type driverConnMock struct { + PrepareFunc func(query string) (driver.Stmt, error) + CloseFunc func() error + BeginFunc func() (driver.Tx, error) +} + +func (m *driverConnMock) Prepare(query string) (driver.Stmt, error) { + return m.PrepareFunc(query) +} + +func (m *driverConnMock) Close() error { + return m.CloseFunc() +} + +func (m *driverConnMock) Begin() (driver.Tx, error) { + return m.BeginFunc() +} + +var _ driver.Conn = (*driverConnMock)(nil) + +type driverStmtMock struct { + CloseFunc func() error + NumInputFunc func() int + ExecFunc func(args []driver.Value) (driver.Result, error) + QueryFunc func(args []driver.Value) (driver.Rows, error) +} + +func (m *driverStmtMock) Close() error { + return m.CloseFunc() +} + +func (m *driverStmtMock) NumInput() int { + return m.NumInputFunc() +} + +func (m *driverStmtMock) Exec(args []driver.Value) (driver.Result, error) { + return m.ExecFunc(args) +} + +func (m *driverStmtMock) Query(args []driver.Value) (driver.Rows, error) { + return m.QueryFunc(args) +} + +var _ driver.Stmt = (*driverStmtMock)(nil) + +type driverResultMock struct { + LastInsertIdFunc func() (int64, error) //nolint:stylecheck // NOTE: sql.Result has LastInsertId method + RowsAffectedFunc func() (int64, error) +} + +func (m *driverResultMock) LastInsertId() (int64, error) { + return m.LastInsertIdFunc() +} + +func (m *driverResultMock) RowsAffected() (int64, error) { + return m.RowsAffectedFunc() +} + +var _ driver.Result = (*driverResultMock)(nil) + +type driverRowsMock struct { + CloseFunc func() error + ColumnsFunc func() []string + NextFunc func(dest []driver.Value) error +} + +func (m *driverRowsMock) Close() error { + return m.CloseFunc() +} + +func (m *driverRowsMock) Columns() []string { + return m.ColumnsFunc() +} + +func (m *driverRowsMock) Next(dest []driver.Value) error { + return m.NextFunc(dest) +} + +var _ driver.Rows = (*driverRowsMock)(nil) + +type sqlDBMock struct { + Rows *sql.Rows + Error error + + BeginTxFunc func(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + +var ( + _ sqlQueryerContext = (*sqlDBMock)(nil) + _ sqlTxBeginner = (*sqlDBMock)(nil) +) + +func (m *sqlDBMock) QueryContext(_ context.Context, _ string, _ ...interface{}) (*sql.Rows, error) { + return m.Rows, m.Error +} + +func (m *sqlDBMock) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + return m.BeginTxFunc(ctx, opts) +} + +type sqlRowsMock struct { + CloseFunc func() error + ColumnsFunc func() ([]string, error) + NextFunc func() bool + ScanFunc func(dest ...interface{}) error + ErrFunc func() error +} + +var _ sqlRows = (*sqlRowsMock)(nil) + +func (m *sqlRowsMock) Close() error { + if m.CloseFunc == nil { + return nil + } + return m.CloseFunc() +} + +func (m *sqlRowsMock) Columns() ([]string, error) { + return m.ColumnsFunc() +} + +func (m *sqlRowsMock) Next() bool { + return m.NextFunc() +} + +func (m *sqlRowsMock) Scan(dest ...interface{}) error { + return m.ScanFunc(dest...) +} + +func (m *sqlRowsMock) Err() error { + return m.ErrFunc() +} diff --git a/database/sql/queryer.go b/database/sql/queryer.go index 3ec7e4d8..ff232ee5 100644 --- a/database/sql/queryer.go +++ b/database/sql/queryer.go @@ -6,49 +6,85 @@ import ( "fmt" ) -type Queryer interface { - QueryStructSliceContext(ctx context.Context, structTag string, destStructSlicePointer interface{}, query string, args ...any) error - QueryStructContext(ctx context.Context, structTag string, destStructPointer interface{}, query string, args ...any) error +type QueryerContext interface { + // QueryContext executes a query that returns rows, typically a SELECT. + // + // The dst must be a pointer. + // The args are for any placeholder parameters in the query. + QueryContext(ctx context.Context, dst interface{}, query string, args ...interface{}) error + // QueryRowContext executes a query that is expected to return at most one row. + // It always returns a non-nil value or an error. + // + // The dst must be a pointer. + // The args are for any placeholder parameters in the query. + QueryRowContext(ctx context.Context, dst interface{}, query string, args ...interface{}) error } -type _Queryer struct { - SQLQueryer +type queryerContext struct { + sqlQueryer sqlQueryerContext + + structTag string } -func NewDB(db SQLQueryer) Queryer { - return &_Queryer{ - SQLQueryer: db, +type NewDBOption func(qc *queryerContext) + +const defaultStructTag = "db" + +func WithNewDBOptionStructTag(structTag string) NewDBOption { + return func(qc *queryerContext) { + qc.structTag = structTag } } -func (s *_Queryer) QueryStructSliceContext(ctx context.Context, structTag string, destStructSlicePointer interface{}, query string, args ...any) error { - rows, err := s.QueryContext(ctx, query, args...) //nolint:rowserrcheck - return s.queryStructSliceContext(rows, err, structTag, destStructSlicePointer) +func NewDB(db sqlQueryerContext, opts ...NewDBOption) QueryerContext { //nolint:ireturn + return newDB(db, opts...) +} + +func newDB(db sqlQueryerContext, opts ...NewDBOption) *queryerContext { + qc := &queryerContext{ + sqlQueryer: db, + structTag: defaultStructTag, + } + + for _, opt := range opts { + opt(qc) + } + + return qc +} + +func (qc *queryerContext) QueryContext(ctx context.Context, dst interface{}, query string, args ...interface{}) error { + rows, err := qc.sqlQueryer.QueryContext(ctx, query, args...) //nolint:rowserrcheck + return qc.queryContext(rows, err, dst) } -func (s *_Queryer) queryStructSliceContext(rows SQLRows, queryContextErr error, structTag string, destStructSlicePointer interface{}) error { +func (qc *queryerContext) queryContext(rows sqlRows, queryContextErr error, dst interface{}) error { if queryContextErr != nil { return fmt.Errorf("QueryContext: %w", queryContextErr) } defer rows.Close() - return ScanRows(rows, structTag, destStructSlicePointer) + return ScanRows(rows, qc.structTag, dst) } -func (s *_Queryer) QueryStructContext(ctx context.Context, structTag string, destStructPointer interface{}, query string, args ...any) error { - rows, err := s.QueryContext(ctx, query, args...) //nolint:rowserrcheck - return s.queryStructContext(rows, err, structTag, destStructPointer) +func (qc *queryerContext) QueryRowContext(ctx context.Context, dst interface{}, query string, args ...interface{}) error { + rows, err := qc.sqlQueryer.QueryContext(ctx, query, args...) //nolint:rowserrcheck + return qc.queryRowContext(rows, err, dst) } -func (s *_Queryer) queryStructContext(rows SQLRows, queryContextErr error, structTag string, destStructPointer interface{}) error { +func (qc *queryerContext) queryRowContext(rows sqlRows, queryContextErr error, dst interface{}) error { if queryContextErr != nil { return fmt.Errorf("QueryContext: %w", queryContextErr) } defer rows.Close() + // behaver like *sql.Row if !rows.Next() { + if err := rows.Err(); err != nil { + return err //nolint:wrapcheck + } return sql.ErrNoRows } - return ScanRows(rows, structTag, destStructPointer) + return ScanRows(rows, qc.structTag, dst) } diff --git a/database/sql/queryer_test.go b/database/sql/queryer_test.go index b5ccdb89..ce7a5b63 100644 --- a/database/sql/queryer_test.go +++ b/database/sql/queryer_test.go @@ -9,50 +9,7 @@ import ( "testing" ) -type mockDB struct { - SQLQueryer - SQLTxBeginner - - Rows *sql.Rows - Error error - - BeginTxFunc func(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) -} - -func (m *mockDB) QueryContext(_ context.Context, _ string, _ ...interface{}) (*sql.Rows, error) { - return m.Rows, m.Error -} - -func (m *mockDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { - return m.BeginTxFunc(ctx, opts) -} - -type mockRows struct { - SQLRows - CloseError error - ColumnsReturn []string - ColumnsError error - NextFunc func() bool - ScanFunc func(dest ...interface{}) error -} - -func (m *mockRows) Close() error { - return m.CloseError -} - -func (m *mockRows) Columns() ([]string, error) { - return m.ColumnsReturn, m.ColumnsError -} - -func (m *mockRows) Next() bool { - return m.NextFunc() -} - -func (m *mockRows) Scan(dest ...interface{}) error { - return m.ScanFunc(dest...) -} - -func Test_DB_QueryStructSliceContext(t *testing.T) { +func Test_DB_QueryContext(t *testing.T) { t.Parallel() t.Run("failure,sql.ErrNoRows", func(t *testing.T) { t.Parallel() @@ -61,13 +18,13 @@ func Test_DB_QueryStructSliceContext(t *testing.T) { Username string `db:"username"` } var u []*user - if err := NewDB(&mockDB{Rows: nil, Error: sql.ErrNoRows}).QueryStructSliceContext(context.Background(), "db", &u, "SELECT * FROM users"); !errors.Is(err, sql.ErrNoRows) { - t.Fatalf("❌: QueryStructSliceContext: %v", err) + if err := NewDB(&sqlDBMock{Rows: nil, Error: sql.ErrNoRows}).QueryContext(context.Background(), &u, "SELECT * FROM users"); !errors.Is(err, sql.ErrNoRows) { + t.Fatalf("❌: QueryContext: %v", err) } }) } -func Test_DB_queryStructSliceContext(t *testing.T) { +func Test_DB_queryContext(t *testing.T) { t.Parallel() t.Run("success", func(t *testing.T) { t.Parallel() @@ -76,14 +33,15 @@ func Test_DB_queryStructSliceContext(t *testing.T) { Username string `db:"username"` } var u user - db := &_Queryer{} + db := newDB(&sqlDBMock{}, WithNewDBOptionStructTag("db")) i := 0 - rows := &mockRows{ + rows := &sqlRowsMock{ NextFunc: func() bool { i++ return i < 2 }, - ColumnsReturn: []string{"user_id", "username"}, + + ColumnsFunc: func() ([]string, error) { return []string{"user_id", "username"}, nil }, ScanFunc: func(dest ...interface{}) error { for i := range dest { reflect.ValueOf(dest[i]).Elem().SetString("column" + strconv.Itoa(i)) @@ -91,14 +49,14 @@ func Test_DB_queryStructSliceContext(t *testing.T) { return nil }, } - if err := db.queryStructSliceContext(rows, nil, "db", &u); err != nil { - t.Fatalf("❌: queryStructSliceContext: %v", err) + if err := db.queryContext(rows, nil, &u); err != nil { + t.Fatalf("❌: queryContext: %v", err) } - t.Logf("✅: queryStructSliceContext: %+v", u) + t.Logf("✅: queryContext: %+v", u) }) } -func Test_DB_QueryStructContext(t *testing.T) { +func Test_DB_QueryRowContext(t *testing.T) { t.Parallel() t.Run("failure,sql.ErrNoRows", func(t *testing.T) { t.Parallel() @@ -107,13 +65,13 @@ func Test_DB_QueryStructContext(t *testing.T) { Username string `db:"username"` } var u user - if err := NewDB(&mockDB{Rows: nil, Error: sql.ErrNoRows}).QueryStructContext(context.Background(), "db", &u, "SELECT * FROM users"); !errors.Is(err, sql.ErrNoRows) { - t.Fatalf("❌: QueryStructContext: %v", err) + if err := NewDB(&sqlDBMock{Rows: nil, Error: sql.ErrNoRows}).QueryRowContext(context.Background(), &u, "SELECT * FROM users"); !errors.Is(err, sql.ErrNoRows) { + t.Fatalf("❌: QueryRowContext: %v", err) } }) } -func Test_DB_queryStructContext(t *testing.T) { +func Test_DB_queryRowContext(t *testing.T) { t.Parallel() t.Run("success", func(t *testing.T) { t.Parallel() @@ -122,14 +80,14 @@ func Test_DB_queryStructContext(t *testing.T) { Username string `db:"username"` } var u user - db := &_Queryer{} + db := newDB(&sqlDBMock{}, WithNewDBOptionStructTag("db")) i := 0 - rows := &mockRows{ + rows := &sqlRowsMock{ NextFunc: func() bool { i++ return i < 2 }, - ColumnsReturn: []string{"user_id", "username"}, + ColumnsFunc: func() ([]string, error) { return []string{"user_id", "username"}, nil }, ScanFunc: func(dest ...interface{}) error { for i := range dest { reflect.ValueOf(dest[i]).Elem().SetString("column" + strconv.Itoa(i)) @@ -137,10 +95,10 @@ func Test_DB_queryStructContext(t *testing.T) { return nil }, } - if err := db.queryStructContext(rows, nil, "db", &u); err != nil { - t.Fatalf("❌: queryStructContext: err != nil: %v", err) + if err := db.queryRowContext(rows, nil, &u); err != nil { + t.Fatalf("❌: queryRowContext: err != nil: %v", err) } - t.Logf("✅: queryStructSliceContext: %+v", u) + t.Logf("✅: queryContext: %+v", u) }) t.Run("failure,sql.ErrNoRows", func(t *testing.T) { t.Parallel() @@ -149,12 +107,29 @@ func Test_DB_queryStructContext(t *testing.T) { Username string `db:"username"` } var u user - db := &_Queryer{} - rows := &mockRows{ + db := newDB(&sqlDBMock{}, WithNewDBOptionStructTag("db")) + rows := &sqlRowsMock{ + NextFunc: func() bool { return false }, + ErrFunc: func() error { return nil }, + } + if err := db.queryRowContext(rows, nil, &u); !errors.Is(err, sql.ErrNoRows) { + t.Fatalf("❌: queryRowContext: expect(%v) != actual(%v)", sql.ErrNoRows, err) + } + }) + t.Run("failure,context.Canceled", func(t *testing.T) { + t.Parallel() + type user struct { + UserID string `db:"user_id"` + Username string `db:"username"` + } + var u user + db := newDB(&sqlDBMock{}, WithNewDBOptionStructTag("db")) + rows := &sqlRowsMock{ NextFunc: func() bool { return false }, + ErrFunc: func() error { return context.Canceled }, } - if err := db.queryStructContext(rows, nil, "db", &u); !errors.Is(err, sql.ErrNoRows) { - t.Fatalf("❌: queryStructContext: expect(%v) != actual(%v)", sql.ErrNoRows, err) + if err := db.queryRowContext(rows, nil, &u); !errors.Is(err, context.Canceled) { + t.Fatalf("❌: queryRowContext: expect(%v) != actual(%v)", context.Canceled, err) } }) } diff --git a/database/sql/rows.go b/database/sql/rows.go index 9e77864f..0b6f465e 100644 --- a/database/sql/rows.go +++ b/database/sql/rows.go @@ -5,8 +5,11 @@ import ( "reflect" ) -func ScanRows(rows SQLRows, structTag string, destPointer interface{}) error { - pointer := reflect.ValueOf(destPointer) // *Type or *[]Type or *[]*Type +// ScanRows scans rows to dst. +// +// dst must be a pointer. +func ScanRows(rows sqlRows, structTag string, dst interface{}) error { + pointer := reflect.ValueOf(dst) // expect *Type or *[]Type or *[]*Type if pointer.Kind() != reflect.Ptr { return fmt.Errorf("structSlicePointer.Kind=%s: %w", pointer.Kind(), ErrMustBePointer) } @@ -17,20 +20,20 @@ func ScanRows(rows SQLRows, structTag string, destPointer interface{}) error { deref := pointer.Elem() switch deref.Kind() { //nolint:exhaustive case reflect.Slice: - if err := scanRowsToStructSlice(rows, deref, structTag); err != nil { // []Type (or []*Type) - return fmt.Errorf("type=%T: %w", destPointer, err) + if err := scanRowsToSlice(rows, structTag, deref); err != nil { // []Type (or []*Type) + return fmt.Errorf("scanRowsToSlice: type=%T: %w", dst, err) } case reflect.Struct: - if err := scanRowsToStruct(rows, deref, structTag); err != nil { // Type (or *Type) - return fmt.Errorf("type=%T: %w", destPointer, err) + if err := scanRowsToStruct(rows, structTag, deref); err != nil { // Type (or *Type) + return fmt.Errorf("scanRowsToStruct: type=%T: %w", dst, err) } default: - return fmt.Errorf("type=%T: %w", destPointer, ErrDataTypeNotSupported) + return fmt.Errorf("type=%T: %w", dst, ErrDataTypeNotSupported) } return nil } -func scanRowsToStructSlice(rows SQLRows, destStructSlice reflect.Value, structTag string) error { // destStructSlice: []Type (or []*Type) +func scanRowsToSlice(rows sqlRows, structTag string, destStructSlice reflect.Value) error { // destStructSlice: []Type (or []*Type) sliceContentType := destStructSlice.Type().Elem() // sliceContentType: Type (or *Type) var sliceContentIsPointer bool if sliceContentType.Kind() == reflect.Ptr { @@ -39,13 +42,14 @@ func scanRowsToStructSlice(rows SQLRows, destStructSlice reflect.Value, structTa } if sliceContentType.Kind() != reflect.Struct { + // TODO: support other types return fmt.Errorf("destStructSlice.Kind=%s: %w", destStructSlice.Kind(), ErrDataTypeNotSupported) } destStructSlice.SetLen(0) for rows.Next() { v := reflect.New(sliceContentType).Elem() - if err := scanRowsToStruct(rows, v, structTag); err != nil { + if err := scanRowsToStruct(rows, structTag, v); err != nil { return fmt.Errorf("scanRowsToStruct: %w", err) } @@ -59,7 +63,7 @@ func scanRowsToStructSlice(rows SQLRows, destStructSlice reflect.Value, structTa return nil } -func scanRowsToStruct(rows SQLRows, destStruct reflect.Value, structTag string) error { +func scanRowsToStruct(rows sqlRows, structTag string, destStruct reflect.Value) error { columns, err := rows.Columns() if err != nil { return fmt.Errorf("rows.Columns: %w", err) diff --git a/database/sql/rows_test.go b/database/sql/rows_test.go index 7895e18d..67e1c0a2 100644 --- a/database/sql/rows_test.go +++ b/database/sql/rows_test.go @@ -18,12 +18,12 @@ func Test_ScanRows(t *testing.T) { } var u []user i := 0 - rows := &mockRows{ + rows := &sqlRowsMock{ NextFunc: func() bool { i++ return i < 3 }, - ColumnsReturn: []string{"user_id", "username"}, + ColumnsFunc: func() ([]string, error) { return []string{"user_id", "username"}, nil }, ScanFunc: func(dest ...interface{}) error { for i := range dest { reflect.ValueOf(dest[i]).Elem().SetString("column" + strconv.Itoa(i)) @@ -44,12 +44,12 @@ func Test_ScanRows(t *testing.T) { } var u []*user i := 0 - rows := &mockRows{ + rows := &sqlRowsMock{ NextFunc: func() bool { i++ return i < 3 }, - ColumnsReturn: []string{"user_id", "username"}, + ColumnsFunc: func() ([]string, error) { return []string{"user_id", "username"}, nil }, ScanFunc: func(dest ...interface{}) error { for i := range dest { reflect.ValueOf(dest[i]).Elem().SetString("column" + strconv.Itoa(i)) @@ -70,12 +70,12 @@ func Test_ScanRows(t *testing.T) { } var u user i := 0 - rows := &mockRows{ + rows := &sqlRowsMock{ NextFunc: func() bool { i++ return i < 2 }, - ColumnsReturn: []string{"user_id", "username"}, + ColumnsFunc: func() ([]string, error) { return []string{"user_id", "username"}, nil }, ScanFunc: func(dest ...interface{}) error { for i := range dest { reflect.ValueOf(dest[i]).Elem().SetString("column" + strconv.Itoa(i)) @@ -96,7 +96,7 @@ func Test_ScanRows(t *testing.T) { } var notPointer user i := 0 - rows := &mockRows{ + rows := &sqlRowsMock{ NextFunc: func() bool { i++ return i < 2 @@ -114,7 +114,7 @@ func Test_ScanRows(t *testing.T) { } var nilPointer *user i := 0 - rows := &mockRows{ + rows := &sqlRowsMock{ NextFunc: func() bool { i++ return i < 2 @@ -132,12 +132,12 @@ func Test_ScanRows(t *testing.T) { } var u []*user i := 0 - rows := &mockRows{ + rows := &sqlRowsMock{ NextFunc: func() bool { i++ return i < 3 }, - ColumnsReturn: []string{"user_id", "username"}, + ColumnsFunc: func() ([]string, error) { return []string{"user_id", "username"}, nil }, ScanFunc: func(dest ...interface{}) error { return sql.ErrConnDone }, @@ -149,7 +149,9 @@ func Test_ScanRows(t *testing.T) { t.Run("failure,reflect.Slice_ErrDataTypeNotSupported", func(t *testing.T) { t.Parallel() var u []string - if err := ScanRows(&mockRows{}, "db", &u); !errors.Is(err, ErrDataTypeNotSupported) { + if err := ScanRows(&sqlRowsMock{ + NextFunc: func() bool { return true }, + }, "db", &u); !errors.Is(err, ErrDataTypeNotSupported) { t.Fatalf("❌: queryStructContext: expect(%v) != actual(%v)", ErrDataTypeNotSupported, err) } }) @@ -161,12 +163,12 @@ func Test_ScanRows(t *testing.T) { } var u user i := 0 - rows := &mockRows{ + rows := &sqlRowsMock{ NextFunc: func() bool { i++ return i < 2 }, - ColumnsReturn: []string{"user_id", "username"}, + ColumnsFunc: func() ([]string, error) { return []string{"user_id", "username"}, nil }, ScanFunc: func(dest ...interface{}) error { return sql.ErrConnDone }, @@ -183,12 +185,12 @@ func Test_ScanRows(t *testing.T) { } var u user i := 0 - rows := &mockRows{ + rows := &sqlRowsMock{ NextFunc: func() bool { i++ return i < 2 }, - ColumnsError: sql.ErrNoRows, + ColumnsFunc: func() ([]string, error) { return nil, sql.ErrNoRows }, } if err := ScanRows(rows, "db", &u); !errors.Is(err, sql.ErrNoRows) { t.Fatalf("❌: queryStructContext: expect(%v) != actual(%v)", sql.ErrNoRows, err) @@ -198,7 +200,7 @@ func Test_ScanRows(t *testing.T) { t.Parallel() var user string i := 0 - rows := &mockRows{ + rows := &sqlRowsMock{ NextFunc: func() bool { i++ return i < 2 diff --git a/database/sql/sql.go b/database/sql/sql.go deleted file mode 100644 index ca06e04b..00000000 --- a/database/sql/sql.go +++ /dev/null @@ -1,26 +0,0 @@ -package sqlz - -import ( - "context" - "database/sql" -) - -type SQLQueryer interface { - QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) -} - -type SQLExecuter interface { - ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) -} - -type SQLRows interface { - Close() error - Columns() ([]string, error) - Next() bool - Scan(...interface{}) error - Err() error -} - -type SQLTxBeginner interface { - BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) -} diff --git a/database/sql/tx.go b/database/sql/tx.go index 8bd423d0..35b9f674 100644 --- a/database/sql/tx.go +++ b/database/sql/tx.go @@ -5,7 +5,7 @@ import ( "database/sql" ) -func MustBeginTx(ctx context.Context, db SQLTxBeginner, opts *sql.TxOptions) *sql.Tx { +func MustBeginTx(ctx context.Context, db sqlTxBeginner, opts *sql.TxOptions) *sql.Tx { tx, err := db.BeginTx(ctx, opts) if err != nil { panic(err) diff --git a/database/sql/tx_test.go b/database/sql/tx_test.go index 5d95b626..5d841d51 100644 --- a/database/sql/tx_test.go +++ b/database/sql/tx_test.go @@ -10,7 +10,7 @@ func TestMustBeginTx(t *testing.T) { t.Parallel() t.Run("success", func(t *testing.T) { t.Parallel() - db := &mockDB{ + db := &sqlDBMock{ BeginTxFunc: func(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { return &sql.Tx{}, nil }, @@ -22,7 +22,7 @@ func TestMustBeginTx(t *testing.T) { }) t.Run("failure", func(t *testing.T) { t.Parallel() - db := &mockDB{ + db := &sqlDBMock{ BeginTxFunc: func(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { return nil, sql.ErrConnDone }, From 60f231bf5180c681e87c6edb450ac2efb9817cd3 Mon Sep 17 00:00:00 2001 From: ginokent <29125616+ginokent@users.noreply.github.com> Date: Sun, 27 Aug 2023 08:03:57 +0900 Subject: [PATCH 2/2] fix: misc --- .tool-versions | 1 - database/sql/queryer.go | 4 +-- database/sql/queryer_test.go | 8 +++--- database/sql/rows_test.go | 52 ++++++++++++++++++------------------ 4 files changed, 32 insertions(+), 33 deletions(-) delete mode 100644 .tool-versions diff --git a/.tool-versions b/.tool-versions deleted file mode 100644 index e7eea5d8..00000000 --- a/.tool-versions +++ /dev/null @@ -1 +0,0 @@ -golang 1.20.4 diff --git a/database/sql/queryer.go b/database/sql/queryer.go index ff232ee5..2a96de55 100644 --- a/database/sql/queryer.go +++ b/database/sql/queryer.go @@ -28,8 +28,6 @@ type queryerContext struct { type NewDBOption func(qc *queryerContext) -const defaultStructTag = "db" - func WithNewDBOptionStructTag(structTag string) NewDBOption { return func(qc *queryerContext) { qc.structTag = structTag @@ -40,6 +38,8 @@ func NewDB(db sqlQueryerContext, opts ...NewDBOption) QueryerContext { //nolint: return newDB(db, opts...) } +const defaultStructTag = "db" + func newDB(db sqlQueryerContext, opts ...NewDBOption) *queryerContext { qc := &queryerContext{ sqlQueryer: db, diff --git a/database/sql/queryer_test.go b/database/sql/queryer_test.go index ce7a5b63..230cf4d6 100644 --- a/database/sql/queryer_test.go +++ b/database/sql/queryer_test.go @@ -33,7 +33,7 @@ func Test_DB_queryContext(t *testing.T) { Username string `db:"username"` } var u user - db := newDB(&sqlDBMock{}, WithNewDBOptionStructTag("db")) + db := newDB(&sqlDBMock{}, WithNewDBOptionStructTag(defaultStructTag)) i := 0 rows := &sqlRowsMock{ NextFunc: func() bool { @@ -80,7 +80,7 @@ func Test_DB_queryRowContext(t *testing.T) { Username string `db:"username"` } var u user - db := newDB(&sqlDBMock{}, WithNewDBOptionStructTag("db")) + db := newDB(&sqlDBMock{}, WithNewDBOptionStructTag(defaultStructTag)) i := 0 rows := &sqlRowsMock{ NextFunc: func() bool { @@ -107,7 +107,7 @@ func Test_DB_queryRowContext(t *testing.T) { Username string `db:"username"` } var u user - db := newDB(&sqlDBMock{}, WithNewDBOptionStructTag("db")) + db := newDB(&sqlDBMock{}, WithNewDBOptionStructTag(defaultStructTag)) rows := &sqlRowsMock{ NextFunc: func() bool { return false }, ErrFunc: func() error { return nil }, @@ -123,7 +123,7 @@ func Test_DB_queryRowContext(t *testing.T) { Username string `db:"username"` } var u user - db := newDB(&sqlDBMock{}, WithNewDBOptionStructTag("db")) + db := newDB(&sqlDBMock{}, WithNewDBOptionStructTag(defaultStructTag)) rows := &sqlRowsMock{ NextFunc: func() bool { return false }, ErrFunc: func() error { return context.Canceled }, diff --git a/database/sql/rows_test.go b/database/sql/rows_test.go index 67e1c0a2..c0739ea5 100644 --- a/database/sql/rows_test.go +++ b/database/sql/rows_test.go @@ -13,8 +13,8 @@ func Test_ScanRows(t *testing.T) { t.Run("success,reflect.Slice", func(t *testing.T) { t.Parallel() type user struct { - UserID string `db:"user_id"` - Username string `db:"username"` + UserID string `testdb:"user_id"` + Username string `testdb:"username"` } var u []user i := 0 @@ -31,7 +31,7 @@ func Test_ScanRows(t *testing.T) { return nil }, } - if err := ScanRows(rows, "db", &u); err != nil { + if err := ScanRows(rows, "testdb", &u); err != nil { t.Fatalf("❌: ScanRows: err != nil: %v", err) } t.Logf("✅: ScanRows: %+v", u) @@ -39,8 +39,8 @@ func Test_ScanRows(t *testing.T) { t.Run("success,reflect.Slice_pointer_slice", func(t *testing.T) { t.Parallel() type user struct { - UserID string `db:"user_id"` - Username string `db:"username"` + UserID string `testdb:"user_id"` + Username string `testdb:"username"` } var u []*user i := 0 @@ -57,7 +57,7 @@ func Test_ScanRows(t *testing.T) { return nil }, } - if err := ScanRows(rows, "db", &u); err != nil { + if err := ScanRows(rows, "testdb", &u); err != nil { t.Fatalf("❌: ScanRows: err != nil: %v", err) } t.Logf("✅: ScanRows: %+v", u) @@ -65,8 +65,8 @@ func Test_ScanRows(t *testing.T) { t.Run("success,reflect.Struct", func(t *testing.T) { t.Parallel() type user struct { - UserID string `db:"user_id"` - Username string `db:"username"` + UserID string `testdb:"user_id"` + Username string `testdb:"username"` } var u user i := 0 @@ -83,7 +83,7 @@ func Test_ScanRows(t *testing.T) { return nil }, } - if err := ScanRows(rows, "db", &u); err != nil { + if err := ScanRows(rows, "testdb", &u); err != nil { t.Fatalf("❌: ScanRows: err != nil: %v", err) } t.Logf("✅: ScanRows: %+v", u) @@ -91,8 +91,8 @@ func Test_ScanRows(t *testing.T) { t.Run("failure,ErrMustBePointer", func(t *testing.T) { t.Parallel() type user struct { - UserID string `db:"user_id"` - Username string `db:"username"` + UserID string `testdb:"user_id"` + Username string `testdb:"username"` } var notPointer user i := 0 @@ -102,15 +102,15 @@ func Test_ScanRows(t *testing.T) { return i < 2 }, } - if err := ScanRows(rows, "db", notPointer); !errors.Is(err, ErrMustBePointer) { + if err := ScanRows(rows, "testdb", notPointer); !errors.Is(err, ErrMustBePointer) { t.Fatalf("❌: ScanRows: expect(%v) != actual(%v)", ErrMustBePointer, err) } }) t.Run("failure,ErrMustNotNil", func(t *testing.T) { t.Parallel() type user struct { - UserID string `db:"user_id"` - Username string `db:"username"` + UserID string `testdb:"user_id"` + Username string `testdb:"username"` } var nilPointer *user i := 0 @@ -120,15 +120,15 @@ func Test_ScanRows(t *testing.T) { return i < 2 }, } - if err := ScanRows(rows, "db", nilPointer); !errors.Is(err, ErrMustNotNil) { + if err := ScanRows(rows, "testdb", nilPointer); !errors.Is(err, ErrMustNotNil) { t.Fatalf("❌: ScanRows: expect(%v) != actual(%v)", ErrMustNotNil, err) } }) t.Run("failure,reflect.Slice_Scan", func(t *testing.T) { t.Parallel() type user struct { - UserID string `db:"user_id"` - Username string `db:"username"` + UserID string `testdb:"user_id"` + Username string `testdb:"username"` } var u []*user i := 0 @@ -142,7 +142,7 @@ func Test_ScanRows(t *testing.T) { return sql.ErrConnDone }, } - if err := ScanRows(rows, "db", &u); !errors.Is(err, sql.ErrConnDone) { + if err := ScanRows(rows, "testdb", &u); !errors.Is(err, sql.ErrConnDone) { t.Fatalf("❌: queryStructContext: expect(%v) != actual(%v)", sql.ErrConnDone, err) } }) @@ -151,15 +151,15 @@ func Test_ScanRows(t *testing.T) { var u []string if err := ScanRows(&sqlRowsMock{ NextFunc: func() bool { return true }, - }, "db", &u); !errors.Is(err, ErrDataTypeNotSupported) { + }, "testdb", &u); !errors.Is(err, ErrDataTypeNotSupported) { t.Fatalf("❌: queryStructContext: expect(%v) != actual(%v)", ErrDataTypeNotSupported, err) } }) t.Run("failure,reflect.Struct_Scan", func(t *testing.T) { t.Parallel() type user struct { - UserID string `db:"user_id"` - Username string `db:"username"` + UserID string `testdb:"user_id"` + Username string `testdb:"username"` } var u user i := 0 @@ -173,15 +173,15 @@ func Test_ScanRows(t *testing.T) { return sql.ErrConnDone }, } - if err := ScanRows(rows, "db", &u); !errors.Is(err, sql.ErrConnDone) { + if err := ScanRows(rows, "testdb", &u); !errors.Is(err, sql.ErrConnDone) { t.Fatalf("❌: queryStructContext: expect(%v) != actual(%v)", sql.ErrConnDone, err) } }) t.Run("failure,reflect.Struct_Scan", func(t *testing.T) { t.Parallel() type user struct { - UserID string `db:"user_id"` - Username string `db:"username"` + UserID string `testdb:"user_id"` + Username string `testdb:"username"` } var u user i := 0 @@ -192,7 +192,7 @@ func Test_ScanRows(t *testing.T) { }, ColumnsFunc: func() ([]string, error) { return nil, sql.ErrNoRows }, } - if err := ScanRows(rows, "db", &u); !errors.Is(err, sql.ErrNoRows) { + if err := ScanRows(rows, "testdb", &u); !errors.Is(err, sql.ErrNoRows) { t.Fatalf("❌: queryStructContext: expect(%v) != actual(%v)", sql.ErrNoRows, err) } }) @@ -206,7 +206,7 @@ func Test_ScanRows(t *testing.T) { return i < 2 }, } - if err := ScanRows(rows, "db", &user); !errors.Is(err, ErrDataTypeNotSupported) { + if err := ScanRows(rows, "testdb", &user); !errors.Is(err, ErrDataTypeNotSupported) { t.Fatalf("❌: ScanRows: expect(%v) != actual(%v)", ErrDataTypeNotSupported, err) } })