Skip to content

Commit

Permalink
Merge branch 'main' into max/trace-ctx
Browse files Browse the repository at this point in the history
  • Loading branch information
max-hoffman committed Jun 26, 2024
2 parents 3928f71 + 55a46c5 commit 4083c07
Show file tree
Hide file tree
Showing 9 changed files with 13,242 additions and 13,002 deletions.
23 changes: 23 additions & 0 deletions go/mysql/binlog_event_make.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,19 @@ func packetize(f BinlogFormat, typ byte, flags uint16, data []byte, m BinlogEven
return result
}

// UpdateChecksum updates the checksum for the specified |event|. The BinlogFormat, |f|, indicates
// if checksums are enabled. If checksums are not enabled, then no change is made to |event|.
func UpdateChecksum(f BinlogFormat, event BinlogEvent) {
result := event.Bytes()
length := len(result)

switch f.ChecksumAlgorithm {
case BinlogChecksumAlgCRC32:
checksum := crc32.ChecksumIEEE(result[0 : length-4])
binary.LittleEndian.PutUint32(result[length-4:], checksum)
}
}

// NewInvalidEvent returns an invalid event (its size is <19).
func NewInvalidEvent() BinlogEvent {
return NewMysql56BinlogEvent([]byte{0})
Expand Down Expand Up @@ -146,6 +159,16 @@ func NewFormatDescriptionEvent(f BinlogFormat, m BinlogEventMetadata) BinlogEven
return NewMysql56BinlogEvent(ev)
}

// NewPreviousGtidsEvent creates a new Previous GTIDs BinlogEvent. The BinlogFormat, |f|, indicates if checksums are
// enabled, the BinlogEventMeatadata, |m|, specifies the unique server ID, and |gtids| is the MySQL 5.6 GTID set
// to include in the event, indicating the events that have been previously executed by the server.
func NewPreviousGtidsEvent(f BinlogFormat, m BinlogEventMetadata, gtids Mysql56GTIDSet) BinlogEvent {
data := gtids.SIDBlock()

ev := packetize(f, ePreviousGTIDsEvent, 0, data, m)
return NewMysql56BinlogEvent(ev)
}

// NewInvalidFormatDescriptionEvent returns an invalid FormatDescriptionEvent.
// The binlog version is set to 3. It IsValid() though.
func NewInvalidFormatDescriptionEvent(f BinlogFormat, m BinlogEventMetadata) BinlogEvent {
Expand Down
49 changes: 49 additions & 0 deletions go/mysql/binlog_event_make_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ limitations under the License.
package mysql

import (
"encoding/binary"
"reflect"
"testing"

"github.com/stretchr/testify/require"

binlogdatapb "github.com/dolthub/vitess/go/vt/proto/binlogdata"
)

Expand Down Expand Up @@ -141,6 +144,52 @@ func TestIntVarEvent(t *testing.T) {
}
}

func TestUpdateChecksum(t *testing.T) {
f := NewMySQL56BinlogFormat()
m := NewTestBinlogMetadata()

q := Query{
Database: "my database",
SQL: "my query",
Charset: &binlogdatapb.Charset{
Client: 0x1234,
Conn: 0x5678,
Server: 0x9abc,
},
}
event := NewQueryEvent(f, m, q)
bytes := event.Bytes()

// Calling UpdateChecksum without changing the event should not change the checksum
oldChecksum := append([]byte{}, bytes[len(bytes)-4:]...)
UpdateChecksum(f, event)
newChecksum := append([]byte{}, bytes[len(bytes)-4:]...)
require.Equal(t, oldChecksum, newChecksum)
require.Equal(t, []byte{0x65, 0xaa, 0x33, 0x0e}, newChecksum)

// Calling UpdateChecksum after changing the event should generate a new checksum
binary.LittleEndian.PutUint32(bytes[13:13+4], uint32(420))
UpdateChecksum(f, event)
newChecksum = append([]byte{}, bytes[len(bytes)-4:]...)
require.NotEqual(t, oldChecksum, newChecksum)
require.Equal(t, []byte{0x26, 0xD0, 0xa4, 0x05}, newChecksum)
}

func TestPreviousGtidsEvent(t *testing.T) {
f := NewMySQL56BinlogFormat()
m := NewTestBinlogMetadata()

gtidSetString := "32a5b8c9-4716-40f5-9a9b-3d7be0cb33d7:1-42"
gtidSet, err := ParseMysql56GTIDSet(gtidSetString)
require.NoError(t, err)

event := NewPreviousGtidsEvent(f, m, gtidSet.(Mysql56GTIDSet))
require.True(t, event.IsPreviousGTIDs())
position, err := event.PreviousGTIDs(f)
require.NoError(t, err)
require.Equal(t, gtidSetString, position.String())
}

func TestInvalidEvents(t *testing.T) {
f := NewMySQL56BinlogFormat()
m := NewTestBinlogMetadata()
Expand Down
58 changes: 48 additions & 10 deletions go/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package mysql
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"strings"
Expand Down Expand Up @@ -91,6 +92,12 @@ type Handler interface {
// ConnectionClosed is called when a connection is closed.
ConnectionClosed(c *Conn)

// ConnectionAborted is called when a new connection cannot be fully established. For
// example, if a client connects to the server, but fails authentication, or can't
// negotiate an authentication handshake, this method will be called to let integrators
// know about the failed connection attempt.
ConnectionAborted(c *Conn, reason string) error

// ComInitDB is called once at the beginning to set db name,
// and subsequently for every ComInitDB event.
ComInitDB(c *Conn, schemaName string) error
Expand Down Expand Up @@ -363,7 +370,7 @@ func (l *Listener) handle(ctx context.Context, conn net.Conn, connectionID uint3
salt, err := c.writeHandshakeV10(l.ServerVersion, l.authServer, l.TLSConfig != nil)
if err != nil {
if err != io.EOF {
log.Errorf("Cannot send HandshakeV10 packet to %s: %v", c, err)
l.handleConnectionError(c, fmt.Sprintf("Cannot send HandshakeV10 packet: %v", err))
}
return
}
Expand All @@ -374,13 +381,16 @@ func (l *Listener) handle(ctx context.Context, conn net.Conn, connectionID uint3
if err != nil {
// Don't log EOF errors. They cause too much spam, same as main read loop.
if err != io.EOF {
log.Infof("Cannot read client handshake response from %s: %v, it may not be a valid MySQL client", c, err)
l.handleConnectionWarning(c, fmt.Sprintf(
"Cannot read client handshake response from %s: %v, " +
"it may not be a valid MySQL client", c, err))
}
return
}
user, authMethod, authResponse, err := l.parseClientHandshakePacket(c, true, response)
if err != nil {
log.Errorf("Cannot parse client handshake response from %s: %v", c, err)
l.handleConnectionError(c, fmt.Sprintf(
"Cannot parse client handshake response from %s: %v", c, err))
return
}

Expand All @@ -390,14 +400,16 @@ func (l *Listener) handle(ctx context.Context, conn net.Conn, connectionID uint3
// SSL was enabled. We need to re-read the auth packet.
response, err = c.readEphemeralPacket(ctx)
if err != nil {
log.Errorf("Cannot read post-SSL client handshake response from %s: %v", c, err)
l.handleConnectionError(c, fmt.Sprintf(
"Cannot read post-SSL client handshake response from %s: %v", c, err))
return
}

// Returns copies of the data, so we can recycle the buffer.
user, authMethod, authResponse, err = l.parseClientHandshakePacket(c, false, response)
if err != nil {
log.Errorf("Cannot parse post-SSL client handshake response from %s: %v", c, err)
l.handleConnectionError(c, fmt.Sprintf(
"Cannot parse post-SSL client handshake response from %s: %v", c, err))
return
}
c.recycleReadPacket()
Expand All @@ -421,6 +433,7 @@ func (l *Listener) handle(ctx context.Context, conn net.Conn, connectionID uint3
// See what auth method the AuthServer wants to use for that user.
authServerMethod, err := l.authServer.AuthMethod(user, conn.RemoteAddr().String())
if err != nil {
l.handleConnectionError(c, "auth server failed to determine auth method")
c.writeErrorPacketFromError(err)
return
}
Expand All @@ -433,7 +446,8 @@ func (l *Listener) handle(ctx context.Context, conn net.Conn, connectionID uint3
// ValidateHash() method.
userData, err := l.authServer.ValidateHash(salt, user, authResponse, conn.RemoteAddr())
if err != nil {
log.Warningf("Error authenticating user using MySQL native password: %v", err)
l.handleConnectionWarning(c, fmt.Sprintf(
"Error authenticating user using MySQL native password: %v", err))
c.writeErrorPacketFromError(err)
return
}
Expand All @@ -452,20 +466,22 @@ func (l *Listener) handle(ctx context.Context, conn net.Conn, connectionID uint3
data := make([]byte, 21)
data = append(salt, byte(0x00))
if err := c.writeAuthSwitchRequest(MysqlNativePassword, data); err != nil {
log.Errorf("Error writing auth switch packet for %s: %v", c, err)
l.handleConnectionError(c, fmt.Sprintf("Error writing auth switch packet for %s: %v", c, err))
return
}

response, err := c.readEphemeralPacket(ctx)
if err != nil {
log.Errorf("Error reading auth switch response for %s: %v", c, err)
l.handleConnectionError(c, fmt.Sprintf(
"Error reading auth switch response for %s: %v", c, err))
return
}
c.recycleReadPacket()

userData, err := l.authServer.ValidateHash(salt, user, response, conn.RemoteAddr())
if err != nil {
log.Warningf("Error authenticating user using MySQL native password: %v", err)
l.handleConnectionWarning(c, fmt.Sprintf(
"Error authenticating user using MySQL native password: %v", err))
c.writeErrorPacketFromError(err)
return
}
Expand All @@ -477,6 +493,8 @@ func (l *Listener) handle(ctx context.Context, conn net.Conn, connectionID uint3

// The negotiation happens in clear text. Let's check we can.
if !l.AllowClearTextWithoutTLS.Get() && c.Capabilities&CapabilityClientSSL == 0 {
l.handleConnectionWarning(c, fmt.Sprintf(
"Cannot use clear text authentication over non-SSL connections."))
c.writeErrorPacket(CRServerHandshakeErr, SSUnknownSQLState, "Cannot use clear text authentication over non-SSL connections.")
return
}
Expand All @@ -488,14 +506,17 @@ func (l *Listener) handle(ctx context.Context, conn net.Conn, connectionID uint3
data = authServerDialogSwitchData()
}
if err := c.writeAuthSwitchRequest(authServerMethod, data); err != nil {
log.Errorf("Error writing auth switch packet for %s: %v", c, err)
l.handleConnectionError(c, fmt.Sprintf(
"Error writing auth switch packet for %s: %v", c, err))
return
}

// Then hand over the rest of the negotiation to the
// auth server.
userData, err := l.authServer.Negotiate(c, user, conn.RemoteAddr())
if err != nil {
l.handleConnectionWarning(c, fmt.Sprintf(
"Unable to negotiate authentication: %v", err))
c.writeErrorPacketFromError(err)
return
}
Expand Down Expand Up @@ -540,6 +561,23 @@ func (l *Listener) handle(ctx context.Context, conn net.Conn, connectionID uint3
}
}

// handleConnectionError logs |reason| as an error and notifies the handler that a connection has been aborted.
func (l *Listener) handleConnectionError(c *Conn, reason string) {
log.Error(reason)
if err := l.handler.ConnectionAborted(c, reason); err != nil {
log.Errorf("unable to report connection aborted to handler: %s", err)
}
}

// handleConnectionWarning logs |reason| as a warning and notifies the handler that a connection has been aborted.
func (l *Listener) handleConnectionWarning(c *Conn, reason string) {
log.Warning(reason)
if err := l.handler.ConnectionAborted(c, reason); err != nil {
log.Errorf("unable to report connection aborted to handler: %s", err)
}
}


// Close stops the listener, which prevents accept of any new connections. Existing connections won't be closed.
func (l *Listener) Close() {
l.listener.Close()
Expand Down
4 changes: 4 additions & 0 deletions go/mysql/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ func (th *testHandler) NewConnection(c *Conn) {
func (th *testHandler) ConnectionClosed(c *Conn) {
}

func (th *testHandler) ConnectionAborted(c *Conn, reason string) error {
return nil
}

func (th *testHandler) ParserOptionsForConnection(c *Conn) (sqlparser.ParserOptions, error) {
return sqlparser.ParserOptions{}, nil
}
Expand Down
29 changes: 20 additions & 9 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -6848,15 +6848,21 @@ func VarScope(nameParts ...string) (string, SetScope, string, error) {
return VarScope(nameParts[0][:dotIdx], nameParts[0][dotIdx+1:])
}
// Session scope is inferred here, but not explicitly requested
return nameParts[0][2:], SetScope_Session, "", nil
return trimQuotes(nameParts[0][2:]), SetScope_Session, "", nil
} else if strings.HasPrefix(nameParts[0], "@") {
return nameParts[0][1:], SetScope_User, "", nil
varName := nameParts[0][1:]
if len(varName) > 0 {
varName = trimQuotes(varName)
}
return varName, SetScope_User, "", nil
} else {
return nameParts[0], SetScope_None, "", nil
}
case 2:
// `@user.var` is valid, so we check for it here.
if len(nameParts[0]) >= 2 && nameParts[0][0] == '@' && nameParts[0][1] != '@' &&
if len(nameParts[0]) >= 2 &&
nameParts[0][0] == '@' &&
nameParts[0][1] != '@' &&
!strings.HasPrefix(nameParts[1], "@") { // `@user.@var` is invalid though.
return fmt.Sprintf("%s.%s", nameParts[0][1:], nameParts[1]), SetScope_User, "", nil
}
Expand All @@ -6872,27 +6878,27 @@ func VarScope(nameParts ...string) (string, SetScope, string, error) {
if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) {
return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1])
}
return nameParts[1], SetScope_Global, nameParts[0][2:], nil
return trimQuotes(nameParts[1]), SetScope_Global, nameParts[0][2:], nil
case "@@persist":
if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) {
return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1])
}
return nameParts[1], SetScope_Persist, nameParts[0][2:], nil
return trimQuotes(nameParts[1]), SetScope_Persist, nameParts[0][2:], nil
case "@@persist_only":
if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) {
return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1])
}
return nameParts[1], SetScope_PersistOnly, nameParts[0][2:], nil
return trimQuotes(nameParts[1]), SetScope_PersistOnly, nameParts[0][2:], nil
case "@@session":
if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) {
return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1])
}
return nameParts[1], SetScope_Session, nameParts[0][2:], nil
return trimQuotes(nameParts[1]), SetScope_Session, nameParts[0][2:], nil
case "@@local":
if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) {
return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1])
}
return nameParts[1], SetScope_Session, nameParts[0][2:], nil
return trimQuotes(nameParts[1]), SetScope_Session, nameParts[0][2:], nil
default:
// This catches `@@@GLOBAL.sys_var`. Due to the earlier check, this does not error on `@user.var`.
if strings.HasPrefix(nameParts[0], "@") {
Expand Down Expand Up @@ -7225,8 +7231,13 @@ func formatID(buf *TrackedBuffer, original, lowered string) {
isDbSystemVariable = true
}

isUserVariable := false
if !isDbSystemVariable && len(original) > 0 && original[:1] == "@" {
isUserVariable = true
}

for i, c := range original {
if !(isLetter(uint16(c)) || c == '@') && (!isDbSystemVariable || !isCarat(uint16(c))) {
if !(isLetter(uint16(c)) || c == '@') && (!isDbSystemVariable || !isCarat(uint16(c))) && !isUserVariable {
if i == 0 || !isDigit(uint16(c)) {
goto mustEscape
}
Expand Down
Loading

0 comments on commit 4083c07

Please sign in to comment.