SNOW-1060064: Eliminate repeated subquery using CTE (part 1, with original query generation framework) #1274
src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
c6fc3de
to
ba5220f
Compare
with_stmt = cte_statement(
    list(duplicate_plan_to_cte_map.values()),
    list(duplicate_plan_to_table_name_map.values()),
)
final_query = with_stmt + SPACE + plan_to_query_map[self]
Since we are finding the topmost duplicate node, we don't need to concern ourselves with finding the correct order, right? This is cool.
Yeah, because we build duplicate_plan_to_cte_map in post order, so each child node comes before its parent node. That satisfies the requirement of the WITH clause:
A WITH clause can refer recursively to itself, and to other CTEs that appear earlier in the same clause. For instance, cte_name2 can refer to cte_name1 and itself, while cte_name1 can refer to itself, but not to cte_name2.
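To illustrate the ordering argument, here is a minimal sketch that uses a hypothetical `Node` class and `post_order_names` helper (neither is part of Snowpark). It only shows that a post-order walk lists every child before its parent, which is the order a WITH clause needs when a later CTE refers to an earlier one.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass(eq=False)  # eq=False keeps identity-based hashing, so nodes can be dict keys
class Node:
    # Hypothetical stand-in for a plan node; not the Snowpark class.
    name: str
    children: List["Node"] = field(default_factory=list)


def post_order_names(root: Node) -> List[str]:
    """Return node names with every child listed before its parent."""
    seen: Dict[Node, None] = {}  # dicts preserve insertion order

    def visit(node: Node) -> None:
        for child in node.children:
            visit(child)
        seen.setdefault(node, None)  # record each node once, after its children

    visit(root)
    return [n.name for n in seen]


inner = Node("inner")
outer = Node("outer", [inner])
root = Node("root", [outer, outer])  # "outer" is referenced twice
order = post_order_names(root)
```

If CTE definitions were emitted in this order, a definition for `outer` could refer to one for `inner`, because `inner` appears earlier in the clause.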
def cte_statement(queries: List[str], table_names: List[str]) -> str:
    return WITH + COMMA.join(
        table_name + AS + LEFT_PARENTHESIS + query + RIGHT_PARENTHESIS
Better to use string formatting here.
done
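For reference, one way the string-format suggestion could look is sketched below. The values of `WITH` and `COMMA` are assumptions made for this example (the real constants are defined elsewhere in the code base), and this is not necessarily the version that was merged.

```python
from typing import List

# Assumed values for constants referenced in the diff; the real
# definitions live elsewhere in the Snowpark code base.
WITH = "WITH "
COMMA = ", "


def cte_statement(queries: List[str], table_names: List[str]) -> str:
    """Build a WITH clause that binds each table name to its query."""
    definitions = COMMA.join(
        f"{table_name} AS ({query})"
        for query, table_name in zip(queries, table_names)
    )
    return f"{WITH}{definitions}"


stmt = cte_statement(["SELECT 1", "SELECT * FROM t1"], ["t1", "t2"])
# stmt == "WITH t1 AS (SELECT 1), t2 AS (SELECT * FROM t1)"
```

The f-string keeps the `name AS (query)` shape visible in one place instead of spreading it across five concatenated constants.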
This function is used to only include nodes that should be converted to CTEs.
"""
node_count_map = defaultdict(int)
node_parents_map = {root: set()}
I feel a defaultdict is a good fit for this case.
done
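A minimal sketch of the defaultdict suggestion, using hypothetical string node names and an `edges` list instead of real plan nodes:

```python
from collections import defaultdict
from typing import DefaultDict, List, Set, Tuple

# Hypothetical (parent, child) edges standing in for plan-node links.
edges: List[Tuple[str, str]] = [
    ("root", "a"),
    ("root", "b"),
    ("a", "c"),
    ("b", "c"),
]

node_parents_map: DefaultDict[str, Set[str]] = defaultdict(set)
for parent, child in edges:
    # A missing key is created as an empty set on first access,
    # so no "if child not in node_parents_map" branch is needed.
    node_parents_map[child].add(parent)
```

One caveat with this approach: merely reading a missing key (for example `node_parents_map["root"]`) inserts an empty set, so membership checks should use `in` rather than indexing.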
def traverse(node: SnowflakePlan) -> None:
    node_count_map[node] += 1
    if node.source_plan and node.source_plan.children:
        for child in node.source_plan.children:
            if child not in node_parents_map:
                node_parents_map[child] = {node}
            else:
                node_parents_map[child].add(node)
            traverse(child)

def is_duplicate_subtree(node: SnowflakePlan) -> bool:
    is_duplicate_node = node_count_map[node] > 1
    if is_duplicate_node:
        is_any_parent_unique_node = any(
            node_count_map[n] == 1 for n in node_parents_map[node]
        )
        if is_any_parent_unique_node:
            return True
        else:
            has_multi_parents = len(node_parents_map[node]) > 1
            if has_multi_parents:
                return True
    return False
How about adding some unit test code?
Yes, very good point, added
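As a rough idea of what such a unit test could exercise, the sketch below re-implements the duplicate-detection logic from the diff on a hypothetical `Node` class. The name `find_duplicate_subtrees` and the toy trees are assumptions for this example; the tests actually added to the PR may look different.

```python
from collections import defaultdict
from typing import Dict, Sequence, Set


class Node:
    # Hypothetical stand-in for SnowflakePlan; hashed by identity.
    def __init__(self, name: str, children: Sequence["Node"] = ()) -> None:
        self.name = name
        self.children = list(children)


def find_duplicate_subtrees(root: Node) -> Set[str]:
    """Return the names of the topmost nodes that occur more than once."""
    node_count_map: Dict[Node, int] = defaultdict(int)
    node_parents_map: Dict[Node, Set[Node]] = defaultdict(set)

    def traverse(node: Node) -> None:
        node_count_map[node] += 1
        for child in node.children:
            node_parents_map[child].add(node)
            traverse(child)

    def is_duplicate_subtree(node: Node) -> bool:
        if node_count_map[node] <= 1:
            return False
        parents = node_parents_map[node]
        # Topmost duplicate: some parent is visited only once,
        # or the node hangs under several distinct parents.
        return any(node_count_map[p] == 1 for p in parents) or len(parents) > 1

    traverse(root)
    return {n.name for n in list(node_count_map) if is_duplicate_subtree(n)}


# Case 1: "mid" is repeated; "leaf" is repeated only because "mid" is.
leaf = Node("leaf")
mid = Node("mid", [leaf])
nested = find_duplicate_subtrees(Node("root", [mid, mid]))

# Case 2: "shared" sits under two different, non-repeated parents.
shared = Node("shared")
diamond = find_duplicate_subtrees(
    Node("root", [Node("a", [shared]), Node("b", [shared])])
)
```

In the first case only `mid` is reported, since `leaf` is covered by the CTE for `mid`. In the second case `shared` itself is the topmost duplicate.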
duplicate_plan_to_cte_map = {}
duplicate_plan_to_table_name_map = {}

def build_plan_to_query_map_in_post_order(node: SnowflakePlan) -> None:
Same here. Unit tests.
I found it difficult and not worthwhile to write a unit test for this function, because what we would be testing is the generated SQL, and I would need to create many fake queries in the unit tests. It's already covered by the integration tests. But I moved these functions to a util file, which makes the code clearer to review.
@@ -272,15 +416,16 @@ def output_dict(self) -> Dict[str, Any]:

    def __copy__(self) -> "SnowflakePlan":
        return SnowflakePlan(
-            self.queries.copy() if self.queries else [],
+            copy.deepcopy(self.queries) if self.queries else [],
Why is deepcopy required here?
Because self.queries contains Query objects rather than query strings, and those objects can still be used by the original query plan object, so a shallow copy would share them with the copy.
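The difference can be seen in a small sketch. The `Query` class below is a hypothetical stand-in with a single mutable `sql` attribute, not the real Snowpark class.

```python
import copy


class Query:
    # Hypothetical stand-in for Snowpark's Query object: a mutable SQL holder.
    def __init__(self, sql: str) -> None:
        self.sql = sql


original = [Query("SELECT 1")]

shallow = original.copy()        # new list, same Query objects
deep = copy.deepcopy(original)   # new list, new Query objects

shallow[0].sql = "SELECT 2"      # also changes original[0]
deep[0].sql = "SELECT 3"         # leaves original[0] untouched
```

With `list.copy()`, rewriting a query on the copied plan would also rewrite it on the original plan, which is presumably what the deep copy is meant to prevent.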
2b60dee
to
6fcf71f
Compare
Please answer these questions before submitting your pull requests. Thanks!
What GitHub issue is this PR addressing? Make sure that there is an accompanying issue to your PR.
Fixes SNOW-1060064
Fill out the following pre-review checklist:
Please describe how your code solves the related issue.
This PR adds an optimization that eliminates repeated subqueries by using CTEs (see details in https://docs.google.com/document/d/1vVUYqLeD_nQRVaH3SX2c4jDNJoBGHJlg8PqiJ8qxgVM/edit), built on the original query generation framework.
The basic idea is to convert repeated subqueries to a CTE, e.g.,
( SELECT * FROM ( SELECT * FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, 2 :: INT), (3 :: INT, 4 :: INT))) WHERE ("A" = 1 :: INT))) UNION ALL ( SELECT * FROM ( SELECT * FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, 2 :: INT), (3 :: INT, 4 :: INT))) WHERE ("A" = 1 :: INT)))
will be converted to
WITH SNOWPARK_TEMP_TABLE_SD2LMA9GKT AS ( SELECT * FROM ( SELECT * FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, 2 :: INT), (3 :: INT, 4 :: INT))) WHERE ("A" = 1 :: INT))) ( SELECT * FROM (SNOWPARK_TEMP_TABLE_SD2LMA9GKT)) UNION ALL ( SELECT * FROM (SNOWPARK_TEMP_TABLE_SD2LMA9GKT))