SNOW-911327: Use CTAS for save as table (#1075)
* SNOW-911327: Use CTAS for save as table

* SNOW-911327: add clustering keys in CTAS plan builder

* changelog updates
sfc-gh-aalam authored Oct 4, 2023
1 parent 1e1860d commit eceeb36
Showing 4 changed files with 55 additions and 30 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@
 ### Bug Fixes
 
 - Fixed a bug where imports from permanent stage locations were ignored for temporary stored procedures, UDTFs, UDFs, and UDAFs.
+- Reverted to using a CTAS (CREATE TABLE AS SELECT) statement for `DataFrame.write.save_as_table`, which does not require INSERT privileges on the target table.
 
 ## 1.8.0 (2023-09-14)

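Why CTAS here: the insert-based plan first created the table and then ran INSERT statements, so the writing role needed INSERT privileges on the new table; a single CTAS statement needs only CREATE TABLE on the schema. A minimal sketch of the two SQL shapes (identifiers below are illustrative placeholders, not from this commit):

```python
# Hypothetical SQL shapes; "my_table" and "src" are placeholders.
old_plan = [
    'CREATE TABLE my_table ("A" BIGINT)',        # needs CREATE TABLE
    'INSERT INTO my_table SELECT * FROM (src)',  # also needs INSERT
]
new_plan = [
    'CREATE TABLE my_table ("A" BIGINT) AS SELECT * FROM (src)',  # CTAS: CREATE TABLE only
]
```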
10 changes: 9 additions & 1 deletion src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py
@@ -715,14 +715,22 @@ def batch_insert_into_statement(table_name: str, column_names: List[str]) -> str
 def create_table_as_select_statement(
     table_name: str,
     child: str,
+    column_definition: str,
     replace: bool = False,
     error: bool = True,
     table_type: str = EMPTY_STRING,
+    clustering_key: Optional[Iterable[str]] = None,
 ) -> str:
+    cluster_by_clause = (
+        (CLUSTER_BY + LEFT_PARENTHESIS + COMMA.join(clustering_key) + RIGHT_PARENTHESIS)
+        if clustering_key
+        else EMPTY_STRING
+    )
     return (
         f"{CREATE}{OR + REPLACE if replace else EMPTY_STRING} {table_type.upper()} {TABLE}"
         f"{IF + NOT + EXISTS if not replace and not error else EMPTY_STRING}"
-        f" {table_name}{AS}{project_statement([], child)}"
+        f" {table_name}{LEFT_PARENTHESIS}{column_definition}{RIGHT_PARENTHESIS}"
+        f"{cluster_by_clause} {AS}{project_statement([], child)}"
     )


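To see what the updated builder emits, here is a self-contained sketch with the SQL fragment constants (`CREATE`, `CLUSTER_BY`, the parentheses, and so on) inlined at their conventional values; it mirrors `create_table_as_select_statement` above but is not the module's code:

```python
from typing import Iterable, Optional

def ctas_sql_sketch(
    table_name: str,
    child: str,
    column_definition: str,
    replace: bool = False,
    error: bool = True,
    table_type: str = "",
    clustering_key: Optional[Iterable[str]] = None,
) -> str:
    # Assumed constant values; the real module composes analyzer constants.
    cluster_by_clause = (
        f" CLUSTER BY ({','.join(clustering_key)})" if clustering_key else ""
    )
    create = "CREATE OR REPLACE" if replace else "CREATE"
    if_not_exists = "" if replace or error else " IF NOT EXISTS"
    table_kind = f" {table_type.upper()}" if table_type else ""
    return (
        f"{create}{table_kind} TABLE{if_not_exists}"
        f" {table_name}({column_definition}){cluster_by_clause}"
        f" AS SELECT * FROM ({child})"
    )

print(ctas_sql_sketch("t1", "select 1 as a", '"A" BIGINT',
                      replace=True, clustering_key=['"A"']))
# CREATE OR REPLACE TABLE t1("A" BIGINT) CLUSTER BY ("A") AS SELECT * FROM (select 1 as a)
```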
50 changes: 37 additions & 13 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
@@ -563,11 +563,12 @@ def save_as_table(
         child: SnowflakePlan,
     ) -> SnowflakePlan:
         full_table_name = ".".join(table_name)
+        column_definition = attribute_to_schema_string(child.attributes)
 
         def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True):
             create_table = create_table_statement(
                 full_table_name,
-                attribute_to_schema_string(child.attributes),
+                column_definition,
                 replace=replace,
                 error=error,
                 table_type=table_type,
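Hoisting `column_definition` lets every branch below reuse the rendered schema string. For intuition, `attribute_to_schema_string` joins each attribute's name with its Snowflake type; a simplified stand-in (the real implementation renders `DataType` objects through a type converter rather than taking strings):

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Attr:  # illustrative stand-in for the analyzer's Attribute
    name: str
    sf_type: str  # the real class holds a DataType, not a string

def schema_string_sketch(attributes: List[Attr]) -> str:
    # Comma-joined "name TYPE" pairs, suitable for CREATE TABLE t(<here>)
    return ",".join(f"{a.name} {a.sf_type}" for a in attributes)

print(schema_string_sketch([Attr('"A"', "BIGINT"), Attr('"B"', "STRING(8)")]))
# "A" BIGINT,"B" STRING(8)
```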
@@ -612,20 +613,43 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True):
             else:
                 return get_create_and_insert_plan(child, replace=False, error=False)
         elif mode == SaveMode.OVERWRITE:
-            return get_create_and_insert_plan(child, replace=True)
+            return self.build(
+                lambda x: create_table_as_select_statement(
+                    full_table_name,
+                    x,
+                    column_definition,
+                    replace=True,
+                    table_type=table_type,
+                    clustering_key=clustering_keys,
+                ),
+                child,
+                None,
+            )
         elif mode == SaveMode.IGNORE:
-            if self.session._table_exists(table_name):
-                return self.build(
-                    lambda x: create_table_as_select_statement(
-                        full_table_name, x, error=False, table_type=table_type
-                    ),
-                    child,
-                    None,
-                )
-            else:
-                return get_create_and_insert_plan(child, replace=False, error=False)
+            return self.build(
+                lambda x: create_table_as_select_statement(
+                    full_table_name,
+                    x,
+                    column_definition,
+                    error=False,
+                    table_type=table_type,
+                    clustering_key=clustering_keys,
+                ),
+                child,
+                None,
+            )
         elif mode == SaveMode.ERROR_IF_EXISTS:
-            return get_create_and_insert_plan(child, replace=False, error=True)
+            return self.build(
+                lambda x: create_table_as_select_statement(
+                    full_table_name,
+                    x,
+                    column_definition,
+                    table_type=table_type,
+                    clustering_key=clustering_keys,
+                ),
+                child,
+                None,
+            )
 
     def limit(
         self,
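All three non-append modes now route through `create_table_as_select_statement`; only the `replace`/`error` flags differ, while APPEND keeps the create-then-insert plan. A condensed sketch of that mapping (the enum below is an illustrative stand-in, not imported from snowpark):

```python
from enum import Enum

class SaveMode(Enum):  # stand-in mirroring snowpark's save modes
    APPEND = "append"
    OVERWRITE = "overwrite"
    IGNORE = "ignore"
    ERROR_IF_EXISTS = "errorifexists"

def ctas_flags(mode: SaveMode) -> dict:
    # Flags passed to create_table_as_select_statement per mode.
    if mode is SaveMode.OVERWRITE:
        return {"replace": True}                   # CREATE OR REPLACE TABLE ... AS ...
    if mode is SaveMode.IGNORE:
        return {"replace": False, "error": False}  # CREATE TABLE IF NOT EXISTS ... AS ...
    return {"replace": False, "error": True}       # CREATE TABLE ... AS ... (error if it exists)

print(ctas_flags(SaveMode.IGNORE))  # {'replace': False, 'error': False}
```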
24 changes: 8 additions & 16 deletions tests/integ/test_dataframe.py
@@ -2306,11 +2306,10 @@ def test_table_types_in_save_as_table(session, save_mode, table_type):
         Utils.drop_table(session, table_name)
 
 
-@pytest.mark.parametrize("table_type", ["", "temp", "temporary", "transient"])
 @pytest.mark.parametrize(
     "save_mode", ["append", "overwrite", "ignore", "errorifexists"]
 )
-def test_save_as_table_respects_schema(session, save_mode, table_type):
+def test_save_as_table_respects_schema(session, save_mode):
     table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
 
     schema1 = StructType(
@@ -2325,32 +2324,29 @@ def test_save_as_table_respects_schema(session, save_mode, table_type):
     df2 = session.create_dataframe([(1), (2)], schema=schema2)
 
     try:
-        df1.write.save_as_table(table_name, mode=save_mode, table_type=table_type)
+        df1.write.save_as_table(table_name, mode=save_mode)
         saved_df = session.table(table_name)
         Utils.is_schema_same(saved_df.schema, schema1)
 
         if save_mode == "overwrite":
-            df2.write.save_as_table(table_name, mode=save_mode, table_type=table_type)
+            df2.write.save_as_table(table_name, mode=save_mode)
             saved_df = session.table(table_name)
             Utils.is_schema_same(saved_df.schema, schema2)
         elif save_mode == "ignore":
-            df2.write.save_as_table(table_name, mode=save_mode, table_type=table_type)
+            df2.write.save_as_table(table_name, mode=save_mode)
             saved_df = session.table(table_name)
             Utils.is_schema_same(saved_df.schema, schema1)
         else:  # save_mode in ('append', 'errorifexists')
             with pytest.raises(SnowparkSQLException):
-                df2.write.save_as_table(
-                    table_name, mode=save_mode, table_type=table_type
-                )
+                df2.write.save_as_table(table_name, mode=save_mode)
     finally:
         Utils.drop_table(session, table_name)
 
 
-@pytest.mark.parametrize("table_type", ["", "temp", "temporary", "transient"])
 @pytest.mark.parametrize(
     "save_mode", ["append", "overwrite", "ignore", "errorifexists"]
 )
-def test_save_as_table_nullable_test(session, save_mode, table_type):
+def test_save_as_table_nullable_test(session, save_mode):
     table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
     schema = StructType(
         [
@@ -2365,7 +2361,7 @@ def test_save_as_table_nullable_test(session, save_mode, table_type):
             (IntegrityError, SnowparkSQLException),
             match="NULL result in a non-nullable column",
         ):
-            df.write.save_as_table(table_name, mode=save_mode, table_type=table_type)
+            df.write.save_as_table(table_name, mode=save_mode)
     finally:
         Utils.drop_table(session, table_name)

@@ -2397,9 +2393,8 @@ def test_save_as_table_with_table_sproc_output(session, save_mode, table_type):
         Utils.drop_procedure(session, f"{temp_sp_name}()")
 
 
-@pytest.mark.parametrize("table_type", ["", "temp", "temporary", "transient"])
 @pytest.mark.parametrize("save_mode", ["append", "overwrite"])
-def test_write_table_with_clustering_keys(session, save_mode, table_type):
+def test_write_table_with_clustering_keys(session, save_mode):
     table_name1 = Utils.random_name_for_temp_object(TempObjectType.TABLE)
     table_name2 = Utils.random_name_for_temp_object(TempObjectType.TABLE)
     table_name3 = Utils.random_name_for_temp_object(TempObjectType.TABLE)
@@ -2433,7 +2428,6 @@
         df1.write.save_as_table(
             table_name1,
             mode=save_mode,
-            table_type=table_type,
             clustering_keys=["c1", "c2"],
         )
         ddl = session._run_query(f"select get_ddl('table', '{table_name1}')")[0][0]
@@ -2442,7 +2436,6 @@
         df2.write.save_as_table(
             table_name2,
             mode=save_mode,
-            table_type=table_type,
             clustering_keys=[
                 col("c1").cast(DateType()),
                 col("c2").substring(0, 10),
@@ -2454,7 +2447,6 @@
         df3.write.save_as_table(
             table_name3,
             mode=save_mode,
-            table_type=table_type,
             clustering_keys=[get_path(col("v"), lit("Data.id")).cast(IntegerType())],
         )
         ddl = session._run_query(f"select get_ddl('table', '{table_name3}')")[0][0]
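The clustering-key tests reduce to the following usage pattern; a sketch assuming an established snowpark `session` (the table name is illustrative, and the casing of the DDL returned by `get_ddl` may vary):

```python
df = session.create_dataframe([("2023-01-01", "abcdefghij")], schema=["c1", "c2"])
df.write.save_as_table("clustered_demo", mode="overwrite", clustering_keys=["c1", "c2"])

ddl = session._run_query("select get_ddl('table', 'clustered_demo')")[0][0]
# Expect the generated DDL to contain a "cluster by (C1, C2)" clause.
```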
