From 5427a8d66217a7251b22769cf841916922a2b4f1 Mon Sep 17 00:00:00 2001 From: dvilaverde Date: Fri, 26 Apr 2024 14:16:14 -0400 Subject: [PATCH 01/13] allow setting the collation in auth handshake --- client/auth.go | 12 +++++++++++- client/client_test.go | 21 ++++++++++++++++++++- client/conn.go | 16 ++++++++++++++++ 3 files changed, 47 insertions(+), 2 deletions(-) diff --git a/client/auth.go b/client/auth.go index e4fa908d3..7392f8fdd 100644 --- a/client/auth.go +++ b/client/auth.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "encoding/binary" "fmt" + "github.com/pingcap/tidb/pkg/parser/charset" . "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/packet" @@ -269,7 +270,16 @@ func (c *Conn) writeAuthHandshake() error { // Charset [1 byte] // use default collation id 33 here, is utf-8 - data[12] = DEFAULT_COLLATION_ID + collationName := c.collation + if len(collationName) == 0 { + collationName = DEFAULT_COLLATION_NAME + } + collation, err := charset.GetCollationByName(collationName) + if err != nil { + return fmt.Errorf("invalid collation name %s", collationName) + } + + data[12] = byte(collation.ID) // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest diff --git a/client/client_test.go b/client/client_test.go index c47c795ef..b27c4c669 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -31,7 +31,11 @@ func TestClientSuite(t *testing.T) { func (s *clientTestSuite) SetupSuite() { var err error addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) - s.c, err = Connect(addr, *testUser, *testPassword, "") + s.c, err = Connect(addr, *testUser, *testPassword, "", func(conn *Conn) { + // test the collation logic, but this is essentially a no-op since + // the collation set is the default value + _ = conn.SetCollation(mysql.DEFAULT_COLLATION_NAME) + }) require.NoError(s.T(), err) var result *mysql.Result @@ -228,6 +232,21 @@ func (s *clientTestSuite) TestConn_SetCharset() { require.NoError(s.T(), err) } +func (s *clientTestSuite) TestConn_SetCollationAfterConnect() { + err := s.c.SetCollation("latin1_swedish_ci") + require.Error(s.T(), err) +} + +func (s *clientTestSuite) TestConn_SetCollation() { + addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) + _, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) { + // test the collation logic + _ = conn.SetCollation("invalid_collation") + }) + + require.Error(s.T(), err) +} + func (s *clientTestSuite) testStmt_DropTable() { str := `drop table if exists mixer_test_stmt` diff --git a/client/conn.go b/client/conn.go index b1f3e52d1..1db021762 100644 --- a/client/conn.go +++ b/client/conn.go @@ -37,6 +37,8 @@ type Conn struct { status uint16 charset string + // sets the collation to be set on the auth handshake, this does not issue a 'set names' command + collation string salt []byte authPluginName string @@ -357,6 +359,20 @@ func (c *Conn) SetCharset(charset string) error { } } +func (c *Conn) SetCollation(collation string) error { + if c.status == 0 { + c.collation = collation + } else { + return errors.Trace(errors.Errorf("cannot set collation after connection is established")) + } + + return nil +} + +func (c *Conn) GetCollation() string { + return c.collation +} + func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) { if err := c.writeCommandStrStr(COM_FIELD_LIST, table, wildcard); err != nil { return nil, errors.Trace(err) From 10339ddc0ad003634f447d92547195bea02c544d Mon Sep 17 00:00:00 2001 From: dvilaverde Date: Fri, 26 Apr 2024 16:42:52 -0400 Subject: [PATCH 02/13] Allow connect with context in order to provide configurable connect timeouts --- client/conn.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/client/conn.go b/client/conn.go index 1db021762..9d7014951 100644 --- a/client/conn.go +++ b/client/conn.go @@ -69,15 +69,19 @@ func Connect(addr string, user string, password string, dbName string, options . ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - dialer := &net.Dialer{} + return ConnectWithContext(ctx, addr, user, password, dbName, options...) +} +// ConnectWithContext to a MySQL addr using the provided context. +func ConnectWithContext(ctx context.Context, addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) { + dialer := &net.Dialer{} return ConnectWithDialer(ctx, "", addr, user, password, dbName, dialer.DialContext, options...) } // Dialer connects to the address on the named network using the provided context. type Dialer func(ctx context.Context, network, address string) (net.Conn, error) -// Connect to a MySQL server using the given Dialer. +// ConnectWithDialer to a MySQL server using the given Dialer. func ConnectWithDialer(ctx context.Context, network string, addr string, user string, password string, dbName string, dialer Dialer, options ...func(*Conn)) (*Conn, error) { c := new(Conn) From 2f5a8124e4cb86c0ff98113e0cd118c8550c5185 Mon Sep 17 00:00:00 2001 From: David Vilaverde Date: Sat, 27 Apr 2024 17:42:42 -0400 Subject: [PATCH 03/13] fixing linting error --- client/auth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/auth.go b/client/auth.go index 7392f8fdd..8e5727829 100644 --- a/client/auth.go +++ b/client/auth.go @@ -5,11 +5,11 @@ import ( "crypto/tls" "encoding/binary" "fmt" - "github.com/pingcap/tidb/pkg/parser/charset" . "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/packet" "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/parser/charset" ) const defaultAuthPluginName = AUTH_NATIVE_PASSWORD From e6de19d622211d20e32fd447bd263aee8853a331 Mon Sep 17 00:00:00 2001 From: dvilaverde Date: Mon, 29 Apr 2024 08:21:28 -0400 Subject: [PATCH 04/13] support collations IDs greater than 255 on the auth handshake --- client/auth.go | 14 ++++++++- client/auth_test.go | 73 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) diff --git a/client/auth.go b/client/auth.go index 8e5727829..b9c8d36e2 100644 --- a/client/auth.go +++ b/client/auth.go @@ -279,7 +279,14 @@ func (c *Conn) writeAuthHandshake() error { return fmt.Errorf("invalid collation name %s", collationName) } - data[12] = byte(collation.ID) + // the MySQL protocol calls for the collation id to be sent as 1, where only the + // lower 8 bits are used in this field. But wireshark shows that the first by of + // the 23 bytes of filler is used to send the upper 8 bits of the collation id. + // see https://github.com/mysql/mysql-server/pull/541 + data[12] = byte(collation.ID & 0xff) + if collation.ID > 255 { + data[13] = byte(collation.ID >> 8) + } // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest @@ -302,6 +309,11 @@ func (c *Conn) writeAuthHandshake() error { // Filler [23 bytes] (all 0x00) pos := 13 + if collation.ID > 255 { + // skip setting the first byte of the filler to 0x00 since it is used to + // send the upper 8 bits of the collation id + pos++ + } for ; pos < 13+23; pos++ { data[pos] = 0 } diff --git a/client/auth_test.go b/client/auth_test.go index 85dba1e98..e5e451ca5 100644 --- a/client/auth_test.go +++ b/client/auth_test.go @@ -1,6 +1,9 @@ package client import ( + "github.com/go-mysql-org/go-mysql/packet" + "github.com/pingcap/tidb/pkg/parser/charset" + "net" "testing" "github.com/go-mysql-org/go-mysql/mysql" @@ -34,3 +37,73 @@ func TestConnGenAttributes(t *testing.T) { require.Subset(t, data, fixt) } } + +func TestConnCollation(t *testing.T) { + collations := []string{"big5_chinese_ci", + "utf8_general_ci", + "utf8mb4_0900_ai_ci", + "utf8mb4_de_pb_0900_ai_ci", + "utf8mb4_ja_0900_as_cs", + "utf8mb4_0900_bin", + "utf8mb4_zh_pinyin_tidb_as_cs"} + + // test all supported collations by calling writeAuthHandshake() and reading the bytes + // sent to the server to ensure the collation id is set correctly + for _, c := range collations { + collation, err := charset.GetCollationByName(c) + require.NoError(t, err) + server := sendAuthResponse(t, collation.Name) + // read the all the bytes of the handshake response so that client goroutine can complete without blocking + // on the server read. + handShakeResponse := make([]byte, 128) + _, err = server.Read(handShakeResponse) + require.NoError(t, err) + + // validate the collation id is set correctly + // if the collation ID is <= 255 the collation ID is stored in the 12th byte + if collation.ID <= 255 { + require.Equal(t, byte(collation.ID), handShakeResponse[12]) + // sanity check: validate the 23 bytes of filler with value 0x00 are set correctly + for i := 13; i < 13+23; i++ { + require.Equal(t, byte(0x00), handShakeResponse[i]) + } + } else { + // if the collation ID is > 255 the collation ID is stored in the 12th and 13th bytes + require.Equal(t, byte(collation.ID&0xff), handShakeResponse[12]) + require.Equal(t, byte(collation.ID>>8), handShakeResponse[13]) + + // sanity check: validate the 22 bytes of filler with value 0x00 are set correctly + for i := 14; i < 14+22; i++ { + require.Equal(t, byte(0x00), handShakeResponse[i]) + } + } + + // and finally the username + password := string(handShakeResponse[36:40]) + require.Equal(t, "test", password) + + require.NoError(t, server.Close()) + } +} + +func sendAuthResponse(t *testing.T, collation string) net.Conn { + server, client := net.Pipe() + c := &Conn{ + Conn: &packet.Conn{ + Conn: client, + }, + authPluginName: "mysql_native_password", + user: "test", + db: "test", + password: "test", + proto: "tcp", + collation: collation, + salt: ([]byte)("123456781234567812345678"), + } + + go func() { + err := c.writeAuthHandshake() + require.NoError(t, err) + }() + return server +} From 54ca96f01cc9c8892f5e94c18f29310e835a015f Mon Sep 17 00:00:00 2001 From: David Vilaverde Date: Mon, 29 Apr 2024 08:37:59 -0400 Subject: [PATCH 05/13] Update client/auth.go MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Daniƫl van Eeden --- client/auth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/auth.go b/client/auth.go index b9c8d36e2..e49f26687 100644 --- a/client/auth.go +++ b/client/auth.go @@ -269,7 +269,7 @@ func (c *Conn) writeAuthHandshake() error { data[11] = 0x00 // Charset [1 byte] - // use default collation id 33 here, is utf-8 + // use default collation id 33 here, is `utf8mb3_general_ci` collationName := c.collation if len(collationName) == 0 { collationName = DEFAULT_COLLATION_NAME From c7c5b97ed93b06cdb82aac08c0d39296dd7b746f Mon Sep 17 00:00:00 2001 From: dvilaverde Date: Mon, 29 Apr 2024 09:48:14 -0400 Subject: [PATCH 06/13] address PR feedback --- client/auth.go | 21 ++++++++++----------- client/auth_test.go | 30 +++++++++++++++--------------- 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/client/auth.go b/client/auth.go index e49f26687..d56f0a972 100644 --- a/client/auth.go +++ b/client/auth.go @@ -281,12 +281,12 @@ func (c *Conn) writeAuthHandshake() error { // the MySQL protocol calls for the collation id to be sent as 1, where only the // lower 8 bits are used in this field. But wireshark shows that the first by of - // the 23 bytes of filler is used to send the upper 8 bits of the collation id. + // the 23 bytes of filler is used to send the right middle 8 bits of the collation id. // see https://github.com/mysql/mysql-server/pull/541 data[12] = byte(collation.ID & 0xff) - if collation.ID > 255 { - data[13] = byte(collation.ID >> 8) - } + // if the collation ID is <= 255 the middle 8 bits are 0s so this is the equivalent of + // padding the filler with a 0. + data[13] = byte((collation.ID & 0xff00) >> 8) // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest @@ -308,13 +308,12 @@ func (c *Conn) writeAuthHandshake() error { } // Filler [23 bytes] (all 0x00) - pos := 13 - if collation.ID > 255 { - // skip setting the first byte of the filler to 0x00 since it is used to - // send the upper 8 bits of the collation id - pos++ - } - for ; pos < 13+23; pos++ { + // the filler starts at position 13, but the first byte of the filler + // maybe have been set by the collation id earlier. So we only position 13 + // will be either 0x00 or the right middle 8 bits of the collation id. Therefore + // here we start at position 14 and fill the remaining 22 bytes with 0x00. + pos := 14 + for ; pos < 14+22; pos++ { data[pos] = 0 } diff --git a/client/auth_test.go b/client/auth_test.go index e5e451ca5..00efea5cc 100644 --- a/client/auth_test.go +++ b/client/auth_test.go @@ -39,13 +39,15 @@ func TestConnGenAttributes(t *testing.T) { } func TestConnCollation(t *testing.T) { - collations := []string{"big5_chinese_ci", - "utf8_general_ci", - "utf8mb4_0900_ai_ci", - "utf8mb4_de_pb_0900_ai_ci", - "utf8mb4_ja_0900_as_cs", - "utf8mb4_0900_bin", - "utf8mb4_zh_pinyin_tidb_as_cs"} + collations := []string{ + //"big5_chinese_ci", + //"utf8_general_ci", + //"utf8mb4_0900_ai_ci", + //"utf8mb4_de_pb_0900_ai_ci", + //"utf8mb4_ja_0900_as_cs", + //"utf8mb4_0900_bin", + "utf8mb4_zh_pinyin_tidb_as_cs", + } // test all supported collations by calling writeAuthHandshake() and reading the bytes // sent to the server to ensure the collation id is set correctly @@ -63,19 +65,17 @@ func TestConnCollation(t *testing.T) { // if the collation ID is <= 255 the collation ID is stored in the 12th byte if collation.ID <= 255 { require.Equal(t, byte(collation.ID), handShakeResponse[12]) - // sanity check: validate the 23 bytes of filler with value 0x00 are set correctly - for i := 13; i < 13+23; i++ { - require.Equal(t, byte(0x00), handShakeResponse[i]) - } + // the 13th byte should always be 0x00 + require.Equal(t, byte(0x00), handShakeResponse[13]) } else { // if the collation ID is > 255 the collation ID is stored in the 12th and 13th bytes require.Equal(t, byte(collation.ID&0xff), handShakeResponse[12]) require.Equal(t, byte(collation.ID>>8), handShakeResponse[13]) + } - // sanity check: validate the 22 bytes of filler with value 0x00 are set correctly - for i := 14; i < 14+22; i++ { - require.Equal(t, byte(0x00), handShakeResponse[i]) - } + // sanity check: validate the 22 bytes of filler with value 0x00 are set correctly + for i := 14; i < 14+22; i++ { + require.Equal(t, byte(0x00), handShakeResponse[i]) } // and finally the username From 541e28481974044ea6510bbc910eabb084382ec4 Mon Sep 17 00:00:00 2001 From: dvilaverde Date: Mon, 29 Apr 2024 09:51:53 -0400 Subject: [PATCH 07/13] fixing comments --- client/auth.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/client/auth.go b/client/auth.go index d56f0a972..93686c69e 100644 --- a/client/auth.go +++ b/client/auth.go @@ -285,7 +285,8 @@ func (c *Conn) writeAuthHandshake() error { // see https://github.com/mysql/mysql-server/pull/541 data[12] = byte(collation.ID & 0xff) // if the collation ID is <= 255 the middle 8 bits are 0s so this is the equivalent of - // padding the filler with a 0. + // padding the filler with a 0. If ID is > 255 then the first byte of filler will contain + // the right middle 8 bits of the collation ID. data[13] = byte((collation.ID & 0xff00) >> 8) // SSL Connection Request Packet @@ -309,9 +310,9 @@ func (c *Conn) writeAuthHandshake() error { // Filler [23 bytes] (all 0x00) // the filler starts at position 13, but the first byte of the filler - // maybe have been set by the collation id earlier. So we only position 13 - // will be either 0x00 or the right middle 8 bits of the collation id. Therefore - // here we start at position 14 and fill the remaining 22 bytes with 0x00. + // has been set earlier with collaction id earlier, so position 13 at this point + // will be either 0x00 or the right middle 8 bits of the collation id. + // Therefore, we start at position 14 and fill the remaining 22 bytes with 0x00. pos := 14 for ; pos < 14+22; pos++ { data[pos] = 0 From 4f41dc33ae1d71cd9709d5a8146d0d2b32fefaf3 Mon Sep 17 00:00:00 2001 From: dvilaverde Date: Mon, 29 Apr 2024 09:52:04 -0400 Subject: [PATCH 08/13] fixing comments --- client/auth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/auth.go b/client/auth.go index 93686c69e..e56ada574 100644 --- a/client/auth.go +++ b/client/auth.go @@ -310,7 +310,7 @@ func (c *Conn) writeAuthHandshake() error { // Filler [23 bytes] (all 0x00) // the filler starts at position 13, but the first byte of the filler - // has been set earlier with collaction id earlier, so position 13 at this point + // has been set earlier with collation id earlier, so position 13 at this point // will be either 0x00 or the right middle 8 bits of the collation id. // Therefore, we start at position 14 and fill the remaining 22 bytes with 0x00. pos := 14 From c471d01b40c76ac845001eb7bb4c27224e428b5c Mon Sep 17 00:00:00 2001 From: dvilaverde Date: Mon, 29 Apr 2024 09:57:55 -0400 Subject: [PATCH 09/13] fix linting errors --- client/auth_test.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/client/auth_test.go b/client/auth_test.go index 00efea5cc..328e19a2c 100644 --- a/client/auth_test.go +++ b/client/auth_test.go @@ -1,13 +1,14 @@ package client import ( - "github.com/go-mysql-org/go-mysql/packet" - "github.com/pingcap/tidb/pkg/parser/charset" "net" "testing" - "github.com/go-mysql-org/go-mysql/mysql" + "github.com/pingcap/tidb/pkg/parser/charset" "github.com/stretchr/testify/require" + + "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/packet" ) func TestConnGenAttributes(t *testing.T) { From 993e3396c3d48a8f9e39089b4cffe8d9501036db Mon Sep 17 00:00:00 2001 From: dvilaverde Date: Mon, 29 Apr 2024 10:13:27 -0400 Subject: [PATCH 10/13] restore tests that were commented out accidently --- client/auth_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/client/auth_test.go b/client/auth_test.go index 328e19a2c..177af1f43 100644 --- a/client/auth_test.go +++ b/client/auth_test.go @@ -41,12 +41,12 @@ func TestConnGenAttributes(t *testing.T) { func TestConnCollation(t *testing.T) { collations := []string{ - //"big5_chinese_ci", - //"utf8_general_ci", - //"utf8mb4_0900_ai_ci", - //"utf8mb4_de_pb_0900_ai_ci", - //"utf8mb4_ja_0900_as_cs", - //"utf8mb4_0900_bin", + "big5_chinese_ci", + "utf8_general_ci", + "utf8mb4_0900_ai_ci", + "utf8mb4_de_pb_0900_ai_ci", + "utf8mb4_ja_0900_as_cs", + "utf8mb4_0900_bin", "utf8mb4_zh_pinyin_tidb_as_cs", } From ef8e8008b1da56f893b5718303a0765d180340b1 Mon Sep 17 00:00:00 2001 From: dvilaverde Date: Mon, 29 Apr 2024 11:33:33 -0400 Subject: [PATCH 11/13] fixing more typos in the comments --- client/auth.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/client/auth.go b/client/auth.go index e56ada574..1f4d7c1de 100644 --- a/client/auth.go +++ b/client/auth.go @@ -280,7 +280,7 @@ func (c *Conn) writeAuthHandshake() error { } // the MySQL protocol calls for the collation id to be sent as 1, where only the - // lower 8 bits are used in this field. But wireshark shows that the first by of + // lower 8 bits are used in this field. But wireshark shows that the first byte of // the 23 bytes of filler is used to send the right middle 8 bits of the collation id. // see https://github.com/mysql/mysql-server/pull/541 data[12] = byte(collation.ID & 0xff) @@ -310,8 +310,8 @@ func (c *Conn) writeAuthHandshake() error { // Filler [23 bytes] (all 0x00) // the filler starts at position 13, but the first byte of the filler - // has been set earlier with collation id earlier, so position 13 at this point - // will be either 0x00 or the right middle 8 bits of the collation id. + // has been set with the collation id earlier, so position 13 at this point + // will be either 0x00, or the right middle 8 bits of the collation id. // Therefore, we start at position 14 and fill the remaining 22 bytes with 0x00. pos := 14 for ; pos < 14+22; pos++ { From 447ea4f7214399806a99c862078d79b43a93d09a Mon Sep 17 00:00:00 2001 From: David Vilaverde Date: Tue, 30 Apr 2024 07:31:18 -0400 Subject: [PATCH 12/13] Apply suggestions from code review Co-authored-by: lance6716 --- client/auth_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/auth_test.go b/client/auth_test.go index 177af1f43..14bbcd081 100644 --- a/client/auth_test.go +++ b/client/auth_test.go @@ -80,8 +80,8 @@ func TestConnCollation(t *testing.T) { } // and finally the username - password := string(handShakeResponse[36:40]) - require.Equal(t, "test", password) + username := string(handShakeResponse[36:40]) + require.Equal(t, "test", username) require.NoError(t, server.Close()) } From 5f0308a20c43c977326746957b2dd6256d5b4825 Mon Sep 17 00:00:00 2001 From: dvilaverde Date: Tue, 30 Apr 2024 07:45:04 -0400 Subject: [PATCH 13/13] addressing PR feedback --- client/auth_test.go | 2 ++ client/client_test.go | 1 + client/conn.go | 5 ++--- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/client/auth_test.go b/client/auth_test.go index 14bbcd081..0837f1767 100644 --- a/client/auth_test.go +++ b/client/auth_test.go @@ -105,6 +105,8 @@ func sendAuthResponse(t *testing.T, collation string) net.Conn { go func() { err := c.writeAuthHandshake() require.NoError(t, err) + err = c.Close() + require.NoError(t, err) }() return server } diff --git a/client/client_test.go b/client/client_test.go index b27c4c669..10515e622 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -235,6 +235,7 @@ func (s *clientTestSuite) TestConn_SetCharset() { func (s *clientTestSuite) TestConn_SetCollationAfterConnect() { err := s.c.SetCollation("latin1_swedish_ci") require.Error(s.T(), err) + require.ErrorContains(s.T(), err, "cannot set collation after connection is established") } func (s *clientTestSuite) TestConn_SetCollation() { diff --git a/client/conn.go b/client/conn.go index 9d7014951..9fc7faf16 100644 --- a/client/conn.go +++ b/client/conn.go @@ -364,12 +364,11 @@ func (c *Conn) SetCharset(charset string) error { } func (c *Conn) SetCollation(collation string) error { - if c.status == 0 { - c.collation = collation - } else { + if len(c.serverVersion) != 0 { return errors.Trace(errors.Errorf("cannot set collation after connection is established")) } + c.collation = collation return nil }