Skip to content

Commit

Permalink
SNOW-1533740: use SnowflakeCreateTable for cache result (#1911)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aalam authored Jul 20, 2024
1 parent c482c53 commit dc16701
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 77 deletions.
2 changes: 2 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,8 @@ def do_resolve_with_resolved_children(
logical_plan.comment,
resolved_children[logical_plan.children[0]],
logical_plan,
self.session._use_scoped_temp_objects,
logical_plan.is_generated,
)

if isinstance(logical_plan, Limit):
Expand Down
91 changes: 28 additions & 63 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import snowflake.connector
import snowflake.snowpark
from snowflake.snowpark._internal.analyzer.analyzer_utils import (
TEMPORARY_STRING_SET,
aggregate_statement,
attribute_to_schema_string,
batch_insert_into_statement,
Expand Down Expand Up @@ -763,10 +764,17 @@ def save_as_table(
clustering_keys: Iterable[str],
comment: Optional[str],
child: SnowflakePlan,
logical_plan: Optional[LogicalPlan],
source_plan: Optional[LogicalPlan],
use_scoped_temp_objects: bool,
is_generated: bool, # true if the table is generated internally
) -> SnowflakePlan:
full_table_name = ".".join(table_name)
if is_generated and mode != SaveMode.ERROR_IF_EXISTS:
raise ValueError(
"Internally generated tables must be called with mode ERROR_IF_EXISTS"
)

full_table_name = ".".join(table_name)
is_temp_table_type = table_type in TEMPORARY_STRING_SET
# here get the column definition from the child attributes. In certain cases we have
# the attributes set to ($1, VariantType()) which cannot be used as valid column name
# in save as table. So we rename ${number} with COL{number}.
Expand All @@ -791,6 +799,8 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True):
table_type=table_type,
clustering_key=clustering_keys,
comment=comment,
use_scoped_temp_objects=use_scoped_temp_objects,
is_generated=is_generated,
)

# so that dataframes created from non-select statements,
Expand All @@ -799,20 +809,21 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True):
return SnowflakePlan(
[
*child.queries[0:-1],
Query(create_table),
Query(create_table, is_ddl_on_temp_object=is_temp_table_type),
Query(
insert_into_statement(
table_name=full_table_name,
child=child.queries[-1].sql,
column_names=column_names,
),
params=child.queries[-1].params,
is_ddl_on_temp_object=is_temp_table_type,
),
],
create_table,
child.post_actions,
{},
logical_plan,
source_plan,
api_calls=child.api_calls,
session=self.session,
)
Expand All @@ -826,7 +837,7 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True):
column_names=column_names,
),
child,
logical_plan,
source_plan,
)
else:
return get_create_and_insert_plan(child, replace=False, error=False)
Expand All @@ -837,7 +848,7 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True):
full_table_name, x, [x.name for x in child.attributes], True
),
child,
logical_plan,
source_plan,
)
else:
return self.build(
Expand All @@ -851,7 +862,8 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True):
comment=comment,
),
child,
logical_plan,
source_plan,
is_ddl_on_temp_object=is_temp_table_type,
)
elif mode == SaveMode.OVERWRITE:
return self.build(
Expand All @@ -865,7 +877,8 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True):
comment=comment,
),
child,
logical_plan,
source_plan,
is_ddl_on_temp_object=is_temp_table_type,
)
elif mode == SaveMode.IGNORE:
return self.build(
Expand All @@ -879,9 +892,13 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True):
comment=comment,
),
child,
logical_plan,
source_plan,
is_ddl_on_temp_object=is_temp_table_type,
)
elif mode == SaveMode.ERROR_IF_EXISTS:
if is_generated:
return get_create_and_insert_plan(child, replace=False, error=True)

return self.build(
lambda x: create_table_as_select_statement(
full_table_name,
Expand All @@ -892,7 +909,8 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True):
comment=comment,
),
child,
logical_plan,
source_plan,
is_ddl_on_temp_object=is_temp_table_type,
)

def limit(
Expand Down Expand Up @@ -997,59 +1015,6 @@ def create_or_replace_dynamic_table(
source_plan,
)

def create_temp_table(
self,
name: str,
child: SnowflakePlan,
*,
use_scoped_temp_objects: bool = False,
is_generated: bool = False,
) -> SnowflakePlan:
child = child.replace_repeated_subquery_with_cte()
return self.build_from_multiple_queries(
lambda x: self.create_table_and_insert(
self.session,
name,
child.schema_query,
x,
use_scoped_temp_objects=use_scoped_temp_objects,
is_generated=is_generated,
),
child,
None,
child.schema_query,
is_ddl_on_temp_object=True,
)

def create_table_and_insert(
self,
session,
name: str,
schema_query: str,
query: "Query",
*,
use_scoped_temp_objects: bool = False,
is_generated: bool = False,
) -> List["Query"]:
attributes = session._get_result_attributes(schema_query)
create_table = create_table_statement(
name,
attribute_to_schema_string(attributes),
table_type="temporary",
use_scoped_temp_objects=use_scoped_temp_objects,
is_generated=is_generated,
)

return [
Query(create_table),
Query(
insert_into_statement(
table_name=name, column_names=None, child=query.sql
),
params=query.params,
),
]

def read_file(
self,
path: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def __init__(
table_type: str = "",
clustering_exprs: Optional[Iterable[Expression]] = None,
comment: Optional[str] = None,
is_generated: bool = False,
) -> None:
super().__init__()
self.table_name = table_name
Expand All @@ -144,6 +145,7 @@ def __init__(
self.children.append(query)
self.clustering_exprs = clustering_exprs or []
self.comment = comment
self.is_generated = is_generated

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
Expand Down
18 changes: 12 additions & 6 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
CopyIntoTableNode,
Limit,
LogicalPlan,
SaveMode,
SnowflakeCreateTable,
)
from snowflake.snowpark._internal.analyzer.sort_expression import (
Ascending,
Expand Down Expand Up @@ -3959,14 +3961,18 @@ def cache_result(
if isinstance(self._session._conn, MockServerConnection):
self.write.save_as_table(temp_table_name, create_temp_table=True)
else:
create_temp_table = self._session._analyzer.plan_builder.create_temp_table(
temp_table_name,
self._plan,
use_scoped_temp_objects=self._session._use_scoped_temp_objects,
is_generated=True,
df = self._with_plan(
SnowflakeCreateTable(
[temp_table_name],
None,
SaveMode.ERROR_IF_EXISTS,
self._plan,
table_type="temp",
is_generated=True,
)
)
self._session._conn.execute(
create_temp_table,
df._plan,
_statement_params=create_or_update_statement_params_with_query_tag(
statement_params or self._statement_params,
self._session.query_tag,
Expand Down
51 changes: 43 additions & 8 deletions tests/integ/scala/test_snowflake_plan_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Query,
SnowflakePlan,
)
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import SaveMode
from snowflake.snowpark._internal.utils import TempObjectType
from snowflake.snowpark.functions import col, lit, table_function
from snowflake.snowpark.session import Session
Expand Down Expand Up @@ -228,9 +229,15 @@ def test_create_scoped_temp_table(session):
df = session.table(table_name)
temp_table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
assert (
session._plan_builder.create_temp_table(
temp_table_name,
session._plan_builder.save_as_table(
[temp_table_name],
None,
SaveMode.ERROR_IF_EXISTS,
"temp",
None,
None,
df._plan,
None,
use_scoped_temp_objects=True,
is_generated=True,
)
Expand All @@ -239,26 +246,54 @@ def test_create_scoped_temp_table(session):
== f' CREATE SCOPED TEMPORARY TABLE {temp_table_name}("NUM" BIGINT, "STR" STRING(8))'
)
assert (
session._plan_builder.create_temp_table(
temp_table_name,
session._plan_builder.save_as_table(
[temp_table_name],
None,
SaveMode.ERROR_IF_EXISTS,
"temp",
None,
None,
df._plan,
None,
use_scoped_temp_objects=False,
is_generated=True,
)
.queries[0]
.sql
== f' CREATE TEMPORARY TABLE {temp_table_name}("NUM" BIGINT, "STR" STRING(8))'
)
assert (
session._plan_builder.create_temp_table(
temp_table_name,
expected_sql = f' CREATE TEMPORARY TABLE {temp_table_name}("NUM" BIGINT, "STR" STRING(8))'
assert expected_sql in (
session._plan_builder.save_as_table(
[temp_table_name],
None,
SaveMode.ERROR_IF_EXISTS,
"temporary",
None,
None,
df._plan,
None,
use_scoped_temp_objects=True,
is_generated=False,
)
.queries[0]
.sql
== f' CREATE TEMPORARY TABLE {temp_table_name}("NUM" BIGINT, "STR" STRING(8))'
)
with pytest.raises(
ValueError,
match="Internally generated tables must be called with mode ERROR_IF_EXISTS",
):
session._plan_builder.save_as_table(
[temp_table_name],
None,
SaveMode.APPEND,
"temporary",
None,
None,
df._plan,
None,
use_scoped_temp_objects=True,
is_generated=True,
)
finally:
Utils.drop_table(session, table_name)
19 changes: 19 additions & 0 deletions tests/integ/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,6 +1462,25 @@ def test_df_col(session):
assert isinstance(c._expression, Star)


@pytest.mark.xfail(
"config.getoption('local_testing_mode', default=False)",
reason="Session.query_history is not supported",
run=False,
)
def test_cache_result_query(session):
df = session.create_dataframe([[1, 2]], schema=["a", "b"])
with session.query_history() as history:
df.cache_result()

assert len(history.queries) == 2
assert "CREATE SCOPED TEMPORARY TABLE" in history.queries[0].sql_text
assert (
"INSERT INTO" in history.queries[1].sql_text
and 'SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, 2 :: INT)'
in history.queries[1].sql_text
)


def test_create_dataframe_with_basic_data_types(session):
data1 = [
1,
Expand Down

0 comments on commit dc16701

Please sign in to comment.