Skip to content

Commit

Permalink
SNOW-1064328 Use object name if db/schema is missing in stored proc (#…
Browse files Browse the repository at this point in the history
…1271)

* SNOW-1064328 Use object name if db/schema is missing in stored proc

Description

In snowpark API, we use fully qualified name for temp objects
created implicitly by the API, mainly to support locating the
object even if user changes schema of the session.

However, in stored procs, we have use cases like bundle where
it is legit to have current_database/current_schema to return
NULL. We still want to support object creation in this case.

Thus, our change is to fall back to use the object name if there
is no current database/schema and the client is running inside
the stored procedure

Testing

Existing tests + bundle integration test

* fix test

* remove fully_qualified_schema in mock

* fix deprecaated
  • Loading branch information
sfc-gh-sfan authored Mar 2, 2024
1 parent 6fb2515 commit 1a22e98
Show file tree
Hide file tree
Showing 18 changed files with 49 additions and 46 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

- Added support for creating vectorized UDTFs with `process` method.

### Deprecations:

- Deprecated `Session.get_fully_qualified_current_schema`. Consider using `Session.get_fully_qualified_name_if_possible` instead.

## 1.13.0 (2024-02-26)

### New Features
Expand Down
1 change: 1 addition & 0 deletions docs/source/session.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Snowpark Session
Session.get_current_user
Session.get_current_warehouse
Session.get_fully_qualified_current_schema
Session.get_fully_qualified_name_if_possible
Session.get_imports
Session.get_packages
Session.get_session_stage
Expand Down
13 changes: 4 additions & 9 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,6 @@ def read_file(
path: str,
format: str,
options: Dict[str, str],
fully_qualified_schema: str,
schema: List[Attribute],
schema_to_cast: Optional[List[Tuple[str, str]]] = None,
transformations: Optional[List[str]] = None,
Expand Down Expand Up @@ -853,10 +852,8 @@ def read_file(
post_queries: List[Query] = []
use_temp_file_format: bool = "FORMAT_NAME" not in options
if use_temp_file_format:
format_name = (
fully_qualified_schema
+ "."
+ random_name_for_temp_object(TempObjectType.FILE_FORMAT)
format_name = self.session.get_fully_qualified_name_if_possible(
random_name_for_temp_object(TempObjectType.FILE_FORMAT)
)
queries.append(
Query(
Expand Down Expand Up @@ -928,10 +925,8 @@ def read_file(
]
)

temp_table_name = (
fully_qualified_schema
+ "."
+ random_name_for_temp_object(TempObjectType.TABLE)
temp_table_name = self.session.get_fully_qualified_name_if_possible(
random_name_for_temp_object(TempObjectType.TABLE)
)
queries = [
Query(
Expand Down
4 changes: 1 addition & 3 deletions src/snowflake/snowpark/_internal/udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,9 +548,7 @@ def process_registration_inputs(
else:
object_name = random_name_for_temp_object(object_type)
if not anonymous:
object_name = (
f"{session.get_fully_qualified_current_schema()}.{object_name}"
)
object_name = session.get_fully_qualified_name_if_possible(object_name)
validate_object_name(object_name)

# get return and input types
Expand Down
4 changes: 3 additions & 1 deletion src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3646,7 +3646,9 @@ def cache_result(
"""
from snowflake.snowpark.mock._connection import MockServerConnection

temp_table_name = f'{self._session.get_current_database()}.{self._session.get_current_schema()}."{random_name_for_temp_object(TempObjectType.TABLE)}"'
temp_table_name = self._session.get_fully_qualified_name_if_possible(
f'"{random_name_for_temp_object(TempObjectType.TABLE)}"'
)

if isinstance(self._session._conn, MockServerConnection):
self.write.save_as_table(temp_table_name, create_temp_table=True)
Expand Down
10 changes: 2 additions & 8 deletions src/snowflake/snowpark/dataframe_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,6 @@ def csv(self, path: str) -> DataFrame:
path,
self._file_type,
self._cur_options,
self._session.get_fully_qualified_current_schema(),
schema,
schema_to_cast=schema_to_cast,
transformations=transformations,
Expand All @@ -453,7 +452,6 @@ def csv(self, path: str) -> DataFrame:
path,
self._file_type,
self._cur_options,
self._session.get_fully_qualified_current_schema(),
schema,
schema_to_cast=schema_to_cast,
transformations=transformations,
Expand Down Expand Up @@ -576,10 +574,8 @@ def _infer_schema_for_file_format(
) -> Tuple[List, List, List, Exception]:
format_type_options, _ = get_copy_into_table_options(self._cur_options)

temp_file_format_name = (
self._session.get_fully_qualified_current_schema()
+ "."
+ random_name_for_temp_object(TempObjectType.FILE_FORMAT)
temp_file_format_name = self._session.get_fully_qualified_name_if_possible(
random_name_for_temp_object(TempObjectType.FILE_FORMAT)
)
drop_tmp_file_format_if_exists_query: Optional[str] = None
use_temp_file_format = "FORMAT_NAME" not in self._cur_options
Expand Down Expand Up @@ -685,7 +681,6 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
path,
format,
self._cur_options,
self._session.get_fully_qualified_current_schema(),
schema,
schema_to_cast=schema_to_cast,
transformations=read_file_transformations,
Expand All @@ -704,7 +699,6 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
path,
format,
self._cur_options,
self._session.get_fully_qualified_current_schema(),
schema,
schema_to_cast=schema_to_cast,
transformations=read_file_transformations,
Expand Down
1 change: 0 additions & 1 deletion src/snowflake/snowpark/mock/_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def read_file(
path: str,
format: str,
options: Dict[str, str],
fully_qualified_schema: str,
schema: List[Attribute],
schema_to_cast: Optional[List[Tuple[str, str]]] = None,
transformations: Optional[List[str]] = None,
Expand Down
26 changes: 17 additions & 9 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1313,7 +1313,6 @@ def _resolve_packages(
package_dict = self._parse_packages(packages)

package_table = "information_schema.packages"
# TODO: Use the database from fully qualified UDF name
if not self.get_current_database():
package_table = f"snowflake.{package_table}"

Expand Down Expand Up @@ -2018,17 +2017,15 @@ def get_session_stage(self) -> str:
These artifacts include libraries and packages for UDFs that you define
in this session via :func:`add_import`.
"""
qualified_stage_name = (
f"{self.get_fully_qualified_current_schema()}.{self._session_stage}"
)
stage_name = self.get_fully_qualified_name_if_possible(self._session_stage)
if not self._stage_created:
self._run_query(
f"create {get_temp_type_for_object(self._use_scoped_temp_objects, True)} \
stage if not exists {qualified_stage_name}",
stage if not exists {stage_name}",
is_ddl_on_temp_object=True,
)
self._stage_created = True
return f"{STAGE_PREFIX}{qualified_stage_name}"
return f"{STAGE_PREFIX}{stage_name}"

def write_pandas(
self,
Expand Down Expand Up @@ -2313,7 +2310,7 @@ def create_dataframe(
self._run_query(
f"CREATE SCOPED TEMP TABLE {temp_table_name} ({schema_string})"
)
schema_query = f"SELECT * FROM {self.get_current_database()}.{self.get_current_schema()}.{temp_table_name}"
schema_query = f"SELECT * FROM {self.get_fully_qualified_name_if_possible(temp_table_name)}"
except ProgrammingError as e:
logging.debug(
f"Cannot create temp table for specified non-nullable schema, fall back to using schema "
Expand Down Expand Up @@ -2596,16 +2593,27 @@ def get_current_schema(self) -> Optional[str]:
"""
return self._conn._get_current_parameter("schema")

@deprecated(version="1.14.0")
def get_fully_qualified_current_schema(self) -> str:
"""Returns the fully qualified name of the current schema for the session."""
return self.get_fully_qualified_name_if_possible("")[:-1]

def get_fully_qualified_name_if_possible(self, name: str) -> str:
"""
Returns the fully qualified object name if current database/schema exists, otherwise returns the object name
"""
database = self.get_current_database()
schema = self.get_current_schema()
if database is None or schema is None:
if database and schema:
return f"{database}.{schema}.{name}"

# In stored procedure, there are scenarios like bundle where we allow empty current schema
if not is_in_stored_procedure():
missing_item = "DATABASE" if not database else "SCHEMA"
raise SnowparkClientExceptionMessages.SERVER_CANNOT_FIND_CURRENT_DB_OR_SCHEMA(
missing_item, missing_item, missing_item
)
return database + "." + schema
return name

def get_current_warehouse(self) -> Optional[str]:
"""
Expand Down
8 changes: 6 additions & 2 deletions tests/integ/scala/test_permanent_udf_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,12 @@ def test_support_fully_qualified_udf_name(session, new_session):
def add_one(x: int) -> int:
return x + 1

temp_func_name = f"{session.get_fully_qualified_current_schema()}.{Utils.random_name_for_temp_object(TempObjectType.FUNCTION)}"
perm_func_name = f"{session.get_fully_qualified_current_schema()}.{Utils.random_name_for_temp_object(TempObjectType.FUNCTION)}"
temp_func_name = session.get_fully_qualified_name_if_possible(
Utils.random_name_for_temp_object(TempObjectType.FUNCTION)
)
perm_func_name = session.get_fully_qualified_name_if_possible(
Utils.random_name_for_temp_object(TempObjectType.FUNCTION)
)
stage_name = Utils.random_stage_name()
try:
Utils.create_stage(session, stage_name, is_temporary=False)
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/scala/test_session_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_negative_test_for_missing_required_parameter_schema(
new_session.sql_simplifier_enabled = sql_simplifier_enabled
with new_session:
with pytest.raises(SnowparkMissingDbOrSchemaException) as ex_info:
new_session.get_fully_qualified_current_schema()
new_session.get_fully_qualified_name_if_possible("table")
assert "The SCHEMA is not set for the current session." in str(ex_info)


Expand Down
2 changes: 1 addition & 1 deletion tests/integ/scala/test_udf_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def test_call_udf_api(session):
df.with_column(
"c",
call_udf(
f"{session.get_fully_qualified_current_schema()}.{function_name}",
session.get_fully_qualified_name_if_possible(function_name),
col("a"),
),
).collect(),
Expand Down
6 changes: 3 additions & 3 deletions tests/integ/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_list_files_in_stage(session, resources_path, local_testing_mode):
assert len(files) == 1
assert os.path.basename(test_files.test_file_avro) in files

full_name = f"{session.get_fully_qualified_current_schema()}.{stage_name}"
full_name = session.get_fully_qualified_name_if_possible(stage_name)
files2 = session._list_files_in_stage(full_name)
assert len(files2) == 1
assert os.path.basename(test_files.test_file_avro) in files2
Expand All @@ -241,8 +241,8 @@ def test_list_files_in_stage(session, resources_path, local_testing_mode):
assert len(files4) == 1
assert os.path.basename(test_files.test_file_avro) in files4

full_name_with_prefix = (
f"{session.get_fully_qualified_current_schema()}.{quoted_name}"
full_name_with_prefix = session.get_fully_qualified_name_if_possible(
quoted_name
)
files5 = session._list_files_in_stage(full_name_with_prefix)
assert len(files5) == 1
Expand Down
4 changes: 1 addition & 3 deletions tests/integ/test_stored_procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,7 @@ def test_call_named_stored_procedure(session, temp_schema, db_parameters):
)
assert session.call(sproc_name, 13, 19) == 13 * 19
assert (
session.call(
f"{session.get_fully_qualified_current_schema()}.{sproc_name}", 13, 19
)
session.call(session.get_fully_qualified_name_if_possible(sproc_name), 13, 19)
== 13 * 19
)

Expand Down
2 changes: 1 addition & 1 deletion tests/integ/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def test_call_named_udf(session, temp_schema, db_parameters):
Utils.check_answer(
df.select(
call_udf(
f"{session.get_fully_qualified_current_schema()}.{mult_udf_name}",
session.get_fully_qualified_name_if_possible(mult_udf_name),
6,
7,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_stored_procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_negative_execute_as():
def test_do_register_sp_negative(cleanup_registration_patch):
fake_session = mock.create_autospec(Session)
fake_session._runtime_version_from_requirement = None
fake_session.get_fully_qualified_current_schema = mock.Mock(
fake_session.get_fully_qualified_name_if_possible = mock.Mock(
return_value="database.schema"
)
fake_session._run_query = mock.Mock(side_effect=ProgrammingError())
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_udaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class FakeClass:
@mock.patch("snowflake.snowpark.udaf.cleanup_failed_permanent_registration")
def test_do_register_udaf_negative(cleanup_registration_patch):
fake_session = mock.create_autospec(Session)
fake_session.get_fully_qualified_current_schema = mock.Mock(
fake_session.get_fully_qualified_name_if_possible = mock.Mock(
return_value="database.schema"
)
fake_session._run_query = mock.Mock(side_effect=ProgrammingError())
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
def test_do_register_sp_negative(cleanup_registration_patch):
fake_session = mock.create_autospec(Session)
fake_session._runtime_version_from_requirement = None
fake_session.get_fully_qualified_current_schema = mock.Mock(
fake_session.get_fully_qualified_name_if_possible = mock.Mock(
return_value="database.schema"
)
fake_session._run_query = mock.Mock(side_effect=ProgrammingError())
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
@mock.patch("snowflake.snowpark.udtf.cleanup_failed_permanent_registration")
def test_do_register_sp_negative(cleanup_registration_patch):
fake_session = mock.create_autospec(Session)
fake_session.get_fully_qualified_current_schema = mock.Mock(
fake_session.get_fully_qualified_name_if_possible = mock.Mock(
return_value="database.schema"
)
fake_session._run_query = mock.Mock(side_effect=ProgrammingError())
Expand Down

0 comments on commit 1a22e98

Please sign in to comment.