First pass at supporting ANSI_QUOTES SQL mode in the parser
fulghum committed Jul 24, 2023
1 parent dc2f84e commit 4c05b8e
Showing 5 changed files with 167 additions and 22 deletions.
11 changes: 9 additions & 2 deletions go/mysql/conn.go
@@ -1086,16 +1086,22 @@ func (c *Conn) handleNextCommand(handler Handler) error {
var statement sqlparser.Statement
var remainder string

parserOptions, err := handler.ParserOptionsForConnection(c)
if err != nil {
return err
}

if !c.DisableClientMultiStatements && c.Capabilities&CapabilityClientMultiStatements != 0 {
var ri int
- statement, ri, err = sqlparser.ParseOne(query)
+ statement, ri, err = sqlparser.ParseOneWithOptions(query, parserOptions)
if ri < len(query) {
remainder = query[ri:]
}
} else {
- statement, err = sqlparser.Parse(query)
+ statement, err = sqlparser.ParseWithOptions(query, parserOptions)
}
if err != nil {
log.Errorf("Error while parsing prepared statement: %s", err.Error())
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Errorf("Error writing query error to %s: %v", c, werr)
@@ -1113,6 +1119,7 @@ func (c *Conn) handleNextCommand(handler Handler) error {
return nil
}

// Walk the parsed statement tree and find any SQLVal nodes that are parameterized.
paramsCount := uint16(0)
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) {
switch node := node.(type) {
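To make the prepared-statement flow above concrete, here is a rough sketch (not part of this commit) of how a bind-parameter count can be derived once a query has been parsed with the connection's parser options. It uses the ParseWithOptions function added in this commit together with the existing sqlparser.Walk, SQLVal, and ValArg APIs; the countBindParams helper name is invented for illustration.

package main

import (
    "fmt"

    "github.com/dolthub/vitess/go/vt/sqlparser"
)

// countBindParams parses |query| with the given parser options and counts the
// parameterized SQLVal nodes (the '?' placeholders) in the resulting tree.
func countBindParams(query string, options sqlparser.ParserOptions) (uint16, error) {
    stmt, err := sqlparser.ParseWithOptions(query, options)
    if err != nil {
        return 0, err
    }
    count := uint16(0)
    _ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) {
        if val, ok := node.(*sqlparser.SQLVal); ok && val.Type == sqlparser.ValArg {
            count++
        }
        return true, nil
    }, stmt)
    return count, nil
}

func main() {
    // With ANSI_QUOTES enabled, "name" is an identifier, not a string literal.
    n, err := countBindParams(`select "name" from t where id = ? and age > ?`,
        sqlparser.ParserOptions{AnsiQuotes: true})
    fmt.Println(n, err) // expected: 2 <nil>
}
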
8 changes: 8 additions & 0 deletions go/mysql/server.go
@@ -31,6 +31,7 @@ import (
"github.com/dolthub/vitess/go/vt/log"
querypb "github.com/dolthub/vitess/go/vt/proto/query"
"github.com/dolthub/vitess/go/vt/proto/vtrpc"
"github.com/dolthub/vitess/go/vt/sqlparser"
"github.com/dolthub/vitess/go/vt/vterrors"
)

@@ -121,6 +122,13 @@ type Handler interface {
WarningCount(c *Conn) uint16

ComResetConnection(c *Conn)

// ParserOptionsForConnection returns any parser options that should be used for the given connection. For
// example, if the connection has enabled ANSI_QUOTES or ANSI SQL_MODE, then the parser needs to know that
// in order to parse queries correctly. This is primarily needed when a prepared statement request comes in,
// and the Vitess layer needs to parse the query to identify the query parameters so that the correct response
// packets can be sent.
ParserOptionsForConnection(c *Conn) (sqlparser.ParserOptions, error)
}

// Listener is the MySQL server protocol listener.
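The following is a hypothetical sketch of how a Handler implementation might satisfy the new ParserOptionsForConnection method. Only the method signature comes from the interface above; the exampleHandler type and its per-connection SQL_MODE bookkeeping are invented for illustration.

package example

import (
    "strings"

    "github.com/dolthub/vitess/go/mysql"
    "github.com/dolthub/vitess/go/vt/sqlparser"
)

type exampleHandler struct {
    // sqlModes is a hypothetical record of each connection's current SQL_MODE value.
    sqlModes map[*mysql.Conn]string
}

// ParserOptionsForConnection enables ANSI_QUOTES parsing when the connection's SQL_MODE
// includes ANSI_QUOTES, or when it is set to the ANSI combination mode (which implies it).
func (h *exampleHandler) ParserOptionsForConnection(c *mysql.Conn) (sqlparser.ParserOptions, error) {
    mode := h.sqlModes[c]
    ansiQuotes := strings.Contains(mode, "ANSI_QUOTES") || mode == "ANSI"
    return sqlparser.ParserOptions{AnsiQuotes: ansiQuotes}, nil
}
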
40 changes: 34 additions & 6 deletions go/vt/sqlparser/ast.go
@@ -83,20 +83,54 @@ func yyParsePooled(yylex yyLexer) int {
// a set of types, define the function as iTypeName.
// This will help avoid name collisions.

// ParserOptions defines options that customize how statements are parsed.
type ParserOptions struct {
// AnsiQuotes controls whether " characters are treated as identifier quotes, as defined
// in the SQL92 standard, or as string quotes. By default, AnsiQuotes is disabled: " characters
// are treated as string quotes and ` characters are the only identifier quotes. When AnsiQuotes
// is set to true, " characters are instead treated as identifier quotes and are NOT valid as
// string quotes. Note that the ` character may always be used to quote identifiers, regardless
// of whether AnsiQuotes is enabled. For more info, see:
// https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html#sqlmode_ansi_quotes
AnsiQuotes bool
}

// Parse parses the SQL in full and returns a Statement, which
// is the AST representation of the query. If a DDL statement
// is partially parsed but still contains a syntax error, the
// error is ignored and the DDL is returned anyway.
func Parse(sql string) (Statement, error) {
return ParseWithOptions(sql, ParserOptions{})
}

// ParseWithOptions fully parses the SQL in |sql|, using any custom options specified
// in |options|, and returns a Statement, which is the AST representation of the query.
// If a DDL statement is partially parsed but contains a syntax error, the
// error is ignored and the DDL is returned anyway.
func ParseWithOptions(sql string, options ParserOptions) (Statement, error) {
tokenizer := NewStringTokenizer(sql)
if options.AnsiQuotes {
tokenizer = NewStringTokenizerForAnsiQuotes(sql)
}
return parseTokenizer(sql, tokenizer)
}
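As a quick illustration (a sketch, not part of the diff), the same query parses differently depending on the AnsiQuotes option; the expected ANSI_QUOTES round-trip below matches the test case added in parse_test.go further down.

package main

import (
    "fmt"

    "github.com/dolthub/vitess/go/vt/sqlparser"
)

func main() {
    query := `select "count", "foo", "bar" from t order by "COUNT"`

    // Default mode: the double-quoted tokens are parsed as string literals.
    stmt, err := sqlparser.Parse(query)
    if err != nil {
        panic(err)
    }
    fmt.Println(sqlparser.String(stmt))

    // ANSI_QUOTES mode: the double-quoted tokens are parsed as identifiers.
    stmt, err = sqlparser.ParseWithOptions(query, sqlparser.ParserOptions{AnsiQuotes: true})
    if err != nil {
        panic(err)
    }
    fmt.Println(sqlparser.String(stmt))
    // per the new test case: select `count`, foo, bar from t order by `COUNT` asc
}
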

// ParseOne parses the first SQL statement in the given string and returns the
// index of the start of the next statement in |sql|. If there was only one
// statement in |sql|, the value of the returned index will be |len(sql)|.
func ParseOne(sql string) (Statement, int, error) {
return ParseOneWithOptions(sql, ParserOptions{})
}

// ParseOneWithOptions parses the first SQL statement in |sql|, using any parsing
// options specified in |options|, and returns the parsed Statement, along with
// the index of the start of the next statement in |sql|. If there was only one
// statement in |sql|, the value of the returned index will be |len(sql)|.
func ParseOneWithOptions(sql string, options ParserOptions) (Statement, int, error) {
tokenizer := NewStringTokenizer(sql)
if options.AnsiQuotes {
tokenizer = NewStringTokenizerForAnsiQuotes(sql)
}
tokenizer.stopAfterFirstStmt = true
tree, err := parseTokenizer(sql, tokenizer)
if err != nil {
@@ -214,12 +248,6 @@ func stringIsUnbrokenQuote(s string, quoteChar byte) bool {
return true
}

- // ParseTokenizer is a raw interface to parse from the given tokenizer.
- // This does not used pooled parsers, and should not be used in general.
- func ParseTokenizer(tokenizer *Tokenizer) int {
- return yyParse(tokenizer)
- }

// ParseNext parses a single SQL statement from the tokenizer
// returning a Statement which is the AST representation of the query.
// The tokenizer will always read up to the end of the statement, allowing for
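A companion sketch (again not part of the diff) of ParseOneWithOptions, whose returned index is what conn.go above uses to compute the remainder of a multi-statement query.

package main

import (
    "fmt"

    "github.com/dolthub/vitess/go/vt/sqlparser"
)

func main() {
    sql := `select "a" from t; select 2`
    opts := sqlparser.ParserOptions{AnsiQuotes: true}

    // Only the first statement is parsed; idx marks where the next statement begins.
    stmt, idx, err := sqlparser.ParseOneWithOptions(sql, opts)
    if err != nil {
        panic(err)
    }
    fmt.Println(sqlparser.String(stmt)) // "a" is parsed as the identifier a, not a string
    if idx < len(sql) {
        fmt.Println("remainder:", sql[idx:]) // everything after the first statement
    }
}
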
65 changes: 64 additions & 1 deletion go/vt/sqlparser/parse_test.go
@@ -3071,6 +3071,7 @@ var (
"description 'description'",
},
}

// Any tests that contain multiple statements within the body (such as BEGIN/END blocks) should go here.
// validSQL is used by TestParseNextValid, which expects a semicolon to mean the end of a full statement.
// Multi-statement bodies do not follow this expectation, hence they are excluded from TestParseNextValid.
@@ -3359,6 +3360,29 @@ end case;
end`,
},
}

// validAnsiQuotesSql contains SQL statements that are valid when the ANSI_QUOTES SQL mode is enabled. This
// mode treats double quotes (and backticks) as identifier quotes, and single quotes as string quotes.
validAnsiQuotesSql = []parseTest{
{
input: `select "count", "foo", "bar" from t order by "COUNT"`,
output: "select `count`, foo, bar from t order by `COUNT` asc",
},
{
input: `select '"' from t order by "foo"`,
output: `select '\"' from t order by foo asc`,
},
{
// Assert that quote escaping is the same as when ANSI_QUOTES is off
input: `select '''foo'''`,
output: `select '\'foo\''`,
},
{
// Assert that quote escaping is the same as when ANSI_QUOTES is off
input: `select """""""foo"""""""`,
output: "select `\"\"\"foo\"\"\"`",
},
}
)

func TestValid(t *testing.T) {
@@ -3368,6 +3392,20 @@ func TestValid(t *testing.T) {
}
}

func TestAnsiQuotesMode(t *testing.T) {
parserOptions := ParserOptions{AnsiQuotes: true}
for _, tcase := range validAnsiQuotesSql {
runParseTestCaseWithParserOptions(t, tcase, parserOptions)
}
for _, tcase := range invalidAnsiQuotesSQL {
t.Run(tcase.input, func(t *testing.T) {
_, err := ParseWithOptions(tcase.input, parserOptions)
require.NotNil(t, err)
assert.Equal(t, tcase.output, err.Error())
})
}
}

func TestSingle(t *testing.T) {
validSQL = append(validSQL, validMultiStatementSql...)
for _, tcase := range validSQL {
@@ -4394,12 +4432,19 @@ func TestKeywords(t *testing.T) {
}
}

// runParseTestCase runs the specific test case, |tcase|, using the default parser options.
func runParseTestCase(t *testing.T, tcase parseTest) bool {
return runParseTestCaseWithParserOptions(t, tcase, ParserOptions{})
}

// runParseTestCaseWithParserOptions runs the specific test case, |tcase|, using the specified parser
// options, |options|, to control any parser behaviors.
func runParseTestCaseWithParserOptions(t *testing.T, tcase parseTest, options ParserOptions) bool {
return t.Run(tcase.input, func(t *testing.T) {
if tcase.output == "" {
tcase.output = tcase.input
}
- tree, err := Parse(tcase.input)
+ tree, err := ParseWithOptions(tcase.input, options)
require.NoError(t, err)

assertTestcaseOutput(t, tcase, tree)
@@ -6302,6 +6347,24 @@ var (
output: "You have an error in your SQL syntax; At least one event field to alter needs to be defined at position 52 near 'myevent'",
},
}

// invalidAnsiQuotesSQL contains SQL statements that are invalid when the ANSI_QUOTES SQL mode is enabled.
invalidAnsiQuotesSQL = []parseTest{
{
// Assert that the two identifier quotes do not match each other
input: "select \"foo`",
output: "syntax error at position 13 near 'foo`'",
},
{
// Assert that the two identifier quotes do not match each other
input: "select `bar\"",
output: "syntax error at position 13 near 'bar\"'",
},
{
input: "select 'a' \"b\" 'c'",
output: "syntax error at position 19 near 'c'",
},
}
)

func TestErrors(t *testing.T) {
65 changes: 52 additions & 13 deletions go/vt/sqlparser/token.go
@@ -32,6 +32,9 @@ import (
const (
defaultBufSize = 4096
eofChar = 0x100
backtickQuote = uint16('`')
doubleQuote = uint16('"')
singleQuote = uint16('\'')
)

// Tokenizer is the struct used to generate SQL
@@ -67,6 +70,14 @@ type Tokenizer struct {
bufPos int
bufSize int

// identifierQuotes holds the characters that are treated as identifier quotes. This always includes
// the backtick char. When the ANSI_QUOTES SQL mode is enabled, it also includes the double quote char.
identifierQuotes []uint16

// stringLiteralQuotes holds the characters that are treated as string literal quotes. This always includes the
// single quote char. When ANSI_QUOTES SQL mode is NOT enabled, this also contains the double quote character.
stringLiteralQuotes []uint16

queryBuf []byte
}

@@ -77,6 +88,21 @@ func NewStringTokenizer(sql string) *Tokenizer {
return &Tokenizer{
buf: buf,
bufSize: len(buf),
identifierQuotes: []uint16{backtickQuote},
stringLiteralQuotes: []uint16{doubleQuote, singleQuote},
}
}

// NewStringTokenizerForAnsiQuotes creates a new Tokenizer for the specified |sql| string, configured for
// ANSI_QUOTES SQL mode, meaning that any double quotes will be interpreted as quotes around an identifier,
// not around a string literal.
func NewStringTokenizerForAnsiQuotes(sql string) *Tokenizer {
buf := []byte(sql)
return &Tokenizer{
buf: buf,
bufSize: len(buf),
identifierQuotes: []uint16{backtickQuote, doubleQuote},
stringLiteralQuotes: []uint16{singleQuote},
}
}
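A small sketch (assumed behavior, not part of the diff) of what the two constructors mean at the token level: scanning the same double-quoted input yields a string token by default and an identifier token under ANSI_QUOTES. STRING and ID are the parser's generated token-type constants.

package main

import (
    "fmt"

    "github.com/dolthub/vitess/go/vt/sqlparser"
)

func main() {
    // Default tokenizer: "foo" scans as a string literal.
    tkn := sqlparser.NewStringTokenizer(`"foo"`)
    typ, val := tkn.Scan()
    fmt.Println(typ == sqlparser.STRING, string(val)) // true foo

    // ANSI_QUOTES tokenizer: "foo" scans as a quoted identifier.
    tkn = sqlparser.NewStringTokenizerForAnsiQuotes(`"foo"`)
    typ, val = tkn.Scan()
    fmt.Println(typ == sqlparser.ID, string(val)) // true foo
}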

@@ -965,16 +991,28 @@ func (tkn *Tokenizer) Scan() (int, []byte) {
return NE, nil
}
return int(ch), nil
- case '\'', '"':
+ case contains(tkn.stringLiteralQuotes, ch):
return tkn.scanString(ch, STRING)
- case '`':
- return tkn.scanLiteralIdentifier()
+ case contains(tkn.identifierQuotes, ch):
+ return tkn.scanLiteralIdentifier(ch)
default:
return LEX_ERROR, []byte{byte(ch)}
}
}
}

// contains searches the specified |slice| for the target |x| and returns |x| if it is found, or
// 0 if it is not. The matched value is returned, instead of a boolean, so that this function can
// be used directly as a case expression in the switch statement above, which switches on a uint16 value.
func contains(slice []uint16, x uint16) uint16 {
for _, element := range slice {
if element == x {
return element
}
}
return 0
}

// skipStatement scans until end of statement.
func (tkn *Tokenizer) skipStatement() int {
for {
@@ -1022,7 +1060,7 @@ func (tkn *Tokenizer) scanIdentifier(firstByte byte, isDbSystemVariable bool) (i
func (tkn *Tokenizer) scanHex() (int, []byte) {
buffer := &bytes2.Buffer{}
tkn.scanMantissa(16, buffer)
- if tkn.lastChar != '\'' {
+ if tkn.lastChar != singleQuote {
return LEX_ERROR, buffer.Bytes()
}
tkn.next()
@@ -1042,23 +1080,24 @@ func (tkn *Tokenizer) scanBitLiteral() (int, []byte) {
return BIT_LITERAL, buffer.Bytes()
}

- func (tkn *Tokenizer) scanLiteralIdentifier() (int, []byte) {
+ // scanLiteralIdentifier scans a quoted identifier delimited by |startingChar| (a backtick, or a double quote in ANSI_QUOTES mode), unescaping doubled delimiters.
+ func (tkn *Tokenizer) scanLiteralIdentifier(startingChar uint16) (int, []byte) {
buffer := &bytes2.Buffer{}
- backTickSeen := false
+ identifierQuoteSeen := false
for {
- if backTickSeen {
- if tkn.lastChar != '`' {
+ if identifierQuoteSeen {
+ if tkn.lastChar != startingChar {
break
}
- backTickSeen = false
- buffer.WriteByte('`')
+ identifierQuoteSeen = false
+ buffer.WriteByte(byte(startingChar))
tkn.next()
continue
}
// The previous char was not the identifier quote character.
switch tkn.lastChar {
- case '`':
- backTickSeen = true
+ case startingChar:
+ identifierQuoteSeen = true
case eofChar:
// Premature EOF.
return LEX_ERROR, buffer.Bytes()
@@ -1227,7 +1266,7 @@ func (tkn *Tokenizer) scanString(delim uint16, typ int) (int, []byte) {

// mysql strings get auto concatenated, so see if the next token is a string and scan it if so
tkn.skipBlank()
- if tkn.lastChar == '\'' || tkn.lastChar == '"' {
+ if contains(tkn.stringLiteralQuotes, tkn.lastChar) == tkn.lastChar {
delim := tkn.lastChar
tkn.next()
nextTyp, nextStr := tkn.scanString(delim, STRING)