-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathqueryer.go
94 lines (77 loc) · 2.63 KB
/
queryer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
package sqlz
import (
"context"
"fmt"
)
// QueryerContext is the query interface that NewDB returns.
// Both methods write their results through dst and report only an error.
type QueryerContext interface {
// QueryContext executes a query that returns rows, typically a SELECT.
//
// The dst must be a pointer.
// The args are for any placeholder parameters in the query.
QueryContext(ctx context.Context, dst interface{}, query string, args ...interface{}) error
// QueryRowContext executes a query that is expected to return at most one row.
// The outcome is reported through the returned error; no row value is returned.
//
// NOTE(review): the implementation in this file follows the same
// query-then-scan path as QueryContext, so any single-row handling
// presumably depends on dst or on ScanRows; confirm.
//
// The dst must be a pointer.
// The args are for any placeholder parameters in the query.
QueryRowContext(ctx context.Context, dst interface{}, query string, args ...interface{}) error
}
// queryerContext is the concrete QueryerContext built by newDB. It sends
// queries to sqlQueryer and passes structTag along when scanning the rows.
type queryerContext struct {
	sqlQueryer sqlQueryerContext

	// Options
	structTag string
}

// NewDBOption adjusts a queryerContext while it is being constructed.
type NewDBOption interface {
	apply(qc *queryerContext)
}

// newDBOptionStructTag is the option type produced by
// WithNewDBOptionStructTag; its string value is the struct tag to use.
type newDBOptionStructTag string
// apply stores the option's value as the struct tag of qc.
func (tag newDBOptionStructTag) apply(qc *queryerContext) {
	qc.structTag = string(tag)
}
// WithNewDBOptionStructTag returns a NewDBOption that sets the struct tag of
// the value being constructed to structTag.
func WithNewDBOptionStructTag(structTag string) NewDBOption { //nolint:ireturn
	opt := newDBOptionStructTag(structTag)

	return opt
}
// NewDB wraps db in a QueryerContext. The opts are forwarded unchanged to
// newDB, which applies them on top of the defaults.
func NewDB(db sqlQueryerContext, opts ...NewDBOption) QueryerContext { //nolint:ireturn
	qc := newDB(db, opts...)

	return qc
}
// defaultStructTag is the struct tag newDB sets before any option is applied.
const defaultStructTag = "db"
// newDB builds a queryerContext around db. It starts from the defaults
// (struct tag defaultStructTag) and then applies opts in the order given, so
// a later option overrides an earlier one.
//
// Nil entries in opts are skipped: calling apply on a nil NewDBOption would
// otherwise panic with a nil dereference, which is easy to trigger when
// callers assemble their option list conditionally.
func newDB(db sqlQueryerContext, opts ...NewDBOption) *queryerContext {
	qc := &queryerContext{
		sqlQueryer: db,
		structTag:  defaultStructTag,
	}

	for _, opt := range opts {
		if opt == nil {
			continue
		}

		opt.apply(qc)
	}

	return qc
}
// QueryContext runs query with args on the underlying handle and passes the
// resulting rows and error to queryContext, which scans the rows into dst.
func (qc *queryerContext) QueryContext(ctx context.Context, dst interface{}, query string, args ...interface{}) error {
	result, queryErr := qc.sqlQueryer.QueryContext(ctx, query, args...) //nolint:rowserrcheck

	return qc.queryContext(result, queryErr, dst)
}
// queryContext completes a QueryContext call. A non-nil queryContextErr is
// wrapped and returned without touching rows. Otherwise rows are scanned
// into dst with the configured struct tag and are always closed before
// returning.
//
// The named result err lets the deferred close report its own failure, but
// only when scanning succeeded, so a scan error is never overwritten.
func (qc *queryerContext) queryContext(rows sqlRows, queryContextErr error, dst interface{}) (err error) {
	if queryContextErr != nil {
		return fmt.Errorf("QueryContext: %w", queryContextErr)
	}

	defer func() {
		closeErr := rows.Close()
		if err != nil || closeErr == nil {
			return
		}

		err = fmt.Errorf("rows.Close: %w", closeErr)
	}()

	err = ScanRows(rows, qc.structTag, dst)

	return err
}
// QueryRowContext runs query with args on the underlying handle and passes
// the resulting rows and error to queryRowContext, which scans them into dst.
func (qc *queryerContext) QueryRowContext(ctx context.Context, dst interface{}, query string, args ...interface{}) error {
	result, queryErr := qc.sqlQueryer.QueryContext(ctx, query, args...) //nolint:rowserrcheck

	return qc.queryRowContext(result, queryErr, dst)
}
// queryRowContext completes a QueryRowContext call. A non-nil
// queryContextErr is wrapped and returned without touching rows. Otherwise
// rows are scanned into dst with the configured struct tag and are always
// closed before returning.
//
// The named result err is what allows the deferred close to surface its own
// failure; it does so only when scanning succeeded.
func (qc *queryerContext) queryRowContext(rows sqlRows, queryContextErr error, dst interface{}) (err error) {
	if queryContextErr != nil {
		return fmt.Errorf("QueryContext: %w", queryContextErr)
	}

	defer func() {
		closeErr := rows.Close()
		if err == nil && closeErr != nil {
			err = fmt.Errorf("rows.Close: %w", closeErr)
		}
	}()

	err = ScanRows(rows, qc.structTag, dst)

	return err
}