Skip to content

Commit

Permalink
Additional Driver args for compression and connection read/write time…
Browse files Browse the repository at this point in the history
…outs (#885)

* allow setting the collation in auth handshake

* Allow connect with context in order to provide configurable connect timeouts

* add driver arguments

* check for empty ssl value when setting conn options

* allow setting the collation in auth handshake (#860)
* Allow connect with context in order to provide configurable connect timeouts
* support collations IDs greater than 255 on the auth handshake
---------

Co-authored-by: dvilaverde <[email protected]>

* refactored and added more driver args

* revert change to Makefile

* added tests for timeouts

* adding more tests

* fixing linting issues

* avoiding panic on test complete

* revert returning set readtimeout error in binlogsyncer

* fixing nil violation when connection with timeout from binlogsyncer

* Update README.md

Co-authored-by: Daniël van Eeden <[email protected]>

* addressing pull request feedback

* revert rename driver arg ssl to tls

* addressing PR feedback

* write compressed packet using writeWithTimeout

* updated README.md

---------

Co-authored-by: dvilaverde <[email protected]>
Co-authored-by: Daniël van Eeden <[email protected]>
  • Loading branch information
3 people authored Jun 6, 2024
1 parent 6c99b4b commit b13191f
Show file tree
Hide file tree
Showing 15 changed files with 551 additions and 58 deletions.
89 changes: 89 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,95 @@ func main() {
}
```

### Driver Options

Configuration options can be provided by the standard DSN (Data Source Name).

```
[user[:password]@]addr[/db[?param=X]]
```

#### `compress`

Enable compression between the client and the server. Valid values are 'zstd','zlib','uncompressed'.

| Type | Default | Example |
| --------- | ------------- | --------------------------------------- |
| string | uncompressed | user:pass@localhost/mydb?compress=zlib |

#### `readTimeout`

I/O read timeout. The time unit is specified in the argument value using
golang's [ParseDuration](https://pkg.go.dev/time#ParseDuration) format.

0 means no timeout.

| Type | Default | Example |
| --------- | --------- | ------------------------------------------- |
| duration | 0 | user:pass@localhost/mydb?readTimeout=10s |

#### `ssl`

Enable TLS between client and server. Valid values are `true` or `custom`. When using `custom`,
the connection will use the TLS configuration set by SetCustomTLSConfig matching the host.

| Type | Default | Example |
| --------- | --------- | ------------------------------------------- |
| string | | user:pass@localhost/mydb?ssl=true |

#### `timeout`

Timeout is the maximum amount of time a dial will wait for a connect to complete.
The time unit is specified in the argument value using golang's [ParseDuration](https://pkg.go.dev/time#ParseDuration) format.

0 means no timeout.

| Type | Default | Example |
| --------- | --------- | ------------------------------------------- |
| duration | 0 | user:pass@localhost/mydb?timeout=1m |

#### `writeTimeout`

I/O write timeout. The time unit is specified in the argument value using
golang's [ParseDuration](https://pkg.go.dev/time#ParseDuration) format.

0 means no timeout.

| Type | Default | Example |
| --------- | --------- | ----------------------------------------------- |
| duration | 0 | user:pass@localhost/mydb?writeTimeout=1m30s |

### Custom Driver Options

The driver package exposes the function `SetDSNOptions`, allowing for modification of the
connection by adding custom driver options.
It requires a full import of the driver (not by side-effects only).

Example of defining a custom option:

```golang
import (
"database/sql"

"github.com/go-mysql-org/go-mysql/driver"
)

func main() {
driver.SetDSNOptions(map[string]DriverOption{
"no_metadata": func(c *client.Conn, value string) error {
c.SetCapability(mysql.CLIENT_OPTIONAL_RESULTSET_METADATA)
return nil
},
})

// dsn format: "user:password@addr/dbname?"
dsn := "[email protected]:3306/test?no_metadata=true"
db, _ := sql.Open(dsn)
db.Close()
}
```


We pass all tests in https://github.com/bradfitz/go-sql-test using go-mysql driver. :-)

## Donate
Expand Down
7 changes: 4 additions & 3 deletions canal/canal.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ func (c *Canal) prepareSyncer() error {
return nil
}

func (c *Canal) connect(options ...func(*client.Conn)) (*client.Conn, error) {
func (c *Canal) connect(options ...client.Option) (*client.Conn, error) {
ctx, cancel := context.WithTimeout(c.ctx, time.Second*10)
defer cancel()

Expand All @@ -511,10 +511,11 @@ func (c *Canal) connect(options ...func(*client.Conn)) (*client.Conn, error) {
func (c *Canal) Execute(cmd string, args ...interface{}) (rr *mysql.Result, err error) {
c.connLock.Lock()
defer c.connLock.Unlock()
argF := make([]func(*client.Conn), 0)
argF := make([]client.Option, 0)
if c.cfg.TLSConfig != nil {
argF = append(argF, func(conn *client.Conn) {
argF = append(argF, func(conn *client.Conn) error {
conn.SetTLSConfig(c.cfg.TLSConfig)
return nil
})
}

Expand Down
20 changes: 12 additions & 8 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ func TestClientSuite(t *testing.T) {
func (s *clientTestSuite) SetupSuite() {
var err error
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
s.c, err = Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
s.c, err = Connect(addr, *testUser, *testPassword, "", func(conn *Conn) error {
// test the collation logic, but this is essentially a no-op since
// the collation set is the default value
_ = conn.SetCollation(mysql.DEFAULT_COLLATION_NAME)
return conn.SetCollation(mysql.DEFAULT_COLLATION_NAME)
})
require.NoError(s.T(), err)

Expand Down Expand Up @@ -91,8 +91,9 @@ func (s *clientTestSuite) TestConn_Ping() {

func (s *clientTestSuite) TestConn_Compress() {
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
conn, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
conn, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) error {
conn.SetCapability(mysql.CLIENT_COMPRESS)
return nil
})
require.NoError(s.T(), err)

Expand Down Expand Up @@ -142,8 +143,9 @@ func (s *clientTestSuite) TestConn_TLS_Verify() {
// Verify that the provided tls.Config is used when attempting to connect to mysql.
// An empty tls.Config will result in a connection error.
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
_, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) {
_, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) error {
c.UseSSL(false)
return nil
})
expected := "either ServerName or InsecureSkipVerify must be specified in the tls.Config"

Expand All @@ -153,8 +155,9 @@ func (s *clientTestSuite) TestConn_TLS_Verify() {
func (s *clientTestSuite) TestConn_TLS_Skip_Verify() {
// An empty tls.Config will result in a connection error but we can configure to skip it.
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
_, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) {
_, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) error {
c.UseSSL(true)
return nil
})
require.NoError(s.T(), err)
}
Expand All @@ -165,8 +168,9 @@ func (s *clientTestSuite) TestConn_TLS_Certificate() {
// "x509: certificate is valid for MySQL_Server_8.0.12_Auto_Generated_Server_Certificate, not not-a-valid-name"
tlsConfig := NewClientTLSConfig(test_keys.CaPem, test_keys.CertPem, test_keys.KeyPem, false, "not-a-valid-name")
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
_, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) {
_, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) error {
c.SetTLSConfig(tlsConfig)
return nil
})
require.Error(s.T(), err)
if !strings.Contains(errors.ErrorStack(err), "certificate is not valid for any names") &&
Expand Down Expand Up @@ -251,9 +255,9 @@ func (s *clientTestSuite) TestConn_SetCollationAfterConnect() {

func (s *clientTestSuite) TestConn_SetCollation() {
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
_, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
_, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) error {
// test the collation logic
_ = conn.SetCollation("invalid_collation")
return conn.SetCollation("invalid_collation")
})

require.Error(s.T(), err)
Expand Down
42 changes: 31 additions & 11 deletions client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"github.com/go-mysql-org/go-mysql/utils"
)

type Option func(*Conn) error

type Conn struct {
*packet.Conn

Expand All @@ -27,6 +29,10 @@ type Conn struct {
tlsConfig *tls.Config
proto string

// Connection read and write timeouts to set on the connection
ReadTimeout time.Duration
WriteTimeout time.Duration

serverVersion string
// server capabilities
capability uint32
Expand Down Expand Up @@ -66,24 +72,26 @@ func getNetProto(addr string) string {

// Connect to a MySQL server, addr can be ip:port, or a unix socket domain like /var/sock.
// Accepts a series of configuration functions as a variadic argument.
func Connect(addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
func Connect(addr, user, password, dbName string, options ...Option) (*Conn, error) {
return ConnectWithTimeout(addr, user, password, dbName, time.Second*10, options...)
}

return ConnectWithContext(ctx, addr, user, password, dbName, options...)
// ConnectWithTimeout to a MySQL address using a timeout.
func ConnectWithTimeout(addr, user, password, dbName string, timeout time.Duration, options ...Option) (*Conn, error) {
return ConnectWithContext(context.Background(), addr, user, password, dbName, time.Second*10, options...)
}

// ConnectWithContext to a MySQL addr using the provided context.
func ConnectWithContext(ctx context.Context, addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) {
dialer := &net.Dialer{}
func ConnectWithContext(ctx context.Context, addr, user, password, dbName string, timeout time.Duration, options ...Option) (*Conn, error) {
dialer := &net.Dialer{Timeout: timeout}
return ConnectWithDialer(ctx, "", addr, user, password, dbName, dialer.DialContext, options...)
}

// Dialer connects to the address on the named network using the provided context.
type Dialer func(ctx context.Context, network, address string) (net.Conn, error)

// ConnectWithDialer to a MySQL server using the given Dialer.
func ConnectWithDialer(ctx context.Context, network string, addr string, user string, password string, dbName string, dialer Dialer, options ...func(*Conn)) (*Conn, error) {
func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbName string, dialer Dialer, options ...Option) (*Conn, error) {
c := new(Conn)

c.attributes = map[string]string{
Expand All @@ -108,23 +116,28 @@ func ConnectWithDialer(ctx context.Context, network string, addr string, user st
c.password = password
c.db = dbName
c.proto = network
c.Conn = packet.NewConn(conn)

// use default charset here, utf-8
c.charset = DEFAULT_CHARSET

// Apply configuration functions.
for i := range options {
options[i](c)
for _, option := range options {
if err := option(c); err != nil {
// must close the connection in the event the provided configuration is not valid
_ = conn.Close()
return nil, err
}
}

c.Conn = packet.NewConnWithTimeout(conn, c.ReadTimeout, c.WriteTimeout)
if c.tlsConfig != nil {
seq := c.Conn.Sequence
c.Conn = packet.NewTLSConn(conn)
c.Conn = packet.NewTLSConnWithTimeout(conn, c.ReadTimeout, c.WriteTimeout)
c.Conn.Sequence = seq
}

if err = c.handshake(); err != nil {
// in the event of an error c.handshake() will close the connection
return nil, errors.Trace(err)
}

Expand All @@ -139,11 +152,13 @@ func ConnectWithDialer(ctx context.Context, network string, addr string, user st
if len(c.collation) != 0 {
collation, err := charset.GetCollationByName(c.collation)
if err != nil {
c.Close()
return nil, errors.Trace(fmt.Errorf("invalid collation name %s", c.collation))
}

if collation.ID > 255 {
if _, err := c.exec(fmt.Sprintf("SET NAMES %s COLLATE %s", c.charset, c.collation)); err != nil {
c.Close()
return nil, errors.Trace(err)
}
}
Expand Down Expand Up @@ -206,6 +221,11 @@ func (c *Conn) UnsetCapability(cap uint32) {
c.ccaps &= ^cap
}

// HasCapability returns true if the connection has the specific capability
func (c *Conn) HasCapability(cap uint32) bool {
return c.ccaps&cap > 0
}

// UseSSL: use default SSL
// pass to options when connect
func (c *Conn) UseSSL(insecureSkipVerify bool) {
Expand Down
3 changes: 2 additions & 1 deletion client/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ func TestConnSuite(t *testing.T) {
func (s *connTestSuite) SetupSuite() {
var err error
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
s.c, err = Connect(addr, *testUser, *testPassword, "", func(c *Conn) {
s.c, err = Connect(addr, *testUser, *testPassword, "", func(c *Conn) error {
// required for the ExecuteMultiple test
c.SetCapability(mysql.CLIENT_MULTI_STATEMENTS)
c.SetAttributes(map[string]string{"attrtest": "attrvalue"})
return nil
})
require.NoError(s.T(), err)

Expand Down
2 changes: 1 addition & 1 deletion client/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func NewPool(
user string,
password string,
dbName string,
options ...func(conn *Conn),
options ...Option,
) *Pool {
pool, err := NewPoolWithOptions(
addr,
Expand Down
4 changes: 2 additions & 2 deletions client/pool_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type (
password string
dbName string

connOptions []func(conn *Conn)
connOptions []Option

newPoolPingTimeout time.Duration
}
Expand Down Expand Up @@ -46,7 +46,7 @@ func WithLogFunc(f LogFunc) PoolOption {
}
}

func WithConnOptions(options ...func(conn *Conn)) PoolOption {
func WithConnOptions(options ...Option) PoolOption {
return func(o *poolOptions) {
o.connOptions = append(o.connOptions, options...)
}
Expand Down
Loading

0 comments on commit b13191f

Please sign in to comment.