Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow setting the collation in auth handshake #860

Merged
merged 13 commits into from
Apr 30, 2024
12 changes: 11 additions & 1 deletion client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
. "github.com/go-mysql-org/go-mysql/mysql"
"github.com/go-mysql-org/go-mysql/packet"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/parser/charset"
)

const defaultAuthPluginName = AUTH_NATIVE_PASSWORD
Expand Down Expand Up @@ -269,7 +270,16 @@ func (c *Conn) writeAuthHandshake() error {

// Charset [1 byte]
// use default collation id 33 here, is utf-8
dvilaverde marked this conversation as resolved.
Show resolved Hide resolved
data[12] = DEFAULT_COLLATION_ID
collationName := c.collation
if len(collationName) == 0 {
collationName = DEFAULT_COLLATION_NAME
}
collation, err := charset.GetCollationByName(collationName)
if err != nil {
return fmt.Errorf("invalid collation name %s", collationName)
}

data[12] = byte(collation.ID)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this would overflow for some collations.

  • 255 / 0xff / utf8mb4_0900_ai_ci would be fine
  • 309 / 0x0135 / utf8mb4_0900_bin would not be

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Weirdly enough the protocol says this is 1 byte and that only the low 8-bits are put in this field. Not sure how that's going to work.

https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_response.html

image

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A MySQL 8.0 Client only allows the user to set the charset, not the collation:

  --default-character-set=name 
                      Set the default character set.

And as all default collations are in the 0-255 range this works with the protocol.

mysql> SELECT MAX(ID) FROM information_schema.COLLATIONS WHERE IS_DEFAULT='Yes';
+---------+
| MAX(ID) |
+---------+
|     255 |
+---------+
1 row in set (0.00 sec)


// SSL Connection Request Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
Expand Down
21 changes: 20 additions & 1 deletion client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ func TestClientSuite(t *testing.T) {
func (s *clientTestSuite) SetupSuite() {
var err error
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
s.c, err = Connect(addr, *testUser, *testPassword, "")
s.c, err = Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
// test the collation logic, but this is essentially a no-op since
// the collation set is the default value
_ = conn.SetCollation(mysql.DEFAULT_COLLATION_NAME)
})
require.NoError(s.T(), err)

var result *mysql.Result
Expand Down Expand Up @@ -228,6 +232,21 @@ func (s *clientTestSuite) TestConn_SetCharset() {
require.NoError(s.T(), err)
}

func (s *clientTestSuite) TestConn_SetCollationAfterConnect() {
err := s.c.SetCollation("latin1_swedish_ci")
require.Error(s.T(), err)
dvilaverde marked this conversation as resolved.
Show resolved Hide resolved
}

func (s *clientTestSuite) TestConn_SetCollation() {
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
_, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
// test the collation logic
_ = conn.SetCollation("invalid_collation")
})

require.Error(s.T(), err)
}

func (s *clientTestSuite) testStmt_DropTable() {
str := `drop table if exists mixer_test_stmt`

Expand Down
24 changes: 22 additions & 2 deletions client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ type Conn struct {
status uint16

charset string
// sets the collation to be set on the auth handshake, this does not issue a 'set names' command
collation string

salt []byte
authPluginName string
Expand Down Expand Up @@ -67,15 +69,19 @@ func Connect(addr string, user string, password string, dbName string, options .
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()

dialer := &net.Dialer{}
return ConnectWithContext(ctx, addr, user, password, dbName, options...)
}

// ConnectWithContext to a MySQL addr using the provided context.
func ConnectWithContext(ctx context.Context, addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) {
dialer := &net.Dialer{}
return ConnectWithDialer(ctx, "", addr, user, password, dbName, dialer.DialContext, options...)
}

// Dialer connects to the address on the named network using the provided context.
type Dialer func(ctx context.Context, network, address string) (net.Conn, error)

// Connect to a MySQL server using the given Dialer.
// ConnectWithDialer to a MySQL server using the given Dialer.
func ConnectWithDialer(ctx context.Context, network string, addr string, user string, password string, dbName string, dialer Dialer, options ...func(*Conn)) (*Conn, error) {
c := new(Conn)

Expand Down Expand Up @@ -357,6 +363,20 @@ func (c *Conn) SetCharset(charset string) error {
}
}

func (c *Conn) SetCollation(collation string) error {
if c.status == 0 {
dvilaverde marked this conversation as resolved.
Show resolved Hide resolved
c.collation = collation
} else {
return errors.Trace(errors.Errorf("cannot set collation after connection is established"))
}

return nil
}

func (c *Conn) GetCollation() string {
return c.collation
}

func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) {
if err := c.writeCommandStrStr(COM_FIELD_LIST, table, wildcard); err != nil {
return nil, errors.Trace(err)
Expand Down
Loading