diff --git a/CHANGELOG.md b/CHANGELOG.md index 7ce669f04d6..5fa7a05a7d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### New Features +- Added support for registering udfs and stored procedure to local testing. - Added support for the following local testing APIs: - snowflake.snowpark.Session: - file.put @@ -25,9 +26,10 @@ - current_database - current_session - date_trunc - - udf - object_construct - object_construct_keep_null + - pow + - sqrt - Added the function `DataFrame.write.csv` to unload data from a ``DataFrame`` into one or more CSV files in a stage. - Added distributed tracing using open telemetry apis for action functions in `DataFrame` and `DataFrameWriter`: - snowflake.snowpark.DataFrame: diff --git a/src/snowflake/snowpark/mock/_functions.py b/src/snowflake/snowpark/mock/_functions.py index 30527b43db6..32144f61a41 100644 --- a/src/snowflake/snowpark/mock/_functions.py +++ b/src/snowflake/snowpark/mock/_functions.py @@ -273,6 +273,20 @@ def mock_listagg(column: ColumnEmulator, delimiter: str, is_distinct: bool): ) +@patch("sqrt") +def mock_sqrt(column: ColumnEmulator): + result = column.apply(math.sqrt) + result.sf_type = ColumnType(FloatType(), column.sf_type.nullable) + return result + + +@patch("pow") +def mock_pow(left: ColumnEmulator, right: ColumnEmulator): + result = left.combine(right, lambda l, r: l**r) + result.sf_type = ColumnType(FloatType(), left.sf_type.nullable) + return result + + @patch("to_date") def mock_to_date( column: ColumnEmulator, diff --git a/src/snowflake/snowpark/mock/_stored_procedure.py b/src/snowflake/snowpark/mock/_stored_procedure.py new file mode 100644 index 00000000000..03d89899622 --- /dev/null +++ b/src/snowflake/snowpark/mock/_stored_procedure.py @@ -0,0 +1,238 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +import json +import sys +import typing +from types import ModuleType +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import snowflake.snowpark +from snowflake.snowpark._internal.udf_utils import ( + check_python_runtime_version, + process_registration_inputs, +) +from snowflake.snowpark._internal.utils import TempObjectType +from snowflake.snowpark.column import Column +from snowflake.snowpark.dataframe import DataFrame +from snowflake.snowpark.exceptions import SnowparkSQLException +from snowflake.snowpark.mock import CUSTOM_JSON_ENCODER +from snowflake.snowpark.mock._plan import calculate_expression +from snowflake.snowpark.mock._snowflake_data_type import ColumnEmulator +from snowflake.snowpark.stored_procedure import ( + StoredProcedure, + StoredProcedureRegistration, +) +from snowflake.snowpark.types import ( + ArrayType, + DataType, + MapType, + StructType, + _FractionalType, + _IntegralType, +) + +from ._telemetry import LocalTestOOBTelemetryService + +if sys.version_info <= (3, 9): + from typing import Iterable +else: + from collections.abc import Iterable + + +def sproc_types_are_compatible(x, y): + if ( + isinstance(x, type(y)) + or isinstance(x, _IntegralType) + and isinstance(y, _IntegralType) + or isinstance(x, _FractionalType) + and isinstance(y, _FractionalType) + ): + return True + return False + + +class MockStoredProcedure(StoredProcedure): + def __call__( + self, + *args: Any, + session: Optional["snowflake.snowpark.session.Session"] = None, + statement_params: Optional[Dict[str, str]] = None, + ) -> Any: + args, session = self._validate_call(args, session) + + # Unpack columns if passed + parsed_args = [] + for arg, expected_type in zip(args, self._input_types): + if isinstance(arg, Column): + expr = arg._expression + + # If expression does not define its datatype we cannot verify it's compatibale. + # This is potentially unsafe. + if expr.datatype and not sproc_types_are_compatible( + expr.datatype, expected_type + ): + raise ValueError( + f"Unexpected type {expr.datatype} for sproc argument of type {expected_type}" + ) + + # Expression may be a nested expression. Expression should not need any input data + # and should only return one value so that it can be passed as a literal value. + # We pass in a single None value so that the expression evaluator has some data to + # pass to the expressions. + resolved_expr = calculate_expression( + expr, + ColumnEmulator(data=[None]), + session._analyzer, + {}, + ) + + # If the length of the resolved expression is not a single value we cannot pass it as a literal. + if len(resolved_expr) != 1: + raise ValueError( + "[Local Testing] Unexpected argument type {expr.__class__.__name__} for call to sproc" + ) + parsed_args.append(resolved_expr[0]) + else: + parsed_args.append(arg) + + result = self.func(session, *parsed_args) + + # Semi-structured types are serialized in json + if isinstance( + self._return_type, + ( + ArrayType, + MapType, + StructType, + ), + ) and not isinstance(result, DataFrame): + result = json.dumps(result, indent=2, cls=CUSTOM_JSON_ENCODER) + + return result + + +class MockStoredProcedureRegistration(StoredProcedureRegistration): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._registry: Dict[str, Callable] = dict() + + def register_from_file( + self, + file_path: str, + func_name: str, + return_type: Optional[DataType] = None, + input_types: Optional[List[DataType]] = None, + name: Optional[Union[str, Iterable[str]]] = None, + is_permanent: bool = False, + stage_location: Optional[str] = None, + imports: Optional[List[Union[str, Tuple[str, str]]]] = None, + packages: Optional[List[Union[str, ModuleType]]] = None, + replace: bool = False, + if_not_exists: bool = False, + parallel: int = 4, + execute_as: typing.Literal["caller", "owner"] = "owner", + strict: bool = False, + external_access_integrations: Optional[List[str]] = None, + secrets: Optional[Dict[str, str]] = None, + *, + statement_params: Optional[Dict[str, str]] = None, + source_code_display: bool = True, + skip_upload_on_content_match: bool = False, + ) -> StoredProcedure: + LocalTestOOBTelemetryService.get_instance().log_not_supported_error( + external_feature_name="register sproc from file", + internal_feature_name="MockStoredProcedureRegistration.register_from_file", + parameters_info={}, + raise_error=NotImplementedError, + ) + + def _do_register_sp( + self, + func: Union[Callable, Tuple[str, str]], + return_type: DataType, + input_types: List[DataType], + sp_name: str, + stage_location: Optional[str], + imports: Optional[List[Union[str, Tuple[str, str]]]], + packages: Optional[List[Union[str, ModuleType]]], + replace: bool, + if_not_exists: bool, + parallel: int, + strict: bool, + *, + source_code_display: bool = False, + statement_params: Optional[Dict[str, str]] = None, + execute_as: typing.Literal["caller", "owner"] = "owner", + anonymous: bool = False, + api_call_source: str, + skip_upload_on_content_match: bool = False, + is_permanent: bool = False, + external_access_integrations: Optional[List[str]] = None, + secrets: Optional[Dict[str, str]] = None, + force_inline_code: bool = False, + ) -> StoredProcedure: + ( + udf_name, + is_pandas_udf, + is_dataframe_input, + return_type, + input_types, + ) = process_registration_inputs( + self._session, + TempObjectType.PROCEDURE, + func, + return_type, + input_types, + sp_name, + anonymous, + ) + + if is_pandas_udf: + raise TypeError("pandas stored procedure is not supported") + + if packages or imports: + LocalTestOOBTelemetryService.get_instance().log_not_supported_error( + external_feature_name="uploading imports and packages for sprocs", + internal_feature_name="MockStoredProcedureRegistration._do_register_sp", + parameters_info={}, + raise_error=NotImplementedError, + ) + + check_python_runtime_version(self._session._runtime_version_from_requirement) + + if udf_name in self._registry and not replace: + raise SnowparkSQLException( + f"002002 (42710): SQL compilation error: \nObject '{udf_name}' already exists.", + error_code="1304", + ) + + sproc = MockStoredProcedure( + func, + return_type, + input_types, + udf_name, + execute_as=execute_as, + ) + + self._registry[udf_name] = sproc + + return sproc + + def call( + self, + sproc_name: str, + *args: Any, + session: Optional["snowflake.snowpark.session.Session"] = None, + statement_params: Optional[Dict[str, str]] = None, + ): + + if sproc_name not in self._registry: + raise SnowparkSQLException( + f"[Local Testing] sproc {sproc_name} does not exist." + ) + + return self._registry[sproc_name]( + *args, session=session, statement_params=statement_params + ) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 7667adbebc8..af7a2d92877 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -136,6 +136,7 @@ _extract_schema_and_data_from_pandas_df, ) from snowflake.snowpark.mock._plan_builder import MockSnowflakePlanBuilder +from snowflake.snowpark.mock._stored_procedure import MockStoredProcedureRegistration from snowflake.snowpark.mock._udf import MockUDFRegistration from snowflake.snowpark.query_history import QueryHistory from snowflake.snowpark.row import Row @@ -441,12 +442,14 @@ def __init__( if isinstance(conn, MockServerConnection): self._udf_registration = MockUDFRegistration(self) + self._sp_registration = MockStoredProcedureRegistration(self) else: self._udf_registration = UDFRegistration(self) + self._sp_registration = StoredProcedureRegistration(self) self._udtf_registration = UDTFRegistration(self) self._udaf_registration = UDAFRegistration(self) - self._sp_registration = StoredProcedureRegistration(self) + self._plan_builder = ( SnowflakePlanBuilder(self) if isinstance(self._conn, ServerConnection) @@ -2805,11 +2808,6 @@ def sproc(self) -> StoredProcedureRegistration: Returns a :class:`stored_procedure.StoredProcedureRegistration` object that you can use to register stored procedures. See details of how to use this object in :class:`stored_procedure.StoredProcedureRegistration`. """ - if isinstance(self, MockServerConnection): - self._conn.log_not_supported_error( - external_feature_name="Session.sproc", - raise_error=NotImplementedError, - ) return self._sp_registration def _infer_is_return_table( @@ -2917,6 +2915,11 @@ def _call( is_return_table: When set to a non-null value, it signifies whether the return type of sproc_name is a table return type. This skips infer check and returns a dataframe with appropriate sql call. """ + if isinstance(self._sp_registration, MockStoredProcedureRegistration): + return self._sp_registration.call( + sproc_name, *args, session=self, statement_params=statement_params + ) + validate_object_name(sproc_name) query = generate_call_python_sp_sql(self, sproc_name, *args) diff --git a/src/snowflake/snowpark/stored_procedure.py b/src/snowflake/snowpark/stored_procedure.py index e066287d703..f3de40ea953 100644 --- a/src/snowflake/snowpark/stored_procedure.py +++ b/src/snowflake/snowpark/stored_procedure.py @@ -76,12 +76,11 @@ def __init__( self._anonymous_sp_sql = anonymous_sp_sql self._is_return_table = isinstance(return_type, StructType) - def __call__( + def _validate_call( self, - *args: Any, + args: List[Any], session: Optional["snowflake.snowpark.session.Session"] = None, - statement_params: Optional[Dict[str, str]] = None, - ) -> Any: + ): if args and isinstance(args[0], snowflake.snowpark.session.Session): if session: raise ValueError( @@ -98,6 +97,16 @@ def __call__( f"Incorrect number of arguments passed to the stored procedure. Expected: {len(self._input_types)}, Found: {len(args)}" ) + return args, session + + def __call__( + self, + *args: Any, + session: Optional["snowflake.snowpark.session.Session"] = None, + statement_params: Optional[Dict[str, str]] = None, + ) -> Any: + args, session = self._validate_call(args, session) + session._conn._telemetry_client.send_function_usage_telemetry( "StoredProcedure.__call__", TelemetryField.FUNC_CAT_USAGE.value ) diff --git a/tests/integ/scala/test_function_suite.py b/tests/integ/scala/test_function_suite.py index b7d95bb77d1..2ebddda9e60 100644 --- a/tests/integ/scala/test_function_suite.py +++ b/tests/integ/scala/test_function_suite.py @@ -482,6 +482,7 @@ def test_random(session): df.select(random()).collect() +@pytest.mark.localtest def test_sqrt(session): Utils.check_answer( TestData.test_data1(session).select(sqrt(col("NUM"))), @@ -552,6 +553,7 @@ def test_log(session): ) +@pytest.mark.localtest def test_pow(session): Utils.check_answer( TestData.double2(session).select(pow(col("A"), col("B"))), diff --git a/tests/integ/test_function.py b/tests/integ/test_function.py index 1903794d0b2..c515ea02420 100644 --- a/tests/integ/test_function.py +++ b/tests/integ/test_function.py @@ -410,7 +410,7 @@ def test_strtok_to_array(session): assert res[0] == "a" and res[1] == "b" and res[2] == "c" -@pytest.mark.local +@pytest.mark.localtest @pytest.mark.parametrize("use_col", [True, False]) @pytest.mark.parametrize( "values,expected", @@ -431,7 +431,7 @@ def test_greatest(session, use_col, values, expected): assert res[0][0] == expected -@pytest.mark.local +@pytest.mark.localtest @pytest.mark.parametrize("use_col", [True, False]) @pytest.mark.parametrize( "values,expected", diff --git a/tests/integ/test_stored_procedure.py b/tests/integ/test_stored_procedure.py index e2a444e1557..481cc7de562 100644 --- a/tests/integ/test_stored_procedure.py +++ b/tests/integ/test_stored_procedure.py @@ -30,11 +30,14 @@ SnowparkSQLException, ) from snowflake.snowpark.functions import ( + cast, col, current_date, date_from_parts, + iff, lit, max as max_, + pow, sproc, sqrt, ) @@ -161,18 +164,34 @@ def return1(session_): ) -def test_basic_stored_procedure(session): +@pytest.mark.localtest +def test_basic_stored_procedure(session, local_testing_mode): def return1(session_): - return session_.sql("select '1'").collect()[0][0] + return session_.create_dataframe([["1"]]).collect()[0][0] def plus1(session_, x): - return session_.sql(f"select {x} + 1").collect()[0][0] + return ( + session_.create_dataframe([[x]]) + .to_df(["a"]) + .select(col("a") + lit(1)) + .collect()[0][0] + ) def add(session_, x, y): - return session_.sql(f"select {x} + {y}").collect()[0][0] + return ( + session_.create_dataframe([[x, y]]) + .to_df(["a", "b"]) + .select(col("a") + col("b")) + .collect()[0][0] + ) def int2str(session_, x): - return session_.sql(f"select cast({x} as string)").collect()[0][0] + return ( + session_.create_dataframe([[x]]) + .to_df(["a"]) + .select(cast(col("a"), "string")) + .collect()[0][0] + ) return1_sp = sproc(return1, return_type=StringType()) plus1_sp = sproc(plus1, return_type=IntegerType(), input_types=[IntegerType()]) @@ -180,59 +199,79 @@ def int2str(session_, x): add, return_type=IntegerType(), input_types=[IntegerType(), IntegerType()] ) int2str_sp = sproc(int2str, return_type=StringType(), input_types=[IntegerType()]) - pow_sp = sproc( - lambda session_, x, y: session_.sql(f"select pow({x}, {y})").collect()[0][0], - return_type=DoubleType(), - input_types=[IntegerType(), IntegerType()], - ) assert return1_sp() == "1" assert plus1_sp(1) == 2 assert add_sp(4, 6) == 10 assert int2str_sp(123) == "123" - assert pow_sp(2, 10) == 1024 assert return1_sp(session=session) == "1" assert plus1_sp(1, session=session) == 2 assert add_sp(4, 6, session=session) == 10 assert int2str_sp(123, session=session) == "123" + + def sp_pow(session_, x, y): + return ( + session_.create_dataframe([[x, y]]) + .to_df(["a", "b"]) + .select(pow(col("a"), col("b"))) + .collect()[0][0] + ) + + pow_sp = sproc( + sp_pow, + return_type=DoubleType(), + input_types=[IntegerType(), IntegerType()], + ) + assert pow_sp(2, 10) == 1024 assert pow_sp(2, 10, session=session) == 1024 -def test_stored_procedure_with_column_datatype(session): +@pytest.mark.localtest +def test_stored_procedure_with_basic_column_datatype(session, local_testing_mode): + expected_err = Exception if local_testing_mode else SnowparkSQLException + def plus1(session_, x): return x + 1 + plus1_sp = sproc(plus1, return_type=IntegerType(), input_types=[IntegerType()]) + assert plus1_sp(lit(6)) == 7 + + with pytest.raises(expected_err) as ex_info: + plus1_sp(col("a")) + assert local_testing_mode or "invalid identifier" in str(ex_info) + + with pytest.raises(expected_err) as ex_info: + plus1_sp(current_date()) + assert local_testing_mode or "Invalid argument types for function" in str(ex_info) + + with pytest.raises(expected_err) as ex_info: + plus1_sp(lit("")) + assert local_testing_mode or "not recognized" in str(ex_info) + + +@pytest.mark.localtest +def test_stored_procedure_with_column_datatype(session, local_testing_mode): def add(session_, x, y): return x + y - def add_date(session_, date, add_days): - return date + datetime.timedelta(days=add_days) - - plus1_sp = sproc(plus1, return_type=IntegerType(), input_types=[IntegerType()]) add_sp = sproc( add, return_type=IntegerType(), input_types=[IntegerType(), IntegerType()] ) - add_date_sp = sproc( - add_date, return_type=DateType(), input_types=[DateType(), IntegerType()] - ) - dt = datetime.date(1992, 12, 14) + datetime.timedelta(days=3) - assert plus1_sp(lit(6)) == 7 assert add_sp(4, sqrt(lit(36))) == 10 - # the date can be different between server and client due to timezone difference - assert -1 <= (add_date_sp(date_from_parts(1992, 12, 14), 3) - dt).days <= 1 - with pytest.raises(SnowparkSQLException) as ex_info: - plus1_sp(col("a")) - assert "invalid identifier" in str(ex_info) + if not local_testing_mode: + dt = datetime.date(1992, 12, 14) + datetime.timedelta(days=3) - with pytest.raises(SnowparkSQLException) as ex_info: - plus1_sp(current_date()) - assert "Invalid argument types for function" in str(ex_info) + def add_date(session_, date, add_days): + return date + datetime.timedelta(days=add_days) - with pytest.raises(SnowparkSQLException) as ex_info: - plus1_sp(lit("")) - assert "not recognized" in str(ex_info) + add_date_sp = sproc( + add_date, return_type=DateType(), input_types=[DateType(), IntegerType()] + ) + + # the date can be different between server and client due to timezone difference + assert -1 <= (add_date_sp(date_from_parts(1992, 12, 14), 3) - dt).days <= 1 @pytest.mark.skipif( @@ -293,6 +332,7 @@ def test_call_named_stored_procedure(session, temp_schema, db_parameters): # restore active session +@pytest.mark.localtest @pytest.mark.parametrize("anonymous", [True, False]) def test_call_table_sproc_triggers_action(session, anonymous): """Here we create a table sproc which creates a table. we call the table sproc using @@ -303,7 +343,7 @@ def test_call_table_sproc_triggers_action(session, anonymous): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) def create_temp_table_sp(session_: Session, name: str): - df = session_.sql("select 1 as A") + df = session_.create_dataframe([1]).to_df("A") df.write.save_as_table(name, mode="overwrite") return df @@ -322,6 +362,7 @@ def create_temp_table_sp(session_: Session, name: str): Utils.drop_table(session, table_name) +@pytest.mark.localtest def test_recursive_function(session): # Test recursive function def factorial(session_, n): @@ -333,15 +374,19 @@ def factorial(session_, n): assert factorial_sp(3) == factorial(session, 3) +@pytest.mark.localtest def test_nested_function(session): def outer_func(session_): def inner_func(): return "snow" - return session_.sql(f"select '{inner_func()}-{inner_func()}'").collect()[0][0] + return session_.create_dataframe([f"{inner_func()}-{inner_func()}"]).collect()[ + 0 + ][0] def square(session_, x): - return session_.sql(f"select square({x})").collect()[0][0] + df = session_.create_dataframe([x]).to_df("a") + return df.select(pow("a", lit(2))).collect()[0][0] def cube(session_, x): return square(session_, x) * x @@ -359,6 +404,7 @@ def cube(session_, x): assert square_sp(2) == 4 +@pytest.mark.localtest def test_decorator_function(session): def decorator_do_twice(func): def wrapper(*args, **kwargs): @@ -370,7 +416,8 @@ def wrapper(*args, **kwargs): @decorator_do_twice def square(session_, x): - return session_.sql(f"select square({x})").collect()[0][0] + df = session_.create_dataframe([x]).to_df("a") + return df.select(pow("a", lit(2))).collect()[0][0] square_twice_sp = sproc( square, @@ -380,14 +427,16 @@ def square(session_, x): assert square_twice_sp(2) == 16 +@pytest.mark.localtest def test_annotation_syntax(session): @sproc(return_type=IntegerType(), input_types=[IntegerType(), IntegerType()]) def add_sp(session_, x, y): - return session_.sql(f"SELECT {x} + {y}").collect()[0][0] + df = session_.create_dataframe([(x, y)]).to_df("a", "b") + return df.select(col("a") + col("b")).collect()[0][0] @sproc(return_type=StringType()) def snow(session_): - return session_.sql("SELECT 'snow'").collect()[0][0] + return session_.create_dataframe(["snow"]).collect()[0][0] assert add_sp(1, 2) == 3 assert snow() == "snow" @@ -436,16 +485,23 @@ def test_register_sp_from_file(session, resources_path, tmpdir): ) +@pytest.mark.localtest def test_session_register_sp(session): add_sp = session.sproc.register( - lambda session_, x, y: session_.sql(f"SELECT {x} + {y}").collect()[0][0], + lambda session_, x, y: session_.create_dataframe([(x, y)]) + .to_df("a", "b") + .select(col("a") + col("b")) + .collect()[0][0], return_type=IntegerType(), input_types=[IntegerType(), IntegerType()], ) assert add_sp(1, 2) == 3 add_sp = session.sproc.register( - lambda session_, x, y: session_.sql(f"SELECT {x} + {y}").collect()[0][0], + lambda session_, x, y: session_.create_dataframe([(x, y)]) + .to_df("a", "b") + .select(col("a") + col("b")) + .collect()[0][0], return_type=IntegerType(), input_types=[IntegerType(), IntegerType()], statement_params={"SF_PARTNER": "FAKE_PARTNER"}, @@ -580,18 +636,34 @@ def plus4_then_mod5(session_, x): assert "No module named" in ex_info.value.message +@pytest.mark.localtest def test_type_hints(session): @sproc() def add_sp(session_: Session, x: int, y: int) -> int: - return session_.sql(f"SELECT {x} + {y}").collect()[0][0] + df = session_.create_dataframe( + [ + (x, y), + ] + ).to_df(["a", "b"]) + return df.select(col("a") + col("b")).collect()[0][0] @sproc def snow_sp(session_: Session, x: int) -> Optional[str]: - return session_.sql(f"SELECT IFF({x} % 2 = 0, 'snow', NULL)").collect()[0][0] + df = session_.create_dataframe( + [ + (x), + ] + ).to_df(["a"]) + return df.select(iff(col("a") % 2 == 0, "snow", None)).collect()[0][0] @sproc def double_str_list_sp(session_: Session, x: str) -> List[str]: - val = session_.sql(f"SELECT '{x}'").collect()[0][0] + df = session_.create_dataframe( + [ + (x), + ] + ).to_df(["a"]) + val = df.collect()[0][0] return [val, val] dt = datetime.datetime.strptime("2017-02-24 12:00:05.456", "%Y-%m-%d %H:%M:%S.%f") @@ -617,9 +689,15 @@ def get_sp(_: Session, d: Dict[str, str], i: str) -> str: assert get_sp({"0": "snow", "1": "flake"}, "0") == "snow" +@pytest.mark.localtest def test_type_hint_no_change_after_registration(session): def add(session_: Session, x: int, y: int) -> int: - return session_.sql(f"SELECT {x} + {y}").collect()[0][0] + return ( + session_.create_dataframe([(x, y)]) + .to_df("a", "b") + .select(col("a") + col("b")) + .collect()[0][0], + ) annotations = add.__annotations__ session.sproc.register(add)