Skip to content

Commit

Permalink
address comment
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jdu committed Mar 5, 2024
1 parent 73650b7 commit e97a05a
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 104 deletions.
5 changes: 3 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1386,7 +1386,8 @@ def get_file_format_spec(


def cte_statement(queries: List[str], table_names: List[str]) -> str:
return WITH + COMMA.join(
table_name + AS + LEFT_PARENTHESIS + query + RIGHT_PARENTHESIS
result = COMMA.join(
f"{table_name}{AS}{LEFT_PARENTHESIS}{query}{RIGHT_PARENTHESIS}"
for query, table_name in zip(queries, table_names)
)
return f"{WITH}{result}"
121 changes: 121 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/cte_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from collections import defaultdict
from typing import TYPE_CHECKING, Set

from snowflake.snowpark._internal.analyzer.analyzer_utils import (
SPACE,
cte_statement,
project_statement,
)
from snowflake.snowpark._internal.utils import (
TempObjectType,
random_name_for_temp_object,
)

if TYPE_CHECKING:
from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan


def find_duplicate_subtrees(root: "SnowflakePlan") -> Set["SnowflakePlan"]:
"""
Returns a set containing all duplicate subtrees in query plan tree.
The root of a duplicate subtree is defined as a duplicate node, if
- it appears more than once in the tree, AND
- one of its parent is unique (only appear once) in the tree, OR
- it has multiple different parents
For example,
root
/ \
df5 df6
/ | | \
df3 df3 df4 df4
| | | |
df2 df2 df2 df2
| | | |
df1 df1 df1 df1
df4, df3 and df2 are duplicate subtrees.
This function is used to only include nodes that should be converted to CTEs.
"""
node_count_map = defaultdict(int)
node_parents_map = defaultdict(set)

def traverse(node: "SnowflakePlan") -> None:
node_count_map[node] += 1
if node.source_plan and node.source_plan.children:
for child in node.source_plan.children:
node_parents_map[child].add(node)
traverse(child)

def is_duplicate_subtree(node: "SnowflakePlan") -> bool:
is_duplicate_node = node_count_map[node] > 1
if is_duplicate_node:
is_any_parent_unique_node = any(
node_count_map[n] == 1 for n in node_parents_map[node]
)
if is_any_parent_unique_node:
return True
else:
has_multi_parents = len(node_parents_map[node]) > 1
if has_multi_parents:
return True
return False

traverse(root)
return {node for node in node_count_map if is_duplicate_subtree(node)}


def create_cte_query(
node: "SnowflakePlan", duplicate_plan_set: Set["SnowflakePlan"]
) -> str:
plan_to_query_map = {}
duplicate_plan_to_cte_map = {}
duplicate_plan_to_table_name_map = {}

def build_plan_to_query_map_in_post_order(node: "SnowflakePlan") -> None:
"""
Builds a mapping from query plans to queries that are optimized with CTEs,
in post-traversal order. We can get the final query from the mapping value of the root node.
The reason of using poster-traversal order is that chained CTEs have to be built
from bottom (innermost subquery) to top (outermost query).
"""
if not node.source_plan or node in plan_to_query_map:
return

for child in node.source_plan.children:
build_plan_to_query_map_in_post_order(child)

if not node.placeholder_query:
plan_to_query_map[node] = node.queries[-1].sql
else:
plan_to_query_map[node] = node.placeholder_query
for child in node.source_plan.children:
# replace the placeholder (id) with child query
plan_to_query_map[node] = plan_to_query_map[node].replace(
child._id, plan_to_query_map[child]
)

# duplicate subtrees will be converted CTEs
if node in duplicate_plan_set:
# when a subquery is converted a CTE to with clause,
# it will be replaced by `SELECT * from TEMP_TABLE` in the original query
table_name = random_name_for_temp_object(TempObjectType.CTE)
select_stmt = project_statement([], table_name)
duplicate_plan_to_table_name_map[node] = table_name
duplicate_plan_to_cte_map[node] = plan_to_query_map[node]
plan_to_query_map[node] = select_stmt

build_plan_to_query_map_in_post_order(node)

# construct with clause
with_stmt = cte_statement(
list(duplicate_plan_to_cte_map.values()),
list(duplicate_plan_to_table_name_map.values()),
)
final_query = with_stmt + SPACE + plan_to_query_map[node]
return final_query
107 changes: 6 additions & 101 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
List,
Optional,
Sequence,
Set,
Tuple,
)

Expand All @@ -37,7 +36,6 @@
import snowflake.connector
import snowflake.snowpark
from snowflake.snowpark._internal.analyzer.analyzer_utils import (
SPACE,
aggregate_statement,
attribute_to_schema_string,
batch_insert_into_statement,
Expand All @@ -48,7 +46,6 @@
create_or_replace_view_statement,
create_table_as_select_statement,
create_table_statement,
cte_statement,
delete_statement,
drop_file_format_if_exists_statement,
drop_table_if_exists_statement,
Expand Down Expand Up @@ -79,6 +76,10 @@
JoinType,
SetOperation,
)
from snowflake.snowpark._internal.analyzer.cte_utils import (
create_cte_query,
find_duplicate_subtrees,
)
from snowflake.snowpark._internal.analyzer.expression import Attribute
from snowflake.snowpark._internal.analyzer.schema_utils import analyze_attributes
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import (
Expand Down Expand Up @@ -256,109 +257,13 @@ def replace_repeated_subquery_with_cte(self) -> "SnowflakePlan":
if not is_sql_select_statement(self.queries[-1].sql):
return self

def find_duplicate_subtrees(root: SnowflakePlan) -> Set[SnowflakePlan]:
"""
Returns a set containing all duplicate subtrees in query plan tree.
The root of a duplicate subtree is defined as a duplicate node, if
- it appears more than once in the tree, AND
- one of its parent is unique (only appear once) in the tree, OR
- it has multiple different parents
For example,
root
/ \
df5 df6
/ | | \
df3 df3 df4 df4
| | | |
df2 df2 df2 df2
| | | |
df1 df1 df1 df1
df4, df3 and df2 are duplicate subtrees.
This function is used to only include nodes that should be converted to CTEs.
"""
node_count_map = defaultdict(int)
node_parents_map = {root: set()}

def traverse(node: SnowflakePlan) -> None:
node_count_map[node] += 1
if node.source_plan and node.source_plan.children:
for child in node.source_plan.children:
if child not in node_parents_map:
node_parents_map[child] = {node}
else:
node_parents_map[child].add(node)
traverse(child)

def is_duplicate_subtree(node: SnowflakePlan) -> bool:
is_duplicate_node = node_count_map[node] > 1
if is_duplicate_node:
is_any_parent_unique_node = any(
node_count_map[n] == 1 for n in node_parents_map[node]
)
if is_any_parent_unique_node:
return True
else:
has_multi_parents = len(node_parents_map[node]) > 1
if has_multi_parents:
return True
return False

traverse(root)
return {node for node in node_count_map if is_duplicate_subtree(node)}

# if there is no duplicate node, no optimization will be performed
duplicate_plan_set = find_duplicate_subtrees(self)
if not duplicate_plan_set:
return self

plan_to_query_map = {}
duplicate_plan_to_cte_map = {}
duplicate_plan_to_table_name_map = {}

def build_plan_to_query_map_in_post_order(node: SnowflakePlan) -> None:
"""
Builds a mapping from query plans to queries that are optimized with CTEs,
in post-traversal order. We can get the final query from the mapping value of the root node.
The reason of using poster-traversal order is that chained CTEs have to be built
from bottom (innermost subquery) to top (outermost query).
"""
if not node.source_plan or node in plan_to_query_map:
return

for child in node.source_plan.children:
build_plan_to_query_map_in_post_order(child)

if not node.placeholder_query:
plan_to_query_map[node] = node.queries[-1].sql
else:
plan_to_query_map[node] = node.placeholder_query
for child in node.source_plan.children:
# replace the placeholder (id) with child query
plan_to_query_map[node] = plan_to_query_map[node].replace(
child._id, plan_to_query_map[child]
)

# duplicate subtrees will be converted CTEs
if node in duplicate_plan_set:
# when a subquery is converted a CTE to with clause,
# it will be replaced by `SELECT * from TEMP_TABLE` in the original query
table_name = random_name_for_temp_object(TempObjectType.TABLE)
select_stmt = project_statement([], table_name)
duplicate_plan_to_table_name_map[node] = table_name
duplicate_plan_to_cte_map[node] = plan_to_query_map[node]
plan_to_query_map[node] = select_stmt

build_plan_to_query_map_in_post_order(self)

# construct with clause
with_stmt = cte_statement(
list(duplicate_plan_to_cte_map.values()),
list(duplicate_plan_to_table_name_map.values()),
)
final_query = with_stmt + SPACE + plan_to_query_map[self]
# create CTE query
final_query = create_cte_query(self, duplicate_plan_set)

# all other parts of query are unchanged, but just replace the original query
plan = copy.copy(self)
Expand Down
1 change: 1 addition & 0 deletions src/snowflake/snowpark/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ class TempObjectType(Enum):
TABLE_FUNCTION = "TABLE_FUNCTION"
DYNAMIC_TABLE = "DYNAMIC_TABLE"
AGGREGATE_FUNCTION = "AGGREGATE_FUNCTION"
CTE = "CTE"


def validate_object_name(name: str):
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/test_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def check_result(session, df, expect_cte_optimized):

def count_number_of_ctes(query):
# a CTE is represented with a pattern `SNOWPARK_TEMP_xxx AS`
pattern = re.compile(rf"{TEMP_OBJECT_NAME_PREFIX}TABLE_[0-9A-Z]+\sAS")
pattern = re.compile(rf"{TEMP_OBJECT_NAME_PREFIX}CTE_[0-9A-Z]+\sAS")
return len(pattern.findall(query))


Expand Down
51 changes: 51 additions & 0 deletions tests/unit/test_cte.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from unittest import mock

import pytest

from snowflake.snowpark._internal.analyzer.cte_utils import find_duplicate_subtrees
from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan


def test_case1():
nodes = [mock.create_autospec(SnowflakePlan) for _ in range(7)]
for i, node in enumerate(nodes):
node.source_plan = node
node._id = i
nodes[0].children = [nodes[1], nodes[3]]
nodes[1].children = [nodes[2], nodes[2]]
nodes[2].children = [nodes[4]]
nodes[3].children = [nodes[5], nodes[6]]
nodes[4].children = [nodes[5]]
nodes[5].children = []
nodes[6].children = []

expected_duplicate_subtree_ids = {2, 5}
return nodes[0], expected_duplicate_subtree_ids


def test_case2():
nodes = [mock.create_autospec(SnowflakePlan) for _ in range(7)]
for i, node in enumerate(nodes):
node.source_plan = node
node._id = i
nodes[0].children = [nodes[1], nodes[3]]
nodes[1].children = [nodes[2], nodes[2]]
nodes[2].children = [nodes[4], nodes[4]]
nodes[3].children = [nodes[6], nodes[6]]
nodes[4].children = [nodes[5]]
nodes[5].children = []
nodes[6].children = [nodes[4], nodes[4]]

expected_duplicate_subtree_ids = {2, 4, 6}
return nodes[0], expected_duplicate_subtree_ids


@pytest.mark.parametrize("test_case", [test_case1(), test_case2()])
def test_find_duplicate_subtrees(test_case):
plan1, expected_duplicate_subtree_ids = test_case
duplicate_subtrees = find_duplicate_subtrees(plan1)
assert {node._id for node in duplicate_subtrees} == expected_duplicate_subtree_ids

0 comments on commit e97a05a

Please sign in to comment.