forked from cockroachdb/copyist
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvalues.go
464 lines (430 loc) · 13.4 KB
/
values.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
// Copyright 2020 The Cockroach Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied. See the License for the specific language governing
// permissions and limitations under the License.
package copyist
import (
"bytes"
"database/sql/driver"
"encoding/base64"
"errors"
"fmt"
"io"
"strconv"
"strings"
"text/scanner"
"time"
"github.com/jackc/pgconn"
"github.com/jackc/pgproto3/v2"
"github.com/lib/pq"
)
// valueType is an enumeration of all types that can be round-tripped to and
// from a copyist recording file, with no important information lost in the
// translation. If/when other drivers are supported, they will need to add
// support for any types that are not already handled:
//
// 1. Add enumeration value below. Use an explicit numeric value so that it's
// easier to look up a type by number. For a new driver, leave plenty of
// space between numeric runs so that previous drivers can add more types.
// 2. Add a case to the formatValueWithType switch.
// 3. Add a case to the parseValueWithType switch.
// 4. Add a case to the deepCopyValue switch if the value's content might be
// mutated across calls to the driver.
type valueType int
// The numeric value of each constant is written as the <dataType> prefix of a
// formatted value, so these numbers are part of the recording file format.
const (
// Zero is deliberately left without a name.
_ valueType = 0
// nilType tags a nil value, formatted as the literal text "nil".
nilType valueType = 1
// stringType tags a Go string, formatted as a quoted Go string literal.
stringType valueType = 2
// intType tags a Go int.
intType valueType = 3
// int64Type tags a Go int64.
int64Type valueType = 4
// float64Type tags a Go float64.
float64Type valueType = 5
// boolType tags a Go bool, formatted as "true" or "false".
boolType valueType = 6
// errorType tags an error that is not one of the driver-specific error
// types below; only its message is recorded.
errorType valueType = 7
// timeType tags a time.Time, formatted using time.RFC3339Nano.
timeType valueType = 8
// stringSliceType tags a []string.
stringSliceType valueType = 9
// byteSliceType tags a []byte, formatted as unpadded base64.
byteSliceType valueType = 10
// valueSliceType tags a []driver.Value whose elements are themselves
// formatted with their own type prefix.
valueSliceType valueType = 11
// Custom pq types.
// pqErrorType tags a *pq.Error.
pqErrorType valueType = 100
// Custom pgx types.
// pgConnErrorType tags a *pgconn.PgError.
pgConnErrorType valueType = 200
)
// formatValueWithType renders val as a single field of a copyist recording
// file, using the layout:
//
//	<dataType>:<formattedValue>
//
// dataType is the decimal value of the valueType constant that describes val.
// formattedValue is a type-specific encoding that obeys these rules:
//
//  1. It contains no linefeed or newline characters, so record declarations
//     can be parsed line by line.
//  2. It contains no tab characters, so a line can be broken into fields by
//     splitting on the tab character.
//  3. It contains no bracket or comma characters except inside a valid Go
//     string literal, so nested slice types can be parsed.
//
// Types whose natural text form could violate these rules (e.g. string) are
// quoted or encoded. A nil val is written with the nilType tag. The function
// panics if val has a type that none of the cases below handle.
func formatValueWithType(val interface{}) string {
	if val == nil {
		return fmt.Sprintf("%d:nil", nilType)
	}

	switch v := val.(type) {
	// Driver-specific error types are matched before the generic error case
	// so that their structured fields are preserved.

	// Custom pq types.
	case *pq.Error:
		return fmt.Sprintf("%d:%s", pqErrorType, formatPqError(v))

	// Custom pgx types.
	case *pgconn.PgError:
		return fmt.Sprintf("%d:%s", pgConnErrorType, formatPgConnError(v))

	// Built-in Go types.
	case string:
		return fmt.Sprintf("%d:%s", stringType, strconv.Quote(v))

	case int:
		return fmt.Sprintf("%d:%d", intType, v)

	case int64:
		return fmt.Sprintf("%d:%d", int64Type, v)

	case float64:
		return fmt.Sprintf("%d:%g", float64Type, v)

	case bool:
		return fmt.Sprintf("%d:%v", boolType, v)

	case error:
		// Only the message of a generic error is recorded.
		return fmt.Sprintf("%d:%s", errorType, strconv.Quote(v.Error()))

	case time.Time:
		// Format collapses a zero offset into "Z" even when the time is not
		// in the time.UTC location. Spell the offset out in that case so that
		// the recorded text matches what the real driver produced and still
		// parses with time.RFC3339Nano.
		formatted := v.Format(time.RFC3339Nano)
		if v.Location() != time.UTC && strings.HasSuffix(formatted, "Z") {
			formatted = strings.TrimSuffix(formatted, "Z") + "+00:00"
		}
		return fmt.Sprintf("%d:%s", timeType, formatted)

	case []string:
		quoted := make([]string, len(v))
		for i := range v {
			quoted[i] = strconv.Quote(v[i])
		}
		return fmt.Sprintf("%d:%s", stringSliceType, "["+strings.Join(quoted, ",")+"]")

	case []byte:
		return fmt.Sprintf("%d:%s", byteSliceType, base64.RawStdEncoding.EncodeToString(v))

	case []driver.Value:
		// Each element carries its own type tag, so slices can nest.
		elems := make([]string, len(v))
		for i := range v {
			elems[i] = formatValueWithType(v[i])
		}
		return fmt.Sprintf("%d:%s", valueSliceType, "["+strings.Join(elems, ",")+"]")

	default:
		panic(fmt.Errorf("unsupported type: %T", v))
	}
}
// formatPqError returns a lib/pq error as a string that is suitable for
// inclusion in a copyist recording file. It copies the error's fields into a
// pgproto3.ErrorResponse, encodes that with the pgproto3 library using the
// Postgres wire protocol, drops the message header, and returns the remaining
// bytes as a quoted Go string literal.
//
// pq stores Position, InternalPosition and Line as strings; they are converted
// with stringToInt32OrZero, so an empty or non-numeric value is recorded as 0.
//
// The function panics if pgproto3 fails to encode the response, since there is
// no error return and no meaningful value could be recorded in that case.
func formatPqError(pqErr *pq.Error) string {
	resp := pgproto3.ErrorResponse{
		Severity:         pqErr.Severity,
		Code:             string(pqErr.Code),
		Message:          pqErr.Message,
		Detail:           pqErr.Detail,
		Hint:             pqErr.Hint,
		Position:         stringToInt32OrZero(pqErr.Position),
		InternalPosition: stringToInt32OrZero(pqErr.InternalPosition),
		InternalQuery:    pqErr.InternalQuery,
		Where:            pqErr.Where,
		SchemaName:       pqErr.Schema,
		TableName:        pqErr.Table,
		ColumnName:       pqErr.Column,
		DataTypeName:     pqErr.DataTypeName,
		ConstraintName:   pqErr.Constraint,
		File:             pqErr.File,
		Line:             stringToInt32OrZero(pqErr.Line),
		Routine:          pqErr.Routine,
	}

	// Encode using the pgproto3 library. A failure here used to be discarded
	// and would only show up as an opaque slice-bounds panic below; report it
	// explicitly instead.
	encoded, err := resp.Encode(nil)
	if err != nil {
		panic(fmt.Errorf("encoding pq error for recording: %w", err))
	}

	// Skip the Error header bytes (the first five bytes of the message); only
	// the field data is recorded.
	encoded = encoded[5:]
	return strconv.Quote(string(encoded))
}
// formatPgConnError returns a pgconn (pgx) error as a string that is suitable
// for inclusion in a copyist recording file. It copies the error's fields into
// a pgproto3.ErrorResponse, encodes that with the pgproto3 library using the
// Postgres wire protocol, drops the message header, and returns the remaining
// bytes as a quoted Go string literal.
//
// The function panics if pgproto3 fails to encode the response, since there is
// no error return and no meaningful value could be recorded in that case.
func formatPgConnError(pgxError *pgconn.PgError) string {
	resp := pgproto3.ErrorResponse{
		Severity:         pgxError.Severity,
		Code:             pgxError.Code,
		Message:          pgxError.Message,
		Detail:           pgxError.Detail,
		Hint:             pgxError.Hint,
		Position:         pgxError.Position,
		InternalPosition: pgxError.InternalPosition,
		InternalQuery:    pgxError.InternalQuery,
		Where:            pgxError.Where,
		SchemaName:       pgxError.SchemaName,
		TableName:        pgxError.TableName,
		ColumnName:       pgxError.ColumnName,
		DataTypeName:     pgxError.DataTypeName,
		ConstraintName:   pgxError.ConstraintName,
		File:             pgxError.File,
		Line:             pgxError.Line,
		Routine:          pgxError.Routine,
	}

	// Encode using the pgproto3 library. A failure here used to be discarded
	// and would only show up as an opaque slice-bounds panic below; report it
	// explicitly instead.
	encoded, err := resp.Encode(nil)
	if err != nil {
		panic(fmt.Errorf("encoding pgconn error for recording: %w", err))
	}

	// Skip the Error header bytes (the first five bytes of the message); only
	// the field data is recorded.
	encoded = encoded[5:]
	return strconv.Quote(string(encoded))
}
// parseValueWithType is the inverse of formatValueWithType: it decodes one
// value read from a copyist recording file, which must have the form:
//
//	<dataType>:<formattedValue>
//
// The text before the first colon is parsed as the decimal valueType, and the
// text after it is decoded according to that type. An error is returned if the
// colon is missing, the type number is not an integer, or the value text is
// malformed for its type. Only well-known valueType values are supported; any
// other type number causes a panic.
func parseValueWithType(valWithTyp string) (interface{}, error) {
	colon := strings.IndexByte(valWithTyp, ':')
	if colon < 0 {
		return nil, errors.New("expected colon")
	}

	typeNum, err := strconv.Atoi(valWithTyp[:colon])
	if err != nil {
		return nil, err
	}
	body := valWithTyp[colon+1:]

	switch typ := valueType(typeNum); typ {
	// Custom pq types.
	case pqErrorType:
		return parsePqError(body)

	// Custom pgx types.
	case pgConnErrorType:
		return parsePgConnError(body)

	// Built-in Go types.
	case nilType:
		if body == "nil" {
			return nil, nil
		}
		return nil, errors.New("expected nil")

	case stringType:
		return strconv.Unquote(body)

	case intType:
		return strconv.Atoi(body)

	case int64Type:
		return strconv.ParseInt(body, 10, 64)

	case float64Type:
		return strconv.ParseFloat(body, 64)

	case boolType:
		switch body {
		case "false":
			return false, nil
		case "true":
			return true, nil
		}
		return nil, errors.New("expected true or false")

	case errorType:
		msg, err := strconv.Unquote(body)
		if err != nil {
			return nil, err
		}
		// Hand back the well-known singleton errors themselves so that
		// callers can keep comparing by reference.
		switch msg {
		case driver.ErrSkip.Error():
			return driver.ErrSkip, nil
		case io.EOF.Error():
			return io.EOF, nil
		}
		return errors.New(msg), nil

	case timeType:
		return time.Parse(time.RFC3339Nano, body)

	case stringSliceType:
		elems, err := parseSlice(body)
		if err != nil {
			return nil, err
		}
		// Each element is still a quoted literal; unquote in place.
		for i, quoted := range elems {
			if elems[i], err = strconv.Unquote(quoted); err != nil {
				return nil, err
			}
		}
		return elems, nil

	case byteSliceType:
		return base64.RawStdEncoding.DecodeString(body)

	case valueSliceType:
		elems, err := parseSlice(body)
		if err != nil {
			return nil, err
		}
		// Every element carries its own type tag, so recurse.
		values := make([]driver.Value, 0, len(elems))
		for _, elem := range elems {
			parsed, err := parseValueWithType(elem)
			if err != nil {
				return nil, err
			}
			values = append(values, parsed)
		}
		return values, nil

	default:
		panic(fmt.Errorf("unsupported type: %v", typ))
	}
}
// parsePqError rebuilds a *pq.Error from text produced by formatPqError. The
// input is expected to be a quoted Go string literal whose content is the
// Postgres wire protocol encoding of an error response. An error is returned
// if the literal cannot be unquoted or the bytes cannot be decoded.
//
// Position, InternalPosition and Line are rendered with strconv.Itoa, so a
// zero value in the decoded response becomes the string "0".
func parsePqError(val string) (interface{}, error) {
	raw, err := strconv.Unquote(val)
	if err != nil {
		return nil, err
	}

	var wire pgproto3.ErrorResponse
	if decodeErr := wire.Decode([]byte(raw)); decodeErr != nil {
		return nil, decodeErr
	}

	// pq keeps its numeric fields as strings.
	itoa := func(n int32) string { return strconv.Itoa(int(n)) }

	pqErr := new(pq.Error)
	pqErr.Severity = wire.Severity
	pqErr.Code = pq.ErrorCode(wire.Code)
	pqErr.Message = wire.Message
	pqErr.Detail = wire.Detail
	pqErr.Hint = wire.Hint
	pqErr.Position = itoa(wire.Position)
	pqErr.InternalPosition = itoa(wire.InternalPosition)
	pqErr.InternalQuery = wire.InternalQuery
	pqErr.Where = wire.Where
	pqErr.Schema = wire.SchemaName
	pqErr.Table = wire.TableName
	pqErr.Column = wire.ColumnName
	pqErr.DataTypeName = wire.DataTypeName
	pqErr.Constraint = wire.ConstraintName
	pqErr.File = wire.File
	pqErr.Line = itoa(wire.Line)
	pqErr.Routine = wire.Routine
	return pqErr, nil
}
// parsePgConnError rebuilds a pgconn error from text produced by
// formatPgConnError. The input is expected to be a quoted Go string literal
// whose content is the Postgres wire protocol encoding of an error response.
// An error is returned if the literal cannot be unquoted or the bytes cannot
// be decoded; otherwise the decoded response is converted with
// pgconn.ErrorResponseToPgError.
func parsePgConnError(val string) (interface{}, error) {
	raw, err := strconv.Unquote(val)
	if err != nil {
		return nil, err
	}

	wire := new(pgproto3.ErrorResponse)
	if decodeErr := wire.Decode([]byte(raw)); decodeErr != nil {
		return nil, decodeErr
	}
	return pgconn.ErrorResponseToPgError(wire), nil
}
// deepCopyValue returns a copy of val that shares no slice storage with the
// original. It exists so that recorded values stay immutable even if the
// application or driver later overwrites the original; one case where this
// can happen is driver.Rows.Next, where the storage for output values can be
// reused across calls to Next.
//
// []string and []uint8 values are duplicated, []driver.Value is duplicated
// recursively, and every other value is returned unchanged.
func deepCopyValue(val interface{}) interface{} {
	switch src := val.(type) {
	case []string:
		dup := make([]string, len(src))
		copy(dup, src)
		return dup

	case []uint8:
		dup := make([]uint8, len(src))
		copy(dup, src)
		return dup

	case []driver.Value:
		dup := make([]driver.Value, 0, len(src))
		for _, elem := range src {
			dup = append(dup, deepCopyValue(elem))
		}
		return dup
	}

	// Most types don't need special handling.
	return val
}
// splitString behaves like strings.Split, except that an empty input string
// yields an empty (non-nil) slice. strings.Split would instead return a slice
// holding one empty string.
func splitString(s, sep string) []string {
	if len(s) > 0 {
		return strings.Split(s, sep)
	}
	return []string{}
}
// parseSlice is a simple parser that handles nested slice declarations of the
// form:
//
//	["foo", ["bar", 55], "baz"]
//
// It returns a slice of strings representing the "top-level" elements of the
// slice. Whitespace outside of string literals is dropped by the tokenizer, so
// the example above yields:
//
//	[]string{`"foo"`, `["bar",55]`, `"baz"`}
//
// Double-quoted string literals are tokenized according to Golang rules, so
// brackets and commas inside them are not treated as delimiters. Elements are
// returned verbatim (string literals keep their quotes).
//
// An error is returned if s is not enclosed in brackets, if nested brackets
// are mismatched, or if the tokenizer reports a problem such as an
// unterminated string literal or an invalid escape sequence. "[]" yields an
// empty, non-nil slice.
func parseSlice(s string) ([]string, error) {
	// Trim leading and trailing brackets.
	if len(s) < 2 || s[0] != '[' || s[len(s)-1] != ']' {
		return nil, fmt.Errorf("invalid slice format: %s", s)
	}
	s = s[1 : len(s)-1]
	if len(s) == 0 {
		// Slice is empty.
		return []string{}, nil
	}

	var scan scanner.Scanner
	scan.Init(strings.NewReader(s))
	scan.Mode = scanner.ScanStrings

	// Without an Error handler the scanner writes diagnostics to stderr and
	// carries on, which would let malformed input produce garbage elements.
	// Capture the first diagnostic instead. This must be set after Init,
	// which resets the handler.
	var scanErr string
	scan.Error = func(_ *scanner.Scanner, msg string) {
		if scanErr == "" {
			scanErr = msg
		}
	}

	// Tokenize comma-delimited list. buf accumulates the text of the current
	// top-level element; nesting tracks the depth of inner brackets.
	var buf bytes.Buffer
	var tokens []string
	nesting := 0
	for {
		tok := scan.Scan()
		if scanErr != "" {
			return nil, fmt.Errorf("invalid slice element: %s: %s", scanErr, s)
		}

		switch tok {
		case ',':
			if nesting == 0 {
				// A top-level comma ends the current element.
				tokens = append(tokens, buf.String())
				buf.Reset()
				continue
			}
		case '[':
			nesting++
		case ']':
			if nesting == 0 {
				return nil, fmt.Errorf("mismatched brackets: %s", s)
			}
			nesting--
		case scanner.EOF:
			if nesting != 0 {
				return nil, fmt.Errorf("mismatched brackets: %s", s)
			}
			tokens = append(tokens, buf.String())
			return tokens, nil
		}
		buf.WriteString(scan.TokenText())
	}
}
// stringToInt32OrZero parses val as a base-10 integer that fits in 32 bits and
// returns it. If val cannot be parsed that way (typically because it is
// empty), zero is returned instead.
func stringToInt32OrZero(val string) int32 {
	if parsed, err := strconv.ParseInt(val, 10, 32); err == nil {
		return int32(parsed)
	}
	return 0
}