From e97a05aca2c0e5d2425911af98dfd473e035886d Mon Sep 17 00:00:00 2001 From: Jianzhun Du Date: Mon, 4 Mar 2024 17:47:19 -0800 Subject: [PATCH] address comment --- .../_internal/analyzer/analyzer_utils.py | 5 +- .../snowpark/_internal/analyzer/cte_utils.py | 121 ++++++++++++++++++ .../_internal/analyzer/snowflake_plan.py | 107 +--------------- src/snowflake/snowpark/_internal/utils.py | 1 + tests/integ/test_cte.py | 2 +- tests/unit/test_cte.py | 51 ++++++++ 6 files changed, 183 insertions(+), 104 deletions(-) create mode 100644 src/snowflake/snowpark/_internal/analyzer/cte_utils.py create mode 100644 tests/unit/test_cte.py diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py index 64accb64a8e..92cc13a7508 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py @@ -1386,7 +1386,8 @@ def get_file_format_spec( def cte_statement(queries: List[str], table_names: List[str]) -> str: - return WITH + COMMA.join( - table_name + AS + LEFT_PARENTHESIS + query + RIGHT_PARENTHESIS + result = COMMA.join( + f"{table_name}{AS}{LEFT_PARENTHESIS}{query}{RIGHT_PARENTHESIS}" for query, table_name in zip(queries, table_names) ) + return f"{WITH}{result}" diff --git a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py new file mode 100644 index 00000000000..b5f5eb02851 --- /dev/null +++ b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py @@ -0,0 +1,121 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from collections import defaultdict +from typing import TYPE_CHECKING, Set + +from snowflake.snowpark._internal.analyzer.analyzer_utils import ( + SPACE, + cte_statement, + project_statement, +) +from snowflake.snowpark._internal.utils import ( + TempObjectType, + random_name_for_temp_object, +) + +if TYPE_CHECKING: + from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan + + +def find_duplicate_subtrees(root: "SnowflakePlan") -> Set["SnowflakePlan"]: + """ + Returns a set containing all duplicate subtrees in the query plan tree. + The root of a duplicate subtree is defined as a duplicate node, if + - it appears more than once in the tree, AND + - one of its parents is unique (only appears once) in the tree, OR + - it has multiple different parents + + For example, + root + / \ + df5 df6 + / | | \ + df3 df3 df4 df4 + | | | | + df2 df2 df2 df2 + | | | | + df1 df1 df1 df1 + + df4, df3 and df2 are duplicate subtrees. + + This function is used to only include nodes that should be converted to CTEs. 
+ """ + node_count_map = defaultdict(int) + node_parents_map = defaultdict(set) + + def traverse(node: "SnowflakePlan") -> None: + node_count_map[node] += 1 + if node.source_plan and node.source_plan.children: + for child in node.source_plan.children: + node_parents_map[child].add(node) + traverse(child) + + def is_duplicate_subtree(node: "SnowflakePlan") -> bool: + is_duplicate_node = node_count_map[node] > 1 + if is_duplicate_node: + is_any_parent_unique_node = any( + node_count_map[n] == 1 for n in node_parents_map[node] + ) + if is_any_parent_unique_node: + return True + else: + has_multi_parents = len(node_parents_map[node]) > 1 + if has_multi_parents: + return True + return False + + traverse(root) + return {node for node in node_count_map if is_duplicate_subtree(node)} + + +def create_cte_query( + node: "SnowflakePlan", duplicate_plan_set: Set["SnowflakePlan"] +) -> str: + plan_to_query_map = {} + duplicate_plan_to_cte_map = {} + duplicate_plan_to_table_name_map = {} + + def build_plan_to_query_map_in_post_order(node: "SnowflakePlan") -> None: + """ + Builds a mapping from query plans to queries that are optimized with CTEs, + in post-order traversal. We can get the final query from the mapping value of the root node. + The reason for using post-order traversal is that chained CTEs have to be built + from bottom (innermost subquery) to top (outermost query). 
+ """ + if not node.source_plan or node in plan_to_query_map: + return + + for child in node.source_plan.children: + build_plan_to_query_map_in_post_order(child) + + if not node.placeholder_query: + plan_to_query_map[node] = node.queries[-1].sql + else: + plan_to_query_map[node] = node.placeholder_query + for child in node.source_plan.children: + # replace the placeholder (id) with child query + plan_to_query_map[node] = plan_to_query_map[node].replace( + child._id, plan_to_query_map[child] + ) + + # duplicate subtrees will be converted to CTEs + if node in duplicate_plan_set: + # when a subquery is converted to a CTE in the with clause, + # it will be replaced by `SELECT * from TEMP_TABLE` in the original query + table_name = random_name_for_temp_object(TempObjectType.CTE) + select_stmt = project_statement([], table_name) + duplicate_plan_to_table_name_map[node] = table_name + duplicate_plan_to_cte_map[node] = plan_to_query_map[node] + plan_to_query_map[node] = select_stmt + + build_plan_to_query_map_in_post_order(node) + + # construct with clause + with_stmt = cte_statement( + list(duplicate_plan_to_cte_map.values()), + list(duplicate_plan_to_table_name_map.values()), + ) + final_query = with_stmt + SPACE + plan_to_query_map[node] + return final_query diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index f31f762c980..ed7734b7d5c 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -18,7 +18,6 @@ List, Optional, Sequence, - Set, Tuple, ) @@ -37,7 +36,6 @@ import snowflake.connector import snowflake.snowpark from snowflake.snowpark._internal.analyzer.analyzer_utils import ( - SPACE, aggregate_statement, attribute_to_schema_string, batch_insert_into_statement, @@ -48,7 +46,6 @@ create_or_replace_view_statement, create_table_as_select_statement, create_table_statement, - cte_statement, delete_statement, 
drop_file_format_if_exists_statement, drop_table_if_exists_statement, @@ -79,6 +76,10 @@ JoinType, SetOperation, ) +from snowflake.snowpark._internal.analyzer.cte_utils import ( + create_cte_query, + find_duplicate_subtrees, +) from snowflake.snowpark._internal.analyzer.expression import Attribute from snowflake.snowpark._internal.analyzer.schema_utils import analyze_attributes from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( @@ -256,109 +257,13 @@ def replace_repeated_subquery_with_cte(self) -> "SnowflakePlan": if not is_sql_select_statement(self.queries[-1].sql): return self - def find_duplicate_subtrees(root: SnowflakePlan) -> Set[SnowflakePlan]: - """ - Returns a set containing all duplicate subtrees in query plan tree. - The root of a duplicate subtree is defined as a duplicate node, if - - it appears more than once in the tree, AND - - one of its parent is unique (only appear once) in the tree, OR - - it has multiple different parents - - For example, - root - / \ - df5 df6 - / | | \ - df3 df3 df4 df4 - | | | | - df2 df2 df2 df2 - | | | | - df1 df1 df1 df1 - - df4, df3 and df2 are duplicate subtrees. - - This function is used to only include nodes that should be converted to CTEs. 
- """ - node_count_map = defaultdict(int) - node_parents_map = {root: set()} - - def traverse(node: SnowflakePlan) -> None: - node_count_map[node] += 1 - if node.source_plan and node.source_plan.children: - for child in node.source_plan.children: - if child not in node_parents_map: - node_parents_map[child] = {node} - else: - node_parents_map[child].add(node) - traverse(child) - - def is_duplicate_subtree(node: SnowflakePlan) -> bool: - is_duplicate_node = node_count_map[node] > 1 - if is_duplicate_node: - is_any_parent_unique_node = any( - node_count_map[n] == 1 for n in node_parents_map[node] - ) - if is_any_parent_unique_node: - return True - else: - has_multi_parents = len(node_parents_map[node]) > 1 - if has_multi_parents: - return True - return False - - traverse(root) - return {node for node in node_count_map if is_duplicate_subtree(node)} - # if there is no duplicate node, no optimization will be performed duplicate_plan_set = find_duplicate_subtrees(self) if not duplicate_plan_set: return self - plan_to_query_map = {} - duplicate_plan_to_cte_map = {} - duplicate_plan_to_table_name_map = {} - - def build_plan_to_query_map_in_post_order(node: SnowflakePlan) -> None: - """ - Builds a mapping from query plans to queries that are optimized with CTEs, - in post-traversal order. We can get the final query from the mapping value of the root node. - The reason of using poster-traversal order is that chained CTEs have to be built - from bottom (innermost subquery) to top (outermost query). 
- """ - if not node.source_plan or node in plan_to_query_map: - return - - for child in node.source_plan.children: - build_plan_to_query_map_in_post_order(child) - - if not node.placeholder_query: - plan_to_query_map[node] = node.queries[-1].sql - else: - plan_to_query_map[node] = node.placeholder_query - for child in node.source_plan.children: - # replace the placeholder (id) with child query - plan_to_query_map[node] = plan_to_query_map[node].replace( - child._id, plan_to_query_map[child] - ) - - # duplicate subtrees will be converted CTEs - if node in duplicate_plan_set: - # when a subquery is converted a CTE to with clause, - # it will be replaced by `SELECT * from TEMP_TABLE` in the original query - table_name = random_name_for_temp_object(TempObjectType.TABLE) - select_stmt = project_statement([], table_name) - duplicate_plan_to_table_name_map[node] = table_name - duplicate_plan_to_cte_map[node] = plan_to_query_map[node] - plan_to_query_map[node] = select_stmt - - build_plan_to_query_map_in_post_order(self) - - # construct with clause - with_stmt = cte_statement( - list(duplicate_plan_to_cte_map.values()), - list(duplicate_plan_to_table_name_map.values()), - ) - final_query = with_stmt + SPACE + plan_to_query_map[self] + # create CTE query + final_query = create_cte_query(self, duplicate_plan_set) # all other parts of query are unchanged, but just replace the original query plan = copy.copy(self) diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py index c4ec0bd2e4a..1c5fe3c1be7 100644 --- a/src/snowflake/snowpark/_internal/utils.py +++ b/src/snowflake/snowpark/_internal/utils.py @@ -188,6 +188,7 @@ class TempObjectType(Enum): TABLE_FUNCTION = "TABLE_FUNCTION" DYNAMIC_TABLE = "DYNAMIC_TABLE" AGGREGATE_FUNCTION = "AGGREGATE_FUNCTION" + CTE = "CTE" def validate_object_name(name: str): diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index fd4a80458e4..a2e098f6202 100644 --- a/tests/integ/test_cte.py 
+++ b/tests/integ/test_cte.py @@ -48,7 +48,7 @@ def check_result(session, df, expect_cte_optimized): def count_number_of_ctes(query): # a CTE is represented with a pattern `SNOWPARK_TEMP_xxx AS` - pattern = re.compile(rf"{TEMP_OBJECT_NAME_PREFIX}TABLE_[0-9A-Z]+\sAS") + pattern = re.compile(rf"{TEMP_OBJECT_NAME_PREFIX}CTE_[0-9A-Z]+\sAS") return len(pattern.findall(query)) diff --git a/tests/unit/test_cte.py b/tests/unit/test_cte.py new file mode 100644 index 00000000000..4cf5240461b --- /dev/null +++ b/tests/unit/test_cte.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from unittest import mock + +import pytest + +from snowflake.snowpark._internal.analyzer.cte_utils import find_duplicate_subtrees +from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan + + +def test_case1(): + nodes = [mock.create_autospec(SnowflakePlan) for _ in range(7)] + for i, node in enumerate(nodes): + node.source_plan = node + node._id = i + nodes[0].children = [nodes[1], nodes[3]] + nodes[1].children = [nodes[2], nodes[2]] + nodes[2].children = [nodes[4]] + nodes[3].children = [nodes[5], nodes[6]] + nodes[4].children = [nodes[5]] + nodes[5].children = [] + nodes[6].children = [] + + expected_duplicate_subtree_ids = {2, 5} + return nodes[0], expected_duplicate_subtree_ids + + +def test_case2(): + nodes = [mock.create_autospec(SnowflakePlan) for _ in range(7)] + for i, node in enumerate(nodes): + node.source_plan = node + node._id = i + nodes[0].children = [nodes[1], nodes[3]] + nodes[1].children = [nodes[2], nodes[2]] + nodes[2].children = [nodes[4], nodes[4]] + nodes[3].children = [nodes[6], nodes[6]] + nodes[4].children = [nodes[5]] + nodes[5].children = [] + nodes[6].children = [nodes[4], nodes[4]] + + expected_duplicate_subtree_ids = {2, 4, 6} + return nodes[0], expected_duplicate_subtree_ids + + +@pytest.mark.parametrize("test_case", [test_case1(), test_case2()]) +def test_find_duplicate_subtrees(test_case): + 
plan1, expected_duplicate_subtree_ids = test_case + duplicate_subtrees = find_duplicate_subtrees(plan1) + assert {node._id for node in duplicate_subtrees} == expected_duplicate_subtree_ids