-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
BREAKING CHANGE: fix database/sql funcs (#84)
- Loading branch information
Showing
13 changed files
with
423 additions
and
167 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
package sqlz | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
) | ||
|
||
func MustOpen(ctx context.Context, driverName string, dataSourceName string) *sql.DB { | ||
db, err := sql.Open(driverName, dataSourceName) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
if err := db.PingContext(ctx); err != nil { | ||
panic(err) | ||
} | ||
|
||
return db | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
package sqlz //nolint:testpackage | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
"database/sql/driver" | ||
"fmt" | ||
"strings" | ||
"testing" | ||
) | ||
|
||
func TestMustOpen(t *testing.T) { | ||
t.Parallel() | ||
|
||
t.Run("success", func(t *testing.T) { | ||
t.Parallel() | ||
|
||
sql.Register(t.Name(), &driverDriverMock{ | ||
OpenFunc: func(name string) (driver.Conn, error) { | ||
return &driverConnMock{ | ||
PrepareFunc: func(query string) (driver.Stmt, error) { | ||
return &driverStmtMock{ | ||
CloseFunc: func() error { | ||
return nil | ||
}, | ||
ExecFunc: func(args []driver.Value) (driver.Result, error) { | ||
return &driverResultMock{}, nil | ||
}, | ||
QueryFunc: func(args []driver.Value) (driver.Rows, error) { | ||
return &driverRowsMock{}, nil | ||
}, | ||
}, nil | ||
}, | ||
}, nil | ||
}, | ||
}) | ||
|
||
ctx := context.Background() | ||
db := MustOpen(ctx, t.Name(), ":memory:") | ||
if db == nil { | ||
t.Fatalf("❌: MustOpen: db == nil") | ||
} | ||
}) | ||
|
||
t.Run("failure,sqlUnknownDriver", func(t *testing.T) { | ||
t.Parallel() | ||
|
||
defer func() { | ||
const expect = "sql: unknown driver" | ||
if actual := fmt.Sprintf("%v", recover()); !strings.Contains(actual, expect) { | ||
t.Errorf("❌: recover: expect(%v) != actual(%s)", expect, actual) | ||
} | ||
}() | ||
|
||
MustOpen(context.Background(), t.Name(), "") | ||
}) | ||
|
||
t.Run("failure,sqlDriverOpenError", func(t *testing.T) { | ||
t.Parallel() | ||
|
||
defer func() { | ||
expect := context.Canceled | ||
if actual := fmt.Sprintf("%v", recover()); fmt.Sprint(expect) != fmt.Sprint(actual) { | ||
t.Errorf("❌: recover: expect(%v) != actual(%s)", expect, actual) | ||
} | ||
}() | ||
|
||
sql.Register(t.Name(), &driverDriverMock{}) | ||
|
||
ctx, cancel := context.WithCancel(context.Background()) | ||
cancel() | ||
MustOpen(ctx, t.Name(), "") | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
package sqlz | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
) | ||
|
||
type sqlQueryerContext interface { | ||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) | ||
} | ||
|
||
type sqlRows interface { | ||
Close() error | ||
Columns() ([]string, error) | ||
Next() bool | ||
Scan(...interface{}) error | ||
Err() error | ||
} | ||
|
||
type sqlTxBeginner interface { | ||
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
package sqlz //nolint:testpackage | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
"database/sql/driver" | ||
) | ||
|
||
type driverDriverMock struct { | ||
OpenFunc func(name string) (driver.Conn, error) | ||
} | ||
|
||
func (m *driverDriverMock) Open(name string) (driver.Conn, error) { | ||
return m.OpenFunc(name) | ||
} | ||
|
||
var _ driver.Driver = (*driverDriverMock)(nil) | ||
|
||
type driverConnMock struct { | ||
PrepareFunc func(query string) (driver.Stmt, error) | ||
CloseFunc func() error | ||
BeginFunc func() (driver.Tx, error) | ||
} | ||
|
||
func (m *driverConnMock) Prepare(query string) (driver.Stmt, error) { | ||
return m.PrepareFunc(query) | ||
} | ||
|
||
func (m *driverConnMock) Close() error { | ||
return m.CloseFunc() | ||
} | ||
|
||
func (m *driverConnMock) Begin() (driver.Tx, error) { | ||
return m.BeginFunc() | ||
} | ||
|
||
var _ driver.Conn = (*driverConnMock)(nil) | ||
|
||
type driverStmtMock struct { | ||
CloseFunc func() error | ||
NumInputFunc func() int | ||
ExecFunc func(args []driver.Value) (driver.Result, error) | ||
QueryFunc func(args []driver.Value) (driver.Rows, error) | ||
} | ||
|
||
func (m *driverStmtMock) Close() error { | ||
return m.CloseFunc() | ||
} | ||
|
||
func (m *driverStmtMock) NumInput() int { | ||
return m.NumInputFunc() | ||
} | ||
|
||
func (m *driverStmtMock) Exec(args []driver.Value) (driver.Result, error) { | ||
return m.ExecFunc(args) | ||
} | ||
|
||
func (m *driverStmtMock) Query(args []driver.Value) (driver.Rows, error) { | ||
return m.QueryFunc(args) | ||
} | ||
|
||
var _ driver.Stmt = (*driverStmtMock)(nil) | ||
|
||
type driverResultMock struct { | ||
LastInsertIdFunc func() (int64, error) //nolint:stylecheck // NOTE: sql.Result has LastInsertId method | ||
RowsAffectedFunc func() (int64, error) | ||
} | ||
|
||
func (m *driverResultMock) LastInsertId() (int64, error) { | ||
return m.LastInsertIdFunc() | ||
} | ||
|
||
func (m *driverResultMock) RowsAffected() (int64, error) { | ||
return m.RowsAffectedFunc() | ||
} | ||
|
||
var _ driver.Result = (*driverResultMock)(nil) | ||
|
||
type driverRowsMock struct { | ||
CloseFunc func() error | ||
ColumnsFunc func() []string | ||
NextFunc func(dest []driver.Value) error | ||
} | ||
|
||
func (m *driverRowsMock) Close() error { | ||
return m.CloseFunc() | ||
} | ||
|
||
func (m *driverRowsMock) Columns() []string { | ||
return m.ColumnsFunc() | ||
} | ||
|
||
func (m *driverRowsMock) Next(dest []driver.Value) error { | ||
return m.NextFunc(dest) | ||
} | ||
|
||
var _ driver.Rows = (*driverRowsMock)(nil) | ||
|
||
type sqlDBMock struct { | ||
Rows *sql.Rows | ||
Error error | ||
|
||
BeginTxFunc func(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) | ||
} | ||
|
||
var ( | ||
_ sqlQueryerContext = (*sqlDBMock)(nil) | ||
_ sqlTxBeginner = (*sqlDBMock)(nil) | ||
) | ||
|
||
func (m *sqlDBMock) QueryContext(_ context.Context, _ string, _ ...interface{}) (*sql.Rows, error) { | ||
return m.Rows, m.Error | ||
} | ||
|
||
func (m *sqlDBMock) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { | ||
return m.BeginTxFunc(ctx, opts) | ||
} | ||
|
||
type sqlRowsMock struct { | ||
CloseFunc func() error | ||
ColumnsFunc func() ([]string, error) | ||
NextFunc func() bool | ||
ScanFunc func(dest ...interface{}) error | ||
ErrFunc func() error | ||
} | ||
|
||
var _ sqlRows = (*sqlRowsMock)(nil) | ||
|
||
func (m *sqlRowsMock) Close() error { | ||
if m.CloseFunc == nil { | ||
return nil | ||
} | ||
return m.CloseFunc() | ||
} | ||
|
||
func (m *sqlRowsMock) Columns() ([]string, error) { | ||
return m.ColumnsFunc() | ||
} | ||
|
||
func (m *sqlRowsMock) Next() bool { | ||
return m.NextFunc() | ||
} | ||
|
||
func (m *sqlRowsMock) Scan(dest ...interface{}) error { | ||
return m.ScanFunc(dest...) | ||
} | ||
|
||
func (m *sqlRowsMock) Err() error { | ||
return m.ErrFunc() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.