Skip to content

Commit

Permalink
[SNOW-1545149] Ensure source plan is attached for Table and View create SnowflakePlan (#1938)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-yzou authored Jul 19, 2024
1 parent ef7ba37 commit 2fbd524
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 12 deletions.
5 changes: 5 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,7 @@ def do_resolve_with_resolved_children(
],
logical_plan.comment,
resolved_children[logical_plan.children[0]],
logical_plan,
)

if isinstance(logical_plan, Limit):
Expand Down Expand Up @@ -1109,6 +1110,7 @@ def do_resolve_with_resolved_children(
resolved_children[logical_plan.child],
is_temp,
logical_plan.comment,
logical_plan,
)

if isinstance(logical_plan, CreateDynamicTableCommand):
Expand All @@ -1118,6 +1120,7 @@ def do_resolve_with_resolved_children(
logical_plan.lag,
logical_plan.comment,
resolved_children[logical_plan.child],
logical_plan,
)

if isinstance(logical_plan, CopyIntoTableNode):
Expand All @@ -1134,6 +1137,7 @@ def do_resolve_with_resolved_children(
path=logical_plan.file_path,
table_name=logical_plan.table_name,
files=logical_plan.files,
source_plan=logical_plan,
pattern=logical_plan.pattern,
file_format=logical_plan.file_format,
format_type_options=format_type_options,
Expand All @@ -1154,6 +1158,7 @@ def do_resolve_with_resolved_children(
return self.plan_builder.copy_into_location(
query=resolved_children[logical_plan.child],
stage_location=logical_plan.stage_location,
source_plan=logical_plan,
partition_by=self.analyze(
logical_plan.partition_by, df_aliased_col_name_to_real_col_name
)
Expand Down
57 changes: 45 additions & 12 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,15 @@
from snowflake.snowpark._internal.analyzer.expression import Attribute
from snowflake.snowpark._internal.analyzer.schema_utils import analyze_attributes
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import (
CopyIntoLocationNode,
CopyIntoTableNode,
LogicalPlan,
SaveMode,
SnowflakeCreateTable,
)
from snowflake.snowpark._internal.analyzer.unary_plan_node import (
CreateDynamicTableCommand,
CreateViewCommand,
)
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.utils import (
Expand Down Expand Up @@ -291,6 +298,21 @@ def replace_repeated_subquery_with_cte(self) -> "SnowflakePlan":
if self.source_plan is None or self.placeholder_query is None:
return self

# When the source plan node is an instance of one of the node types listed in
# pre_handled_logical_node, the CTE optimization has already been handled during
# the plan build step, so skip the optimization step here for now.
# TODO: Once SNOW-1541094 is done, all the optimization steps can be unified and
# this check will no longer be needed.
pre_handled_logical_node = (
CreateDynamicTableCommand,
CreateViewCommand,
SnowflakeCreateTable,
CopyIntoTableNode,
CopyIntoLocationNode,
)
if isinstance(self.source_plan, pre_handled_logical_node):
return self

# only select statement can be converted to CTEs
if not is_sql_select_statement(self.queries[-1].sql):
return self
Expand Down Expand Up @@ -741,6 +763,7 @@ def save_as_table(
clustering_keys: Iterable[str],
comment: Optional[str],
child: SnowflakePlan,
logical_plan: Optional[LogicalPlan],
) -> SnowflakePlan:
full_table_name = ".".join(table_name)

Expand Down Expand Up @@ -789,7 +812,7 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True):
create_table,
child.post_actions,
{},
None,
logical_plan,
api_calls=child.api_calls,
session=self.session,
)
Expand All @@ -803,7 +826,7 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True):
column_names=column_names,
),
child,
None,
logical_plan,
)
else:
return get_create_and_insert_plan(child, replace=False, error=False)
Expand All @@ -814,7 +837,7 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True):
full_table_name, x, [x.name for x in child.attributes], True
),
child,
None,
logical_plan,
)
else:
return self.build(
Expand All @@ -828,7 +851,7 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True):
comment=comment,
),
child,
None,
logical_plan,
)
elif mode == SaveMode.OVERWRITE:
return self.build(
Expand All @@ -842,7 +865,7 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True):
comment=comment,
),
child,
None,
logical_plan,
)
elif mode == SaveMode.IGNORE:
return self.build(
Expand All @@ -856,7 +879,7 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True):
comment=comment,
),
child,
None,
logical_plan,
)
elif mode == SaveMode.ERROR_IF_EXISTS:
return self.build(
Expand All @@ -869,7 +892,7 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True):
comment=comment,
),
child,
None,
logical_plan,
)

def limit(
Expand Down Expand Up @@ -930,7 +953,12 @@ def rename(
)

def create_or_replace_view(
self, name: str, child: SnowflakePlan, is_temp: bool, comment: Optional[str]
self,
name: str,
child: SnowflakePlan,
is_temp: bool,
comment: Optional[str],
source_plan: Optional[LogicalPlan],
) -> SnowflakePlan:
if len(child.queries) != 1:
raise SnowparkClientExceptionMessages.PLAN_CREATE_VIEW_FROM_DDL_DML_OPERATIONS()
Expand All @@ -942,7 +970,7 @@ def create_or_replace_view(
return self.build(
lambda x: create_or_replace_view_statement(name, x, is_temp, comment),
child,
None,
source_plan,
)

def create_or_replace_dynamic_table(
Expand All @@ -952,6 +980,7 @@ def create_or_replace_dynamic_table(
lag: str,
comment: Optional[str],
child: SnowflakePlan,
source_plan: Optional[LogicalPlan],
) -> SnowflakePlan:
if len(child.queries) != 1:
raise SnowparkClientExceptionMessages.PLAN_CREATE_DYNAMIC_TABLE_FROM_DDL_DML_OPERATIONS()
Expand All @@ -965,7 +994,7 @@ def create_or_replace_dynamic_table(
name, warehouse, lag, comment, x
),
child,
None,
source_plan,
)

def create_temp_table(
Expand Down Expand Up @@ -1190,6 +1219,7 @@ def copy_into_table(
file_format: str,
table_name: Iterable[str],
path: str,
source_plan: Optional[LogicalPlan],
files: Optional[str] = None,
pattern: Optional[str] = None,
validation_mode: Optional[str] = None,
Expand Down Expand Up @@ -1245,12 +1275,15 @@ def copy_into_table(
raise SnowparkClientExceptionMessages.DF_COPY_INTO_CANNOT_CREATE_TABLE(
full_table_name
)
return SnowflakePlan(queries, copy_command, [], {}, None, session=self.session)
return SnowflakePlan(
queries, copy_command, [], {}, source_plan, session=self.session
)

def copy_into_location(
self,
query: SnowflakePlan,
stage_location: str,
source_plan: Optional[LogicalPlan],
partition_by: Optional[str] = None,
file_format_name: Optional[str] = None,
file_format_type: Optional[str] = None,
Expand All @@ -1271,7 +1304,7 @@ def copy_into_location(
**copy_options,
),
query,
None,
source_plan,
query.schema_query,
)

Expand Down
1 change: 1 addition & 0 deletions src/snowflake/snowpark/mock/_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,7 @@ def do_resolve_with_resolved_children(
return self.plan_builder.copy_into_location(
query=resolved_children[logical_plan.child],
stage_location=logical_plan.stage_location,
source_plan=logical_plan,
partition_by=self.analyze(logical_plan.partition_by, expr_to_alias)
if logical_plan.partition_by
else None,
Expand Down

0 comments on commit 2fbd524

Please sign in to comment.