diff --git a/go/vt/vitessdriver/rows.go b/go/vt/vitessdriver/rows.go index a2438bb891c..1af88e64ec3 100644 --- a/go/vt/vitessdriver/rows.go +++ b/go/vt/vitessdriver/rows.go @@ -119,3 +119,80 @@ func (ri *rows) ColumnTypeScanType(index int) reflect.Type { return typeUnknown } } + +func (ri *rows) ColumnTypeDatabaseTypeName(index int) string { + field := ri.qr.Fields[index] + switch field.GetType() { + case query.Type_INT8: + return "TINYINT" + case query.Type_UINT8: + return "UNSIGNED TINYINT" + case query.Type_INT16: + return "SMALLINT" + case query.Type_UINT16: + return "UNSIGNED SMALLINT" + case query.Type_YEAR: + return "YEAR" + case query.Type_INT24: + return "MEDIUMINT" + case query.Type_UINT24: + return "UNSIGNED MEDIUMINT" + case query.Type_INT32: + return "INT" + case query.Type_UINT32: + return "UNSIGNED INT" + case query.Type_INT64: + return "BIGINT" + case query.Type_UINT64: + return "UNSIGNED BIGINT" + case query.Type_FLOAT32: + return "FLOAT" + case query.Type_FLOAT64: + return "DOUBLE" + case query.Type_DECIMAL: + return "DECIMAL" + case query.Type_VARCHAR: + return "VARCHAR" + case query.Type_TEXT: + return "TEXT" + case query.Type_BLOB: + return "BLOB" + case query.Type_VARBINARY: + return "VARBINARY" + case query.Type_CHAR: + return "CHAR" + case query.Type_BINARY: + return "BINARY" + case query.Type_BIT: + return "BIT" + case query.Type_ENUM: + return "ENUM" + case query.Type_SET: + return "SET" + case query.Type_HEXVAL: + return "VARBINARY" + case query.Type_HEXNUM: + return "VARBINARY" + case query.Type_BITNUM: + return "VARBINARY" + case query.Type_GEOMETRY: + return "GEOMETRY" + case query.Type_JSON: + return "JSON" + case query.Type_TIMESTAMP: + return "TIMESTAMP" + case query.Type_DATE: + return "DATE" + case query.Type_TIME: + return "TIME" + case query.Type_DATETIME: + return "DATETIME" + default: + return "" + } +} + +func (ri *rows) ColumnTypeNullable(index int) (nullable, ok bool) { + field := ri.qr.Fields[index] + return field.GetFlags()&uint32(query.MySqlFlag_NOT_NULL_FLAG) == 0, true +} diff --git a/go/vt/vitessdriver/rows_test.go b/go/vt/vitessdriver/rows_test.go index 13584e70dd8..bb196da30c3 100644 --- a/go/vt/vitessdriver/rows_test.go +++ b/go/vt/vitessdriver/rows_test.go @@ -226,3 +226,123 @@ func TestColumnTypeScanType(t *testing.T) { assert.Equal(t, ri.ColumnTypeScanType(i), wantTypes[i], fmt.Sprintf("unexpected type %v, wanted %v", ri.ColumnTypeScanType(i), wantTypes[i])) } } + +// Test that the ColumnTypeScanType function returns the correct reflection type for each +// sql type. The sql type in turn comes from a table column's type. +func TestColumnTypeDatabaseTypeName(t *testing.T) { + var r = sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "field1", + Type: sqltypes.Int8, + }, + { + Name: "field2", + Type: sqltypes.Uint8, + }, + { + Name: "field3", + Type: sqltypes.Int16, + }, + { + Name: "field4", + Type: sqltypes.Uint16, + }, + { + Name: "field5", + Type: sqltypes.Int24, + }, + { + Name: "field6", + Type: sqltypes.Uint24, + }, + { + Name: "field7", + Type: sqltypes.Int32, + }, + { + Name: "field8", + Type: sqltypes.Uint32, + }, + { + Name: "field9", + Type: sqltypes.Int64, + }, + { + Name: "field10", + Type: sqltypes.Uint64, + }, + { + Name: "field11", + Type: sqltypes.Float32, + }, + { + Name: "field12", + Type: sqltypes.Float64, + }, + { + Name: "field13", + Type: sqltypes.VarBinary, + }, + { + Name: "field14", + Type: sqltypes.Datetime, + }, + }, + } + + ri := newRows(&r, &converter{}).(driver.RowsColumnTypeDatabaseTypeName) + defer ri.Close() + + wantTypes := []string{ + "TINYINT", + "UNSIGNED TINYINT", + "SMALLINT", + "UNSIGNED SMALLINT", + "MEDIUMINT", + "UNSIGNED MEDIUMINT", + "INT", + "UNSIGNED INT", + "BIGINT", + "UNSIGNED BIGINT", + "FLOAT", + "DOUBLE", + "VARBINARY", + "DATETIME", + } + + for i := 0; i < len(wantTypes); i++ { + assert.Equal(t, ri.ColumnTypeDatabaseTypeName(i), wantTypes[i], fmt.Sprintf("unexpected type %v, wanted %v", ri.ColumnTypeDatabaseTypeName(i), wantTypes[i])) + } +} + +// Test that the ColumnTypeScanType function returns the correct reflection type for each +// sql type. The sql type in turn comes from a table column's type. +func TestColumnTypeNullable(t *testing.T) { + var r = sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "field1", + Type: sqltypes.Int64, + Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG), + }, + { + Name: "field2", + Type: sqltypes.Int64, + }, + }, + } + + ri := newRows(&r, &converter{}).(driver.RowsColumnTypeNullable) + defer ri.Close() + + nullable := []bool{ + false, + true, + } + + for i := 0; i < len(nullable); i++ { + null, _ := ri.ColumnTypeNullable(i) + assert.Equal(t, null, nullable[i], fmt.Sprintf("unexpected type %v, wanted %v", null, nullable[i])) + } +}