Skip to content

Commit

Permalink
parsing intersect and except (#271)
Browse files Browse the repository at this point in the history
  • Loading branch information
jycor authored Sep 20, 2023
1 parent ef1b927 commit 648c869
Show file tree
Hide file tree
Showing 8 changed files with 12,666 additions and 12,419 deletions.
46 changes: 26 additions & 20 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ type Statement interface {

type Statements []Statement

func (*Union) iStatement() {}
func (*SetOp) iStatement() {}
func (*Select) iStatement() {}
func (*Stream) iStatement() {}
func (*Insert) iStatement() {}
Expand Down Expand Up @@ -467,7 +467,7 @@ type SelectStatement interface {
}

func (*Select) iSelectStatement() {}
func (*Union) iSelectStatement() {}
func (*SetOp) iSelectStatement() {}
func (*ParenSelect) iSelectStatement() {}
func (*ValuesStatement) iSelectStatement() {}

Expand Down Expand Up @@ -701,8 +701,8 @@ func (s *ValuesStatement) walkSubtree(visit Visit) error {
return Walk(visit, s.Rows)
}

// Union represents a UNION statement.
type Union struct {
// SetOp represents a UNION, INTERSECT, and EXCEPT statement.
type SetOp struct {
Type string
Left, Right SelectStatement
OrderBy OrderBy
Expand All @@ -712,36 +712,42 @@ type Union struct {
Into *Into
}

// Union.Type
// SetOp.Type
const (
UnionStr = "union"
UnionAllStr = "union all"
UnionDistinctStr = "union distinct"
UnionStr = "union"
UnionAllStr = "union all"
UnionDistinctStr = "union distinct"
IntersectStr = "intersect"
IntersectAllStr = "intersect all"
IntersectDistinctStr = "intersect distinct"
ExceptStr = "except"
ExceptAllStr = "except all"
ExceptDistinctStr = "except distinct"
)

// AddOrder adds an order by element
func (node *Union) AddOrder(order *Order) {
func (node *SetOp) AddOrder(order *Order) {
node.OrderBy = append(node.OrderBy, order)
}

func (node *Union) SetOrderBy(orderBy OrderBy) {
func (node *SetOp) SetOrderBy(orderBy OrderBy) {
node.OrderBy = orderBy
}

func (node *Union) SetWith(w *With) {
func (node *SetOp) SetWith(w *With) {
node.With = w
}

// SetLimit sets the limit clause
func (node *Union) SetLimit(limit *Limit) {
func (node *SetOp) SetLimit(limit *Limit) {
node.Limit = limit
}

func (node *Union) SetLock(lock string) {
func (node *SetOp) SetLock(lock string) {
node.Lock = lock
}

func (node *Union) SetInto(into *Into) error {
func (node *SetOp) SetInto(into *Into) error {
if into == nil {
if r, ok := node.Right.(*Select); ok {
node.Into = r.Into
Expand All @@ -756,17 +762,17 @@ func (node *Union) SetInto(into *Into) error {
return nil
}

func (node *Union) GetInto() *Into {
func (node *SetOp) GetInto() *Into {
return node.Into
}

// Format formats the node.
func (node *Union) Format(buf *TrackedBuffer) {
func (node *SetOp) Format(buf *TrackedBuffer) {
buf.Myprintf("%v%v %s %v%v%v%s%v", node.With, node.Left, node.Type, node.Right,
node.OrderBy, node.Limit, node.Lock, node.Into)
}

func (node *Union) walkSubtree(visit Visit) error {
func (node *SetOp) walkSubtree(visit Visit) error {
if node == nil {
return nil
}
Expand Down Expand Up @@ -1594,9 +1600,9 @@ type InsertRows interface {
SQLNode
}

func (*Select) iInsertRows() {}
func (*Union) iInsertRows() {}
func (Values) iInsertRows() {}
func (*Select) iInsertRows() {}
func (*SetOp) iInsertRows() {}
func (Values) iInsertRows() {}
func (*ParenSelect) iInsertRows() {}

// Update represents an UPDATE statement.
Expand Down
4 changes: 2 additions & 2 deletions go/vt/sqlparser/ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func TestAddOrder(t *testing.T) {
if err != nil {
t.Error(err)
}
dst.(*Union).AddOrder(order)
dst.(*SetOp).AddOrder(order)
buf = NewTrackedBuffer(nil)
dst.Format(buf)
want = "select * from t union select * from s order by foo asc"
Expand Down Expand Up @@ -184,7 +184,7 @@ func TestSetLimit(t *testing.T) {
if err != nil {
t.Error(err)
}
dst.(*Union).SetLimit(limit)
dst.(*SetOp).SetLimit(limit)
buf = NewTrackedBuffer(nil)
dst.Format(buf)
want = "select * from t union select * from s limit 4"
Expand Down
2 changes: 1 addition & 1 deletion go/vt/sqlparser/impossible_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func FormatImpossibleQuery(buf *TrackedBuffer, node SQLNode) {
if node.GroupBy != nil {
node.GroupBy.Format(buf)
}
case *Union:
case *SetOp:
buf.Myprintf("%v %s %v", node.Left, node.Type, node.Right)
default:
node.Format(buf)
Expand Down
1 change: 1 addition & 0 deletions go/vt/sqlparser/keywords.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ var keywords = map[string]int{
"int8": INT8,
"integer": INTEGER,
"interval": INTERVAL,
"intersect": INTERSECT,
"into": INTO,
"invisible": INVISIBLE,
"invoker": INVOKER,
Expand Down
21 changes: 20 additions & 1 deletion go/vt/sqlparser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,26 @@ var (
}, {
input: "select /* union order by limit lock */ 1 from t union select 1 from t order by a limit 1 for update",
output: "select /* union order by limit lock */ 1 from t union select 1 from t order by a asc limit 1 for update",
}, {
},
{
input: "select 1 from t intersect select 1 from t",
},
{
input: "select 1 from t intersect all select 1 from t",
},
{
input: "select 1 from t intersect distinct select 1 from t",
},
{
input: "select 1 from t except select 1 from t",
},
{
input: "select 1 from t except all select 1 from t",
},
{
input: "select 1 from t except distinct select 1 from t",
},
{
input: "(select id, a from t order by id limit 1) union (select id, b as a from s order by id limit 1) order by a limit 1",
output: "(select id, a from t order by id asc limit 1) union (select id, b as a from s order by id asc limit 1) order by a asc limit 1",
}, {
Expand Down
88 changes: 88 additions & 0 deletions go/vt/sqlparser/precedence_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,91 @@ func TestIsPrecedence(t *testing.T) {
}
}
}

func fmtSetOp(s SelectStatement) string {
switch s := s.(type) {
case *SetOp:
return fmt.Sprintf("(%s %s %s)", fmtSetOp(s.Left), s.Type, fmtSetOp(s.Right))
case *Select:
return String(s)
case *ParenSelect:
return String(s)
}
return ""
}

func TestSetOperatorPrecedence(t *testing.T) {
validSQL := []struct {
input string
output string
}{
{
input: "select 1 union select 2 union select 3 union select 4",
output: "(((select 1 union select 2) union select 3) union select 4)",
},
{
input: "select 1 intersect select 2 intersect select 3 intersect select 4",
output: "(((select 1 intersect select 2) intersect select 3) intersect select 4)",
},
{
input: "select 1 except select 2 except select 3 except select 4",
output: "(((select 1 except select 2) except select 3) except select 4)",
},

{
input: "select 1 union select 2 intersect select 3 except select 4",
output: "((select 1 union (select 2 intersect select 3)) except select 4)",
},
{
input: "select 1 union select 2 except select 3 intersect select 4",
output: "((select 1 union select 2) except (select 3 intersect select 4))",
},

{
input: "select 1 intersect select 2 union select 3 except select 4",
output: "(((select 1 intersect select 2) union select 3) except select 4)",
},
{
input: "select 1 intersect select 2 except select 3 union select 4",
output: "(((select 1 intersect select 2) except select 3) union select 4)",
},

{
input: "select 1 except select 2 intersect select 3 union select 4",
output: "((select 1 except (select 2 intersect select 3)) union select 4)",
},
{
input: "select 1 except select 2 union select 3 intersect select 4",
output: "((select 1 except select 2) union (select 3 intersect select 4))",
},

{
input: "(table a) union (table b)",
output: "((select * from a) union (select * from b))",
},
{
input: "(table a) intersect (table b)",
output: "((select * from a) intersect (select * from b))",
},
{
input: "(table a) except (table b)",
output: "((select * from a) except (select * from b))",
},
{
input: "select 1 intersect (select 2 union select 3)",
output: "(select 1 intersect (select 2 union select 3))",
},

}
for _, tcase := range validSQL {
tree, err := Parse(tcase.input)
if err != nil {
t.Error(err)
continue
}
expr := fmtSetOp(tree.(SelectStatement))
if expr != tcase.output {
t.Errorf("Parse: \n%s, want: \n%s", expr, tcase.output)
}
}
}
Loading

0 comments on commit 648c869

Please sign in to comment.