Skip to content

Commit

Permalink
bugfix: Handle CTEs with columns named in the CTE def
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Nov 8, 2024
1 parent d9ab9f7 commit dc0fce9
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
9 changes: 9 additions & 0 deletions go/test/endtoend/vtgate/vitess_tester/cte/queries.test
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ WITH RECURSIVE numbers AS (SELECT 1 AS n
SELECT *
FROM numbers;

# Simple recursive CTE using literal values, column named in the CTE def
WITH RECURSIVE numbers(n) AS (SELECT 1
UNION ALL
SELECT n + 1
FROM numbers
WHERE n < 5)
SELECT *
FROM numbers;

# Recursive CTE joined with a normal table
WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id
FROM employees
Expand Down
6 changes: 3 additions & 3 deletions go/vt/vtgate/planbuilder/operators/SQL_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ func (qb *queryBuilder) unionWith(other *queryBuilder, distinct bool) {
}
}

func (qb *queryBuilder) recursiveCteWith(other *queryBuilder, name, alias string, distinct bool) {
func (qb *queryBuilder) recursiveCteWith(other *queryBuilder, name, alias string, distinct bool, columns sqlparser.Columns) {
cteUnion := &sqlparser.Union{
Left: qb.stmt.(sqlparser.SelectStatement),
Right: other.stmt.(sqlparser.SelectStatement),
Expand All @@ -254,7 +254,7 @@ func (qb *queryBuilder) recursiveCteWith(other *queryBuilder, name, alias string
Recursive: true,
CTEs: []*sqlparser.CommonTableExpr{{
ID: sqlparser.NewIdentifierCS(name),
Columns: nil,
Columns: columns,
Subquery: cteUnion,
}},
},
Expand Down Expand Up @@ -726,7 +726,7 @@ func buildRecursiveCTE(op *RecurseCTE, qb *queryBuilder) {
panic(err)
}

qb.recursiveCteWith(qbR, op.Def.Name, infoFor.GetAliasedTableExpr().As.String(), op.Distinct)
qb.recursiveCteWith(qbR, op.Def.Name, infoFor.GetAliasedTableExpr().As.String(), op.Distinct, op.Def.Columns)
}

func mergeHaving(h1, h2 *sqlparser.Where) *sqlparser.Where {
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/testdata/cte_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -2260,8 +2260,8 @@
"Name": "main",
"Sharded": false
},
"FieldQuery": "with recursive cte as (select 1 from dual where 1 != 1 union all select n + 1 from cte where 1 != 1) select n from cte where 1 != 1",
"Query": "with recursive cte as (select 1 from dual union all select n + 1 from cte where n < 5) select n from cte",
"FieldQuery": "with recursive cte(n) as (select 1 from dual where 1 != 1 union all select n + 1 from cte where 1 != 1) select n from cte where 1 != 1",
"Query": "with recursive cte(n) as (select 1 from dual union all select n + 1 from cte where n < 5) select n from cte",
"Table": "dual"
},
"TablesUsed": [
Expand Down

0 comments on commit dc0fce9

Please sign in to comment.