From b13191feccf2eedd8497f6e5f1ee548201f6fdaf Mon Sep 17 00:00:00 2001 From: David Vilaverde Date: Thu, 6 Jun 2024 07:29:20 -0400 Subject: [PATCH] Additional Driver args for compression and connection read/write timeouts (#885) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * allow setting the collation in auth handshake * Allow connect with context in order to provide configurable connect timeouts * add driver arguments * check for empty ssl value when setting conn options * allow setting the collation in auth handshake (#860) * Allow connect with context in order to provide configurable connect timeouts * support collations IDs greater than 255 on the auth handshake --------- Co-authored-by: dvilaverde * refactored and added more driver args * revert change to Makefile * added tests for timeouts * adding more tests * fixing linting issues * avoiding panic on test complete * revert returning set readtimeout error in binlogsyncer * fixing nil violation when connection with timeout from binlogsyncer * Update README.md Co-authored-by: Daniël van Eeden * addressing pull request feedback * revert rename driver arg ssl to tls * addressing PR feedback * write compressed packet using writeWithTimeout * updated README.md --------- Co-authored-by: dvilaverde Co-authored-by: Daniël van Eeden --- README.md | 89 +++++++++++++ canal/canal.go | 7 +- client/client_test.go | 20 +-- client/conn.go | 42 ++++-- client/conn_test.go | 3 +- client/pool.go | 2 +- client/pool_options.go | 4 +- driver/driver.go | 89 ++++++++++--- driver/driver_options.go | 50 ++++++++ driver/driver_options_test.go | 233 ++++++++++++++++++++++++++++++++++ driver/driver_test.go | 16 ++- mysql/error.go | 2 +- packet/conn.go | 44 ++++++- replication/backup_test.go | 3 +- replication/binlogsyncer.go | 5 +- 15 files changed, 551 insertions(+), 58 deletions(-) create mode 100644 driver/driver_options.go create mode 100644 driver/driver_options_test.go diff --git a/README.md b/README.md index 926c1532d..e108469e0 100644 --- a/README.md +++ b/README.md @@ -360,6 +360,95 @@ func main() { } ``` +### Driver Options + +Configuration options can be provided by the standard DSN (Data Source Name). + +``` +[user[:password]@]addr[/db[?param=X]] +``` + +#### `compress` + +Enable compression between the client and the server. Valid values are 'zstd','zlib','uncompressed'. + +| Type | Default | Example | +| --------- | ------------- | --------------------------------------- | +| string | uncompressed | user:pass@localhost/mydb?compress=zlib | + +#### `readTimeout` + +I/O read timeout. The time unit is specified in the argument value using +golang's [ParseDuration](https://pkg.go.dev/time#ParseDuration) format. + +0 means no timeout. + +| Type | Default | Example | +| --------- | --------- | ------------------------------------------- | +| duration | 0 | user:pass@localhost/mydb?readTimeout=10s | + +#### `ssl` + +Enable TLS between client and server. Valid values are `true` or `custom`. When using `custom`, +the connection will use the TLS configuration set by SetCustomTLSConfig matching the host. + +| Type | Default | Example | +| --------- | --------- | ------------------------------------------- | +| string | | user:pass@localhost/mydb?ssl=true | + +#### `timeout` + +Timeout is the maximum amount of time a dial will wait for a connect to complete. +The time unit is specified in the argument value using golang's [ParseDuration](https://pkg.go.dev/time#ParseDuration) format. + +0 means no timeout. + +| Type | Default | Example | +| --------- | --------- | ------------------------------------------- | +| duration | 0 | user:pass@localhost/mydb?timeout=1m | + +#### `writeTimeout` + +I/O write timeout. The time unit is specified in the argument value using +golang's [ParseDuration](https://pkg.go.dev/time#ParseDuration) format. + +0 means no timeout. + +| Type | Default | Example | +| --------- | --------- | ----------------------------------------------- | +| duration | 0 | user:pass@localhost/mydb?writeTimeout=1m30s | + +### Custom Driver Options + +The driver package exposes the function `SetDSNOptions`, allowing for modification of the +connection by adding custom driver options. +It requires a full import of the driver (not by side-effects only). + +Example of defining a custom option: + +```golang +import ( + "database/sql" + + "github.com/go-mysql-org/go-mysql/driver" +) + +func main() { + driver.SetDSNOptions(map[string]DriverOption{ + "no_metadata": func(c *client.Conn, value string) error { + c.SetCapability(mysql.CLIENT_OPTIONAL_RESULTSET_METADATA) + return nil + }, + }) + + // dsn format: "user:password@addr/dbname?" + dsn := "root@127.0.0.1:3306/test?no_metadata=true" + db, _ := sql.Open(dsn) + db.Close() +} +``` + + We pass all tests in https://github.com/bradfitz/go-sql-test using go-mysql driver. :-) ## Donate diff --git a/canal/canal.go b/canal/canal.go index 0108f3c27..20e09952e 100644 --- a/canal/canal.go +++ b/canal/canal.go @@ -499,7 +499,7 @@ func (c *Canal) prepareSyncer() error { return nil } -func (c *Canal) connect(options ...func(*client.Conn)) (*client.Conn, error) { +func (c *Canal) connect(options ...client.Option) (*client.Conn, error) { ctx, cancel := context.WithTimeout(c.ctx, time.Second*10) defer cancel() @@ -511,10 +511,11 @@ func (c *Canal) connect(options ...func(*client.Conn)) (*client.Conn, error) { func (c *Canal) Execute(cmd string, args ...interface{}) (rr *mysql.Result, err error) { c.connLock.Lock() defer c.connLock.Unlock() - argF := make([]func(*client.Conn), 0) + argF := make([]client.Option, 0) if c.cfg.TLSConfig != nil { - argF = append(argF, func(conn *client.Conn) { + argF = append(argF, func(conn *client.Conn) error { conn.SetTLSConfig(c.cfg.TLSConfig) + return nil }) } diff --git a/client/client_test.go b/client/client_test.go index aaf72ff42..3917db3f5 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -31,10 +31,10 @@ func TestClientSuite(t *testing.T) { func (s *clientTestSuite) SetupSuite() { var err error addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) - s.c, err = Connect(addr, *testUser, *testPassword, "", func(conn *Conn) { + s.c, err = Connect(addr, *testUser, *testPassword, "", func(conn *Conn) error { // test the collation logic, but this is essentially a no-op since // the collation set is the default value - _ = conn.SetCollation(mysql.DEFAULT_COLLATION_NAME) + return conn.SetCollation(mysql.DEFAULT_COLLATION_NAME) }) require.NoError(s.T(), err) @@ -91,8 +91,9 @@ func (s *clientTestSuite) TestConn_Ping() { func (s *clientTestSuite) TestConn_Compress() { addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) - conn, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) { + conn, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) error { conn.SetCapability(mysql.CLIENT_COMPRESS) + return nil }) require.NoError(s.T(), err) @@ -142,8 +143,9 @@ func (s *clientTestSuite) TestConn_TLS_Verify() { // Verify that the provided tls.Config is used when attempting to connect to mysql. // An empty tls.Config will result in a connection error. addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) - _, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) { + _, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) error { c.UseSSL(false) + return nil }) expected := "either ServerName or InsecureSkipVerify must be specified in the tls.Config" @@ -153,8 +155,9 @@ func (s *clientTestSuite) TestConn_TLS_Verify() { func (s *clientTestSuite) TestConn_TLS_Skip_Verify() { // An empty tls.Config will result in a connection error but we can configure to skip it. addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) - _, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) { + _, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) error { c.UseSSL(true) + return nil }) require.NoError(s.T(), err) } @@ -165,8 +168,9 @@ func (s *clientTestSuite) TestConn_TLS_Certificate() { // "x509: certificate is valid for MySQL_Server_8.0.12_Auto_Generated_Server_Certificate, not not-a-valid-name" tlsConfig := NewClientTLSConfig(test_keys.CaPem, test_keys.CertPem, test_keys.KeyPem, false, "not-a-valid-name") addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) - _, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) { + _, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) error { c.SetTLSConfig(tlsConfig) + return nil }) require.Error(s.T(), err) if !strings.Contains(errors.ErrorStack(err), "certificate is not valid for any names") && @@ -251,9 +255,9 @@ func (s *clientTestSuite) TestConn_SetCollationAfterConnect() { func (s *clientTestSuite) TestConn_SetCollation() { addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) - _, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) { + _, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) error { // test the collation logic - _ = conn.SetCollation("invalid_collation") + return conn.SetCollation("invalid_collation") }) require.Error(s.T(), err) diff --git a/client/conn.go b/client/conn.go index bef9b2de9..c7be06b85 100644 --- a/client/conn.go +++ b/client/conn.go @@ -18,6 +18,8 @@ import ( "github.com/go-mysql-org/go-mysql/utils" ) +type Option func(*Conn) error + type Conn struct { *packet.Conn @@ -27,6 +29,10 @@ type Conn struct { tlsConfig *tls.Config proto string + // Connection read and write timeouts to set on the connection + ReadTimeout time.Duration + WriteTimeout time.Duration + serverVersion string // server capabilities capability uint32 @@ -66,16 +72,18 @@ func getNetProto(addr string) string { // Connect to a MySQL server, addr can be ip:port, or a unix socket domain like /var/sock. // Accepts a series of configuration functions as a variadic argument. -func Connect(addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() +func Connect(addr, user, password, dbName string, options ...Option) (*Conn, error) { + return ConnectWithTimeout(addr, user, password, dbName, time.Second*10, options...) +} - return ConnectWithContext(ctx, addr, user, password, dbName, options...) +// ConnectWithTimeout to a MySQL address using a timeout. +func ConnectWithTimeout(addr, user, password, dbName string, timeout time.Duration, options ...Option) (*Conn, error) { + return ConnectWithContext(context.Background(), addr, user, password, dbName, time.Second*10, options...) } // ConnectWithContext to a MySQL addr using the provided context. -func ConnectWithContext(ctx context.Context, addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) { - dialer := &net.Dialer{} +func ConnectWithContext(ctx context.Context, addr, user, password, dbName string, timeout time.Duration, options ...Option) (*Conn, error) { + dialer := &net.Dialer{Timeout: timeout} return ConnectWithDialer(ctx, "", addr, user, password, dbName, dialer.DialContext, options...) } @@ -83,7 +91,7 @@ func ConnectWithContext(ctx context.Context, addr string, user string, password type Dialer func(ctx context.Context, network, address string) (net.Conn, error) // ConnectWithDialer to a MySQL server using the given Dialer. -func ConnectWithDialer(ctx context.Context, network string, addr string, user string, password string, dbName string, dialer Dialer, options ...func(*Conn)) (*Conn, error) { +func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbName string, dialer Dialer, options ...Option) (*Conn, error) { c := new(Conn) c.attributes = map[string]string{ @@ -108,23 +116,28 @@ func ConnectWithDialer(ctx context.Context, network string, addr string, user st c.password = password c.db = dbName c.proto = network - c.Conn = packet.NewConn(conn) // use default charset here, utf-8 c.charset = DEFAULT_CHARSET // Apply configuration functions. - for i := range options { - options[i](c) + for _, option := range options { + if err := option(c); err != nil { + // must close the connection in the event the provided configuration is not valid + _ = conn.Close() + return nil, err + } } + c.Conn = packet.NewConnWithTimeout(conn, c.ReadTimeout, c.WriteTimeout) if c.tlsConfig != nil { seq := c.Conn.Sequence - c.Conn = packet.NewTLSConn(conn) + c.Conn = packet.NewTLSConnWithTimeout(conn, c.ReadTimeout, c.WriteTimeout) c.Conn.Sequence = seq } if err = c.handshake(); err != nil { + // in the event of an error c.handshake() will close the connection return nil, errors.Trace(err) } @@ -139,11 +152,13 @@ func ConnectWithDialer(ctx context.Context, network string, addr string, user st if len(c.collation) != 0 { collation, err := charset.GetCollationByName(c.collation) if err != nil { + c.Close() return nil, errors.Trace(fmt.Errorf("invalid collation name %s", c.collation)) } if collation.ID > 255 { if _, err := c.exec(fmt.Sprintf("SET NAMES %s COLLATE %s", c.charset, c.collation)); err != nil { + c.Close() return nil, errors.Trace(err) } } @@ -206,6 +221,11 @@ func (c *Conn) UnsetCapability(cap uint32) { c.ccaps &= ^cap } +// HasCapability returns true if the connection has the specific capability +func (c *Conn) HasCapability(cap uint32) bool { + return c.ccaps&cap > 0 +} + // UseSSL: use default SSL // pass to options when connect func (c *Conn) UseSSL(insecureSkipVerify bool) { diff --git a/client/conn_test.go b/client/conn_test.go index e2091d50e..55ea973d6 100644 --- a/client/conn_test.go +++ b/client/conn_test.go @@ -28,10 +28,11 @@ func TestConnSuite(t *testing.T) { func (s *connTestSuite) SetupSuite() { var err error addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) - s.c, err = Connect(addr, *testUser, *testPassword, "", func(c *Conn) { + s.c, err = Connect(addr, *testUser, *testPassword, "", func(c *Conn) error { // required for the ExecuteMultiple test c.SetCapability(mysql.CLIENT_MULTI_STATEMENTS) c.SetAttributes(map[string]string{"attrtest": "attrvalue"}) + return nil }) require.NoError(s.T(), err) diff --git a/client/pool.go b/client/pool.go index 91341a537..6e5d6dc21 100644 --- a/client/pool.go +++ b/client/pool.go @@ -166,7 +166,7 @@ func NewPool( user string, password string, dbName string, - options ...func(conn *Conn), + options ...Option, ) *Pool { pool, err := NewPoolWithOptions( addr, diff --git a/client/pool_options.go b/client/pool_options.go index f47b00716..90bf5bd0d 100644 --- a/client/pool_options.go +++ b/client/pool_options.go @@ -17,7 +17,7 @@ type ( password string dbName string - connOptions []func(conn *Conn) + connOptions []Option newPoolPingTimeout time.Duration } @@ -46,7 +46,7 @@ func WithLogFunc(f LogFunc) PoolOption { } } -func WithConnOptions(options ...func(conn *Conn)) PoolOption { +func WithConnOptions(options ...Option) PoolOption { return func(o *poolOptions) { o.connOptions = append(o.connOptions, options...) } diff --git a/driver/driver.go b/driver/driver.go index b86c4b374..8f132d2b3 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -11,6 +11,7 @@ import ( "net/url" "regexp" "sync" + "time" "github.com/go-mysql-org/go-mysql/client" "github.com/go-mysql-org/go-mysql/mysql" @@ -21,7 +22,10 @@ import ( var customTLSMutex sync.Mutex // Map of dsn address (makes more sense than full dsn?) to tls Config -var customTLSConfigMap = make(map[string]*tls.Config) +var ( + customTLSConfigMap = make(map[string]*tls.Config) + options = make(map[string]DriverOption) +) type driver struct { } @@ -92,26 +96,52 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) { } if ci.standardDSN { - if ci.params["ssl"] != nil { - tlsConfigName := ci.params.Get("ssl") - switch tlsConfigName { - case "true": - // This actually does insecureSkipVerify - // But not even sure if it makes sense to handle false? According to - // client_test.go it doesn't - it'd result in an error - c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db, func(c *client.Conn) { c.UseSSL(true) }) - case "custom": - // I was too concerned about mimicking what go-sql-driver/mysql does which will - // allow any name for a custom tls profile and maps the query parameter value to - // that TLSConfig variable... there is no need to be that clever. - // Instead of doing that, let's store required custom TLSConfigs in a map that - // uses the DSN address as the key - c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db, func(c *client.Conn) { c.SetTLSConfig(customTLSConfigMap[ci.addr]) }) - default: - return nil, errors.Errorf("Supported options are ssl=true or ssl=custom") + var timeout time.Duration + configuredOptions := make([]client.Option, 0, len(ci.params)) + for key, value := range ci.params { + if key == "ssl" && len(value) > 0 { + tlsConfigName := value[0] + switch tlsConfigName { + case "true": + // This actually does insecureSkipVerify + // But not even sure if it makes sense to handle false? According to + // client_test.go it doesn't - it'd result in an error + configuredOptions = append(configuredOptions, UseSslOption) + case "custom": + // I was too concerned about mimicking what go-sql-driver/mysql does which will + // allow any name for a custom tls profile and maps the query parameter value to + // that TLSConfig variable... there is no need to be that clever. + // Instead of doing that, let's store required custom TLSConfigs in a map that + // uses the DSN address as the key + configuredOptions = append(configuredOptions, func(c *client.Conn) error { + c.SetTLSConfig(customTLSConfigMap[ci.addr]) + return nil + }) + default: + return nil, errors.Errorf("Supported options are ssl=true or ssl=custom") + } + } else if key == "timeout" && len(value) > 0 { + if timeout, err = time.ParseDuration(value[0]); err != nil { + return nil, errors.Wrap(err, "invalid duration value for timeout option") + } + } else { + if option, ok := options[key]; ok { + opt := func(o DriverOption, v string) client.Option { + return func(c *client.Conn) error { + return o(c, v) + } + }(option, value[0]) + configuredOptions = append(configuredOptions, opt) + } else { + return nil, errors.Errorf("unsupported connection option: %s", key) + } } + } + + if timeout > 0 { + c, err = client.ConnectWithTimeout(ci.addr, ci.user, ci.password, ci.db, timeout, configuredOptions...) } else { - c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db) + c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db, configuredOptions...) } } else { // No more processing here. Let's only support url parameters with the newer style DSN @@ -296,6 +326,11 @@ func (r *rows) Next(dest []sqldriver.Value) error { } func init() { + options["compress"] = CompressOption + options["collation"] = CollationOption + options["readTimeout"] = ReadTimeoutOption + options["writeTimeout"] = WriteTimeoutOption + sql.Register("mysql", driver{}) } @@ -324,3 +359,19 @@ func SetCustomTLSConfig(dsn string, caPem []byte, certPem []byte, keyPem []byte, return nil } + +// SetDSNOptions sets custom options to the driver that allows modifications to the connection. +// It requires a full import of the driver (not by side-effects only). +// Example of supplying a custom option: +// +// driver.SetDSNOptions(map[string]DriverOption{ +// "my_option": func(c *client.Conn, value string) error { +// c.SetCapability(mysql.CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS) +// return nil +// }, +// }) +func SetDSNOptions(customOptions map[string]DriverOption) { + for o, f := range customOptions { + options[o] = f + } +} diff --git a/driver/driver_options.go b/driver/driver_options.go new file mode 100644 index 000000000..605e68f81 --- /dev/null +++ b/driver/driver_options.go @@ -0,0 +1,50 @@ +package driver + +import ( + "time" + + "github.com/go-mysql-org/go-mysql/client" + "github.com/go-mysql-org/go-mysql/mysql" + "github.com/pingcap/errors" +) + +// DriverOption sets configuration on a client connection before the MySQL handshake. +// The value represents the query string parameter value supplied by in the DNS. +type DriverOption func(c *client.Conn, value string) error + +func UseSslOption(c *client.Conn) error { + c.UseSSL(true) + return nil +} + +func CollationOption(c *client.Conn, value string) error { + return c.SetCollation(value) +} + +func ReadTimeoutOption(c *client.Conn, value string) error { + var err error + c.ReadTimeout, err = time.ParseDuration(value) + return errors.Wrap(err, "invalid duration value for readTimeout option") +} + +func WriteTimeoutOption(c *client.Conn, value string) error { + var err error + c.WriteTimeout, err = time.ParseDuration(value) + return errors.Wrap(err, "invalid duration value for writeTimeout option") +} + +func CompressOption(c *client.Conn, value string) error { + switch value { + case "zlib": + c.SetCapability(mysql.CLIENT_COMPRESS) + case "zstd": + c.SetCapability(mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM) + case "uncompressed": + c.UnsetCapability(mysql.CLIENT_COMPRESS) + c.UnsetCapability(mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM) + default: + return errors.Errorf("invalid compression algorithm '%s', valid values are 'zstd','zlib','uncompressed'", value) + } + + return nil +} diff --git a/driver/driver_options_test.go b/driver/driver_options_test.go new file mode 100644 index 000000000..32431932a --- /dev/null +++ b/driver/driver_options_test.go @@ -0,0 +1,233 @@ +package driver + +import ( + "context" + "database/sql" + "fmt" + "net" + "strings" + "testing" + "time" + + "github.com/go-mysql-org/go-mysql/client" + "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/server" + "github.com/pingcap/errors" + "github.com/siddontang/go/log" + "github.com/stretchr/testify/require" +) + +var _ server.Handler = &mockHandler{} + +type testServer struct { + *server.Server + + listener net.Listener + handler *mockHandler +} + +type mockHandler struct { +} + +func TestDriverOptions_SetCollation(t *testing.T) { + c := &client.Conn{} + err := CollationOption(c, "latin2_bin") + require.NoError(t, err) + require.Equal(t, "latin2_bin", c.GetCollation()) +} + +func TestDriverOptions_SetCompression(t *testing.T) { + var err error + c := &client.Conn{} + err = CompressOption(c, "zlib") + require.NoError(t, err) + require.True(t, c.HasCapability(mysql.CLIENT_COMPRESS)) + + err = CompressOption(c, "zstd") + require.NoError(t, err) + require.True(t, c.HasCapability(mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM)) + + err = CompressOption(c, "uncompressed") + require.NoError(t, err) + require.False(t, c.HasCapability(mysql.CLIENT_COMPRESS)) + require.False(t, c.HasCapability(mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM)) + + require.Error(t, CompressOption(c, "foo")) +} + +func TestDriverOptions_ConnectTimeout(t *testing.T) { + log.SetLevel(log.LevelDebug) + srv := CreateMockServer(t) + defer srv.Stop() + + conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?timeout=1s") + require.NoError(t, err) + + rows, err := conn.QueryContext(context.TODO(), "select * from table;") + require.NotNil(t, rows) + require.NoError(t, err) + + conn.Close() +} + +func TestDriverOptions_ReadTimeout(t *testing.T) { + log.SetLevel(log.LevelDebug) + srv := CreateMockServer(t) + defer srv.Stop() + + conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?readTimeout=1s") + require.NoError(t, err) + + rows, err := conn.QueryContext(context.TODO(), "select * from slow;") + require.Nil(t, rows) + require.Error(t, err) + + rows, err = conn.QueryContext(context.TODO(), "select * from fast;") + require.NotNil(t, rows) + require.NoError(t, err) + + conn.Close() +} + +func TestDriverOptions_writeTimeout(t *testing.T) { + log.SetLevel(log.LevelDebug) + srv := CreateMockServer(t) + defer srv.Stop() + + conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?writeTimeout=10") + require.NoError(t, err) + + result, err := conn.ExecContext(context.TODO(), "insert into slow(a,b) values(1,2);") + require.Nil(t, result) + require.Error(t, err) + + conn.Close() +} + +func CreateMockServer(t *testing.T) *testServer { + inMemProvider := server.NewInMemoryProvider() + inMemProvider.AddUser(*testUser, *testPassword) + defaultServer := server.NewDefaultServer() + + l, err := net.Listen("tcp", "127.0.0.1:3307") + require.NoError(t, err) + + handler := &mockHandler{} + + go func() { + for { + conn, err := l.Accept() + if err != nil { + return + } + + go func() { + co, err := server.NewCustomizedConn(conn, defaultServer, inMemProvider, handler) + if err != nil { + return + } + for { + err = co.HandleCommand() + if err != nil { + return + } + } + }() + } + }() + + return &testServer{ + Server: defaultServer, + listener: l, + handler: handler, + } +} + +func (s *testServer) Stop() { + s.listener.Close() +} + +func (h *mockHandler) UseDB(dbName string) error { + return nil +} + +func (h *mockHandler) handleQuery(query string, binary bool) (*mysql.Result, error) { + ss := strings.Split(query, " ") + switch strings.ToLower(ss[0]) { + case "select": + var r *mysql.Resultset + var err error + //for handle go mysql driver select @@max_allowed_packet + if strings.Contains(strings.ToLower(query), "max_allowed_packet") { + r, err = mysql.BuildSimpleResultset([]string{"@@max_allowed_packet"}, [][]interface{}{ + {mysql.MaxPayloadLen}, + }, binary) + } else { + if strings.Contains(query, "slow") { + time.Sleep(time.Second * 5) + } + + r, err = mysql.BuildSimpleResultset([]string{"a", "b"}, [][]interface{}{ + {1, "hello world"}, + }, binary) + } + + if err != nil { + return nil, errors.Trace(err) + } else { + return &mysql.Result{ + Status: 0, + Warnings: 0, + InsertId: 0, + AffectedRows: 0, + Resultset: r, + }, nil + } + case "insert": + return &mysql.Result{ + Status: 0, + Warnings: 0, + InsertId: 1, + AffectedRows: 0, + Resultset: nil, + }, nil + default: + return nil, fmt.Errorf("invalid query %s", query) + } +} + +func (h *mockHandler) HandleQuery(query string) (*mysql.Result, error) { + return h.handleQuery(query, false) +} + +func (h *mockHandler) HandleFieldList(table string, fieldWildcard string) ([]*mysql.Field, error) { + return nil, nil +} + +func (h *mockHandler) HandleStmtPrepare(query string) (params int, columns int, context interface{}, err error) { + params = 1 + columns = 0 + return params, columns, nil, nil +} + +func (h *mockHandler) HandleStmtExecute(context interface{}, query string, args []interface{}) (*mysql.Result, error) { + if strings.HasPrefix(strings.ToLower(query), "select") { + return h.HandleQuery(query) + } + + return &mysql.Result{ + Status: 0, + Warnings: 0, + InsertId: 1, + AffectedRows: 0, + Resultset: nil, + }, nil +} + +func (h *mockHandler) HandleStmtClose(context interface{}) error { + return nil +} + +func (h *mockHandler) HandleOtherCommand(cmd byte, data []byte) error { + return nil +} diff --git a/driver/driver_test.go b/driver/driver_test.go index 3f7575613..1c21bd0e3 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -82,12 +82,16 @@ func TestParseDSN(t *testing.T) { // Use different numbered domains to more readily see what has failed - since we // test in a loop we get the same line number on error testDSNs := map[string]connInfo{ - "user:password@localhost?db": {standardDSN: false, addr: "localhost", user: "user", password: "password", db: "db", params: url.Values{}}, - "user@1.domain.com?db": {standardDSN: false, addr: "1.domain.com", user: "user", password: "", db: "db", params: url.Values{}}, - "user:password@2.domain.com/db": {standardDSN: true, addr: "2.domain.com", user: "user", password: "password", db: "db", params: url.Values{}}, - "user:password@3.domain.com/db?ssl=true": {standardDSN: true, addr: "3.domain.com", user: "user", password: "password", db: "db", params: url.Values{"ssl": []string{"true"}}}, - "user:password@4.domain.com/db?ssl=custom": {standardDSN: true, addr: "4.domain.com", user: "user", password: "password", db: "db", params: url.Values{"ssl": []string{"custom"}}}, - "user:password@5.domain.com/db?unused=param": {standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"unused": []string{"param"}}}, + "user:password@localhost?db": {standardDSN: false, addr: "localhost", user: "user", password: "password", db: "db", params: url.Values{}}, + "user@1.domain.com?db": {standardDSN: false, addr: "1.domain.com", user: "user", password: "", db: "db", params: url.Values{}}, + "user:password@2.domain.com/db": {standardDSN: true, addr: "2.domain.com", user: "user", password: "password", db: "db", params: url.Values{}}, + "user:password@3.domain.com/db?ssl=true": {standardDSN: true, addr: "3.domain.com", user: "user", password: "password", db: "db", params: url.Values{"ssl": []string{"true"}}}, + "user:password@4.domain.com/db?ssl=custom": {standardDSN: true, addr: "4.domain.com", user: "user", password: "password", db: "db", params: url.Values{"ssl": []string{"custom"}}}, + "user:password@5.domain.com/db?unused=param": {standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"unused": []string{"param"}}}, + "user:password@5.domain.com/db?timeout=1s": {standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"timeout": []string{"1s"}}}, + "user:password@5.domain.com/db?readTimeout=1m": {standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"readTimeout": []string{"1m"}}}, + "user:password@5.domain.com/db?writeTimeout=1m": {standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"writeTimeout": []string{"1m"}}}, + "user:password@5.domain.com/db?compress=zlib": {standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"compress": []string{"zlib"}}}, } for supplied, expected := range testDSNs { diff --git a/mysql/error.go b/mysql/error.go index abda6dea0..e9915779b 100644 --- a/mysql/error.go +++ b/mysql/error.go @@ -61,6 +61,6 @@ func NewError(errCode uint16, message string) *MyError { func ErrorCode(errMsg string) (code int) { var tmpStr string // golang scanf doesn't support %*,so I used a temporary variable - fmt.Sscanf(errMsg, "%s%d", &tmpStr, &code) + _, _ = fmt.Sscanf(errMsg, "%s%d", &tmpStr, &code) return } diff --git a/packet/conn.go b/packet/conn.go index 97860129c..9901e34be 100644 --- a/packet/conn.go +++ b/packet/conn.go @@ -11,6 +11,7 @@ import ( goErrors "errors" "io" "net" + "time" "github.com/go-mysql-org/go-mysql/compress" . "github.com/go-mysql-org/go-mysql/mysql" @@ -26,6 +27,9 @@ const DefaultBufferSize = 16 * 1024 type Conn struct { net.Conn + readTimeout time.Duration + writeTimeout time.Duration + // Buffered reader for net.Conn in Non-TLS connection only to address replication performance issue. // See https://github.com/go-mysql-org/go-mysql/pull/422 for more details. br *bufio.Reader @@ -60,6 +64,13 @@ func NewConn(conn net.Conn) *Conn { return c } +func NewConnWithTimeout(conn net.Conn, readTimeout, writeTimeout time.Duration) *Conn { + c := NewConn(conn) + c.readTimeout = readTimeout + c.writeTimeout = writeTimeout + return c +} + func NewTLSConn(conn net.Conn) *Conn { c := new(Conn) c.Conn = conn @@ -71,6 +82,13 @@ func NewTLSConn(conn net.Conn) *Conn { return c } +func NewTLSConnWithTimeout(conn net.Conn, readTimeout, writeTimeout time.Duration) *Conn { + c := NewTLSConn(conn) + c.readTimeout = readTimeout + c.writeTimeout = writeTimeout + return c +} + func (c *Conn) ReadPacket() ([]byte, error) { return c.ReadPacketReuseMem(nil) } @@ -127,6 +145,11 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) { // newCompressedPacketReader creates a new compressed packet reader. func (c *Conn) newCompressedPacketReader() (io.Reader, error) { + if c.readTimeout != 0 { + if err := c.SetReadDeadline(time.Now().Add(c.readTimeout)); err != nil { + return nil, err + } + } if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil { return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err) } @@ -172,6 +195,11 @@ func (c *Conn) copyN(dst io.Writer, n int64) (int64, error) { // Call ReadAtLeast with the currentPacketReader as it may change on every iteration // of this loop. + if c.readTimeout != 0 { + if err := c.SetReadDeadline(time.Now().Add(c.readTimeout)); err != nil { + return written, err + } + } rd, err := io.ReadAtLeast(c.currentPacketReader(), buf, bcap) n -= int64(rd) @@ -265,7 +293,7 @@ func (c *Conn) WritePacket(data []byte) error { data[3] = c.Sequence - if n, err := c.Write(data[:4+MaxPayloadLen]); err != nil { + if n, err := c.writeWithTimeout(data[:4+MaxPayloadLen]); err != nil { return errors.Wrapf(ErrBadConn, "Write(payload portion) failed. err %v", err) } else if n != (4 + MaxPayloadLen) { return errors.Wrapf(ErrBadConn, "Write(payload portion) failed. only %v bytes written, while %v expected", n, 4+MaxPayloadLen) @@ -283,7 +311,7 @@ func (c *Conn) WritePacket(data []byte) error { switch c.Compression { case MYSQL_COMPRESS_NONE: - if n, err := c.Write(data); err != nil { + if n, err := c.writeWithTimeout(data); err != nil { return errors.Wrapf(ErrBadConn, "Write failed. err %v", err) } else if n != len(data) { return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data)) @@ -310,6 +338,16 @@ func (c *Conn) WritePacket(data []byte) error { return nil } +func (c *Conn) writeWithTimeout(b []byte) (n int, err error) { + if c.writeTimeout != 0 { + if err := c.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil { + return n, err + } + } + + return c.Write(b) +} + func (c *Conn) writeCompressed(data []byte) (n int, err error) { var ( compressedLength, uncompressedLength int @@ -374,7 +412,7 @@ func (c *Conn) writeCompressed(data []byte) (n int, err error) { if err != nil { return 0, err } - if _, err = c.Write(compressedPacket.Bytes()); err != nil { + if _, err = c.writeWithTimeout(compressedPacket.Bytes()); err != nil { return 0, err } diff --git a/replication/backup_test.go b/replication/backup_test.go index 1f77e8e3e..abefd3f8d 100644 --- a/replication/backup_test.go +++ b/replication/backup_test.go @@ -38,7 +38,8 @@ func (t *testSyncerSuite) TestStartBackupEndInGivenTime() { done <- true }() failTimeout := 5 * timeout - ctx, _ := context.WithTimeout(context.Background(), failTimeout) + ctx, cancel := context.WithTimeout(context.Background(), failTimeout) + defer cancel() select { case <-done: return diff --git a/replication/binlogsyncer.go b/replication/binlogsyncer.go index 72a22c45c..39e5749ea 100644 --- a/replication/binlogsyncer.go +++ b/replication/binlogsyncer.go @@ -897,12 +897,13 @@ func (b *BinlogSyncer) newConnection(ctx context.Context) (*client.Conn, error) defer cancel() return client.ConnectWithDialer(timeoutCtx, "", addr, b.cfg.User, b.cfg.Password, - "", b.cfg.Dialer, func(c *client.Conn) { + "", b.cfg.Dialer, func(c *client.Conn) error { c.SetTLSConfig(b.cfg.TLSConfig) c.SetAttributes(map[string]string{"_client_role": "binary_log_listener"}) if b.cfg.ReadTimeout > 0 { - _ = c.SetReadDeadline(time.Now().Add(b.cfg.ReadTimeout)) + c.ReadTimeout = b.cfg.ReadTimeout } + return nil }) }