Skip to content

Commit

Permalink
BREAKING CHANGE: fix database/sql funcs (#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
ginokent authored Aug 26, 2023
2 parents ac61976 + 60f231b commit dc5cbab
Show file tree
Hide file tree
Showing 13 changed files with 423 additions and 167 deletions.
3 changes: 2 additions & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ linters:
- ifshort # for readability
- interfacer # deprecated https://github.com/mvdan/interfacer
- interfacebloat # unnecessary
- ireturn # unnecessary
- ireturn # TODO: enable
- lll # unnecessary
- maligned # deprecated https://github.com/mdempsky/maligned
- nlreturn # ignore "return with no blank line before"
Expand Down Expand Up @@ -58,5 +58,6 @@ issues:
- maintidx
- noctx
- revive
- testpackage
- varnamelen
- wrapcheck
1 change: 0 additions & 1 deletion .tool-versions

This file was deleted.

19 changes: 19 additions & 0 deletions database/sql/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package sqlz

import (
"context"
"database/sql"
)

func MustOpen(ctx context.Context, driverName string, dataSourceName string) *sql.DB {
db, err := sql.Open(driverName, dataSourceName)
if err != nil {
panic(err)
}

if err := db.PingContext(ctx); err != nil {
panic(err)
}

return db
}
74 changes: 74 additions & 0 deletions database/sql/db_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package sqlz //nolint:testpackage

import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"strings"
"testing"
)

func TestMustOpen(t *testing.T) {
t.Parallel()

t.Run("success", func(t *testing.T) {
t.Parallel()

sql.Register(t.Name(), &driverDriverMock{
OpenFunc: func(name string) (driver.Conn, error) {
return &driverConnMock{
PrepareFunc: func(query string) (driver.Stmt, error) {
return &driverStmtMock{
CloseFunc: func() error {
return nil
},
ExecFunc: func(args []driver.Value) (driver.Result, error) {
return &driverResultMock{}, nil
},
QueryFunc: func(args []driver.Value) (driver.Rows, error) {
return &driverRowsMock{}, nil
},
}, nil
},
}, nil
},
})

ctx := context.Background()
db := MustOpen(ctx, t.Name(), ":memory:")
if db == nil {
t.Fatalf("❌: MustOpen: db == nil")
}
})

t.Run("failure,sqlUnknownDriver", func(t *testing.T) {
t.Parallel()

defer func() {
const expect = "sql: unknown driver"
if actual := fmt.Sprintf("%v", recover()); !strings.Contains(actual, expect) {
t.Errorf("❌: recover: expect(%v) != actual(%s)", expect, actual)
}
}()

MustOpen(context.Background(), t.Name(), "")
})

t.Run("failure,sqlDriverOpenError", func(t *testing.T) {
t.Parallel()

defer func() {
expect := context.Canceled
if actual := fmt.Sprintf("%v", recover()); fmt.Sprint(expect) != fmt.Sprint(actual) {
t.Errorf("❌: recover: expect(%v) != actual(%s)", expect, actual)
}
}()

sql.Register(t.Name(), &driverDriverMock{})

ctx, cancel := context.WithCancel(context.Background())
cancel()
MustOpen(ctx, t.Name(), "")
})
}
22 changes: 22 additions & 0 deletions database/sql/interface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package sqlz

import (
"context"
"database/sql"
)

type sqlQueryerContext interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}

type sqlRows interface {
Close() error
Columns() ([]string, error)
Next() bool
Scan(...interface{}) error
Err() error
}

type sqlTxBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
150 changes: 150 additions & 0 deletions database/sql/mock_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package sqlz //nolint:testpackage

import (
"context"
"database/sql"
"database/sql/driver"
)

type driverDriverMock struct {
OpenFunc func(name string) (driver.Conn, error)
}

func (m *driverDriverMock) Open(name string) (driver.Conn, error) {
return m.OpenFunc(name)
}

var _ driver.Driver = (*driverDriverMock)(nil)

type driverConnMock struct {
PrepareFunc func(query string) (driver.Stmt, error)
CloseFunc func() error
BeginFunc func() (driver.Tx, error)
}

func (m *driverConnMock) Prepare(query string) (driver.Stmt, error) {
return m.PrepareFunc(query)
}

func (m *driverConnMock) Close() error {
return m.CloseFunc()
}

func (m *driverConnMock) Begin() (driver.Tx, error) {
return m.BeginFunc()
}

var _ driver.Conn = (*driverConnMock)(nil)

type driverStmtMock struct {
CloseFunc func() error
NumInputFunc func() int
ExecFunc func(args []driver.Value) (driver.Result, error)
QueryFunc func(args []driver.Value) (driver.Rows, error)
}

func (m *driverStmtMock) Close() error {
return m.CloseFunc()
}

func (m *driverStmtMock) NumInput() int {
return m.NumInputFunc()
}

func (m *driverStmtMock) Exec(args []driver.Value) (driver.Result, error) {
return m.ExecFunc(args)
}

func (m *driverStmtMock) Query(args []driver.Value) (driver.Rows, error) {
return m.QueryFunc(args)
}

var _ driver.Stmt = (*driverStmtMock)(nil)

type driverResultMock struct {
LastInsertIdFunc func() (int64, error) //nolint:stylecheck // NOTE: sql.Result has LastInsertId method
RowsAffectedFunc func() (int64, error)
}

func (m *driverResultMock) LastInsertId() (int64, error) {
return m.LastInsertIdFunc()
}

func (m *driverResultMock) RowsAffected() (int64, error) {
return m.RowsAffectedFunc()
}

var _ driver.Result = (*driverResultMock)(nil)

type driverRowsMock struct {
CloseFunc func() error
ColumnsFunc func() []string
NextFunc func(dest []driver.Value) error
}

func (m *driverRowsMock) Close() error {
return m.CloseFunc()
}

func (m *driverRowsMock) Columns() []string {
return m.ColumnsFunc()
}

func (m *driverRowsMock) Next(dest []driver.Value) error {
return m.NextFunc(dest)
}

var _ driver.Rows = (*driverRowsMock)(nil)

type sqlDBMock struct {
Rows *sql.Rows
Error error

BeginTxFunc func(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}

var (
_ sqlQueryerContext = (*sqlDBMock)(nil)
_ sqlTxBeginner = (*sqlDBMock)(nil)
)

func (m *sqlDBMock) QueryContext(_ context.Context, _ string, _ ...interface{}) (*sql.Rows, error) {
return m.Rows, m.Error
}

func (m *sqlDBMock) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
return m.BeginTxFunc(ctx, opts)
}

type sqlRowsMock struct {
CloseFunc func() error
ColumnsFunc func() ([]string, error)
NextFunc func() bool
ScanFunc func(dest ...interface{}) error
ErrFunc func() error
}

var _ sqlRows = (*sqlRowsMock)(nil)

func (m *sqlRowsMock) Close() error {
if m.CloseFunc == nil {
return nil
}
return m.CloseFunc()
}

func (m *sqlRowsMock) Columns() ([]string, error) {
return m.ColumnsFunc()
}

func (m *sqlRowsMock) Next() bool {
return m.NextFunc()
}

func (m *sqlRowsMock) Scan(dest ...interface{}) error {
return m.ScanFunc(dest...)
}

func (m *sqlRowsMock) Err() error {
return m.ErrFunc()
}
72 changes: 54 additions & 18 deletions database/sql/queryer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,49 +6,85 @@ import (
"fmt"
)

type Queryer interface {
QueryStructSliceContext(ctx context.Context, structTag string, destStructSlicePointer interface{}, query string, args ...any) error
QueryStructContext(ctx context.Context, structTag string, destStructPointer interface{}, query string, args ...any) error
type QueryerContext interface {
// QueryContext executes a query that returns rows, typically a SELECT.
//
// The dst must be a pointer.
// The args are for any placeholder parameters in the query.
QueryContext(ctx context.Context, dst interface{}, query string, args ...interface{}) error
// QueryRowContext executes a query that is expected to return at most one row.
// It always returns a non-nil value or an error.
//
// The dst must be a pointer.
// The args are for any placeholder parameters in the query.
QueryRowContext(ctx context.Context, dst interface{}, query string, args ...interface{}) error
}

type _Queryer struct {
SQLQueryer
type queryerContext struct {
sqlQueryer sqlQueryerContext

structTag string
}

func NewDB(db SQLQueryer) Queryer {
return &_Queryer{
SQLQueryer: db,
type NewDBOption func(qc *queryerContext)

func WithNewDBOptionStructTag(structTag string) NewDBOption {
return func(qc *queryerContext) {
qc.structTag = structTag
}
}

func (s *_Queryer) QueryStructSliceContext(ctx context.Context, structTag string, destStructSlicePointer interface{}, query string, args ...any) error {
rows, err := s.QueryContext(ctx, query, args...) //nolint:rowserrcheck
return s.queryStructSliceContext(rows, err, structTag, destStructSlicePointer)
func NewDB(db sqlQueryerContext, opts ...NewDBOption) QueryerContext { //nolint:ireturn
return newDB(db, opts...)
}

const defaultStructTag = "db"

func newDB(db sqlQueryerContext, opts ...NewDBOption) *queryerContext {
qc := &queryerContext{
sqlQueryer: db,
structTag: defaultStructTag,
}

for _, opt := range opts {
opt(qc)
}

return qc
}

func (qc *queryerContext) QueryContext(ctx context.Context, dst interface{}, query string, args ...interface{}) error {
rows, err := qc.sqlQueryer.QueryContext(ctx, query, args...) //nolint:rowserrcheck
return qc.queryContext(rows, err, dst)
}

func (s *_Queryer) queryStructSliceContext(rows SQLRows, queryContextErr error, structTag string, destStructSlicePointer interface{}) error {
func (qc *queryerContext) queryContext(rows sqlRows, queryContextErr error, dst interface{}) error {
if queryContextErr != nil {
return fmt.Errorf("QueryContext: %w", queryContextErr)
}
defer rows.Close()

return ScanRows(rows, structTag, destStructSlicePointer)
return ScanRows(rows, qc.structTag, dst)
}

func (s *_Queryer) QueryStructContext(ctx context.Context, structTag string, destStructPointer interface{}, query string, args ...any) error {
rows, err := s.QueryContext(ctx, query, args...) //nolint:rowserrcheck
return s.queryStructContext(rows, err, structTag, destStructPointer)
func (qc *queryerContext) QueryRowContext(ctx context.Context, dst interface{}, query string, args ...interface{}) error {
rows, err := qc.sqlQueryer.QueryContext(ctx, query, args...) //nolint:rowserrcheck
return qc.queryRowContext(rows, err, dst)
}

func (s *_Queryer) queryStructContext(rows SQLRows, queryContextErr error, structTag string, destStructPointer interface{}) error {
func (qc *queryerContext) queryRowContext(rows sqlRows, queryContextErr error, dst interface{}) error {
if queryContextErr != nil {
return fmt.Errorf("QueryContext: %w", queryContextErr)
}
defer rows.Close()

// behaver like *sql.Row
if !rows.Next() {
if err := rows.Err(); err != nil {
return err //nolint:wrapcheck
}
return sql.ErrNoRows
}

return ScanRows(rows, structTag, destStructPointer)
return ScanRows(rows, qc.structTag, dst)
}
Loading

0 comments on commit dc5cbab

Please sign in to comment.