Skip to content

Commit

Permalink
fix detection of multi-statements in ComPrepare (#359)
Browse files Browse the repository at this point in the history
  • Loading branch information
jycor authored Jul 29, 2024
1 parent 9f8cd7c commit 60026c4
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 23 deletions.
52 changes: 29 additions & 23 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1105,32 +1105,47 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error {
return nil
}

// Populate PrepareData
c.StatementID++
prepare := &PrepareData{
StatementID: c.StatementID,
PrepareStmt: query,
}

var err error
var statement sqlparser.Statement
var remainder string

parserOptions, err := handler.ParserOptionsForConnection(c)
if err != nil {
log.Errorf("unable to determine parser options for current connection: %s", err.Error())
return err
}

var queries []string
if !c.DisableClientMultiStatements && c.Capabilities&CapabilityClientMultiStatements != 0 {
var ri int
statement, ri, err = sqlparser.ParseOneWithOptions(ctx, query, parserOptions)
if ri < len(query) {
remainder = query[ri:]
queries, err = sqlparser.SplitStatementToPieces(query)
if err != nil {
log.Errorf("error splitting query: %v", c, err)
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Errorf("Error writing query error to %s: %v", c, werr)
return werr
}
return nil
}
if len(queries) != 1 {
err := fmt.Errorf("cannot prepare multiple statements")
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Errorf("Error writing query error to %s: %v", c, werr)
return werr
}
return nil
}
} else {
statement, err = sqlparser.ParseWithOptions(ctx, query, parserOptions)
queries = []string{query}
}

// Populate PrepareData
c.StatementID++
prepare := &PrepareData{
StatementID: c.StatementID,
PrepareStmt: queries[0],
}

statement, err = sqlparser.ParseWithOptions(ctx, query, parserOptions)
if err != nil {
log.Errorf("Error while parsing prepared statement: %s", err.Error())
if werr := c.writeErrorPacketFromError(err); werr != nil {
Expand All @@ -1140,15 +1155,6 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error {
}
return nil
}
if remainder != "" {
err := fmt.Errorf("can not prepare multiple statements")
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Errorf("Error writing query error to %s: %v", c, werr)
return werr
}
return nil
}

// Walk the parsed statement tree and find any SQLVal nodes that are parameterized.
paramsCount := uint16(0)
Expand Down
66 changes: 66 additions & 0 deletions go/mysql/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1114,3 +1114,69 @@ func TestExecuteQueries(t *testing.T) {
}
})
}

func TestComStmtPrepareWithTrailingNewLine (t *testing.T) {
listener, sConn, cConn := createSocketPair(t)
defer func() {
listener.Close()
sConn.Close()
cConn.Close()
}()

sql := "select ?;\n"
data := MockQueryPackets(t, sql)
if err := cConn.writePacket(data); err != nil {
t.Fatalf("writePacket failed: %v", err)
}

prepare, result := MockPrepareData(t)
sConn.PrepareData = make(map[uint32]*PrepareData)
sConn.PrepareData[prepare.StatementID] = prepare

sConn.Capabilities |= CapabilityClientMultiStatements
handler := &testHandler{
result: result,
}
err := sConn.handleNextCommand(context.Background(), handler)
if err != nil {
t.Fatalf("handleNextCommand failed: %v", err)
}

if err := cConn.ExecuteStreamFetch(sql); err != nil {
t.Fatalf("ExecuteStreamFetch(%v) failed: %v", sql, err)
return
}
}

func TestComStmtPrepareMultiStmt (t *testing.T) {
listener, sConn, cConn := createSocketPair(t)
defer func() {
listener.Close()
sConn.Close()
cConn.Close()
}()

sql := "select ?; select ?;"
data := MockQueryPackets(t, sql)
if err := cConn.writePacket(data); err != nil {
t.Fatalf("writePacket failed: %v", err)
}

prepare, result := MockPrepareData(t)
sConn.PrepareData = make(map[uint32]*PrepareData)
sConn.PrepareData[prepare.StatementID] = prepare

sConn.Capabilities |= CapabilityClientMultiStatements
handler := &testHandler{
result: result,
}
err := sConn.handleNextCommand(context.Background(), handler)
if err != nil {
t.Fatalf("handleNextCommand failed: %v", err)
}

if err := cConn.ExecuteStreamFetch(sql); err == nil {
t.Fatalf("expected error, but received nil")
return
}
}

0 comments on commit 60026c4

Please sign in to comment.