diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 09e944afd6b..ff7d2897e1f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -982,6 +982,8 @@ def do_resolve_with_resolved_children( logical_plan.comment, resolved_children[logical_plan.children[0]], logical_plan, + self.session._use_scoped_temp_objects, + logical_plan.is_generated, ) if isinstance(logical_plan, Limit): diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 5a55ab615eb..9f9d39055d8 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -41,6 +41,7 @@ import snowflake.connector import snowflake.snowpark from snowflake.snowpark._internal.analyzer.analyzer_utils import ( + TEMPORARY_STRING_SET, aggregate_statement, attribute_to_schema_string, batch_insert_into_statement, @@ -763,10 +764,17 @@ def save_as_table( clustering_keys: Iterable[str], comment: Optional[str], child: SnowflakePlan, - logical_plan: Optional[LogicalPlan], + source_plan: Optional[LogicalPlan], + use_scoped_temp_objects: bool, + is_generated: bool, # true if the table is generated internally ) -> SnowflakePlan: - full_table_name = ".".join(table_name) + if is_generated and mode != SaveMode.ERROR_IF_EXISTS: + raise ValueError( + "Internally generated tables must be called with mode ERROR_IF_EXISTS" + ) + full_table_name = ".".join(table_name) + is_temp_table_type = table_type in TEMPORARY_STRING_SET # here get the column definition from the child attributes. In certain cases we have # the attributes set to ($1, VariantType()) which cannot be used as valid column name # in save as table. So we rename ${number} with COL{number}. @@ -791,6 +799,8 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True): table_type=table_type, clustering_key=clustering_keys, comment=comment, + use_scoped_temp_objects=use_scoped_temp_objects, + is_generated=is_generated, ) # so that dataframes created from non-select statements, @@ -799,7 +809,7 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True): return SnowflakePlan( [ *child.queries[0:-1], - Query(create_table), + Query(create_table, is_ddl_on_temp_object=is_temp_table_type), Query( insert_into_statement( table_name=full_table_name, @@ -807,12 +817,13 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True): column_names=column_names, ), params=child.queries[-1].params, + is_ddl_on_temp_object=is_temp_table_type, ), ], create_table, child.post_actions, {}, - logical_plan, + source_plan, api_calls=child.api_calls, session=self.session, ) @@ -826,7 +837,7 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True): column_names=column_names, ), child, - logical_plan, + source_plan, ) else: return get_create_and_insert_plan(child, replace=False, error=False) @@ -837,7 +848,7 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True): full_table_name, x, [x.name for x in child.attributes], True ), child, - logical_plan, + source_plan, ) else: return self.build( @@ -851,7 +862,8 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True): comment=comment, ), child, - logical_plan, + source_plan, + is_ddl_on_temp_object=is_temp_table_type, ) elif mode == SaveMode.OVERWRITE: return self.build( @@ -865,7 +877,8 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True): comment=comment, ), child, - logical_plan, + source_plan, + is_ddl_on_temp_object=is_temp_table_type, ) elif mode == SaveMode.IGNORE: return self.build( @@ -879,9 +892,13 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True): comment=comment, ), child, - logical_plan, + source_plan, + is_ddl_on_temp_object=is_temp_table_type, ) elif mode == SaveMode.ERROR_IF_EXISTS: + if is_generated: + return get_create_and_insert_plan(child, replace=False, error=True) + return self.build( lambda x: create_table_as_select_statement( full_table_name, @@ -892,7 +909,8 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True): comment=comment, ), child, - logical_plan, + source_plan, + is_ddl_on_temp_object=is_temp_table_type, ) def limit( @@ -997,59 +1015,6 @@ def create_or_replace_dynamic_table( source_plan, ) - def create_temp_table( - self, - name: str, - child: SnowflakePlan, - *, - use_scoped_temp_objects: bool = False, - is_generated: bool = False, - ) -> SnowflakePlan: - child = child.replace_repeated_subquery_with_cte() - return self.build_from_multiple_queries( - lambda x: self.create_table_and_insert( - self.session, - name, - child.schema_query, - x, - use_scoped_temp_objects=use_scoped_temp_objects, - is_generated=is_generated, - ), - child, - None, - child.schema_query, - is_ddl_on_temp_object=True, - ) - - def create_table_and_insert( - self, - session, - name: str, - schema_query: str, - query: "Query", - *, - use_scoped_temp_objects: bool = False, - is_generated: bool = False, - ) -> List["Query"]: - attributes = session._get_result_attributes(schema_query) - create_table = create_table_statement( - name, - attribute_to_schema_string(attributes), - table_type="temporary", - use_scoped_temp_objects=use_scoped_temp_objects, - is_generated=is_generated, - ) - - return [ - Query(create_table), - Query( - insert_into_statement( - table_name=name, column_names=None, child=query.sql - ), - params=query.params, - ), - ] - def read_file( self, path: str, diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index 10a43a9fcd6..946829b0a0b 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -134,6 +134,7 @@ def __init__( table_type: str = "", clustering_exprs: Optional[Iterable[Expression]] = None, comment: Optional[str] = None, + is_generated: bool = False, ) -> None: super().__init__() self.table_name = table_name @@ -144,6 +145,7 @@ def __init__( self.children.append(query) self.clustering_exprs = clustering_exprs or [] self.comment = comment + self.is_generated = is_generated @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index b2bed904ee1..47b4731d890 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -60,6 +60,8 @@ CopyIntoTableNode, Limit, LogicalPlan, + SaveMode, + SnowflakeCreateTable, ) from snowflake.snowpark._internal.analyzer.sort_expression import ( Ascending, @@ -3959,14 +3961,18 @@ def cache_result( if isinstance(self._session._conn, MockServerConnection): self.write.save_as_table(temp_table_name, create_temp_table=True) else: - create_temp_table = self._session._analyzer.plan_builder.create_temp_table( - temp_table_name, - self._plan, - use_scoped_temp_objects=self._session._use_scoped_temp_objects, - is_generated=True, + df = self._with_plan( + SnowflakeCreateTable( + [temp_table_name], + None, + SaveMode.ERROR_IF_EXISTS, + self._plan, + table_type="temp", + is_generated=True, + ) ) self._session._conn.execute( - create_temp_table, + df._plan, _statement_params=create_or_update_statement_params_with_query_tag( statement_params or self._statement_params, self._session.query_tag, diff --git a/tests/integ/scala/test_snowflake_plan_suite.py b/tests/integ/scala/test_snowflake_plan_suite.py index c5bb285a575..123c7e31802 100644 --- a/tests/integ/scala/test_snowflake_plan_suite.py +++ b/tests/integ/scala/test_snowflake_plan_suite.py @@ -14,6 +14,7 @@ Query, SnowflakePlan, ) +from snowflake.snowpark._internal.analyzer.snowflake_plan_node import SaveMode from snowflake.snowpark._internal.utils import TempObjectType from snowflake.snowpark.functions import col, lit, table_function from snowflake.snowpark.session import Session @@ -228,9 +229,15 @@ def test_create_scoped_temp_table(session): df = session.table(table_name) temp_table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) assert ( - session._plan_builder.create_temp_table( - temp_table_name, + session._plan_builder.save_as_table( + [temp_table_name], + None, + SaveMode.ERROR_IF_EXISTS, + "temp", + None, + None, df._plan, + None, use_scoped_temp_objects=True, is_generated=True, ) @@ -239,9 +246,15 @@ def test_create_scoped_temp_table(session): == f' CREATE SCOPED TEMPORARY TABLE {temp_table_name}("NUM" BIGINT, "STR" STRING(8))' ) assert ( - session._plan_builder.create_temp_table( - temp_table_name, + session._plan_builder.save_as_table( + [temp_table_name], + None, + SaveMode.ERROR_IF_EXISTS, + "temp", + None, + None, df._plan, + None, use_scoped_temp_objects=False, is_generated=True, ) @@ -249,16 +262,38 @@ def test_create_scoped_temp_table(session): .sql == f' CREATE TEMPORARY TABLE {temp_table_name}("NUM" BIGINT, "STR" STRING(8))' ) - assert ( - session._plan_builder.create_temp_table( - temp_table_name, + expected_sql = f' CREATE TEMPORARY TABLE {temp_table_name}("NUM" BIGINT, "STR" STRING(8))' + assert expected_sql in ( + session._plan_builder.save_as_table( + [temp_table_name], + None, + SaveMode.ERROR_IF_EXISTS, + "temporary", + None, + None, df._plan, + None, use_scoped_temp_objects=True, is_generated=False, ) .queries[0] .sql - == f' CREATE TEMPORARY TABLE {temp_table_name}("NUM" BIGINT, "STR" STRING(8))' ) + with pytest.raises( + ValueError, + match="Internally generated tables must be called with mode ERROR_IF_EXISTS", + ): + session._plan_builder.save_as_table( + [temp_table_name], + None, + SaveMode.APPEND, + "temporary", + None, + None, + df._plan, + None, + use_scoped_temp_objects=True, + is_generated=True, + ) finally: Utils.drop_table(session, table_name) diff --git a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py index f5fb4d5d8d8..176ae57281d 100644 --- a/tests/integ/test_dataframe.py +++ b/tests/integ/test_dataframe.py @@ -1462,6 +1462,25 @@ def test_df_col(session): assert isinstance(c._expression, Star) +@pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="Session.query_history is not supported", + run=False, +) +def test_cache_result_query(session): + df = session.create_dataframe([[1, 2]], schema=["a", "b"]) + with session.query_history() as history: + df.cache_result() + + assert len(history.queries) == 2 + assert "CREATE SCOPED TEMPORARY TABLE" in history.queries[0].sql_text + assert ( + "INSERT INTO" in history.queries[1].sql_text + and 'SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, 2 :: INT)' + in history.queries[1].sql_text + ) + + def test_create_dataframe_with_basic_data_types(session): data1 = [ 1,