Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow setting the collation in auth handshake #860

Merged
merged 13 commits into from
Apr 30, 2024
30 changes: 26 additions & 4 deletions client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
. "github.com/go-mysql-org/go-mysql/mysql"
"github.com/go-mysql-org/go-mysql/packet"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/parser/charset"
)

const defaultAuthPluginName = AUTH_NATIVE_PASSWORD
Expand Down Expand Up @@ -268,8 +269,25 @@ func (c *Conn) writeAuthHandshake() error {
data[11] = 0x00

// Charset [1 byte]
// use default collation id 33 here, is utf-8
data[12] = DEFAULT_COLLATION_ID
// use default collation id 33 here, is `utf8mb3_general_ci`
collationName := c.collation
if len(collationName) == 0 {
collationName = DEFAULT_COLLATION_NAME
}
collation, err := charset.GetCollationByName(collationName)
if err != nil {
return fmt.Errorf("invalid collation name %s", collationName)
}

// the MySQL protocol calls for the collation id to be sent as 1, where only the
// lower 8 bits are used in this field. But wireshark shows that the first byte of
// the 23 bytes of filler is used to send the right middle 8 bits of the collation id.
// see https://github.com/mysql/mysql-server/pull/541
data[12] = byte(collation.ID & 0xff)
// if the collation ID is <= 255 the middle 8 bits are 0s so this is the equivalent of
// padding the filler with a 0. If ID is > 255 then the first byte of filler will contain
// the right middle 8 bits of the collation ID.
data[13] = byte((collation.ID & 0xff00) >> 8)

// SSL Connection Request Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
Expand All @@ -291,8 +309,12 @@ func (c *Conn) writeAuthHandshake() error {
}

// Filler [23 bytes] (all 0x00)
pos := 13
for ; pos < 13+23; pos++ {
// the filler starts at position 13, but the first byte of the filler
// has been set with the collation id earlier, so position 13 at this point
// will be either 0x00, or the right middle 8 bits of the collation id.
// Therefore, we start at position 14 and fill the remaining 22 bytes with 0x00.
pos := 14
for ; pos < 14+22; pos++ {
data[pos] = 0
}

Expand Down
78 changes: 77 additions & 1 deletion client/auth_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package client

import (
"net"
"testing"

"github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/tidb/pkg/parser/charset"
"github.com/stretchr/testify/require"

"github.com/go-mysql-org/go-mysql/mysql"
"github.com/go-mysql-org/go-mysql/packet"
)

func TestConnGenAttributes(t *testing.T) {
Expand Down Expand Up @@ -34,3 +38,75 @@ func TestConnGenAttributes(t *testing.T) {
require.Subset(t, data, fixt)
}
}

func TestConnCollation(t *testing.T) {
collations := []string{
"big5_chinese_ci",
"utf8_general_ci",
"utf8mb4_0900_ai_ci",
"utf8mb4_de_pb_0900_ai_ci",
"utf8mb4_ja_0900_as_cs",
"utf8mb4_0900_bin",
"utf8mb4_zh_pinyin_tidb_as_cs",
}

// test all supported collations by calling writeAuthHandshake() and reading the bytes
// sent to the server to ensure the collation id is set correctly
for _, c := range collations {
collation, err := charset.GetCollationByName(c)
require.NoError(t, err)
server := sendAuthResponse(t, collation.Name)
// read the all the bytes of the handshake response so that client goroutine can complete without blocking
// on the server read.
handShakeResponse := make([]byte, 128)
_, err = server.Read(handShakeResponse)
require.NoError(t, err)

// validate the collation id is set correctly
// if the collation ID is <= 255 the collation ID is stored in the 12th byte
if collation.ID <= 255 {
require.Equal(t, byte(collation.ID), handShakeResponse[12])
// the 13th byte should always be 0x00
require.Equal(t, byte(0x00), handShakeResponse[13])
} else {
// if the collation ID is > 255 the collation ID is stored in the 12th and 13th bytes
require.Equal(t, byte(collation.ID&0xff), handShakeResponse[12])
require.Equal(t, byte(collation.ID>>8), handShakeResponse[13])
}

// sanity check: validate the 22 bytes of filler with value 0x00 are set correctly
for i := 14; i < 14+22; i++ {
require.Equal(t, byte(0x00), handShakeResponse[i])
}

// and finally the username
username := string(handShakeResponse[36:40])
require.Equal(t, "test", username)

require.NoError(t, server.Close())
}
}

func sendAuthResponse(t *testing.T, collation string) net.Conn {
server, client := net.Pipe()
c := &Conn{
Conn: &packet.Conn{
Conn: client,
},
authPluginName: "mysql_native_password",
user: "test",
db: "test",
password: "test",
proto: "tcp",
collation: collation,
salt: ([]byte)("123456781234567812345678"),
}

go func() {
err := c.writeAuthHandshake()
require.NoError(t, err)
dvilaverde marked this conversation as resolved.
Show resolved Hide resolved
err = c.Close()
require.NoError(t, err)
}()
return server
}
22 changes: 21 additions & 1 deletion client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ func TestClientSuite(t *testing.T) {
func (s *clientTestSuite) SetupSuite() {
var err error
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
s.c, err = Connect(addr, *testUser, *testPassword, "")
s.c, err = Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
// test the collation logic, but this is essentially a no-op since
// the collation set is the default value
_ = conn.SetCollation(mysql.DEFAULT_COLLATION_NAME)
})
require.NoError(s.T(), err)

var result *mysql.Result
Expand Down Expand Up @@ -228,6 +232,22 @@ func (s *clientTestSuite) TestConn_SetCharset() {
require.NoError(s.T(), err)
}

func (s *clientTestSuite) TestConn_SetCollationAfterConnect() {
err := s.c.SetCollation("latin1_swedish_ci")
require.Error(s.T(), err)
dvilaverde marked this conversation as resolved.
Show resolved Hide resolved
require.ErrorContains(s.T(), err, "cannot set collation after connection is established")
}

func (s *clientTestSuite) TestConn_SetCollation() {
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
_, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
// test the collation logic
_ = conn.SetCollation("invalid_collation")
})

require.Error(s.T(), err)
}

func (s *clientTestSuite) testStmt_DropTable() {
str := `drop table if exists mixer_test_stmt`

Expand Down
23 changes: 21 additions & 2 deletions client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ type Conn struct {
status uint16

charset string
// sets the collation to be set on the auth handshake, this does not issue a 'set names' command
collation string

salt []byte
authPluginName string
Expand Down Expand Up @@ -67,15 +69,19 @@ func Connect(addr string, user string, password string, dbName string, options .
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()

dialer := &net.Dialer{}
return ConnectWithContext(ctx, addr, user, password, dbName, options...)
}

// ConnectWithContext to a MySQL addr using the provided context.
func ConnectWithContext(ctx context.Context, addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) {
dialer := &net.Dialer{}
return ConnectWithDialer(ctx, "", addr, user, password, dbName, dialer.DialContext, options...)
}

// Dialer connects to the address on the named network using the provided context.
type Dialer func(ctx context.Context, network, address string) (net.Conn, error)

// Connect to a MySQL server using the given Dialer.
// ConnectWithDialer to a MySQL server using the given Dialer.
func ConnectWithDialer(ctx context.Context, network string, addr string, user string, password string, dbName string, dialer Dialer, options ...func(*Conn)) (*Conn, error) {
c := new(Conn)

Expand Down Expand Up @@ -357,6 +363,19 @@ func (c *Conn) SetCharset(charset string) error {
}
}

func (c *Conn) SetCollation(collation string) error {
if len(c.serverVersion) != 0 {
return errors.Trace(errors.Errorf("cannot set collation after connection is established"))
}

c.collation = collation
return nil
}

func (c *Conn) GetCollation() string {
return c.collation
}

func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) {
if err := c.writeCommandStrStr(COM_FIELD_LIST, table, wildcard); err != nil {
return nil, errors.Trace(err)
Expand Down
Loading