diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1e9a671d675..a67fbd35adc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,10 @@
 
 - Added back the dependency of `typing-extensions`.
 
+### Bug Fixes
+
+- Fixed a bug where imports from permanent stage locations were ignored for temporary stored procedures, UDTFs, UDFs, and UDAFs.
+
 ## 1.8.0 (2023-09-14)
 
 ### New Features
diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py
index f18f29f8605..237d5a33843 100644
--- a/src/snowflake/snowpark/_internal/udf_utils.py
+++ b/src/snowflake/snowpark/_internal/udf_utils.py
@@ -349,11 +349,6 @@ def check_register_args(
         raise ValueError(
             f"stage_location must be specified for permanent {get_error_message_abbr(object_type)}"
         )
-    else:
-        if stage_location:
-            logger.warn(
-                "is_permanent is False therefore stage_location will be ignored"
-            )
 
     if parallel < 1 or parallel > 99:
         raise ValueError(
@@ -815,12 +810,16 @@ def resolve_imports_and_packages(
     skip_upload_on_content_match: bool = False,
     is_permanent: bool = False,
 ) -> Tuple[str, str, str, str, str, bool]:
-    upload_stage = (
+    import_only_stage = (
         unwrap_stage_location_single_quote(stage_location)
-        if stage_location and is_permanent
+        if stage_location
         else session.get_session_stage()
     )
 
+    upload_and_import_stage = (
+        import_only_stage if is_permanent else session.get_session_stage()
+    )
+
     # resolve packages
     resolved_packages = (
         session._resolve_packages(packages, include_pandas=is_pandas_udf)
@@ -850,11 +849,16 @@ def resolve_imports_and_packages(
             )
             udf_level_imports[resolved_import_tuple[0]] = resolved_import_tuple[1:]
         all_urls = session._resolve_imports(
-            upload_stage, udf_level_imports, statement_params=statement_params
+            import_only_stage,
+            upload_and_import_stage,
+            udf_level_imports,
+            statement_params=statement_params,
         )
     elif imports is None:
         all_urls = session._resolve_imports(
-            upload_stage, statement_params=statement_params
+            import_only_stage,
+            upload_and_import_stage,
+            statement_params=statement_params,
         )
     else:
         all_urls = []
@@ -883,7 +887,7 @@ def resolve_imports_and_packages(
     if len(code) > _MAX_INLINE_CLOSURE_SIZE_BYTES:
         dest_prefix = get_udf_upload_prefix(udf_name)
         upload_file_stage_location = normalize_remote_file_or_dir(
-            f"{upload_stage}/{dest_prefix}/{udf_file_name}"
+            f"{upload_and_import_stage}/{dest_prefix}/{udf_file_name}"
         )
         udf_file_name_base = os.path.splitext(udf_file_name)[0]
         with io.BytesIO() as input_stream:
@@ -893,7 +897,7 @@ def resolve_imports_and_packages(
                 zf.writestr(f"{udf_file_name_base}.py", code)
             session._conn.upload_stream(
                 input_stream=input_stream,
-                stage_location=upload_stage,
+                stage_location=upload_and_import_stage,
                 dest_filename=udf_file_name,
                 dest_prefix=dest_prefix,
                 parallel=parallel,
@@ -924,11 +928,11 @@ def resolve_imports_and_packages(
             all_urls.append(func[0])
         else:
             upload_file_stage_location = normalize_remote_file_or_dir(
-                f"{upload_stage}/{dest_prefix}/{udf_file_name}"
+                f"{upload_and_import_stage}/{dest_prefix}/{udf_file_name}"
             )
             session._conn.upload_file(
                 path=func[0],
-                stage_location=upload_stage,
+                stage_location=upload_and_import_stage,
                 dest_prefix=dest_prefix,
                 parallel=parallel,
                 compress_data=False,
diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py
index 6313fd6d195..a07ab02e46b 100644
--- a/src/snowflake/snowpark/session.py
+++ b/src/snowflake/snowpark/session.py
@@ -685,7 +685,8 @@ def _resolve_import_path(
 
     def _resolve_imports(
         self,
-        stage_location: str,
+        import_only_stage: str,
+        upload_and_import_stage: str,
         udf_level_import_paths: Optional[
             Dict[str, Tuple[Optional[str], Optional[str]]]
         ] = None,
@@ -695,9 +696,15 @@ def _resolve_imports(
         """Resolve the imports and upload local files (if any) to the stage."""
         resolved_stage_files = []
         stage_file_list = self._list_files_in_stage(
-            stage_location, statement_params=statement_params
+            import_only_stage, statement_params=statement_params
+        )
+
+        normalized_import_only_location = unwrap_stage_location_single_quote(
+            import_only_stage
+        )
+        normalized_upload_and_import_location = unwrap_stage_location_single_quote(
+            upload_and_import_stage
         )
-        normalized_stage_location = unwrap_stage_location_single_quote(stage_location)
 
         import_paths = udf_level_import_paths or self._import_paths
         for path, (prefix, leading_path) in import_paths.items():
@@ -713,7 +720,12 @@ def _resolve_imports(
                 filename_with_prefix = f"{prefix}/{filename}"
                 if filename_with_prefix in stage_file_list:
                     _logger.debug(
-                        f"{filename} exists on {normalized_stage_location}, skipped"
+                        f"{filename} exists on {normalized_import_only_location}, skipped"
+                    )
+                    resolved_stage_files.append(
+                        normalize_remote_file_or_dir(
+                            f"{normalized_import_only_location}/{filename_with_prefix}"
+                        )
                     )
                 else:
                     # local directory or .py file
@@ -723,7 +735,7 @@ def _resolve_imports(
                     ) as input_stream:
                         self._conn.upload_stream(
                             input_stream=input_stream,
-                            stage_location=normalized_stage_location,
+                            stage_location=normalized_upload_and_import_location,
                             dest_filename=filename,
                             dest_prefix=prefix,
                             source_compression="DEFLATE",
@@ -736,17 +748,17 @@ def _resolve_imports(
                 else:
                     self._conn.upload_file(
                         path=path,
-                        stage_location=normalized_stage_location,
+                        stage_location=normalized_upload_and_import_location,
                         dest_prefix=prefix,
                         compress_data=False,
                         overwrite=True,
                         skip_upload_on_content_match=True,
                     )
-            resolved_stage_files.append(
-                normalize_remote_file_or_dir(
-                    f"{normalized_stage_location}/{filename_with_prefix}"
+                resolved_stage_files.append(
+                    normalize_remote_file_or_dir(
+                        f"{normalized_upload_and_import_location}/{filename_with_prefix}"
+                    )
                 )
-            )
 
         return resolved_stage_files
 
@@ -1177,7 +1189,9 @@ def get_req_identifiers_list(
         if name in result_dict:
             if version is not None:
                 added_package_has_version = "==" in result_dict[name]
-                if added_package_has_version and result_dict[name] != str(package):
+                if added_package_has_version and result_dict[name] != str(
+                    package
+                ):
                     raise ValueError(
                         f"Cannot add dependency package '{name}=={version}' "
                         f"because {result_dict[name]} is already added."
diff --git a/tests/integ/test_stored_procedure.py b/tests/integ/test_stored_procedure.py
index c80e1ae871a..c74774f3c86 100644
--- a/tests/integ/test_stored_procedure.py
+++ b/tests/integ/test_stored_procedure.py
@@ -592,28 +592,24 @@ def test_permanent_sp(session, db_parameters):
 
 
 @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
-def test_permanent_sp_negative(session, db_parameters, caplog):
+def test_permanent_sp_negative(session, db_parameters):
     stage_name = Utils.random_stage_name()
     sp_name = Utils.random_name_for_temp_object(TempObjectType.PROCEDURE)
     with Session.builder.configs(db_parameters).create() as new_session:
         new_session.sql_simplifier_enabled = session.sql_simplifier_enabled
         new_session.add_packages("snowflake-snowpark-python")
         try:
-            with caplog.at_level(logging.WARN):
-                sproc(
-                    lambda session_, x, y: session_.sql(f"SELECT {x} + {y}").collect()[
-                        0
-                    ][0],
-                    return_type=IntegerType(),
-                    input_types=[IntegerType(), IntegerType()],
-                    name=sp_name,
-                    is_permanent=False,
-                    stage_location=stage_name,
-                    session=new_session,
-                )
-            assert (
-                "is_permanent is False therefore stage_location will be ignored"
-                in caplog.text
+            Utils.create_stage(session, stage_name, is_temporary=False)
+            sproc(
+                lambda session_, x, y: session_.sql(f"SELECT {x} + {y}").collect()[0][
+                    0
+                ],
+                return_type=IntegerType(),
+                input_types=[IntegerType(), IntegerType()],
+                name=sp_name,
+                is_permanent=False,
+                stage_location=stage_name,
+                session=new_session,
             )
 
             with pytest.raises(
@@ -623,6 +619,7 @@ def test_permanent_sp_negative(session, db_parameters):
             assert new_session.call(sp_name, 8, 9) == 17
         finally:
             new_session._run_query(f"drop function if exists {sp_name}(int, int)")
+            Utils.drop_stage(session, stage_name)
 
 
 @pytest.mark.skipif(not is_pandas_available, reason="Requires pandas")
@@ -934,6 +931,78 @@ def hello_sp(session: Session, name: str, age: int) -> DataFrame:
     Utils.drop_procedure(session, f"{temp_sp_name}(string, bigint)")
 
 
+def test_temp_sp_with_import_and_upload_stage(session, resources_path):
+    """We want temporary stored procs to be able to do the following:
+    - Do not upload packages to permanent stage locations
+    - Can import packages from permanent stage locations
+    - Can upload packages to temp stages for custom usage
+    - Import from permanent stage location and upload to temp stage + import from temp stage should
+      work
+    """
+    stage_name = Utils.random_stage_name()
+    Utils.create_stage(session, stage_name, is_temporary=False)
+    test_files = TestFiles(resources_path)
+    # upload test_sp_dir.test_sp_file (mod5) to permanent stage and use mod3
+    # file for temporary stage import correctness
+    session._conn.upload_file(
+        path=test_files.test_sp_py_file,
+        stage_location=unwrap_stage_location_single_quote(stage_name),
+        compress_data=False,
+        overwrite=True,
+        skip_upload_on_content_match=True,
+    )
+    try:
+        # Can import packages from permanent stage locations
+        def mod5_(session_, x):
+            from test_sp_file import mod5
+
+            return mod5(session_, x)
+
+        mod5_sproc = sproc(
+            mod5_,
+            return_type=IntegerType(),
+            input_types=[IntegerType()],
+            imports=[f"@{stage_name}/test_sp_file.py"],
+            is_permanent=False,
+        )
+        assert mod5_sproc(5) == 0
+
+        # Can upload packages to temp stages for custom usage
+        def mod3_(session_, x):
+            from test_sp_mod3_file import mod3
+
+            return mod3(session_, x)
+
+        mod3_sproc = sproc(
+            mod3_,
+            return_type=IntegerType(),
+            input_types=[IntegerType()],
+            imports=[test_files.test_sp_mod3_py_file],
+        )
+
+        assert mod3_sproc(3) == 0
+
+        # Import from permanent stage location and upload to temp stage + import
+        # from temp stage should work
+        def mod3_of_mod5_(session_, x):
+            from test_sp_file import mod5
+            from test_sp_mod3_file import mod3
+
+            return mod3(session_, mod5(session_, x))
+
+        mod3_of_mod5_sproc = sproc(
+            mod3_of_mod5_,
+            return_type=IntegerType(),
+            input_types=[IntegerType()],
+            imports=[f"@{stage_name}/test_sp_file.py", test_files.test_sp_mod3_py_file],
+        )
+
+        assert mod3_of_mod5_sproc(4) == 1
+    finally:
+        Utils.drop_stage(session, stage_name)
+    pass
+
+
 def test_add_import_negative(session, resources_path):
     test_files = TestFiles(resources_path)
 
diff --git a/tests/integ/test_udaf.py b/tests/integ/test_udaf.py
index e6bb82d0bb4..a063adb6e98 100644
--- a/tests/integ/test_udaf.py
+++ b/tests/integ/test_udaf.py
@@ -4,7 +4,6 @@
 
 import datetime
 import decimal
-import logging
 from typing import Any, Dict, List
 
 import pytest
@@ -414,7 +413,7 @@ def test_register_udaf_from_file_with_type_hints(session, resources_path):
 
 
 @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
-def test_permanent_udaf_negative(session, db_parameters, caplog):
+def test_permanent_udaf_negative(session, db_parameters):
     stage_name = Utils.random_stage_name()
     udaf_name = Utils.random_name_for_temp_object(TempObjectType.AGGREGATE_FUNCTION)
     df1 = session.create_dataframe([[1, 3], [1, 4], [2, 5], [2, 6]]).to_df("a", "b")
@@ -442,19 +441,15 @@ def finish(self):
             "a", "b"
         )
         try:
-            with caplog.at_level(logging.WARN):
-                sum_udaf = udaf(
-                    PythonSumUDAFHandler,
-                    return_type=IntegerType(),
-                    input_types=[IntegerType()],
-                    name=udaf_name,
-                    is_permanent=False,
-                    stage_location=stage_name,
-                    session=new_session,
-                )
-            assert (
-                "is_permanent is False therefore stage_location will be ignored"
-                in caplog.text
+            Utils.create_stage(session, stage_name, is_temporary=False)
+            sum_udaf = udaf(
+                PythonSumUDAFHandler,
+                return_type=IntegerType(),
+                input_types=[IntegerType()],
+                name=udaf_name,
+                is_permanent=False,
+                stage_location=stage_name,
+                session=new_session,
             )
 
             with pytest.raises(
@@ -465,6 +460,7 @@ def finish(self):
             Utils.check_answer(df2.agg(sum_udaf("a")), [Row(6)])
         finally:
             new_session._run_query(f"drop function if exists {udaf_name}(int)")
+            Utils.drop_stage(session, stage_name)
 
 
 def test_udaf_negative(session):
diff --git a/tests/integ/test_udf.py b/tests/integ/test_udf.py
index 8b2f1b10bc5..441c17a7bb7 100644
--- a/tests/integ/test_udf.py
+++ b/tests/integ/test_udf.py
@@ -694,9 +694,9 @@ def test_add_import_duplicate(session, resources_path, caplog):
 
     # skip upload the file because the calculated checksum is same
     session_stage = session.get_session_stage()
-    session._resolve_imports(session_stage)
+    session._resolve_imports(session_stage, session_stage)
     session.add_import(abs_path)
-    session._resolve_imports(session_stage)
+    session._resolve_imports(session_stage, session_stage)
     assert (
         f"{os.path.basename(abs_path)}.zip exists on {session_stage}, skipped"
         in caplog.text
@@ -982,25 +982,21 @@ def test_permanent_udf(session, db_parameters):
 
 
 @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
-def test_permanent_udf_negative(session, db_parameters, caplog):
+def test_permanent_udf_negative(session, db_parameters):
     stage_name = Utils.random_stage_name()
     udf_name = Utils.random_name_for_temp_object(TempObjectType.FUNCTION)
     with Session.builder.configs(db_parameters).create() as new_session:
         new_session.sql_simplifier_enabled = session.sql_simplifier_enabled
        try:
-            with caplog.at_level(logging.WARN):
-                udf(
-                    lambda x, y: x + y,
-                    return_type=IntegerType(),
-                    input_types=[IntegerType(), IntegerType()],
-                    name=udf_name,
-                    is_permanent=False,
-                    stage_location=stage_name,
-                    session=new_session,
-                )
-            assert (
-                "is_permanent is False therefore stage_location will be ignored"
-                in caplog.text
+            Utils.create_stage(session, stage_name, is_temporary=False)
+            udf(
+                lambda x, y: x + y,
+                return_type=IntegerType(),
+                input_types=[IntegerType(), IntegerType()],
+                name=udf_name,
+                is_permanent=False,
+                stage_location=stage_name,
+                session=new_session,
             )
 
             with pytest.raises(
@@ -1013,6 +1009,7 @@ def test_permanent_udf_negative(session, db_parameters):
             )
         finally:
             new_session._run_query(f"drop function if exists {udf_name}(int, int)")
+            Utils.drop_stage(session, stage_name)
 
 
 def test_udf_negative(session):
diff --git a/tests/integ/test_udtf.py b/tests/integ/test_udtf.py
index db301f08e0b..0206abf5d8a 100644
--- a/tests/integ/test_udtf.py
+++ b/tests/integ/test_udtf.py
@@ -3,7 +3,6 @@
 #
 
 import decimal
-import logging
 import sys
 from typing import Tuple
 
@@ -383,7 +382,7 @@ def group_sum(pdf):
 
 
 @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
-def test_permanent_udtf_negative(session, db_parameters, caplog):
+def test_permanent_udtf_negative(session, db_parameters):
     stage_name = Utils.random_stage_name()
     udtf_name = Utils.random_name_for_temp_object(TempObjectType.TABLE_FUNCTION)
 
@@ -397,19 +396,15 @@ def process(
     with Session.builder.configs(db_parameters).create() as new_session:
         new_session.sql_simplifier_enabled = session.sql_simplifier_enabled
         try:
-            with caplog.at_level(logging.WARN):
-                echo_udtf = udtf(
-                    UDTFEcho,
-                    output_schema=StructType([StructField("A", IntegerType())]),
-                    input_types=[IntegerType()],
-                    name=udtf_name,
-                    is_permanent=False,
-                    stage_location=stage_name,
-                    session=new_session,
-                )
-            assert (
-                "is_permanent is False therefore stage_location will be ignored"
-                in caplog.text
+            Utils.create_stage(session, stage_name, is_temporary=False)
+            echo_udtf = udtf(
+                UDTFEcho,
+                output_schema=StructType([StructField("A", IntegerType())]),
+                input_types=[IntegerType()],
+                name=udtf_name,
+                is_permanent=False,
+                stage_location=stage_name,
+                session=new_session,
             )
 
             with pytest.raises(
@@ -420,6 +415,7 @@ def process(
             Utils.check_answer(new_session.table_function(echo_udtf(lit(1))), [Row(1)])
         finally:
             new_session._run_query(f"drop function if exists {udtf_name}(int)")
+            Utils.drop_stage(session, stage_name)
 
 
 @pytest.mark.xfail(reason="SNOW-757054 flaky test", strict=False)
diff --git a/tests/resources/test_sp_dir/test_sp_mod3_file.py b/tests/resources/test_sp_dir/test_sp_mod3_file.py
new file mode 100644
index 00000000000..18fdceef81e
--- /dev/null
+++ b/tests/resources/test_sp_dir/test_sp_mod3_file.py
@@ -0,0 +1,5 @@
+import snowflake.snowpark
+
+
+def mod3(session: snowflake.snowpark.Session, x: int) -> int:
+    return session.sql(f"SELECT {x} % 3").collect()[0][0]
diff --git a/tests/unit/scala/test_utils_suite.py b/tests/unit/scala/test_utils_suite.py
index ddbadcb9c26..ab2e9d584b8 100644
--- a/tests/unit/scala/test_utils_suite.py
+++ b/tests/unit/scala/test_utils_suite.py
@@ -99,15 +99,26 @@ def test_calculate_checksum():
 
 def test_normalize_stage_location():
     name1 = "stage"
-    assert unwrap_stage_location_single_quote(name1 + " ") == f"@{name1}"
+    unwrap_name1 = unwrap_stage_location_single_quote(name1 + " ")
+    assert unwrap_name1 == f"@{name1}"
     assert unwrap_stage_location_single_quote("@" + name1 + " ") == f"@{name1}"
+    assert unwrap_stage_location_single_quote(unwrap_name1) == unwrap_name1
+
     name2 = '"DATABASE"."SCHEMA"."STAGE"'
-    assert unwrap_stage_location_single_quote(name2 + " ") == f"@{name2}"
+    unwrap_name2 = unwrap_stage_location_single_quote(name2 + " ")
+    assert unwrap_name2 == f"@{name2}"
     assert unwrap_stage_location_single_quote("@" + name2 + " ") == f"@{name2}"
+    assert unwrap_stage_location_single_quote(unwrap_name2) == unwrap_name2
+
     name3 = "s t a g 'e"
-    assert unwrap_stage_location_single_quote(name3) == "@s t a g 'e"
+    unwrap_name3 = unwrap_stage_location_single_quote(name3)
+    assert unwrap_name3 == "@s t a g 'e"
+    assert unwrap_stage_location_single_quote(unwrap_name3) == unwrap_name3
+
     name4 = "' s t a g 'e'"
-    assert unwrap_stage_location_single_quote(name4) == "@ s t a g 'e"
+    unwrap_name4 = unwrap_stage_location_single_quote(name4)
+    assert unwrap_name4 == "@ s t a g 'e"
+    assert unwrap_stage_location_single_quote(unwrap_name4) == unwrap_name4
 
 
 @pytest.mark.parametrize("is_local", [True, False])
@@ -247,6 +258,7 @@ def check_zip_files_and_close_stream(input_stream, expected_files):
         "resources/test_environment.yml",
         "resources/test_sp_dir/",
         "resources/test_sp_dir/test_sp_file.py",
+        "resources/test_sp_dir/test_sp_mod3_file.py",
         "resources/test_sp_dir/test_table_sp_file.py",
         "resources/test_udf_dir/",
         "resources/test_udf_dir/test_pandas_udf_file.py",
diff --git a/tests/unit/test_dataframe.py b/tests/unit/test_dataframe.py
index be78b8745c1..c2168de21f7 100644
--- a/tests/unit/test_dataframe.py
+++ b/tests/unit/test_dataframe.py
@@ -8,7 +8,6 @@
 import pytest
 
 import snowflake.snowpark.session
-from snowflake.snowpark.session import Session
 from snowflake.snowpark import (
     DataFrame,
     DataFrameNaFunctions,
@@ -21,6 +20,7 @@
 from snowflake.snowpark._internal.server_connection import ServerConnection
 from snowflake.snowpark.dataframe import _get_unaliased
 from snowflake.snowpark.exceptions import SnowparkCreateDynamicTableException
+from snowflake.snowpark.session import Session
 from snowflake.snowpark.types import IntegerType, StringType
 
 
@@ -300,7 +300,5 @@ def test_session():
     fake_session._analyzer = mock.Mock()
     df = DataFrame(fake_session)
 
-    assert(df.session == fake_session)
-    assert(df.session._session_id == fake_session._session_id)
-
-
+    assert df.session == fake_session
+    assert df.session._session_id == fake_session._session_id
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index 032c926f1c5..39ab959d608 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -3,7 +3,7 @@
 #
 import json
 import os
-from typing import Optional, Dict, Union
+from typing import Optional
 from unittest import mock
 from unittest.mock import MagicMock
 
@@ -375,7 +375,7 @@ def test_session_id():
     fake_server_connection.get_session_id = mock.Mock(return_value=123456)
 
     session = Session(fake_server_connection)
-    assert(session.session_id == 123456)
+    assert session.session_id == 123456
 
 
 def test_connection():
@@ -387,6 +387,5 @@ def test_connection():
     server_connection = ServerConnection(fake_options, fake_snowflake_connection)
 
     session = Session(server_connection)
-    assert(session.connection == session._conn._conn)
-    assert(session.connection == fake_snowflake_connection)
-
+    assert session.connection == session._conn._conn
+    assert session.connection == fake_snowflake_connection
diff --git a/tests/utils.py b/tests/utils.py
index adc5016f3ef..948f3e7c11a 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -941,6 +941,10 @@ def test_sp_directory(self):
     def test_sp_py_file(self):
         return os.path.join(self.test_sp_directory, "test_sp_file.py")
 
+    @property
+    def test_sp_mod3_py_file(self):
+        return os.path.join(self.test_sp_directory, "test_sp_mod3_file.py")
+
     @property
     def test_table_sp_py_file(self):
         return os.path.join(self.test_sp_directory, "test_table_sp_file.py")
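
Usage sketch (not part of the patch) of the behavior this change restores, mirroring `test_temp_sp_with_import_and_upload_stage` above: a temporary stored procedure can read its imports from a permanent stage while nothing is uploaded there. The stage `my_perm_stage` and module `helpers.py` (defining `mod5`) are hypothetical stand-ins; the registration API is used exactly as in the tests.

```python
from snowflake.snowpark import Session
from snowflake.snowpark.functions import sproc
from snowflake.snowpark.types import IntegerType

# Assumes an existing Snowpark `session` and a permanent stage "my_perm_stage"
# that already holds helpers.py, which is assumed to define:
#   def mod5(session, x): return session.sql(f"SELECT {x} % 5").collect()[0][0]
def use_helper(session_: Session, x: int) -> int:
    from helpers import mod5  # resolved via the stage import below

    return mod5(session_, x)

# With is_permanent=False, the import is read from the permanent stage, while
# any generated code is uploaded only to the session's temporary stage.
mod5_sproc = sproc(
    use_helper,
    return_type=IntegerType(),
    input_types=[IntegerType()],
    imports=["@my_perm_stage/helpers.py"],
    is_permanent=False,
    session=session,  # hypothetical already-created session
)

assert mod5_sproc(7) == 2
```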