-
Notifications
You must be signed in to change notification settings - Fork 6
/
rows.go
200 lines (168 loc) · 5.83 KB
/
rows.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
package embedded
import (
"database/sql/driver"
"errors"
"fmt"
"io"
gms "github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
)
// doltMultiRows implements driver.RowsNextResultSet by aggregating a set of individual
// doltRows instances.
type doltMultiRows struct {
	// rowSets holds the individual row sets, in order (presumably one per executed statement —
	// confirm against the constructor). Entries that are neither query result sets nor carry an
	// error are skipped by HasNextResultSet and NextResultSet.
	rowSets []*doltRows
	// currentRowSet is the index into rowSets of the result set that Columns and Next read from.
	currentRowSet int
}
var _ driver.RowsNextResultSet = (*doltMultiRows)(nil)
// Columns implements driver.Rows by reporting the column names of the current result set. It
// returns nil when currentRowSet is past the end of rowSets.
func (d *doltMultiRows) Columns() []string {
	if cur := d.currentRowSet; cur < len(d.rowSets) {
		return d.rowSets[cur].Columns()
	}
	return nil
}
// Close implements the driver.Rows interface. When Close is called on a doltMultiRows instance,
// it will close all individual doltRows instances that it contains. If any errors are encountered
// while closing the individual row sets, the first error will be returned, after attempting to close
// all row sets.
func (d *doltMultiRows) Close() error {
	var firstErr error
	for _, rowSet := range d.rowSets {
		// Keep closing the remaining row sets after a failure so no iterator is left open, but only
		// remember the first error so the documented contract holds (previously the last one won).
		if err := rowSet.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}
// Next implements driver.Rows by reading the next row of the current result set into dest. It
// returns io.EOF when currentRowSet is past the end of rowSets.
func (d *doltMultiRows) Next(dest []driver.Value) error {
	if cur := d.currentRowSet; cur < len(d.rowSets) {
		return d.rowSets[cur].Next(dest)
	}
	return io.EOF
}
// nextResultSetIndex scans the row sets after the current one. It returns the index of the first
// one that should be surfaced to the caller (a query result set, or one holding a saved error),
// and false if there is no such row set.
func (d *doltMultiRows) nextResultSetIndex() (int, bool) {
	for i := d.currentRowSet + 1; i < len(d.rowSets); i++ {
		if rs := d.rowSets[i]; rs.isQueryResultSet || rs.err != nil {
			return i, true
		}
	}
	return 0, false
}

// HasNextResultSet implements driver.RowsNextResultSet. It reports whether any later row set is
// either a query result set or carries an error.
func (d *doltMultiRows) HasNextResultSet() bool {
	_, found := d.nextResultSetIndex()
	return found
}

// NextResultSet implements driver.RowsNextResultSet. It advances to the next row set that is a
// query result set or carries an error, and returns io.EOF when none remains.
func (d *doltMultiRows) NextResultSet() error {
	idx, found := d.nextResultSetIndex()
	if !found {
		return io.EOF
	}
	// Move to the located row set. If running its statement failed earlier, the saved error is
	// reported now, when that result set is requested; this matches the MySQL driver's behavior.
	d.currentRowSet = idx
	return d.rowSets[idx].err
}
// doltRows implements driver.Rows for a single result set, converting each GMS row into driver.Values.
type doltRows struct {
	// sch is the result schema; it supplies the column names and the column types used for enum/set conversion.
	sch gms.Schema
	// rowIter yields the result rows. It may be nil (Close checks for this).
	rowIter gms.RowIter
	// gmsCtx is the context passed to rowIter.Next and rowIter.Close.
	gmsCtx *gms.Context
	// columns caches the column names, computed lazily by Columns.
	columns []string
	// err holds any error encountered while trying to retrieve this result set. doltMultiRows.NextResultSet
	// returns it when this row set is reached.
	err error
	// isQueryResultSet is true when this result set was generated by a statement that produces a result set, and
	// false otherwise. For example, an INSERT or other DML statement doesn't return a result set, but we still keep
	// track of a doltRows instance for its results in case an error was returned; such instances have this field set
	// to false. This field is used to skip over doltRows that are not result sets when calling NextResultSet() on a
	// doltMultiRows instance.
	isQueryResultSet bool
}
var _ driver.Rows = (*doltRows)(nil)
// Columns returns the names of the columns. The number of columns of the result is inferred from the length of the
// slice. If a particular column name isn't known, an empty string should be returned for that entry.
//
// The names are taken from the schema on first use and cached for later calls.
func (rows *doltRows) Columns() []string {
	if rows.columns != nil {
		return rows.columns
	}
	names := make([]string, 0, len(rows.sch))
	for _, col := range rows.sch {
		names = append(names, col.Name)
	}
	rows.columns = names
	return names
}
// Close closes the underlying row iterator, if there is one, and translates any error it reports.
// A doltRows without a row iterator has nothing to release, so Close returns nil.
func (rows *doltRows) Close() error {
	if iter := rows.rowIter; iter != nil {
		return translateError(iter.Close(rows.gmsCtx))
	}
	return nil
}
// Next is called to populate the next row of data into the provided slice. The provided slice will be the same size as
// the Columns() are wide. Next returns io.EOF when there are no more rows.
//
// A doltRows without a row iterator is treated as an empty result set. Close already tolerates a nil rowIter, and
// without this check Next would panic on such an instance.
func (rows *doltRows) Next(dest []driver.Value) error {
	if rows.rowIter == nil {
		return io.EOF
	}
	nextRow, err := rows.rowIter.Next(rows.gmsCtx)
	if err != nil {
		if errors.Is(err, io.EOF) {
			return io.EOF
		}
		return translateError(err)
	}
	if len(dest) != len(nextRow) {
		return errors.New("mismatch between expected column count and actual column count")
	}
	for i, val := range nextRow {
		if dest[i], err = rows.columnValue(i, val); err != nil {
			return err
		}
	}
	return nil
}

// columnValue converts the raw GMS value val for column i into a value the database/sql package can consume:
//   - driver.Valuer values are resolved via Value();
//   - geometry values are serialized;
//   - enum and set columns are rendered as their string form;
//   - anything else is passed through unchanged.
func (rows *doltRows) columnValue(i int, val interface{}) (driver.Value, error) {
	// SQL NULL stays NULL. Checking this first keeps NULLs out of the enum/set branches below. There, a nil result
	// from Convert previously hit an unchecked type assertion and panicked.
	if val == nil {
		return nil, nil
	}
	if v, ok := val.(driver.Valuer); ok {
		dv, err := v.Value()
		if err != nil {
			return nil, fmt.Errorf("error processing column %d: %w", i, err)
		}
		return dv, nil
	}
	if geomValue, ok := val.(types.GeometryValue); ok {
		return geomValue.Serialize(), nil
	}
	switch typ := rows.sch[i].Type.(type) {
	case gms.EnumType:
		converted, _, err := typ.Convert(val)
		if err != nil {
			return nil, fmt.Errorf("could not convert to expected enum type for column %d: %w", i, err)
		}
		idx, ok := converted.(uint16)
		if !ok {
			return nil, fmt.Errorf("unexpected enum value type %T for column %d", converted, i)
		}
		enumStr, ok := typ.At(int(idx))
		if !ok {
			return nil, fmt.Errorf("not a valid enum index for column %d: %v", i, converted)
		}
		return enumStr, nil
	case gms.SetType:
		converted, _, err := typ.Convert(val)
		if err != nil {
			return nil, fmt.Errorf("could not convert to expected set type for column %d: %w", i, err)
		}
		bits, ok := converted.(uint64)
		if !ok {
			return nil, fmt.Errorf("unexpected set value type %T for column %d", converted, i)
		}
		setStr, err := typ.BitsToString(bits)
		if err != nil {
			return nil, fmt.Errorf("could not convert value to set string for column %d: %w", i, err)
		}
		return setStr, nil
	}
	return val, nil
}
// peekableRowIter wraps another gms.RowIter and allows the caller to peek at results, without disturbing the order
// in which results are returned from the Next() method.
type peekableRowIter struct {
	// iter is the wrapped iterator that rows are pulled from.
	iter gms.RowIter
	// peeks buffers rows already read from iter by Peek, oldest first; Next drains it before reading from iter again.
	peeks []gms.Row
}
var _ gms.RowIter = (*peekableRowIter)(nil)
// Peek reads the next not-yet-peeked row from the wrapped iterator and buffers it, so a later call to Next() still
// returns it in order. Repeated calls to Peek therefore look progressively further ahead. There is no limit on how
// many rows can be buffered this way. Any error from the wrapped iterator is returned as-is, and nothing is buffered.
func (p *peekableRowIter) Peek(ctx *gms.Context) (gms.Row, error) {
	row, err := p.iter.Next(ctx)
	if err == nil {
		p.peeks = append(p.peeks, row)
		return row, nil
	}
	return nil, err
}
// Next implements gms.RowIter. Rows buffered by Peek are returned first, oldest first; once that buffer is empty,
// rows come straight from the wrapped iterator.
func (p *peekableRowIter) Next(ctx *gms.Context) (gms.Row, error) {
	if len(p.peeks) == 0 {
		return p.iter.Next(ctx)
	}
	head, rest := p.peeks[0], p.peeks[1:]
	p.peeks = rest
	return head, nil
}
// Close implements gms.RowIter by closing the wrapped iterator and returning its error. Rows still buffered by
// Peek are not touched.
func (p *peekableRowIter) Close(ctx *gms.Context) error {
	inner := p.iter
	return inner.Close(ctx)
}