diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py index 8ef0e7e2103..92cc13a7508 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py @@ -157,6 +157,7 @@ INTERSECT = f" {Intersect.sql} " EXCEPT = f" {Except.sql} " NOT_NULL = " NOT NULL " +WITH = "WITH " TEMPORARY_STRING_SET = frozenset(["temporary", "temp"]) @@ -1382,3 +1383,11 @@ def get_file_format_spec( file_format_str += FORMAT_NAME + EQUALS + file_format_name file_format_str += RIGHT_PARENTHESIS return file_format_str + + +def cte_statement(queries: List[str], table_names: List[str]) -> str: + result = COMMA.join( + f"{table_name}{AS}{LEFT_PARENTHESIS}{query}{RIGHT_PARENTHESIS}" + for query, table_name in zip(queries, table_names) + ) + return f"{WITH}{result}" diff --git a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py new file mode 100644 index 00000000000..b5f5eb02851 --- /dev/null +++ b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py @@ -0,0 +1,121 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from collections import defaultdict +from typing import TYPE_CHECKING, Set + +from snowflake.snowpark._internal.analyzer.analyzer_utils import ( + SPACE, + cte_statement, + project_statement, +) +from snowflake.snowpark._internal.utils import ( + TempObjectType, + random_name_for_temp_object, +) + +if TYPE_CHECKING: + from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan + + +def find_duplicate_subtrees(root: "SnowflakePlan") -> Set["SnowflakePlan"]: + """ + Returns a set containing all duplicate subtrees in query plan tree. 
+ The root of a duplicate subtree is defined as a duplicate node, if + - it appears more than once in the tree, AND + - one of its parents is unique (only appears once) in the tree, OR + - it has multiple different parents + + For example, + root + / \ + df5 df6 + / | | \ + df3 df3 df4 df4 + | | | | + df2 df2 df2 df2 + | | | | + df1 df1 df1 df1 + + df4, df3 and df2 are duplicate subtrees. + + This function is used to only include nodes that should be converted to CTEs. + """ + node_count_map = defaultdict(int) + node_parents_map = defaultdict(set) + + def traverse(node: "SnowflakePlan") -> None: + node_count_map[node] += 1 + if node.source_plan and node.source_plan.children: + for child in node.source_plan.children: + node_parents_map[child].add(node) + traverse(child) + + def is_duplicate_subtree(node: "SnowflakePlan") -> bool: + is_duplicate_node = node_count_map[node] > 1 + if is_duplicate_node: + is_any_parent_unique_node = any( + node_count_map[n] == 1 for n in node_parents_map[node] + ) + if is_any_parent_unique_node: + return True + else: + has_multi_parents = len(node_parents_map[node]) > 1 + if has_multi_parents: + return True + return False + + traverse(root) + return {node for node in node_count_map if is_duplicate_subtree(node)} + + +def create_cte_query( + node: "SnowflakePlan", duplicate_plan_set: Set["SnowflakePlan"] +) -> str: + plan_to_query_map = {} + duplicate_plan_to_cte_map = {} + duplicate_plan_to_table_name_map = {} + + def build_plan_to_query_map_in_post_order(node: "SnowflakePlan") -> None: + """ + Builds a mapping from query plans to queries that are optimized with CTEs, + in post-traversal order. We can get the final query from the mapping value of the root node. + The reason for using post-traversal order is that chained CTEs have to be built + from bottom (innermost subquery) to top (outermost query).
+ """ + if not node.source_plan or node in plan_to_query_map: + return + + for child in node.source_plan.children: + build_plan_to_query_map_in_post_order(child) + + if not node.placeholder_query: + plan_to_query_map[node] = node.queries[-1].sql + else: + plan_to_query_map[node] = node.placeholder_query + for child in node.source_plan.children: + # replace the placeholder (id) with child query + plan_to_query_map[node] = plan_to_query_map[node].replace( + child._id, plan_to_query_map[child] + ) + + # duplicate subtrees will be converted CTEs + if node in duplicate_plan_set: + # when a subquery is converted a CTE to with clause, + # it will be replaced by `SELECT * from TEMP_TABLE` in the original query + table_name = random_name_for_temp_object(TempObjectType.CTE) + select_stmt = project_statement([], table_name) + duplicate_plan_to_table_name_map[node] = table_name + duplicate_plan_to_cte_map[node] = plan_to_query_map[node] + plan_to_query_map[node] = select_stmt + + build_plan_to_query_map_in_post_order(node) + + # construct with clause + with_stmt = cte_statement( + list(duplicate_plan_to_cte_map.values()), + list(duplicate_plan_to_table_name_map.values()), + ) + final_query = with_stmt + SPACE + plan_to_query_map[node] + return final_query diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 0e2e70477ac..e2661cb9ac0 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -2,7 +2,8 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
# - +import copy +import hashlib import re import sys import uuid @@ -75,6 +76,10 @@ JoinType, SetOperation, ) +from snowflake.snowpark._internal.analyzer.cte_utils import ( + create_cte_query, + find_duplicate_subtrees, +) from snowflake.snowpark._internal.analyzer.expression import Attribute from snowflake.snowpark._internal.analyzer.schema_utils import analyze_attributes from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( @@ -198,6 +203,7 @@ def __init__( df_aliased_col_name_to_real_col_name: Optional[ DefaultDict[str, Dict[str, str]] ] = None, + placeholder_query: Optional[str] = None, *, session: "snowflake.snowpark.session.Session", ) -> None: @@ -220,6 +226,49 @@ def __init__( ) else: self.df_aliased_col_name_to_real_col_name = defaultdict(dict) + # In the placeholder query, subquery (child) is held by the ID of query plan + # It is used for optimization, by replacing a subquery with a CTE + self.placeholder_query = placeholder_query + # encode an id for CTE optimization + self._id = hashlib.sha256( + f"{queries[-1].sql}#{queries[-1].params}".encode() + ).hexdigest()[:10] + + def __eq__(self, other: "SnowflakePlan") -> bool: + return isinstance(other, SnowflakePlan) and (self._id == other._id) + + def __hash__(self) -> int: + return hash(self._id) + + def replace_repeated_subquery_with_cte(self) -> "SnowflakePlan": + # parameter protection + # TODO SNOW-106671: enable cte optimization with sql simplifier + if ( + not self.session._cte_optimization_enabled + or self.session._sql_simplifier_enabled + ): + return self + + # if source_plan is none, it must be a leaf node, no optimization is needed + if self.source_plan is None: + return self + + # only select statement can be converted to CTEs + if not is_sql_select_statement(self.queries[-1].sql): + return self + + # if there is no duplicate node, no optimization will be performed + duplicate_plan_set = find_duplicate_subtrees(self) + if not duplicate_plan_set: + return self + + # create CTE 
query + final_query = create_cte_query(self, duplicate_plan_set) + + # all other parts of query are unchanged, but just replace the original query + plan = copy.copy(self) + plan.queries[-1].sql = final_query + return plan def with_subqueries(self, subquery_plans: List["SnowflakePlan"]) -> "SnowflakePlan": pre_queries = self.queries[:-1] @@ -271,17 +320,32 @@ def output_dict(self) -> Dict[str, Any]: return self._output_dict def __copy__(self) -> "SnowflakePlan": - return SnowflakePlan( - self.queries.copy() if self.queries else [], - self.schema_query, - self.post_actions.copy() if self.post_actions else None, - dict(self.expr_to_alias) if self.expr_to_alias else None, - self.source_plan, - self.is_ddl_on_temp_object, - self.api_calls.copy() if self.api_calls else None, - self.df_aliased_col_name_to_real_col_name, - session=self.session, - ) + if self.session._cte_optimization_enabled: + return SnowflakePlan( + copy.deepcopy(self.queries) if self.queries else [], + self.schema_query, + copy.deepcopy(self.post_actions) if self.post_actions else None, + dict(self.expr_to_alias) if self.expr_to_alias else None, + self.source_plan, + self.is_ddl_on_temp_object, + copy.deepcopy(self.api_calls) if self.api_calls else None, + self.df_aliased_col_name_to_real_col_name, + session=self.session, + placeholder_query=self.placeholder_query, + ) + else: + return SnowflakePlan( + self.queries.copy() if self.queries else [], + self.schema_query, + self.post_actions.copy() if self.post_actions else None, + dict(self.expr_to_alias) if self.expr_to_alias else None, + self.source_plan, + self.is_ddl_on_temp_object, + self.api_calls.copy() if self.api_calls else None, + self.df_aliased_col_name_to_real_col_name, + session=self.session, + placeholder_query=self.placeholder_query, + ) def add_aliases(self, to_add: Dict) -> None: self.expr_to_alias = {**self.expr_to_alias, **to_add} @@ -312,6 +376,7 @@ def build( new_schema_query = ( schema_query if schema_query else 
sql_generator(child.schema_query) ) + placeholder_query = sql_generator(select_child._id) return SnowflakePlan( queries, @@ -323,6 +388,7 @@ def build( api_calls=select_child.api_calls, df_aliased_col_name_to_real_col_name=child.df_aliased_col_name_to_real_col_name, session=self.session, + placeholder_query=placeholder_query, ) @SnowflakePlan.Decorator.wrap_exception @@ -348,6 +414,7 @@ def build_from_multiple_queries( if schema_query is not None else multi_sql_generator(Query(child.schema_query))[-1].sql ) + placeholder_query = multi_sql_generator(Query(child._id))[-1].sql return SnowflakePlan( queries, @@ -357,6 +424,7 @@ def build_from_multiple_queries( source_plan, api_calls=select_child.api_calls, session=self.session, + placeholder_query=placeholder_query, ) @SnowflakePlan.Decorator.wrap_exception @@ -388,6 +456,7 @@ def build_binary( left_schema_query = schema_value_statement(select_left.attributes) right_schema_query = schema_value_statement(select_right.attributes) schema_query = sql_generator(left_schema_query, right_schema_query) + placeholder_query = sql_generator(select_left._id, select_right._id) common_columns = set(select_left.expr_to_alias.keys()).intersection( select_right.expr_to_alias.keys() @@ -410,6 +479,7 @@ def build_binary( source_plan, api_calls=api_calls, session=self.session, + placeholder_query=placeholder_query, ) def query( @@ -504,7 +574,10 @@ def aggregate( ) def filter( - self, condition: str, child: SnowflakePlan, source_plan: Optional[LogicalPlan] + self, + condition: str, + child: SnowflakePlan, + source_plan: Optional[LogicalPlan], ) -> SnowflakePlan: return self.build(lambda x: filter_statement(condition, x), child, source_plan) @@ -525,7 +598,10 @@ def sample( ) def sort( - self, order: List[str], child: SnowflakePlan, source_plan: Optional[LogicalPlan] + self, + order: List[str], + child: SnowflakePlan, + source_plan: Optional[LogicalPlan], ) -> SnowflakePlan: return self.build(lambda x: sort_statement(order, x), child, 
source_plan) @@ -585,6 +661,8 @@ def save_as_table( column_definition_with_hidden_columns, ) + child = child.replace_repeated_subquery_with_cte() + def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True): create_table = create_table_statement( full_table_name, @@ -734,6 +812,7 @@ def create_or_replace_view( if not is_sql_select_statement(child.queries[0].sql.lower().strip()): raise SnowparkClientExceptionMessages.PLAN_CREATE_VIEWS_FROM_SELECT_ONLY() + child = child.replace_repeated_subquery_with_cte() return self.build( lambda x: create_or_replace_view_statement(name, x, is_temp), child, @@ -753,6 +832,7 @@ def create_or_replace_dynamic_table( if not is_sql_select_statement(child.queries[0].sql.lower().strip()): raise SnowparkClientExceptionMessages.PLAN_CREATE_DYNAMIC_TABLE_FROM_SELECT_ONLY() + child = child.replace_repeated_subquery_with_cte() return self.build( lambda x: create_or_replace_dynamic_table_statement( name, warehouse, lag, x @@ -769,6 +849,7 @@ def create_temp_table( use_scoped_temp_objects: bool = False, is_generated: bool = False, ) -> SnowflakePlan: + child = child.replace_repeated_subquery_with_cte() return self.build_from_multiple_queries( lambda x: self.create_table_and_insert( self.session, @@ -1050,6 +1131,7 @@ def copy_into_location( header: bool = False, **copy_options: Optional[Any], ) -> SnowflakePlan: + query = query.replace_repeated_subquery_with_cte() return self.build( lambda x: copy_into_location( query=x, @@ -1075,6 +1157,7 @@ def update( source_plan: Optional[LogicalPlan], ) -> SnowflakePlan: if source_data: + source_data = source_data.replace_repeated_subquery_with_cte() return self.build( lambda x: update_statement( table_name, @@ -1104,6 +1187,7 @@ def delete( source_plan: Optional[LogicalPlan], ) -> SnowflakePlan: if source_data: + source_data = source_data.replace_repeated_subquery_with_cte() return self.build( lambda x: delete_statement( table_name, @@ -1131,6 +1215,7 @@ def merge( clauses: List[str], 
source_plan: Optional[LogicalPlan], ) -> SnowflakePlan: + source_data = source_data.replace_repeated_subquery_with_cte() return self.build( lambda x: merge_statement(table_name, x, join_expr, clauses), source_data, diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index a650087d6a7..85d38f5fe3d 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -554,7 +554,8 @@ def get_result_set( Union[List[ResultMetadata], List["ResultMetadataV2"]], ]: action_id = plan.session._generate_new_action_id() - + # potentially optimize the query using CTEs + plan = plan.replace_repeated_subquery_with_cte() result, result_meta = None, None try: placeholders = {} diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py index c4ec0bd2e4a..1c5fe3c1be7 100644 --- a/src/snowflake/snowpark/_internal/utils.py +++ b/src/snowflake/snowpark/_internal/utils.py @@ -188,6 +188,7 @@ class TempObjectType(Enum): TABLE_FUNCTION = "TABLE_FUNCTION" DYNAMIC_TABLE = "DYNAMIC_TABLE" AGGREGATE_FUNCTION = "AGGREGATE_FUNCTION" + CTE = "CTE" def validate_object_name(name: str): diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index 89c783c37c8..214a89400c4 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -3746,9 +3746,10 @@ def queries(self) -> Dict[str, List[str]]: evaluate this DataFrame with the key `queries`, and a list of post-execution actions (e.g., queries to clean up temporary objects) with the key `post_actions`. 
""" + plan = self._plan.replace_repeated_subquery_with_cte() return { - "queries": [query.sql.strip() for query in self._plan.queries], - "post_actions": [query.sql.strip() for query in self._plan.post_actions], + "queries": [query.sql.strip() for query in plan.queries], + "post_actions": [query.sql.strip() for query in plan.post_actions], } def explain(self) -> None: @@ -3762,19 +3763,20 @@ def explain(self) -> None: print(self._explain_string()) def _explain_string(self) -> str: + plan = self._plan.replace_repeated_subquery_with_cte() output_queries = "\n---\n".join( - f"{i+1}.\n{query.sql.strip()}" for i, query in enumerate(self._plan.queries) + f"{i+1}.\n{query.sql.strip()}" for i, query in enumerate(plan.queries) ) msg = f"""---------DATAFRAME EXECUTION PLAN---------- Query List: {output_queries}""" # if query list contains more then one queries, skip execution plan - if len(self._plan.queries) == 1: - exec_plan = self._session._explain_query(self._plan.queries[0].sql) + if len(plan.queries) == 1: + exec_plan = self._session._explain_query(plan.queries[0].sql) if exec_plan: msg = f"{msg}\nLogical Execution Plan:\n{exec_plan}" else: - msg = f"{self._plan.queries[0].sql} can't be explained" + msg = f"{plan.queries[0].sql} can't be explained" return f"{msg}\n--------------------------------------------" diff --git a/src/snowflake/snowpark/mock/_connection.py b/src/snowflake/snowpark/mock/_connection.py index 2968131cb34..ed0f4eaf810 100644 --- a/src/snowflake/snowpark/mock/_connection.py +++ b/src/snowflake/snowpark/mock/_connection.py @@ -343,7 +343,8 @@ def _to_data_or_iter( ) if to_iter else _fix_pandas_df_fixed_type( - results_cursor.fetch_pandas_all(split_blocks=True), results_cursor + results_cursor.fetch_pandas_all(split_blocks=True), + results_cursor, ) ) except NotSupportedError: diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 02c2afa3b08..88d2401d426 100644 --- a/src/snowflake/snowpark/session.py +++ 
b/src/snowflake/snowpark/session.py @@ -460,6 +460,7 @@ def __init__( _PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER_STRING, True ) ) + self._cte_optimization_enabled: bool = False self._use_logical_type_for_create_df: bool = ( self._conn._get_client_side_session_parameter( _PYTHON_SNOWPARK_USE_LOGICAL_TYPE_FOR_CREATE_DATAFRAME_STRING, True diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py new file mode 100644 index 00000000000..a2e098f6202 --- /dev/null +++ b/tests/integ/test_cte.py @@ -0,0 +1,231 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +import re + +import pytest + +from snowflake.snowpark._internal.analyzer import analyzer +from snowflake.snowpark._internal.utils import ( + TEMP_OBJECT_NAME_PREFIX, + TempObjectType, + random_name_for_temp_object, +) +from snowflake.snowpark.functions import col, when_matched +from tests.utils import Utils + +WITH = "WITH" + + +@pytest.fixture(autouse=True) +def setup(session): + # TODO SNOW-106671: enable cte optimization with sql simplifier + is_sql_simplifier_enabled = session._sql_simplifier_enabled + is_cte_optimization_enabled = session._cte_optimization_enabled + session._sql_simplifier_enabled = False + session._cte_optimization_enabled = True + yield + session._sql_simplifier_enabled = is_sql_simplifier_enabled + session._cte_optimization_enabled = is_cte_optimization_enabled + + +def check_result(session, df, expect_cte_optimized): + session._cte_optimization_enabled = False + result = df.collect() + + session._cte_optimization_enabled = True + cte_result = df.collect() + + Utils.check_answer(cte_result, result) + last_query = df.queries["queries"][-1] + if expect_cte_optimized: + assert last_query.startswith(WITH) + assert last_query.count(WITH) == 1 + else: + assert last_query.count(WITH) == 0 + + +def count_number_of_ctes(query): + # a CTE is represented with a pattern `SNOWPARK_TEMP_xxx AS` + pattern = re.compile(rf"{TEMP_OBJECT_NAME_PREFIX}CTE_[0-9A-Z]+\sAS") + return 
len(pattern.findall(query)) + + +@pytest.mark.parametrize( + "action", + [ + lambda x: x.select("a", "b").select("b"), + lambda x: x.filter(col("a") == 1).select("b"), + lambda x: x.select("a").filter(col("a") == 1), + ], +) +def test_no_duplicate_unary(session, action): + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + check_result(session, action(df), expect_cte_optimized=False) + + +@pytest.mark.parametrize( + "action", + [ + lambda x, y: x.union_all(y), + lambda x, y: x.select("a").union_all(y.select("a")), + lambda x, y: x.except_(y), + lambda x, y: x.select("a").except_(y.select("a")), + lambda x, y: x.join(y.select("a", "b"), rsuffix="_y"), + lambda x, y: x.select("a").join(y, rsuffix="_y"), + lambda x, y: x.join(y.select("a"), rsuffix="_y"), + ], +) +def test_binary(session, action): + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + check_result(session, action(df, df), expect_cte_optimized=True) + + df1 = session.create_dataframe([[3, 4], [2, 1]], schema=["a", "b"]) + check_result(session, action(df, df1), expect_cte_optimized=False) + + # multiple queries + original_threshold = analyzer.ARRAY_BIND_THRESHOLD + try: + analyzer.ARRAY_BIND_THRESHOLD = 2 + df2 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + finally: + analyzer.ARRAY_BIND_THRESHOLD = original_threshold + check_result(session, action(df2, df2), expect_cte_optimized=True) + + +@pytest.mark.parametrize( + "action", + [ + lambda x, y: x.union_all(y), + lambda x, y: x.join(y.select("a")), + ], +) +def test_number_of_ctes(session, action): + df3 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + df2 = df3.filter(col("a") == 1) + df1 = df2.select("*") + + # only df1 will be converted to a CTE + root = action(df1, df1) + check_result(session, root, expect_cte_optimized=True) + assert count_number_of_ctes(root.queries["queries"][-1]) == 1 + + # df1 and df3 will be converted to CTEs + root = action(root, df3) + 
check_result(session, root, expect_cte_optimized=True) + assert count_number_of_ctes(root.queries["queries"][-1]) == 2 + + # df1, df2 and df3 will be converted to CTEs + root = action(root, df2) + check_result(session, root, expect_cte_optimized=True) + assert count_number_of_ctes(root.queries["queries"][-1]) == 3 + + +def test_different_df_same_query(session): + df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("a") + df2 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("a") + df = df2.union_all(df1) + check_result(session, df, expect_cte_optimized=True) + assert count_number_of_ctes(df.queries["queries"][-1]) == 1 + + +def test_same_duplicate_subtree(session): + """ + root + / \ + df3 df3 + | | + df2 df2 + | | + df1 df1 + + Only should df3 be converted to a CTE + """ + df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + df2 = df1.filter(col("a") == 1) + df3 = df2.select("b") + df_result1 = df3.union_all(df3) + check_result(session, df_result1, expect_cte_optimized=True) + assert count_number_of_ctes(df_result1.queries["queries"][-1]) == 1 + + """ + root + / \ + df5 df6 + / | | \ + df3 df3 df4 df4 + | | | | + df2 df2 df2 df2 + | | | | + df1 df1 df1 df1 + + df4, df3 and df2 should be converted to CTEs + """ + df4 = df2.select("a") + df_result2 = df3.union_all(df3).union_all(df4.union_all(df4)) + check_result(session, df_result2, expect_cte_optimized=True) + assert count_number_of_ctes(df_result2.queries["queries"][-1]) == 3 + + +@pytest.mark.parametrize("mode", ["append", "overwrite", "errorifexists", "ignore"]) +def test_save_as_table(session, mode): + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + with session.query_history() as query_history: + df.union_all(df).write.save_as_table( + random_name_for_temp_object(TempObjectType.TABLE), + table_type="temp", + mode=mode, + ) + query = query_history.queries[-1].sql_text + assert query.count(WITH) == 1 + assert 
count_number_of_ctes(query) == 1 + + +def test_create_or_replace_view(session): + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + with session.query_history() as query_history: + df.union_all(df).create_or_replace_temp_view( + random_name_for_temp_object(TempObjectType.VIEW) + ) + query = query_history.queries[-1].sql_text + assert query.count(WITH) == 1 + assert count_number_of_ctes(query) == 1 + + +def test_table_update_delete_merge(session): + table_name = random_name_for_temp_object(TempObjectType.VIEW) + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + df.write.save_as_table(table_name, table_type="temp") + source_df = df.union_all(df) + t = session.table(table_name) + + # update + with session.query_history() as query_history: + t.update({"b": 0}, t.a == source_df.a, source_df) + query = query_history.queries[-1].sql_text + assert query.count(WITH) == 1 + assert count_number_of_ctes(query) == 1 + + # delete + with session.query_history() as query_history: + t.delete(t.a == source_df.a, source_df) + query = query_history.queries[-1].sql_text + assert query.count(WITH) == 1 + assert count_number_of_ctes(query) == 1 + + # merge + with session.query_history() as query_history: + t.merge( + source_df, t.a == source_df.a, [when_matched().update({"b": source_df.b})] + ) + query = query_history.queries[-1].sql_text + assert query.count(WITH) == 1 + assert count_number_of_ctes(query) == 1 + + +def test_explain(session): + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + explain_string = df.union_all(df)._explain_string() + assert "WithReference" in explain_string + assert "WithClause" in explain_string diff --git a/tests/unit/test_cte.py b/tests/unit/test_cte.py new file mode 100644 index 00000000000..4cf5240461b --- /dev/null +++ b/tests/unit/test_cte.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from unittest import mock + +import pytest + +from snowflake.snowpark._internal.analyzer.cte_utils import find_duplicate_subtrees +from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan + + +def test_case1(): + nodes = [mock.create_autospec(SnowflakePlan) for _ in range(7)] + for i, node in enumerate(nodes): + node.source_plan = node + node._id = i + nodes[0].children = [nodes[1], nodes[3]] + nodes[1].children = [nodes[2], nodes[2]] + nodes[2].children = [nodes[4]] + nodes[3].children = [nodes[5], nodes[6]] + nodes[4].children = [nodes[5]] + nodes[5].children = [] + nodes[6].children = [] + + expected_duplicate_subtree_ids = {2, 5} + return nodes[0], expected_duplicate_subtree_ids + + +def test_case2(): + nodes = [mock.create_autospec(SnowflakePlan) for _ in range(7)] + for i, node in enumerate(nodes): + node.source_plan = node + node._id = i + nodes[0].children = [nodes[1], nodes[3]] + nodes[1].children = [nodes[2], nodes[2]] + nodes[2].children = [nodes[4], nodes[4]] + nodes[3].children = [nodes[6], nodes[6]] + nodes[4].children = [nodes[5]] + nodes[5].children = [] + nodes[6].children = [nodes[4], nodes[4]] + + expected_duplicate_subtree_ids = {2, 4, 6} + return nodes[0], expected_duplicate_subtree_ids + + +@pytest.mark.parametrize("test_case", [test_case1(), test_case2()]) +def test_find_duplicate_subtrees(test_case): + plan1, expected_duplicate_subtree_ids = test_case + duplicate_subtrees = find_duplicate_subtrees(plan1) + assert {node._id for node in duplicate_subtrees} == expected_duplicate_subtree_ids diff --git a/tests/unit/test_dataframe.py b/tests/unit/test_dataframe.py index c2168de21f7..5c7e4fb921e 100644 --- a/tests/unit/test_dataframe.py +++ b/tests/unit/test_dataframe.py @@ -107,6 +107,7 @@ def test_dataframe_method_alias(): def test_copy_into_format_name_syntax(format_type, sql_simplifier_enabled): fake_session = mock.create_autospec(snowflake.snowpark.session.Session) fake_session.sql_simplifier_enabled = 
sql_simplifier_enabled + fake_session._cte_optimization_enabled = False fake_session._conn = mock.create_autospec(ServerConnection) fake_session._plan_builder = SnowflakePlanBuilder(fake_session) fake_session._analyzer = Analyzer(fake_session) diff --git a/tests/unit/test_server_connection.py b/tests/unit/test_server_connection.py index fd7e7b1388a..e1fc2a4a7d8 100644 --- a/tests/unit/test_server_connection.py +++ b/tests/unit/test_server_connection.py @@ -117,6 +117,7 @@ def test_get_result_set_exception(mock_server_connection): fake_session._generate_new_action_id.return_value = 1 fake_session._last_canceled_id = 100 fake_session._conn = mock_server_connection + fake_session._cte_optimization_enabled = False fake_plan = SnowflakePlan( queries=[Query("fake query 1"), Query("fake query 2")], schema_query="fake schema query",