From 1a22e981b563d1e12add63d3d6ef6cbe76d280d4 Mon Sep 17 00:00:00 2001 From: Shixuan Fan Date: Fri, 1 Mar 2024 17:18:01 -0800 Subject: [PATCH] SNOW-1064328 Use object name if db/schema is missing in stored proc (#1271) * SNOW-1064328 Use object name if db/schema is missing in stored proc Description In the Snowpark API, we use fully qualified names for temp objects created implicitly by the API, mainly to support locating the object even if the user changes the schema of the session. However, in stored procs, we have use cases like bundle where it is legitimate for current_database/current_schema to return NULL. We still want to support object creation in this case. Thus, our change is to fall back to using the object name if there is no current database/schema and the client is running inside a stored procedure. Testing Existing tests + bundle integration test * fix test * remove fully_qualified_schema in mock * fix deprecated --- CHANGELOG.md | 4 +++ docs/source/session.rst | 1 + .../_internal/analyzer/snowflake_plan.py | 13 +++------- src/snowflake/snowpark/_internal/udf_utils.py | 4 +-- src/snowflake/snowpark/dataframe.py | 4 ++- src/snowflake/snowpark/dataframe_reader.py | 10 ++----- src/snowflake/snowpark/mock/_plan_builder.py | 1 - src/snowflake/snowpark/session.py | 26 ++++++++++++------- tests/integ/scala/test_permanent_udf_suite.py | 8 ++++-- tests/integ/scala/test_session_suite.py | 2 +- tests/integ/scala/test_udf_suite.py | 2 +- tests/integ/test_session.py | 6 ++--- tests/integ/test_stored_procedure.py | 4 +-- tests/integ/test_udf.py | 2 +- tests/unit/test_stored_procedure.py | 2 +- tests/unit/test_udaf.py | 2 +- tests/unit/test_udf.py | 2 +- tests/unit/test_udtf.py | 2 +- 18 files changed, 49 insertions(+), 46 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0aadc8a86bf..d8e9937bcfb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ - Added support for creating vectorized UDTFs with `process` method. 
+### Deprecations: + +- Deprecated `Session.get_fully_qualified_current_schema`. Consider using `Session.get_fully_qualified_name_if_possible` instead. + ## 1.13.0 (2024-02-26) ### New Features diff --git a/docs/source/session.rst b/docs/source/session.rst index e156c790117..6e07149a8d4 100644 --- a/docs/source/session.rst +++ b/docs/source/session.rst @@ -52,6 +52,7 @@ Snowpark Session Session.get_current_user Session.get_current_warehouse Session.get_fully_qualified_current_schema + Session.get_fully_qualified_name_if_possible Session.get_imports Session.get_packages Session.get_session_stage diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index bba846ce46b..0e2e70477ac 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -818,7 +818,6 @@ def read_file( path: str, format: str, options: Dict[str, str], - fully_qualified_schema: str, schema: List[Attribute], schema_to_cast: Optional[List[Tuple[str, str]]] = None, transformations: Optional[List[str]] = None, @@ -853,10 +852,8 @@ def read_file( post_queries: List[Query] = [] use_temp_file_format: bool = "FORMAT_NAME" not in options if use_temp_file_format: - format_name = ( - fully_qualified_schema - + "." - + random_name_for_temp_object(TempObjectType.FILE_FORMAT) + format_name = self.session.get_fully_qualified_name_if_possible( + random_name_for_temp_object(TempObjectType.FILE_FORMAT) ) queries.append( Query( @@ -928,10 +925,8 @@ def read_file( ] ) - temp_table_name = ( - fully_qualified_schema - + "." 
- + random_name_for_temp_object(TempObjectType.TABLE) + temp_table_name = self.session.get_fully_qualified_name_if_possible( + random_name_for_temp_object(TempObjectType.TABLE) ) queries = [ Query( diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index fac19cb84fd..053dfc166df 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -548,9 +548,7 @@ def process_registration_inputs( else: object_name = random_name_for_temp_object(object_type) if not anonymous: - object_name = ( - f"{session.get_fully_qualified_current_schema()}.{object_name}" - ) + object_name = session.get_fully_qualified_name_if_possible(object_name) validate_object_name(object_name) # get return and input types diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index 3632d701e92..89c783c37c8 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -3646,7 +3646,9 @@ def cache_result( """ from snowflake.snowpark.mock._connection import MockServerConnection - temp_table_name = f'{self._session.get_current_database()}.{self._session.get_current_schema()}."{random_name_for_temp_object(TempObjectType.TABLE)}"' + temp_table_name = self._session.get_fully_qualified_name_if_possible( + f'"{random_name_for_temp_object(TempObjectType.TABLE)}"' + ) if isinstance(self._session._conn, MockServerConnection): self.write.save_as_table(temp_table_name, create_temp_table=True) diff --git a/src/snowflake/snowpark/dataframe_reader.py b/src/snowflake/snowpark/dataframe_reader.py index 965b3ab14cc..d74c45abf51 100644 --- a/src/snowflake/snowpark/dataframe_reader.py +++ b/src/snowflake/snowpark/dataframe_reader.py @@ -434,7 +434,6 @@ def csv(self, path: str) -> DataFrame: path, self._file_type, self._cur_options, - self._session.get_fully_qualified_current_schema(), schema, schema_to_cast=schema_to_cast, 
transformations=transformations, @@ -453,7 +452,6 @@ def csv(self, path: str) -> DataFrame: path, self._file_type, self._cur_options, - self._session.get_fully_qualified_current_schema(), schema, schema_to_cast=schema_to_cast, transformations=transformations, @@ -576,10 +574,8 @@ def _infer_schema_for_file_format( ) -> Tuple[List, List, List, Exception]: format_type_options, _ = get_copy_into_table_options(self._cur_options) - temp_file_format_name = ( - self._session.get_fully_qualified_current_schema() - + "." - + random_name_for_temp_object(TempObjectType.FILE_FORMAT) + temp_file_format_name = self._session.get_fully_qualified_name_if_possible( + random_name_for_temp_object(TempObjectType.FILE_FORMAT) ) drop_tmp_file_format_if_exists_query: Optional[str] = None use_temp_file_format = "FORMAT_NAME" not in self._cur_options @@ -685,7 +681,6 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame: path, format, self._cur_options, - self._session.get_fully_qualified_current_schema(), schema, schema_to_cast=schema_to_cast, transformations=read_file_transformations, @@ -704,7 +699,6 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame: path, format, self._cur_options, - self._session.get_fully_qualified_current_schema(), schema, schema_to_cast=schema_to_cast, transformations=read_file_transformations, diff --git a/src/snowflake/snowpark/mock/_plan_builder.py b/src/snowflake/snowpark/mock/_plan_builder.py index 813abd8b5fb..52c6cb9c8a9 100644 --- a/src/snowflake/snowpark/mock/_plan_builder.py +++ b/src/snowflake/snowpark/mock/_plan_builder.py @@ -21,7 +21,6 @@ def read_file( path: str, format: str, options: Dict[str, str], - fully_qualified_schema: str, schema: List[Attribute], schema_to_cast: Optional[List[Tuple[str, str]]] = None, transformations: Optional[List[str]] = None, diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 6bdfbdaf2b7..02c2afa3b08 100644 --- 
a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -1313,7 +1313,6 @@ def _resolve_packages( package_dict = self._parse_packages(packages) package_table = "information_schema.packages" - # TODO: Use the database from fully qualified UDF name if not self.get_current_database(): package_table = f"snowflake.{package_table}" @@ -2018,17 +2017,15 @@ def get_session_stage(self) -> str: These artifacts include libraries and packages for UDFs that you define in this session via :func:`add_import`. """ - qualified_stage_name = ( - f"{self.get_fully_qualified_current_schema()}.{self._session_stage}" - ) + stage_name = self.get_fully_qualified_name_if_possible(self._session_stage) if not self._stage_created: self._run_query( f"create {get_temp_type_for_object(self._use_scoped_temp_objects, True)} \ - stage if not exists {qualified_stage_name}", + stage if not exists {stage_name}", is_ddl_on_temp_object=True, ) self._stage_created = True - return f"{STAGE_PREFIX}{qualified_stage_name}" + return f"{STAGE_PREFIX}{stage_name}" def write_pandas( self, @@ -2313,7 +2310,7 @@ def create_dataframe( self._run_query( f"CREATE SCOPED TEMP TABLE {temp_table_name} ({schema_string})" ) - schema_query = f"SELECT * FROM {self.get_current_database()}.{self.get_current_schema()}.{temp_table_name}" + schema_query = f"SELECT * FROM {self.get_fully_qualified_name_if_possible(temp_table_name)}" except ProgrammingError as e: logging.debug( f"Cannot create temp table for specified non-nullable schema, fall back to using schema " @@ -2596,16 +2593,27 @@ def get_current_schema(self) -> Optional[str]: """ return self._conn._get_current_parameter("schema") + @deprecated(version="1.14.0") def get_fully_qualified_current_schema(self) -> str: """Returns the fully qualified name of the current schema for the session.""" + return self.get_fully_qualified_name_if_possible("")[:-1] + + def get_fully_qualified_name_if_possible(self, name: str) -> str: + """ + Returns the fully 
qualified object name if current database/schema exists, otherwise returns the object name + """ database = self.get_current_database() schema = self.get_current_schema() - if database is None or schema is None: + if database and schema: + return f"{database}.{schema}.{name}" + + # In stored procedure, there are scenarios like bundle where we allow empty current schema + if not is_in_stored_procedure(): missing_item = "DATABASE" if not database else "SCHEMA" raise SnowparkClientExceptionMessages.SERVER_CANNOT_FIND_CURRENT_DB_OR_SCHEMA( missing_item, missing_item, missing_item ) - return database + "." + schema + return name def get_current_warehouse(self) -> Optional[str]: """ diff --git a/tests/integ/scala/test_permanent_udf_suite.py b/tests/integ/scala/test_permanent_udf_suite.py index 1de6d581fb1..1e5a25b1e0e 100644 --- a/tests/integ/scala/test_permanent_udf_suite.py +++ b/tests/integ/scala/test_permanent_udf_suite.py @@ -107,8 +107,12 @@ def test_support_fully_qualified_udf_name(session, new_session): def add_one(x: int) -> int: return x + 1 - temp_func_name = f"{session.get_fully_qualified_current_schema()}.{Utils.random_name_for_temp_object(TempObjectType.FUNCTION)}" - perm_func_name = f"{session.get_fully_qualified_current_schema()}.{Utils.random_name_for_temp_object(TempObjectType.FUNCTION)}" + temp_func_name = session.get_fully_qualified_name_if_possible( + Utils.random_name_for_temp_object(TempObjectType.FUNCTION) + ) + perm_func_name = session.get_fully_qualified_name_if_possible( + Utils.random_name_for_temp_object(TempObjectType.FUNCTION) + ) stage_name = Utils.random_stage_name() try: Utils.create_stage(session, stage_name, is_temporary=False) diff --git a/tests/integ/scala/test_session_suite.py b/tests/integ/scala/test_session_suite.py index 541921cbcfb..62c5702f252 100644 --- a/tests/integ/scala/test_session_suite.py +++ b/tests/integ/scala/test_session_suite.py @@ -128,7 +128,7 @@ def test_negative_test_for_missing_required_parameter_schema( 
new_session.sql_simplifier_enabled = sql_simplifier_enabled with new_session: with pytest.raises(SnowparkMissingDbOrSchemaException) as ex_info: - new_session.get_fully_qualified_current_schema() + new_session.get_fully_qualified_name_if_possible("table") assert "The SCHEMA is not set for the current session." in str(ex_info) diff --git a/tests/integ/scala/test_udf_suite.py b/tests/integ/scala/test_udf_suite.py index 5a6492d9ad7..53dbfc60e6b 100644 --- a/tests/integ/scala/test_udf_suite.py +++ b/tests/integ/scala/test_udf_suite.py @@ -299,7 +299,7 @@ def test_call_udf_api(session): df.with_column( "c", call_udf( - f"{session.get_fully_qualified_current_schema()}.{function_name}", + session.get_fully_qualified_name_if_possible(function_name), col("a"), ), ).collect(), diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py index f38e0c688ea..8ad3bc1aee6 100644 --- a/tests/integ/test_session.py +++ b/tests/integ/test_session.py @@ -222,7 +222,7 @@ def test_list_files_in_stage(session, resources_path, local_testing_mode): assert len(files) == 1 assert os.path.basename(test_files.test_file_avro) in files - full_name = f"{session.get_fully_qualified_current_schema()}.{stage_name}" + full_name = session.get_fully_qualified_name_if_possible(stage_name) files2 = session._list_files_in_stage(full_name) assert len(files2) == 1 assert os.path.basename(test_files.test_file_avro) in files2 @@ -241,8 +241,8 @@ def test_list_files_in_stage(session, resources_path, local_testing_mode): assert len(files4) == 1 assert os.path.basename(test_files.test_file_avro) in files4 - full_name_with_prefix = ( - f"{session.get_fully_qualified_current_schema()}.{quoted_name}" + full_name_with_prefix = session.get_fully_qualified_name_if_possible( + quoted_name ) files5 = session._list_files_in_stage(full_name_with_prefix) assert len(files5) == 1 diff --git a/tests/integ/test_stored_procedure.py b/tests/integ/test_stored_procedure.py index 552fdc1968e..3ea58d778e9 100644 --- 
a/tests/integ/test_stored_procedure.py +++ b/tests/integ/test_stored_procedure.py @@ -250,9 +250,7 @@ def test_call_named_stored_procedure(session, temp_schema, db_parameters): ) assert session.call(sproc_name, 13, 19) == 13 * 19 assert ( - session.call( - f"{session.get_fully_qualified_current_schema()}.{sproc_name}", 13, 19 - ) + session.call(session.get_fully_qualified_name_if_possible(sproc_name), 13, 19) == 13 * 19 ) diff --git a/tests/integ/test_udf.py b/tests/integ/test_udf.py index 60b75ffa34c..0a88f65fefb 100644 --- a/tests/integ/test_udf.py +++ b/tests/integ/test_udf.py @@ -192,7 +192,7 @@ def test_call_named_udf(session, temp_schema, db_parameters): Utils.check_answer( df.select( call_udf( - f"{session.get_fully_qualified_current_schema()}.{mult_udf_name}", + session.get_fully_qualified_name_if_possible(mult_udf_name), 6, 7, ) diff --git a/tests/unit/test_stored_procedure.py b/tests/unit/test_stored_procedure.py index 2fa83db68b1..a024e22a815 100644 --- a/tests/unit/test_stored_procedure.py +++ b/tests/unit/test_stored_procedure.py @@ -72,7 +72,7 @@ def test_negative_execute_as(): def test_do_register_sp_negative(cleanup_registration_patch): fake_session = mock.create_autospec(Session) fake_session._runtime_version_from_requirement = None - fake_session.get_fully_qualified_current_schema = mock.Mock( + fake_session.get_fully_qualified_name_if_possible = mock.Mock( return_value="database.schema" ) fake_session._run_query = mock.Mock(side_effect=ProgrammingError()) diff --git a/tests/unit/test_udaf.py b/tests/unit/test_udaf.py index b812380ea1b..4cee34b20c2 100644 --- a/tests/unit/test_udaf.py +++ b/tests/unit/test_udaf.py @@ -39,7 +39,7 @@ class FakeClass: @mock.patch("snowflake.snowpark.udaf.cleanup_failed_permanent_registration") def test_do_register_udaf_negative(cleanup_registration_patch): fake_session = mock.create_autospec(Session) - fake_session.get_fully_qualified_current_schema = mock.Mock( + fake_session.get_fully_qualified_name_if_possible = 
mock.Mock( return_value="database.schema" ) fake_session._run_query = mock.Mock(side_effect=ProgrammingError()) diff --git a/tests/unit/test_udf.py b/tests/unit/test_udf.py index 563947c6d2e..431b8c00b5b 100644 --- a/tests/unit/test_udf.py +++ b/tests/unit/test_udf.py @@ -18,7 +18,7 @@ def test_do_register_sp_negative(cleanup_registration_patch): fake_session = mock.create_autospec(Session) fake_session._runtime_version_from_requirement = None - fake_session.get_fully_qualified_current_schema = mock.Mock( + fake_session.get_fully_qualified_name_if_possible = mock.Mock( return_value="database.schema" ) fake_session._run_query = mock.Mock(side_effect=ProgrammingError()) diff --git a/tests/unit/test_udtf.py b/tests/unit/test_udtf.py index a4964054744..29ded9fbac0 100644 --- a/tests/unit/test_udtf.py +++ b/tests/unit/test_udtf.py @@ -26,7 +26,7 @@ @mock.patch("snowflake.snowpark.udtf.cleanup_failed_permanent_registration") def test_do_register_sp_negative(cleanup_registration_patch): fake_session = mock.create_autospec(Session) - fake_session.get_fully_qualified_current_schema = mock.Mock( + fake_session.get_fully_qualified_name_if_possible = mock.Mock( return_value="database.schema" ) fake_session._run_query = mock.Mock(side_effect=ProgrammingError())