Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ANSI_QUOTES parsing mode #256

Merged
merged 5 commits into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1086,16 +1086,22 @@ func (c *Conn) handleNextCommand(handler Handler) error {
var statement sqlparser.Statement
var remainder string

parserOptions, err := handler.ParserOptionsForConnection(c)
if err != nil {
return err
}

if !c.DisableClientMultiStatements && c.Capabilities&CapabilityClientMultiStatements != 0 {
var ri int
statement, ri, err = sqlparser.ParseOne(query)
statement, ri, err = sqlparser.ParseOneWithOptions(query, parserOptions)
if ri < len(query) {
remainder = query[ri:]
}
} else {
statement, err = sqlparser.Parse(query)
statement, err = sqlparser.ParseWithOptions(query, parserOptions)
}
if err != nil {
log.Errorf("Error while parsing prepared statement: %s", err.Error())
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Errorf("Error writing query error to %s: %v", c, werr)
Expand All @@ -1113,6 +1119,7 @@ func (c *Conn) handleNextCommand(handler Handler) error {
return nil
}

// Walk the parsed statement tree and find any SQLVal nodes that are parameterized.
paramsCount := uint16(0)
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) {
switch node := node.(type) {
Expand Down
8 changes: 8 additions & 0 deletions go/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/dolthub/vitess/go/vt/log"
querypb "github.com/dolthub/vitess/go/vt/proto/query"
"github.com/dolthub/vitess/go/vt/proto/vtrpc"
"github.com/dolthub/vitess/go/vt/sqlparser"
"github.com/dolthub/vitess/go/vt/vterrors"
)

Expand Down Expand Up @@ -121,6 +122,13 @@ type Handler interface {
WarningCount(c *Conn) uint16

ComResetConnection(c *Conn)

// ParserOptionsForConnection returns any parser options that should be used for the given connection. For
// example, if the connection has enabled ANSI_QUOTES or ANSI SQL_MODE, then the parser needs to know that
// in order to parse queries correctly. This is primarily needed when a prepared statement request comes in,
// and the Vitess layer needs to parse the query to identify the query parameters so that the correct response
// packets can be sent.
ParserOptionsForConnection(c *Conn) (sqlparser.ParserOptions, error)
}

// Listener is the MySQL server protocol listener.
Expand Down
5 changes: 5 additions & 0 deletions go/mysql/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (

"github.com/dolthub/vitess/go/sqltypes"
vtenv "github.com/dolthub/vitess/go/vt/env"
"github.com/dolthub/vitess/go/vt/sqlparser"
"github.com/dolthub/vitess/go/vt/tlstest"
"github.com/dolthub/vitess/go/vt/vterrors"
"github.com/dolthub/vitess/go/vt/vttls"
Expand Down Expand Up @@ -111,6 +112,10 @@ func (th *testHandler) NewConnection(c *Conn) {
func (th *testHandler) ConnectionClosed(c *Conn) {
}

func (th *testHandler) ParserOptionsForConnection(c *Conn) (sqlparser.ParserOptions, error) {
return sqlparser.ParserOptions{}, nil
}

func (th *testHandler) ComInitDB(c *Conn, schemaName string) error {
return nil
}
Expand Down
40 changes: 34 additions & 6 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,54 @@ func yyParsePooled(yylex yyLexer) int {
// a set of types, define the function as iTypeName.
// This will help avoid name collisions.

// ParserOptions defines options that customize how statements are parsed.
type ParserOptions struct {
// AnsiQuotes controls whether " characters are treated as the identifier character, as
// defined in the SQL92 standard, or as a string quote character. By default, AnsiQuotes is
// disabled, and ` characters are treated as the identifier character (and not a string
fulghum marked this conversation as resolved.
Show resolved Hide resolved
// quoting character). When AnsiQuotes is set to true, " characters are instead treated
// as identifier quotes and NOT valid as string quotes. Note that the ` character may always
// be used to quote identifiers, regardless of whether AnsiQuotes is enabled or not. For
// more info, see: https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html#sqlmode_ansi_quotes
AnsiQuotes bool
}

// Parse parses the SQL in full and returns a Statement, which
// is the AST representation of the query. If a DDL statement
// is partially parsed but still contains a syntax error, the
// error is ignored and the DDL is returned anyway.
func Parse(sql string) (Statement, error) {
return ParseWithOptions(sql, ParserOptions{})
}

// ParseWithOptions fully parses the SQL in |sql|, using any custom options specified
// in |options|, and returns a Statement, which is the AST representation of the query.
// If a DDL statement is partially parsed but contains a syntax error, the
// error is ignored and the DDL is returned anyway.
func ParseWithOptions(sql string, options ParserOptions) (Statement, error) {
tokenizer := NewStringTokenizer(sql)
if options.AnsiQuotes {
tokenizer = NewStringTokenizerForAnsiQuotes(sql)
}
return parseTokenizer(sql, tokenizer)
}

// ParseOne parses the first SQL statement in the given string and returns the
// index of the start of the next statement in |sql|. If there was only one
// statement in |sql|, the value of the returned index will be |len(sql)|.
func ParseOne(sql string) (Statement, int, error) {
return ParseOneWithOptions(sql, ParserOptions{})
}

// ParseOneWithOptions parses the first SQL statement in |sql|, using any parsing
// options specified in |options|, and returns the parsed Statement, along with
// the index of the start of the next statement in |sql|. If there was only one
// statement in |sql|, the value of the returned index will be |len(sql)|.
func ParseOneWithOptions(sql string, options ParserOptions) (Statement, int, error) {
tokenizer := NewStringTokenizer(sql)
if options.AnsiQuotes {
tokenizer = NewStringTokenizerForAnsiQuotes(sql)
}
tokenizer.stopAfterFirstStmt = true
tree, err := parseTokenizer(sql, tokenizer)
if err != nil {
Expand Down Expand Up @@ -214,12 +248,6 @@ func stringIsUnbrokenQuote(s string, quoteChar byte) bool {
return true
}

// ParseTokenizer is a raw interface to parse from the given tokenizer.
// This does not used pooled parsers, and should not be used in general.
func ParseTokenizer(tokenizer *Tokenizer) int {
return yyParse(tokenizer)
}

// ParseNext parses a single SQL statement from the tokenizer
// returning a Statement which is the AST representation of the query.
// The tokenizer will always read up to the end of the statement, allowing for
Expand Down
78 changes: 77 additions & 1 deletion go/vt/sqlparser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3093,6 +3093,7 @@ var (
"description 'description'",
},
}

// Any tests that contain multiple statements within the body (such as BEGIN/END blocks) should go here.
// validSQL is used by TestParseNextValid, which expects a semicolon to mean the end of a full statement.
// Multi-statement bodies do not follow this expectation, hence they are excluded from TestParseNextValid.
Expand Down Expand Up @@ -3381,6 +3382,41 @@ end case;
end`,
},
}

// validAnsiQuotesSQL contains SQL statements that are valid when the ANSI_QUOTES SQL mode is enabled. This
// mode treats double quotes (and backticks) as identifier quotes, and single quotes as string quotes.
validAnsiQuotesSQL = []parseTest{
{
input: `select "count", "foo", "bar" from t order by "COUNT"`,
output: "select `count`, foo, bar from t order by `COUNT` asc",
},
{
input: `INSERT INTO hourly_logins ("applications_id", "count", "hour") VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE "count" = "count" + VALUES(count)`,
output: "insert into hourly_logins(applications_id, `count`, `hour`) values (:v1, :v2, :v3) on duplicate key update count = `count` + values(`count`)",
},
{
input: `CREATE TABLE "webhook_events" ("pk" int primary key, "event" varchar(255) DEFAULT NULL)`,
output: "create table webhook_events (\n\tpk int primary key,\n\t`event` varchar(255) default null\n)",
},
{
input: `with "test" as (select 1 from "dual"), "test_two" as (select 2 from "dual") select * from "test", "test_two" union all (with "b" as (with "c" as (select 1, 2 from "dual") select * from "c") select * from "b")`,
output: "with test as (select 1 from `dual`), test_two as (select 2 from `dual`) select * from test, test_two union all (with b as (with c as (select 1, 2 from `dual`) select * from c) select * from b)",
},
{
input: `select '"' from t order by "foo"`,
output: `select '\"' from t order by foo asc`,
},
{
// Assert that quote escaping is the same as when ANSI_QUOTES is off
input: `select '''foo'''`,
output: `select '\'foo\''`,
},
{
// Assert that quote escaping is the same as when ANSI_QUOTES is off
input: `select """""""foo"""""""`,
output: "select `\"\"\"foo\"\"\"`",
},
}
)

func TestValid(t *testing.T) {
Expand All @@ -3390,6 +3426,20 @@ func TestValid(t *testing.T) {
}
}

func TestAnsiQuotesMode(t *testing.T) {
parserOptions := ParserOptions{AnsiQuotes: true}
for _, tcase := range validAnsiQuotesSQL {
runParseTestCaseWithParserOptions(t, tcase, parserOptions)
}
for _, tcase := range invalidAnsiQuotesSQL {
t.Run(tcase.input, func(t *testing.T) {
_, err := ParseWithOptions(tcase.input, parserOptions)
require.NotNil(t, err)
assert.Equal(t, tcase.output, err.Error())
})
}
}

func TestSingle(t *testing.T) {
validSQL = append(validSQL, validMultiStatementSql...)
for _, tcase := range validSQL {
Expand Down Expand Up @@ -4424,12 +4474,19 @@ func TestKeywords(t *testing.T) {
}
}

// runParseTestCase runs the specific test case, |tcase|, using the default parser options.
func runParseTestCase(t *testing.T, tcase parseTest) bool {
return runParseTestCaseWithParserOptions(t, tcase, ParserOptions{})
}

// runParseTestCaseWithParserOptions runs the specific test case, |tcase|, using the specified parser
// options, |options|, to control any parser behaviors.
func runParseTestCaseWithParserOptions(t *testing.T, tcase parseTest, options ParserOptions) bool {
return t.Run(tcase.input, func(t *testing.T) {
if tcase.output == "" {
tcase.output = tcase.input
}
tree, err := Parse(tcase.input)
tree, err := ParseWithOptions(tcase.input, options)
require.NoError(t, err)

assertTestcaseOutput(t, tcase, tree)
Expand Down Expand Up @@ -6332,6 +6389,25 @@ var (
output: "You have an error in your SQL syntax; At least one event field to alter needs to be defined at position 52 near 'myevent'",
},
}

// invalidAnsiQuotesSQL contains invalid SQL statements that use ANSI_QUOTES mode.
invalidAnsiQuotesSQL = []parseTest{
{
// Assert that the two identifier quotes do not match each other
input: "select \"foo`",
output: "syntax error at position 13 near 'foo`'",
},
{
// Assert that the two identifier quotes do not match each other
input: "select `bar\"",
output: "syntax error at position 13 near 'bar\"'",
},
{
// Assert that single and double quotes do not auto concatenate in ANSI_QUOTES mode
input: "select 'a' \"b\" 'c'",
output: "syntax error at position 19 near 'c'",
},
}
)

func TestErrors(t *testing.T) {
Expand Down
Loading