SNOW-919476: Do not ignore stage location for using imports when is_permanent is False #1053

Merged
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,11 @@
# Release History

+## 1.9 (TBD)
+
+### Bug Fixes
+
+- Fixed a bug where imports from permanent stage locations were ignored for temporary stored procedures, UDTFs, UDFs, and UDAFs.
+
## 1.8.0 (2023-09-14)

### New Features
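In user terms, the 1.9 bug fix above means a temporary UDF (or stored procedure, UDTF, UDAF) can now pick up imports that already live on a permanent stage. A minimal sketch; the stage name, helper file, and connection parameters are hypothetical:

```python
from snowflake.snowpark import Session
from snowflake.snowpark.functions import udf
from snowflake.snowpark.types import IntegerType

connection_parameters = {"account": "...", "user": "...", "password": "..."}
session = Session.builder.configs(connection_parameters).create()

# Before this fix, stage_location was ignored (with a warning) whenever
# is_permanent=False, so the staged helper file was never attached to the
# temporary UDF's imports. Now the import resolves against the permanent
# stage, while local files are still uploaded to the session stage.
add_one = udf(
    lambda x: x + 1,
    return_type=IntegerType(),
    input_types=[IntegerType()],
    is_permanent=False,
    stage_location="@my_permanent_stage",
    imports=["@my_permanent_stage/helpers.py"],
    session=session,
)
```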
18 changes: 11 additions & 7 deletions src/snowflake/snowpark/_internal/udf_utils.py
@@ -349,11 +349,6 @@ def check_register_args(
raise ValueError(
f"stage_location must be specified for permanent {get_error_message_abbr(object_type)}"
)
-        else:
-            if stage_location:
-                logger.warn(
-                    "is_permanent is False therefore stage_location will be ignored"
-                )

if parallel < 1 or parallel > 99:
raise ValueError(
@@ -821,6 +816,12 @@ def resolve_imports_and_packages(
else session.get_session_stage()
)

+    import_stage = (
+        unwrap_stage_location_single_quote(stage_location)
+        if stage_location
+        else session.get_session_stage()
+    )

Collaborator: Let's define import_stage first, then upload_stage = import_stage if is_permanent else session_stage.

Collaborator: I wonder if there's a better name for upload_stage, because technically we need to upload files to import_stage too. How about something like sproc_stage vs import_stage? In addition, could you please add a one-line documentation for these args to clarify?

Contributor Author: I agree with your point, but the stage is used for UDFs as well. How about import_only_stage and upload_and_import_stage?

Collaborator: I don't think we upload files to the import stage; it is the other way around (we import the upload stage). And as Afroz mentioned, the import/upload distinction applies to the UDF families as well. I cannot come up with a better name than what Afroz suggested.

Contributor Author: done
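For readers following along, a minimal sketch of the two-stage split this thread converges on (assuming `stage_location`, `is_permanent`, and `session` are in scope; illustrative, not the exact merged code):

```python
# Imports resolve against the user-provided stage when one is given,
# falling back to the session stage; uploads for a temporary object go
# to the session stage so nothing is persisted.
import_stage = (
    unwrap_stage_location_single_quote(stage_location)
    if stage_location
    else session.get_session_stage()
)
upload_stage = import_stage if is_permanent else session.get_session_stage()
```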

# resolve packages
resolved_packages = (
session._resolve_packages(packages, include_pandas=is_pandas_udf)
@@ -850,11 +851,14 @@
)
udf_level_imports[resolved_import_tuple[0]] = resolved_import_tuple[1:]
all_urls = session._resolve_imports(
-            upload_stage, udf_level_imports, statement_params=statement_params
+            import_stage,
+            upload_stage,
+            udf_level_imports,
+            statement_params=statement_params,
)
elif imports is None:
all_urls = session._resolve_imports(
-                upload_stage, statement_params=statement_params
+                import_stage, upload_stage, statement_params=statement_params
)
else:
all_urls = []
29 changes: 18 additions & 11 deletions src/snowflake/snowpark/session.py
@@ -685,7 +685,8 @@ def _resolve_import_path(

def _resolve_imports(
self,
-        stage_location: str,
+        import_stage: str,
+        upload_stage: str,
udf_level_import_paths: Optional[
Dict[str, Tuple[Optional[str], Optional[str]]]
] = None,
@@ -695,9 +696,11 @@
"""Resolve the imports and upload local files (if any) to the stage."""
resolved_stage_files = []
stage_file_list = self._list_files_in_stage(
-            stage_location, statement_params=statement_params
+            import_stage, statement_params=statement_params
)
-        normalized_stage_location = unwrap_stage_location_single_quote(stage_location)
+        # probably shouldn't do it. It is already done in resolve_imports_and_packages
+        normalized_import_location = unwrap_stage_location_single_quote(import_stage)
+        normalized_upload_location = unwrap_stage_location_single_quote(upload_stage)

Collaborator (on the code comment above): Let's remove this comment if we already have tests :p

Collaborator (on the unwrapping): Can we verify if this could lead to bugs (e.g. over-unwrapping)?

Contributor Author: There are some tests in test_utils_suite.py/test_normalize_stage_location which test for this.

Contributor Author: Added more tests. Over-unwrapping does not lead to bugs.
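A quick illustration of the idempotence question raised here (a hedged sketch; it assumes the helper lives in snowflake.snowpark._internal.utils and that the expected values are as commented):

```python
from snowflake.snowpark._internal.utils import unwrap_stage_location_single_quote

# Unwrapping should strip one layer of wrapping quotes...
once = unwrap_stage_location_single_quote("'@my_stage/path'")  # expected: @my_stage/path

# ...and applying it again to an already-unwrapped location should be a
# no-op, so the double unwrap in this code path is harmless.
twice = unwrap_stage_location_single_quote(once)
assert once == twice
```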

import_paths = udf_level_import_paths or self._import_paths
for path, (prefix, leading_path) in import_paths.items():
@@ -713,7 +716,12 @@
filename_with_prefix = f"{prefix}/{filename}"
if filename_with_prefix in stage_file_list:
_logger.debug(
f"{filename} exists on {normalized_stage_location}, skipped"
f"{filename} exists on {normalized_import_location}, skipped"
)
resolved_stage_files.append(
normalize_remote_file_or_dir(
f"{normalized_import_location}/{filename_with_prefix}"
)
)
else:
# local directory or .py file
@@ -723,7 +731,7 @@
) as input_stream:
self._conn.upload_stream(
input_stream=input_stream,
-                            stage_location=normalized_stage_location,
+                            stage_location=normalized_upload_location,
dest_filename=filename,
dest_prefix=prefix,
source_compression="DEFLATE",
@@ -736,17 +744,17 @@
else:
self._conn.upload_file(
path=path,
-                        stage_location=normalized_stage_location,
+                        stage_location=normalized_upload_location,
dest_prefix=prefix,
compress_data=False,
overwrite=True,
skip_upload_on_content_match=True,
)
-                        resolved_stage_files.append(
-                            normalize_remote_file_or_dir(
-                                f"{normalized_stage_location}/{filename_with_prefix}"
-                            )
-                        )
+                        resolved_stage_files.append(
+                            normalize_remote_file_or_dir(
+                                f"{normalized_upload_location}/{filename_with_prefix}"
+                            )
+                        )

return resolved_stage_files

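Pieced together, the new resolution flow in _resolve_imports is: reuse anything already present on the import stage, and upload everything else to the upload stage. A simplified, self-contained sketch (helper names and argument shapes are assumed for illustration):

```python
from typing import Callable, Dict, List

def resolve_imports_sketch(
    import_paths: Dict[str, str],   # filename_with_prefix -> local path
    stage_file_list: List[str],     # files already listed on the import stage
    import_location: str,           # normalized import stage
    upload_location: str,           # normalized upload stage
    upload_file: Callable[[str, str], None],
) -> List[str]:
    resolved = []
    for filename_with_prefix, local_path in import_paths.items():
        if filename_with_prefix in stage_file_list:
            # Already staged under the import location: reference it in
            # place instead of re-uploading (the behavior this PR adds).
            resolved.append(f"{import_location}/{filename_with_prefix}")
        else:
            # Not staged yet: upload to the (possibly session-scoped)
            # upload stage and reference the uploaded copy.
            upload_file(local_path, upload_location)
            resolved.append(f"{upload_location}/{filename_with_prefix}")
    return resolved
```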
@@ -1707,7 +1715,6 @@ def connection(self) -> "SnowflakeConnection":
and Snowflake server."""
return self._conn._conn


def _run_query(
self,
query: str,
29 changes: 13 additions & 16 deletions tests/integ/test_stored_procedure.py
@@ -592,28 +592,24 @@ def test_permanent_sp(session, db_parameters):


@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
-def test_permanent_sp_negative(session, db_parameters, caplog):
+def test_permanent_sp_negative(session, db_parameters):
stage_name = Utils.random_stage_name()
sp_name = Utils.random_name_for_temp_object(TempObjectType.PROCEDURE)
with Session.builder.configs(db_parameters).create() as new_session:
new_session.sql_simplifier_enabled = session.sql_simplifier_enabled
new_session.add_packages("snowflake-snowpark-python")
try:
-            with caplog.at_level(logging.WARN):
-                sproc(
-                    lambda session_, x, y: session_.sql(f"SELECT {x} + {y}").collect()[
-                        0
-                    ][0],
-                    return_type=IntegerType(),
-                    input_types=[IntegerType(), IntegerType()],
-                    name=sp_name,
-                    is_permanent=False,
-                    stage_location=stage_name,
-                    session=new_session,
-                )
-                assert (
-                    "is_permanent is False therefore stage_location will be ignored"
-                    in caplog.text
-                )
+            Utils.create_stage(session, stage_name, is_temporary=False)
+            sproc(
+                lambda session_, x, y: session_.sql(f"SELECT {x} + {y}").collect()[
+                    0
+                ][0],
+                return_type=IntegerType(),
+                input_types=[IntegerType(), IntegerType()],
+                name=sp_name,
+                is_permanent=False,
+                stage_location=stage_name,
+                session=new_session,
+            )

with pytest.raises(
@@ -623,6 +619,7 @@ def test_permanent_sp_negative(session, db_parameters, caplog):
assert new_session.call(sp_name, 8, 9) == 17
finally:
new_session._run_query(f"drop function if exists {sp_name}(int, int)")
+            Utils.drop_stage(session, stage_name)


@pytest.mark.skipif(not is_pandas_available, reason="Requires pandas")
25 changes: 11 additions & 14 deletions tests/integ/test_udaf.py
@@ -414,7 +414,7 @@ def test_register_udaf_from_file_with_type_hints(session, resources_path):


@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
-def test_permanent_udaf_negative(session, db_parameters, caplog):
+def test_permanent_udaf_negative(session, db_parameters):
stage_name = Utils.random_stage_name()
udaf_name = Utils.random_name_for_temp_object(TempObjectType.AGGREGATE_FUNCTION)
df1 = session.create_dataframe([[1, 3], [1, 4], [2, 5], [2, 6]]).to_df("a", "b")
@@ -442,19 +442,15 @@ def finish(self):
"a", "b"
)
try:
-        with caplog.at_level(logging.WARN):
-            sum_udaf = udaf(
-                PythonSumUDAFHandler,
-                return_type=IntegerType(),
-                input_types=[IntegerType()],
-                name=udaf_name,
-                is_permanent=False,
-                stage_location=stage_name,
-                session=new_session,
-            )
-            assert (
-                "is_permanent is False therefore stage_location will be ignored"
-                in caplog.text
-            )
+        Utils.create_stage(session, stage_name, is_temporary=False)
+        sum_udaf = udaf(
+            PythonSumUDAFHandler,
+            return_type=IntegerType(),
+            input_types=[IntegerType()],
+            name=udaf_name,
+            is_permanent=False,
+            stage_location=stage_name,
+            session=new_session,
+        )

with pytest.raises(
@@ -465,6 +461,7 @@ def finish(self):
Utils.check_answer(df2.agg(sum_udaf("a")), [Row(6)])
finally:
new_session._run_query(f"drop function if exists {udaf_name}(int)")
+        Utils.drop_stage(session, stage_name)


def test_udaf_negative(session):
25 changes: 11 additions & 14 deletions tests/integ/test_udf.py
@@ -982,25 +982,21 @@ def test_permanent_udf(session, db_parameters):


@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
-def test_permanent_udf_negative(session, db_parameters, caplog):
+def test_permanent_udf_negative(session, db_parameters):
stage_name = Utils.random_stage_name()
udf_name = Utils.random_name_for_temp_object(TempObjectType.FUNCTION)
with Session.builder.configs(db_parameters).create() as new_session:
new_session.sql_simplifier_enabled = session.sql_simplifier_enabled
try:
-            with caplog.at_level(logging.WARN):
-                udf(
-                    lambda x, y: x + y,
-                    return_type=IntegerType(),
-                    input_types=[IntegerType(), IntegerType()],
-                    name=udf_name,
-                    is_permanent=False,
-                    stage_location=stage_name,
-                    session=new_session,
-                )
-                assert (
-                    "is_permanent is False therefore stage_location will be ignored"
-                    in caplog.text
-                )
+            Utils.create_stage(session, stage_name, is_temporary=False)
+            udf(
+                lambda x, y: x + y,
+                return_type=IntegerType(),
+                input_types=[IntegerType(), IntegerType()],
+                name=udf_name,
+                is_permanent=False,
+                stage_location=stage_name,
+                session=new_session,
+            )

with pytest.raises(
@@ -1013,6 +1009,7 @@ def test_permanent_udf_negative(session, db_parameters, caplog):
)
finally:
new_session._run_query(f"drop function if exists {udf_name}(int, int)")
+            Utils.drop_stage(session, stage_name)


def test_udf_negative(session):
25 changes: 11 additions & 14 deletions tests/integ/test_udtf.py
@@ -383,7 +383,7 @@ def group_sum(pdf):


@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
-def test_permanent_udtf_negative(session, db_parameters, caplog):
+def test_permanent_udtf_negative(session, db_parameters):
stage_name = Utils.random_stage_name()
udtf_name = Utils.random_name_for_temp_object(TempObjectType.TABLE_FUNCTION)

@@ -397,19 +397,15 @@ def process(
with Session.builder.configs(db_parameters).create() as new_session:
new_session.sql_simplifier_enabled = session.sql_simplifier_enabled
try:
-            with caplog.at_level(logging.WARN):
-                echo_udtf = udtf(
-                    UDTFEcho,
-                    output_schema=StructType([StructField("A", IntegerType())]),
-                    input_types=[IntegerType()],
-                    name=udtf_name,
-                    is_permanent=False,
-                    stage_location=stage_name,
-                    session=new_session,
-                )
-                assert (
-                    "is_permanent is False therefore stage_location will be ignored"
-                    in caplog.text
-                )
+            Utils.create_stage(session, stage_name, is_temporary=False)
+            echo_udtf = udtf(
+                UDTFEcho,
+                output_schema=StructType([StructField("A", IntegerType())]),
+                input_types=[IntegerType()],
+                name=udtf_name,
+                is_permanent=False,
+                stage_location=stage_name,
+                session=new_session,
+            )

with pytest.raises(
@@ -420,6 +416,7 @@ def process(
Utils.check_answer(new_session.table_function(echo_udtf(lit(1))), [Row(1)])
finally:
new_session._run_query(f"drop function if exists {udtf_name}(int)")
+            Utils.drop_stage(session, stage_name)


@pytest.mark.xfail(reason="SNOW-757054 flaky test", strict=False)
32 changes: 32 additions & 0 deletions tests/unit/test_stored_procedure.py
@@ -67,6 +67,38 @@ def test_negative_execute_as():
)


+@pytest.mark.parametrize("is_permanent", [True, False])
+def test_sp_stage_location(is_permanent):
+    """Make sure the given stage_location is used for imports regardless of
+    whether the stored procedure is permanent."""
+    fake_session = mock.create_autospec(Session)
+    fake_session._conn = mock.create_autospec(ServerConnection)
+    fake_session._conn._telemetry_client = mock.create_autospec(TelemetryClient)
+    fake_session.sproc = StoredProcedureRegistration(fake_session)
+    fake_session._plan_builder = SnowflakePlanBuilder(fake_session)
+    fake_session._analyzer = Analyzer(fake_session)
+    fake_session._runtime_version_from_requirement = None
+    fake_session._resolve_imports = lambda x, y, z, statement_params: [x]
+    stage_location = "@permanent_stage_location/packages.zip"
+
+    def return1(_):
+        return 1
+
+    sproc(
+        return1,
+        name="UNIT_TEST",
+        packages=[],
+        return_type=IntegerType(),
+        session=fake_session,
+        stage_location=stage_location,
+        imports=["package"],
+        is_permanent=is_permanent,
+    )
+    assert any(
+        f"IMPORTS=('{stage_location}')" in c.args[0]
+        for c in fake_session._run_query.call_args_list
+    )
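The assertion relies on standard unittest.mock call recording; a minimal standalone illustration of the same pattern (names invented for the demo):

```python
from unittest import mock

fake = mock.Mock()
fake.run_query("CREATE PROCEDURE ... IMPORTS=('@stage/pkg.zip')")

# Mock records every call; call_args_list lets the test scan all generated
# queries for the expected IMPORTS clause.
assert any(
    "IMPORTS=('@stage/pkg.zip')" in c.args[0]
    for c in fake.run_query.call_args_list
)
```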


@mock.patch("snowflake.snowpark.stored_procedure.cleanup_failed_permanent_registration")
def test_do_register_sp_negative(cleanup_registration_patch):
fake_session = mock.create_autospec(Session)