diff --git a/AUTHORS b/AUTHORS index 876b2964a..cbd284e40 100644 --- a/AUTHORS +++ b/AUTHORS @@ -95,6 +95,7 @@ Tan Jinhua <312841925 at qq.com> Thomas Wodarek Tim Ruffles Tom Jenkinson +Tzu-Chiao Yeh Vladimir Kovpak Vladyslav Zhelezniak Xiangyu Hu diff --git a/README.md b/README.md index ded6e3b16..a071b26aa 100644 --- a/README.md +++ b/README.md @@ -399,6 +399,7 @@ Examples: * `autocommit=1`: `SET autocommit=1` * [`time_zone=%27Europe%2FParis%27`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `SET time_zone='Europe/Paris'` * [`transaction_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_transaction_isolation): `SET transaction_isolation='REPEATABLE-READ'` + * metata=none`](https://dev.mysql.com/doc/refman/8.0/en/server-system-variables.html#sysvar_resultset_metadata): `SET resultset_metadata=none` (note that this is only applicable to MySQL 8.0+ versions). #### Examples diff --git a/connection.go b/connection.go index 835f89729..7038a7d8e 100644 --- a/connection.go +++ b/connection.go @@ -21,20 +21,22 @@ import ( ) type mysqlConn struct { - buf buffer - netConn net.Conn - rawConn net.Conn // underlying connection when netConn is TLS connection. - affectedRows uint64 - insertId uint64 - cfg *Config - maxAllowedPacket int - maxWriteSize int - writeTimeout time.Duration - flags clientFlag - status statusFlag - sequence uint8 - parseTime bool - reset bool // set when the Go SQL package calls ResetSession + buf buffer + netConn net.Conn + rawConn net.Conn // underlying connection when netConn is TLS connection. + affectedRows uint64 + insertId uint64 + cfg *Config + maxAllowedPacket int + maxWriteSize int + writeTimeout time.Duration + flags clientFlag + status statusFlag + sequence uint8 + parseTime bool + reset bool // set when the Go SQL package calls ResetSession + optionalResultSetMetadata bool + resultSetMetadata uint8 // for context support (Go 1.8+) watching bool @@ -392,6 +394,10 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) } } + if mc.optionalResultSetMetadata && mc.resultSetMetadata == resultSetMetadataNone { + return mc.readIgnoreColumns(rows, resLen) + } + // Columns rows.rs.columns, err = mc.readColumns(resLen) return rows, err @@ -400,6 +406,21 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) return nil, mc.markBadConn(err) } +func (mc *mysqlConn) readIgnoreColumns(rows *textRows, resLen int) (*textRows, error) { + data, err := mc.readPacket() + if err != nil { + errLog.Print(err) + return nil, err + } + // Expected an EOF packet + if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { + // Set empty columnNames, we will first read these columnNames via rows.Columns(). + rows.rs.columnNames = make([]string, resLen) + return rows, nil + } + return nil, ErrOptionalResultSetMetadataPkt +} + // Gets the value of the given MySQL System Variable // The returned byte slice is only valid until the next read func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { diff --git a/connector.go b/connector.go index d567b4e4f..695cc1fe2 100644 --- a/connector.go +++ b/connector.go @@ -12,6 +12,7 @@ import ( "context" "database/sql/driver" "net" + "strings" ) type connector struct { @@ -88,6 +89,20 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { plugin = defaultAuthPlugin } + // Set the optionalResultSetMetadata ahead to set the client capability flag. + if resultSetMetadata, ok := mc.cfg.Params["resultset_metadata"]; ok { + upperVal := strings.ToUpper(resultSetMetadata) + switch upperVal { + case resultSetMetadataSysVarNone: + mc.optionalResultSetMetadata = true + mc.resultSetMetadata = resultSetMetadataNone + case resultSetMetadataSysVarFull: + mc.optionalResultSetMetadata = true + mc.resultSetMetadata = resultSetMetadataFull + } + // To be consistent with other params, in case the param is passed wrongly still send to MySQL to let the server side rejects it. + } + // Send Client Authentication Packet authResp, err := mc.auth(authData, plugin) if err != nil { diff --git a/const.go b/const.go index b1e6b85ef..b4de1eaa1 100644 --- a/const.go +++ b/const.go @@ -56,6 +56,7 @@ const ( clientCanHandleExpiredPasswords clientSessionTrack clientDeprecateEOF + clientOptionalResultSetMetadata ) const ( @@ -172,3 +173,16 @@ const ( cachingSha2PasswordFastAuthSuccess = 3 cachingSha2PasswordPerformFullAuthentication = 4 ) + +const ( + // One-byte metadata flag + // https://dev.mysql.com/worklog/task/?id=8134 + resultSetMetadataNone uint8 = iota + resultSetMetadataFull +) + +const ( + // ResultSet Metadata system var + resultSetMetadataSysVarNone = "NONE" + resultSetMetadataSysVarFull = "FULL" +) diff --git a/driver_test.go b/driver_test.go index 4850498d0..d2e12a01c 100644 --- a/driver_test.go +++ b/driver_test.go @@ -44,6 +44,7 @@ var ( prot string addr string dbname string + vendor string dsn string netAddr string available bool @@ -202,6 +203,7 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) func maybeSkip(t *testing.T, err error, skipErrno uint16) { mySQLErr, ok := err.(*MySQLError) if !ok { + errLog.Print("non match") return } @@ -1345,6 +1347,49 @@ func TestFoundRows(t *testing.T) { }) } +func TestOptionalResultSetMetadata(t *testing.T) { + runTests(t, dsn+"&resultset_metadata=none", func(dbt *DBTest) { + _, err := dbt.db.Exec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") + if err == ErrNoOptionalResultMetadataSet { + t.Skip("server does not support resultset metadata") + } else if err != nil { + dbt.Fatal(err) + } + dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") + + row := dbt.db.QueryRow("SELECT id, data FROM test WHERE id = 1") + id, data := 0, 0 + err = row.Scan(&id, &data) + if err != nil { + dbt.Fatal(err) + } + + if id != 1 && data != 0 { + dbt.Fatal("invalid result") + } + }) + runTests(t, dsn+"&resultset_metadata=full", func(dbt *DBTest) { + _, err := dbt.db.Exec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") + if err == ErrNoOptionalResultMetadataSet { + t.Skip("server does not support resultset metadata") + } else if err != nil { + dbt.Fatal(err) + } + dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") + + row := dbt.db.QueryRow("SELECT id, data FROM test WHERE id = 1") + id, data := 0, 0 + err = row.Scan(&id, &data) + if err != nil { + dbt.Fatal(err) + } + + if id != 1 && data != 0 { + dbt.Fatal("invalid result") + } + }) +} + func TestTLS(t *testing.T) { tlsTestReq := func(dbt *DBTest) { if err := dbt.db.Ping(); err != nil { diff --git a/dsn.go b/dsn.go index a306d66a3..399b8cb8d 100644 --- a/dsn.go +++ b/dsn.go @@ -34,22 +34,22 @@ var ( // If a new Config is created instead of being parsed from a DSN string, // the NewConfig function should be used, which sets default values. type Config struct { - User string // Username - Passwd string // Password (requires User) - Net string // Network type - Addr string // Network address (requires Net) - DBName string // Database name - Params map[string]string // Connection parameters - Collation string // Connection collation - Loc *time.Location // Location for time.Time values - MaxAllowedPacket int // Max packet size allowed - ServerPubKey string // Server public key name - pubKey *rsa.PublicKey // Server public key - TLSConfig string // TLS configuration name - tls *tls.Config // TLS configuration - Timeout time.Duration // Dial timeout - ReadTimeout time.Duration // I/O read timeout - WriteTimeout time.Duration // I/O write timeout + User string // Username + Passwd string // Password (requires User) + Net string // Network type + Addr string // Network address (requires Net) + DBName string // Database name + Params map[string]string // Connection parameters + Collation string // Connection collation + Loc *time.Location // Location for time.Time values + MaxAllowedPacket int // Max packet size allowed + ServerPubKey string // Server public key name + pubKey *rsa.PublicKey // Server public key + TLSConfig string // TLS configuration name + tls *tls.Config // TLS configuration + Timeout time.Duration // Dial timeout + ReadTimeout time.Duration // I/O read timeout + WriteTimeout time.Duration // I/O write timeout AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE AllowCleartextPasswords bool // Allows the cleartext client side plugin diff --git a/dsn_test.go b/dsn_test.go index fc6eea9c8..9d0c50442 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -44,6 +44,9 @@ var testDSNs = []struct { }, { "user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0", &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false, CheckConnLiveness: false}, +}, { + "user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false, CheckConnLiveness: false}, }, { "user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, diff --git a/errors.go b/errors.go index 92cc9a361..26e393de1 100644 --- a/errors.go +++ b/errors.go @@ -17,18 +17,20 @@ import ( // Various errors the driver might return. Can change between driver versions. var ( - ErrInvalidConn = errors.New("invalid connection") - ErrMalformPkt = errors.New("malformed packet") - ErrNoTLS = errors.New("TLS requested but server does not support TLS") - ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN") - ErrNativePassword = errors.New("this user requires mysql native password authentication.") - ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords") - ErrUnknownPlugin = errors.New("this authentication plugin is not supported") - ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+") - ErrPktSync = errors.New("commands out of sync. You can't run this command now") - ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") - ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server") - ErrBusyBuffer = errors.New("busy buffer") + ErrInvalidConn = errors.New("invalid connection") + ErrMalformPkt = errors.New("malformed packet") + ErrNoTLS = errors.New("TLS requested but server does not support TLS") + ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN") + ErrNativePassword = errors.New("this user requires mysql native password authentication") + ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords") + ErrUnknownPlugin = errors.New("this authentication plugin is not supported") + ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+") + ErrPktSync = errors.New("commands out of sync. You can't run this command now") + ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") + ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server") + ErrBusyBuffer = errors.New("busy buffer") + ErrNoOptionalResultMetadataSet = errors.New("requested optional resultset metadata but server does not support") + ErrOptionalResultSetMetadataPkt = errors.New("malformed optional resultset metadata packets") // errBadConnNoWrite is used for connection errors where nothing was sent to the database yet. // If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn diff --git a/packets.go b/packets.go index ab30601ae..e37aea510 100644 --- a/packets.go +++ b/packets.go @@ -234,10 +234,18 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro if len(data) > pos { // character set [1 byte] // status flags [2 bytes] + pos += 1 + 2 // capability flags (upper 2 bytes) [2 bytes] + upperFlags := clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) + mc.flags |= upperFlags << 16 + pos += 2 + if mc.flags&clientOptionalResultSetMetadata == 0 && mc.optionalResultSetMetadata { + return nil, "", ErrNoOptionalResultMetadataSet + } + // length of auth-plugin-data [1 byte] // reserved (all [00]) [10 bytes] - pos += 1 + 2 + 2 + 1 + 10 + pos += 1 + 10 // second part of the password cipher [mininum 13 bytes], // where len=MAX(13, length of auth-plugin-data - 8) @@ -300,6 +308,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string clientFlags |= clientMultiStatements } + if mc.optionalResultSetMetadata { + clientFlags |= clientOptionalResultSetMetadata + } + // encode length of the auth plugin data var authRespLEIBuf [9]byte authRespLen := len(authResp) @@ -554,6 +566,17 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { return int(num), nil } + // Sniff one extra byte for resultset metadata if we set capability + // CLIENT_OPTIONAL_RESULTSET_METADTA + // https://dev.mysql.com/worklog/task/?id=8134 + if len(data) == 2 && mc.flags&clientOptionalResultSetMetadata != 0 { + // ResultSet metadata flag check + if mc.resultSetMetadata != data[1] { + return 0, ErrOptionalResultSetMetadataPkt + } + return int(num), nil + } + return 0, ErrMalformPkt } return 0, err