Skip to content

Commit

Permalink
feat: Support mysql upsert syntax (#5)
Browse files Browse the repository at this point in the history
* feat: Support mysql upsert syntax

* fix: parser_test

---------

Co-authored-by: 江 杨 <[email protected]>
  • Loading branch information
geekeryy and 江 杨 authored Mar 29, 2024
1 parent c96041c commit c95d80e
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 31 deletions.
56 changes: 33 additions & 23 deletions ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -689,47 +689,57 @@ type UpsertClause struct {

DoNothing bool // position of NOTHING keyword after DO
DoUpdate bool // position of UPDATE keyword after DO
DuplicateKey bool // position of ON DUPLICATE KEY UPDATE keyword
Assignments []*Assignment // list of column assignments
UpdateWhereExpr Expr // optional conditional expression for DO UPDATE SET
}

// String returns the string representation of the clause.
func (c *UpsertClause) String() string {
var buf bytes.Buffer
buf.WriteString("ON CONFLICT")

if len(c.Columns) != 0 {
buf.WriteString(" (")
for i, col := range c.Columns {
if c.DuplicateKey {
buf.WriteString("ON DUPLICATE KEY UPDATE ")
for i := range c.Assignments {
if i != 0 {
buf.WriteString(", ")
}
buf.WriteString(col.String())
buf.WriteString(c.Assignments[i].String())
}
buf.WriteString(")")
} else {
buf.WriteString("ON CONFLICT")

if c.WhereExpr != nil {
fmt.Fprintf(&buf, " WHERE %s", c.WhereExpr.String())
}
}
if len(c.Columns) != 0 {
buf.WriteString(" (")
for i, col := range c.Columns {
if i != 0 {
buf.WriteString(", ")
}
buf.WriteString(col.String())
}
buf.WriteString(")")

buf.WriteString(" DO")
if c.DoNothing {
buf.WriteString(" NOTHING")
} else {
buf.WriteString(" UPDATE SET ")
for i := range c.Assignments {
if i != 0 {
buf.WriteString(", ")
if c.WhereExpr != nil {
fmt.Fprintf(&buf, " WHERE %s", c.WhereExpr.String())
}
buf.WriteString(c.Assignments[i].String())
}

if c.UpdateWhereExpr != nil {
fmt.Fprintf(&buf, " WHERE %s", c.UpdateWhereExpr.String())
buf.WriteString(" DO")
if c.DoNothing {
buf.WriteString(" NOTHING")
} else {
buf.WriteString(" UPDATE SET ")
for i := range c.Assignments {
if i != 0 {
buf.WriteString(", ")
}
buf.WriteString(c.Assignments[i].String())
}

if c.UpdateWhereExpr != nil {
fmt.Fprintf(&buf, " WHERE %s", c.UpdateWhereExpr.String())
}
}
}

return buf.String()
}

Expand Down
22 changes: 15 additions & 7 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,12 @@ func (p *Parser) parseUpsertClause() (_ *UpsertClause, err error) {

// Parse "ON CONFLICT"
p.lex()
if p.peek() != CONFLICT {
return &clause, p.errorExpected(p.pos, p.tok, "CONFLICT")
switch p.peek() {
case CONFLICT:
case DUPLICATE:
clause.DuplicateKey = true
default:
return &clause, p.errorExpected(p.pos, p.tok, "CONFLICT or DUPLICATE")
}
p.lex()

Expand Down Expand Up @@ -244,9 +248,11 @@ func (p *Parser) parseUpsertClause() (_ *UpsertClause, err error) {
}
}

// Parse "DO NOTHING" or "DO UPDATE SET".
if p.peek() != DO {
// Parse "DO NOTHING" or "DO UPDATE SET" or "ON DUPLICATE KEY".
if !clause.DuplicateKey && p.peek() != DO {
return &clause, p.errorExpected(p.pos, p.tok, "DO")
} else if clause.DuplicateKey && p.peek() != KEY {
return &clause, p.errorExpected(p.pos, p.tok, "KEY")
}
p.lex()

Expand All @@ -262,10 +268,12 @@ func (p *Parser) parseUpsertClause() (_ *UpsertClause, err error) {
// Otherwise parse "UPDATE SET"
p.lex()
clause.DoUpdate = true
if p.peek() != SET {
return &clause, p.errorExpected(p.pos, p.tok, "SET")
if !clause.DuplicateKey {
if p.peek() != SET {
return &clause, p.errorExpected(p.pos, p.tok, "SET")
}
p.lex()
}
p.lex()

// Parse list of assignments.
for {
Expand Down
2 changes: 1 addition & 1 deletion parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ func TestParser_ParseStatement(t *testing.T) {
AssertParseStatementError(t, `INSERT INTO tbl (x) VALUES (1`, `1:29: expected comma or right paren, found 'EOF'`)
AssertParseStatementError(t, `INSERT INTO tbl (x) SELECT`, `1:26: expected expression, found 'EOF'`)
AssertParseStatementError(t, `INSERT INTO tbl (x) DEFAULT`, `1:27: expected VALUES, found 'EOF'`)
AssertParseStatementError(t, `INSERT INTO tbl (x) VALUES (1) ON`, `1:33: expected CONFLICT, found 'EOF'`)
AssertParseStatementError(t, `INSERT INTO tbl (x) VALUES (1) ON`, `1:33: expected CONFLICT or DUPLICATE, found 'EOF'`)
AssertParseStatementError(t, `INSERT INTO tbl (x) VALUES (1) ON CONFLICT (`, `1:44: expected expression, found 'EOF'`)
AssertParseStatementError(t, `INSERT INTO tbl (x) VALUES (1) ON CONFLICT (x`, `1:45: expected comma or right paren, found 'EOF'`)
AssertParseStatementError(t, `INSERT INTO tbl (x) VALUES (1) ON CONFLICT (x) WHERE`, `1:52: expected expression, found 'EOF'`)
Expand Down
2 changes: 2 additions & 0 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ const (
WINDOW
WITH
WITHOUT
DUPLICATE
keyword_end

ANY // ???
Expand Down Expand Up @@ -444,6 +445,7 @@ var tokens = [...]string{
WINDOW: "WINDOW",
WITH: "WITH",
WITHOUT: "WITHOUT",
DUPLICATE: "DUPLICATE",
}

func (tok Token) String() string {
Expand Down

0 comments on commit c95d80e

Please sign in to comment.