From 79c24342ab2a9d4fe0338b3a371fb9cb19fcdc26 Mon Sep 17 00:00:00 2001 From: Jianzhun Du Date: Fri, 26 Jan 2024 23:35:53 +0000 Subject: [PATCH] init --- .../snowpark/_internal/analyzer/analyzer.py | 4 + .../_internal/analyzer/analyzer_utils.py | 19 +++ .../_internal/analyzer/snowflake_plan.py | 153 +++++++++++++++--- .../_internal/analyzer/snowflake_plan_node.py | 7 + src/snowflake/snowpark/dataframe.py | 5 + tests/integ/test_cte.py | 110 +++++++++++++ 6 files changed, 278 insertions(+), 20 deletions(-) create mode 100644 tests/integ/test_cte.py diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 3bf9e11e99d..ba0f8f0b5a7 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -89,6 +89,7 @@ SnowflakePlanBuilder, ) from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( + CTE, CopyIntoLocationNode, CopyIntoTableNode, Limit, @@ -1126,6 +1127,9 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, Selectable): return self.plan_builder.select_statement(logical_plan) + if isinstance(logical_plan, CTE): + return self.plan_builder.cte(resolved_children[logical_plan.child]) + raise TypeError( f"Cannot resolve type logical_plan of {type(logical_plan).__name__} to a SnowflakePlan" ) diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py index 2b7a11f4d02..a56e44c7590 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py @@ -157,6 +157,7 @@ INTERSECT = f" {Intersect.sql} " EXCEPT = f" {Except.sql} " NOT_NULL = " NOT NULL " +WITH = " WITH " TEMPORARY_STRING_SET = frozenset(["temporary", "temp"]) @@ -1382,3 +1383,21 @@ def get_file_format_spec( file_format_str += FORMAT_NAME + EQUALS + file_format_name file_format_str += 
RIGHT_PARENTHESIS return file_format_str + + +def cte_statement(query: str, table_name: str) -> str: + return WITH + table_name + AS + LEFT_PARENTHESIS + query + RIGHT_PARENTHESIS + + +def combine_cte_statements(statements: List[str]) -> str: + # order is maintained + statements_without_dups = list(dict.fromkeys(statements)) + if len(statements_without_dups) == 1: + return statements_without_dups[0] + elif len(statements_without_dups) > 1: + return ( + statements_without_dups[0] + + COMMA + + SPACE + + COMMA.join(s[len(WITH) :] for s in statements_without_dups[1:]) + ) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index bba846ce46b..9381245c357 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -35,9 +35,11 @@ import snowflake.connector import snowflake.snowpark from snowflake.snowpark._internal.analyzer.analyzer_utils import ( + SPACE, aggregate_statement, attribute_to_schema_string, batch_insert_into_statement, + combine_cte_statements, copy_into_location, copy_into_table, create_file_format_statement, @@ -45,6 +47,7 @@ create_or_replace_view_statement, create_table_as_select_statement, create_table_statement, + cte_statement, delete_statement, drop_file_format_if_exists_statement, drop_table_if_exists_statement, @@ -301,14 +304,24 @@ def build( is_ddl_on_temp_object: bool = False, ) -> SnowflakePlan: select_child = self.add_result_scan_if_not_select(child) - queries = select_child.queries[:-1] + [ - Query( - sql_generator(select_child.queries[-1].sql), + last_query = select_child.queries[-1] + if isinstance(last_query, CTEQuery): + new_query = CTEQuery( + sql_generator(last_query.select_sql), + last_query.with_sql, query_id_place_holder="", is_ddl_on_temp_object=is_ddl_on_temp_object, - params=select_child.queries[-1].params, + params=last_query.params, ) - ] + else: + new_query = 
Query( + sql_generator(last_query.sql), + query_id_place_holder="", + is_ddl_on_temp_object=is_ddl_on_temp_object, + params=last_query.params, + ) + + queries = select_child.queries[:-1] + [new_query] new_schema_query = ( schema_query if schema_query else sql_generator(child.schema_query) ) @@ -369,21 +382,37 @@ def build_binary( ) -> SnowflakePlan: select_left = self.add_result_scan_if_not_select(left) select_right = self.add_result_scan_if_not_select(right) - queries = ( - select_left.queries[:-1] - + select_right.queries[:-1] - + [ - Query( - sql_generator( - select_left.queries[-1].sql, select_right.queries[-1].sql - ), - params=[ - *select_left.queries[-1].params, - *select_right.queries[-1].params, - ], - ) - ] - ) + select_left_last_query = select_left.queries[-1] + select_right_last_query = select_right.queries[-1] + params = [*select_left_last_query.params, *select_right_last_query.params] + + is_left_cte = isinstance(select_left_last_query, CTEQuery) + is_right_cte = isinstance(select_right_last_query, CTEQuery) + if is_left_cte and is_right_cte: + sql = sql_generator( + select_left_last_query.select_sql, select_right_last_query.select_sql + ) + with_sql = combine_cte_statements( + [select_left_last_query.with_sql, select_right_last_query.with_sql] + ) + elif is_left_cte: + sql = sql_generator( + select_left_last_query.select_sql, select_right_last_query.sql + ) + with_sql = select_left_last_query.with_sql + elif is_right_cte: + sql = sql_generator( + select_left_last_query.sql, select_right_last_query.select_sql + ) + with_sql = select_right_last_query.with_sql + else: + sql = sql_generator(select_left_last_query.sql, select_right_last_query.sql) + with_sql = None + if with_sql: + last_query = CTEQuery(sql, with_sql, params=params) + else: + last_query = Query(sql, params=params) + queries = select_left.queries[:-1] + select_right.queries[:-1] + [last_query] left_schema_query = schema_value_statement(select_left.attributes) right_schema_query = 
schema_value_statement(select_right.attributes) @@ -1205,6 +1234,33 @@ def add_result_scan_if_not_select(self, plan: SnowflakePlan) -> SnowflakePlan: session=self.session, ) + def cte(self, plan: SnowflakePlan) -> SnowflakePlan: + last_query = plan.queries[-1] + # CTE only supports select query + assert is_sql_select_statement(last_query.sql) + if isinstance(last_query, CTEQuery): + # once the frontend CTE API is called, we always move the select sql to be a CTE + new_query = last_query.move_select_sql_to_cte() + else: + new_query = CTEQuery( + last_query.sql, + query_id_place_holder=last_query.query_id_place_holder, + is_ddl_on_temp_object=last_query.is_ddl_on_temp_object, + params=last_query.params, + ) + queries = plan.queries[:-1] + [new_query] + return SnowflakePlan( + queries, + plan.schema_query, + plan.post_actions, + plan.expr_to_alias, + plan.source_plan, + plan.is_ddl_on_temp_object, + api_calls=plan.api_calls, + df_aliased_col_name_to_real_col_name=plan.df_aliased_col_name_to_real_col_name, + session=self.session, + ) + class Query: def __init__( @@ -1250,3 +1306,60 @@ def __init__( ) -> None: super().__init__(sql) self.rows = rows + + +class CTEQuery(Query): + """ + CTEQuery class has two extra attributes, `with_sql` and `select_sql`, which are broken down + from the full sql (`sql`). 
+ """ + + def __init__( + self, + select_sql: str, + with_sql: Optional[str] = None, + *, + query_id_place_holder: Optional[str] = None, + is_ddl_on_temp_object: bool = False, + params: Optional[Sequence[Any]] = None, + ) -> None: + if with_sql: + self.with_sql = with_sql + self.select_sql = select_sql + else: + # When with_sql is None, we convert the plain select sql to + # a with sql + a select * from table sql + self.with_sql, self.select_sql = CTEQuery.generate_with_and_select_sql( + select_sql + ) + super().__init__( + self.with_sql + SPACE + self.select_sql, + query_id_place_holder=query_id_place_holder, + is_ddl_on_temp_object=is_ddl_on_temp_object, + params=params, + ) + + @staticmethod + def generate_with_and_select_sql(sql: str) -> Tuple[str, str]: + # The rewritten logic is (for example), + # select a, b from t1 -> + # with t2 as (select a, b from t1) select * from t2 + temp_table_name = random_name_for_temp_object(TempObjectType.TABLE) + select_sql = project_statement([], temp_table_name) + with_sql = cte_statement(sql, temp_table_name) + return with_sql, select_sql + + def move_select_sql_to_cte(self) -> "CTEQuery": + # The rewritten logic is (for example), + # with t2 as (select a, b from t1) select a, b from t2 -> + # with t2 as (select a, b from t1), t3 as (select a, b from t2) select * from t3 + # and it can actually be called recursively + with_sql, select_sql = CTEQuery.generate_with_and_select_sql(self.select_sql) + with_sql = combine_cte_statements([self.with_sql, with_sql]) + return CTEQuery( + select_sql, + with_sql, + query_id_place_holder=self.query_id_place_holder, + is_ddl_on_temp_object=self.is_ddl_on_temp_object, + params=self.params, + ) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index 3c57658ea4a..dc98c7d0805 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ 
b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -154,3 +154,10 @@ def __init__( self.file_format_name = file_format_name self.file_format_type = file_format_type self.copy_options = copy_options + + +class CTE(LogicalPlan): + def __init__(self, child: LogicalPlan) -> None: + super().__init__() + self.child = child + self.children.append(child) diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index 6f4c447675b..ff2a892d991 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -55,6 +55,7 @@ SelectTableFunction, ) from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( + CTE, CopyIntoTableNode, Limit, LogicalPlan, @@ -3737,6 +3738,10 @@ def random_split( ] return res_dfs + def _as_cte(self) -> "DataFrame": + """Returns a new DataFrame where select query is rewritten using common table expressions (CTE).""" + return self._with_plan(CTE(self._plan)) + @property def queries(self) -> Dict[str, List[str]]: """ diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py new file mode 100644 index 00000000000..c39d4ab3471 --- /dev/null +++ b/tests/integ/test_cte.py @@ -0,0 +1,110 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +import re + +import pytest + +from snowflake.snowpark._internal.analyzer import analyzer +from snowflake.snowpark._internal.utils import TEMP_OBJECT_NAME_PREFIX +from snowflake.snowpark.functions import col +from tests.utils import Utils + +WITH = "WITH" +original_threshold = analyzer.ARRAY_BIND_THRESHOLD + + +@pytest.fixture(autouse=True) +def cleanup(session): + yield + analyzer.ARRAY_BIND_THRESHOLD = original_threshold + + +@pytest.fixture(params=[False, True]) +def has_multi_queries(request): + return request.param + + +@pytest.fixture(scope="function") +def df(session, has_multi_queries): + # TODO SNOW-1020742: Integrate CTE support with sql simplifier + session.sql_simplifier_enabled = False + if has_multi_queries: + analyzer.ARRAY_BIND_THRESHOLD = 2 + else: + analyzer.ARRAY_BIND_THRESHOLD = original_threshold + return session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + + +def is_select_star_from_table_query(query): + # check whether the query ends with `SELECT * FROM (SNOWPARK_TEMP_TABLE_XXX)` + pattern = re.compile( + rf"SELECT\s+\*\s+FROM \({TEMP_OBJECT_NAME_PREFIX}TABLE_[0-9A-Z]+\)$" + ) + return bool(pattern.search(query)) + + +def check_result(df_result, df_cte_result): + Utils.check_answer(df_result, df_cte_result) + last_query = df_cte_result.queries["queries"][-1] + assert last_query.startswith(WITH) + assert last_query.count(WITH) == 1 + + # move the select sql to CTE + df_cte_result2 = df_cte_result._as_cte() + Utils.check_answer(df_result, df_cte_result2) + last_query = df_cte_result2.queries["queries"][-1] + assert last_query.startswith(WITH) + assert last_query.count(WITH) == 1 + assert is_select_star_from_table_query(last_query) + + +@pytest.mark.parametrize( + "action", + [ + lambda x: x, + lambda x: x.select("a"), + lambda x: x.select("a", "b"), + lambda x: x.select("*"), + lambda x: x.select("a", "b").select("b"), + lambda x: x.filter(col("a") == 1), + lambda x: x.filter(col("a") == 1).select("b"), + lambda x: 
x.select("a").filter(col("a") == 1), + lambda x: x.sort("a", ascending=False), + lambda x: x.filter(col("a") == 1).sort("a"), + lambda x: x.limit(1), + lambda x: x.sort("a").limit(1), + lambda x: x.drop("b"), + lambda x: x.select("a", "b").drop("b"), + lambda x: x.agg({"a": "count", "b": "sum"}), + lambda x: x.group_by("a").min("b"), + ], +) +def test_basic(session, df, action): + df_cte = df._as_cte() + df_result, df_cte_result = action(df), action(df_cte) + check_result(df_result, df_cte_result) + + +@pytest.mark.parametrize( + "action", + [ + lambda x, y: x.union_all(y), + lambda x, y: x.select("a").union_all(y.select("a")), + lambda x, y: x.except_(y), + lambda x, y: x.select("a").except_(y.select("a")), + lambda x, y: x.join(y.select("a", "b"), rsuffix="_y"), + lambda x, y: x.select("a").join(y, rsuffix="_y"), + lambda x, y: x.join(y.select("a"), rsuffix="_y"), + ], +) +def test_binary(session, df, action): + df_cte = df._as_cte() + df_result = action(df, df) + for df_cte_result in [ + action(df_cte, df), + action(df, df_cte), + action(df_cte, df_cte), + ]: + check_result(df_result, df_cte_result)