-
Notifications
You must be signed in to change notification settings - Fork 0
/
conn.go
111 lines (96 loc) · 2.45 KB
/
conn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
package sqlite
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log/slog"
"reflect"
"strings"
"time"
"ella.to/logger"
"zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
)
type Conn struct {
conn *sqlite.Conn
put func(conn *Conn)
}
// When your try to use transaction in a nice way, you can use the following
// at the beginning of your code:
//
// defer conn.Save(&err)
func (c *Conn) Save(err *error) {
sqlitex.Save(c.conn)(err)
}
// Done returns the connection back to the pool
func (c *Conn) Done() {
c.put(c)
}
func (c *Conn) Prepare(ctx context.Context, sql string, values ...any) (*Stmt, error) {
if slog.Default().Enabled(ctx, slog.LevelDebug) {
logger.Debug(ctx, "prepare sql", "sql", ShowSql(sql, values...))
}
sql = strings.TrimSpace(sql)
stmt, err := c.conn.Prepare(sql)
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrPrepareSQL, err)
}
for i, value := range values {
i++ // bind starts from 1
if value == nil {
stmt.BindNull(i)
continue
}
valueType := reflect.TypeOf(value)
switch valueType.Kind() {
case reflect.Slice:
if valueType.Elem().Kind() == reflect.Uint8 {
blob, ok := value.([]byte)
if !ok {
blob = value.(json.RawMessage)
}
stmt.BindZeroBlob(i, int64(len(blob)))
stmt.BindBytes(i, blob)
continue
}
fallthrough
case reflect.Map:
var buffer bytes.Buffer
err = json.NewEncoder(&buffer).Encode(value)
if err != nil {
return nil, err
}
stmt.BindText(i, buffer.String())
continue
case reflect.String:
stmt.BindText(i, reflect.ValueOf(value).String())
continue
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
stmt.BindInt64(i, reflect.ValueOf(value).Int())
continue
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
stmt.BindInt64(i, int64(reflect.ValueOf(value).Uint()))
continue
case reflect.Float32, reflect.Float64:
stmt.BindFloat(i, reflect.ValueOf(value).Float())
continue
case reflect.Bool:
stmt.BindBool(i, reflect.ValueOf(value).Bool())
continue
}
switch v := value.(type) {
case time.Time:
stmt.BindInt64(i, v.UTC().Unix())
case fmt.Stringer:
stmt.BindText(i, v.String())
default:
return nil, ErrUnknownType
}
}
return stmt, nil
}
// Use this function to execute a script that contains multiple SQL statements
func (c *Conn) ExecScript(sql string) error {
return sqlitex.ExecScript(c.conn, strings.TrimSpace(sql))
}