Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1183322: [Local Testing] Add support for registering sprocs #1338

Merged
merged 10 commits into from
Apr 15, 2024
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### New Features

- Added support for registering UDFs and stored procedures to local testing.
- Added support for the following local testing APIs:
- snowflake.snowpark.Session:
- file.put
Expand All @@ -25,9 +26,10 @@
- current_database
- current_session
- date_trunc
- udf
- object_construct
- object_construct_keep_null
- pow
- sqrt
- Added the function `DataFrame.write.csv` to unload data from a ``DataFrame`` into one or more CSV files in a stage.
- Added distributed tracing using open telemetry apis for action functions in `DataFrame` and `DataFrameWriter`:
- snowflake.snowpark.DataFrame:
Expand Down
14 changes: 14 additions & 0 deletions src/snowflake/snowpark/mock/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,20 @@ def mock_listagg(column: ColumnEmulator, delimiter: str, is_distinct: bool):
)


@patch("sqrt")
def mock_sqrt(column: ColumnEmulator):
    """Emulate Snowflake's SQRT: element-wise square root of *column*."""
    out = column.apply(math.sqrt)
    # SQRT always produces a FLOAT column; nullability carries over from the input.
    out.sf_type = ColumnType(FloatType(), column.sf_type.nullable)
    return out


@patch("pow")
def mock_pow(left: ColumnEmulator, right: ColumnEmulator):
    """Emulate Snowflake's POW: element-wise ``left ** right``."""
    combined = left.combine(right, lambda base, exponent: base**exponent)
    # POW yields a FLOAT column; nullability follows the base (left) column.
    combined.sf_type = ColumnType(FloatType(), left.sf_type.nullable)
    return combined


@patch("to_date")
def mock_to_date(
column: ColumnEmulator,
Expand Down
238 changes: 238 additions & 0 deletions src/snowflake/snowpark/mock/_stored_procedure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

import json
import sys
import typing
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import snowflake.snowpark
from snowflake.snowpark._internal.udf_utils import (
check_python_runtime_version,
process_registration_inputs,
)
from snowflake.snowpark._internal.utils import TempObjectType
from snowflake.snowpark.column import Column
from snowflake.snowpark.dataframe import DataFrame
from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark.mock import CUSTOM_JSON_ENCODER
from snowflake.snowpark.mock._plan import calculate_expression
from snowflake.snowpark.mock._snowflake_data_type import ColumnEmulator
from snowflake.snowpark.stored_procedure import (
StoredProcedure,
StoredProcedureRegistration,
)
from snowflake.snowpark.types import (
ArrayType,
DataType,
MapType,
StructType,
_FractionalType,
_IntegralType,
)

from ._telemetry import LocalTestOOBTelemetryService

if sys.version_info <= (3, 9):
from typing import Iterable
else:
from collections.abc import Iterable


def sproc_types_are_compatible(x, y):
    """Return ``True`` when a value of type *x* may be passed where *y* is expected.

    Two types are compatible when one is an instance of the other's class,
    or when both are integral types, or when both are fractional types.

    NOTE(review): once coercion support lands, this helper should be replaced
    by something like ``DataType.can_coerce_to(DataType)``.
    """
    # Return the boolean expression directly instead of the
    # `if ...: return True / return False` form.
    return bool(
        isinstance(x, type(y))
        or (isinstance(x, _IntegralType) and isinstance(y, _IntegralType))
        or (isinstance(x, _FractionalType) and isinstance(y, _FractionalType))
    )


class MockStoredProcedure(StoredProcedure):
    """Local-testing stand-in for :class:`StoredProcedure`.

    Instead of issuing a ``CALL`` to a Snowflake server, invoking this object
    runs the registered Python callable directly, passing the session as the
    first argument.
    """

    def __call__(
        self,
        *args: Any,
        session: Optional["snowflake.snowpark.session.Session"] = None,
        statement_params: Optional[Dict[str, str]] = None,
    ) -> Any:
        args, session = self._validate_call(args, session)

        # Unpack columns if passed
        parsed_args = []
        for arg, expected_type in zip(args, self._input_types):
            if isinstance(arg, Column):
                expr = arg._expression

                # If expression does not define its datatype we cannot verify it's compatible.
                # This is potentially unsafe.
                if expr.datatype and not sproc_types_are_compatible(
                    expr.datatype, expected_type
                ):
                    raise ValueError(
                        f"Unexpected type {expr.datatype} for sproc argument of type {expected_type}"
                    )

                # Expression may be a nested expression. Expression should not need any input data
                # and should only return one value so that it can be passed as a literal value.
                # We pass in a single None value so that the expression evaluator has some data to
                # pass to the expressions.
                resolved_expr = calculate_expression(
                    expr,
                    ColumnEmulator(data=[None]),
                    session._analyzer,
                    {},
                )

                # If the length of the resolved expression is not a single value we cannot pass it as a literal.
                if len(resolved_expr) != 1:
                    # BUG FIX: this message was a plain string missing the `f`
                    # prefix, so {expr.__class__.__name__} was never interpolated.
                    raise ValueError(
                        f"[Local Testing] Unexpected argument type {expr.__class__.__name__} for call to sproc"
                    )
                parsed_args.append(resolved_expr[0])
            else:
                parsed_args.append(arg)

        result = self.func(session, *parsed_args)

        # Semi-structured types are serialized in json.  A sproc may declare a
        # semi-structured return type while actually returning a DataFrame;
        # only plain Python values need JSON serialization.
        if isinstance(
            self._return_type,
            (
                ArrayType,
                MapType,
                StructType,
            ),
        ) and not isinstance(result, DataFrame):
            result = json.dumps(result, indent=2, cls=CUSTOM_JSON_ENCODER)

        return result


class MockStoredProcedureRegistration(StoredProcedureRegistration):
    """In-memory stored-procedure registry used by local testing sessions.

    Registered procedures are kept in ``self._registry`` keyed by name so that
    ``Session._call`` can dispatch to them without contacting a server.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # name -> MockStoredProcedure for locally registered sprocs
        self._registry: Dict[str, Callable] = {}

    def register_from_file(
        self,
        file_path: str,
        func_name: str,
        return_type: Optional[DataType] = None,
        input_types: Optional[List[DataType]] = None,
        name: Optional[Union[str, Iterable[str]]] = None,
        is_permanent: bool = False,
        stage_location: Optional[str] = None,
        imports: Optional[List[Union[str, Tuple[str, str]]]] = None,
        packages: Optional[List[Union[str, ModuleType]]] = None,
        replace: bool = False,
        if_not_exists: bool = False,
        parallel: int = 4,
        execute_as: typing.Literal["caller", "owner"] = "owner",
        strict: bool = False,
        external_access_integrations: Optional[List[str]] = None,
        secrets: Optional[Dict[str, str]] = None,
        *,
        statement_params: Optional[Dict[str, str]] = None,
        source_code_display: bool = True,
        skip_upload_on_content_match: bool = False,
    ) -> StoredProcedure:
        """Not supported in local testing; raises ``NotImplementedError``."""
        LocalTestOOBTelemetryService.get_instance().log_not_supported_error(
            external_feature_name="register sproc from file",
            internal_feature_name="MockStoredProcedureRegistration.register_from_file",
            parameters_info={},
            raise_error=NotImplementedError,
        )

    def _do_register_sp(
        self,
        func: Union[Callable, Tuple[str, str]],
        return_type: DataType,
        input_types: List[DataType],
        sp_name: str,
        stage_location: Optional[str],
        imports: Optional[List[Union[str, Tuple[str, str]]]],
        packages: Optional[List[Union[str, ModuleType]]],
        replace: bool,
        if_not_exists: bool,
        parallel: int,
        strict: bool,
        *,
        source_code_display: bool = False,
        statement_params: Optional[Dict[str, str]] = None,
        execute_as: typing.Literal["caller", "owner"] = "owner",
        anonymous: bool = False,
        api_call_source: str,
        skip_upload_on_content_match: bool = False,
        is_permanent: bool = False,
        external_access_integrations: Optional[List[str]] = None,
        secrets: Optional[Dict[str, str]] = None,
        force_inline_code: bool = False,
    ) -> StoredProcedure:
        """Validate the registration inputs and store a MockStoredProcedure."""
        (
            registered_name,
            is_pandas_udf,
            is_dataframe_input,
            return_type,
            input_types,
        ) = process_registration_inputs(
            self._session,
            TempObjectType.PROCEDURE,
            func,
            return_type,
            input_types,
            sp_name,
            anonymous,
        )

        if is_pandas_udf:
            raise TypeError("pandas stored procedure is not supported")

        # Uploading artifacts has no meaning without a server connection.
        if packages or imports:
            LocalTestOOBTelemetryService.get_instance().log_not_supported_error(
                external_feature_name="uploading imports and packages for sprocs",
                internal_feature_name="MockStoredProcedureRegistration._do_register_sp",
                parameters_info={},
                raise_error=NotImplementedError,
            )

        check_python_runtime_version(self._session._runtime_version_from_requirement)

        # Mirror the server-side "object already exists" error unless replace=True.
        if registered_name in self._registry and not replace:
            raise SnowparkSQLException(
                f"002002 (42710): SQL compilation error: \nObject '{registered_name}' already exists.",
                error_code="1304",
            )

        procedure = MockStoredProcedure(
            func,
            return_type,
            input_types,
            registered_name,
            execute_as=execute_as,
        )
        self._registry[registered_name] = procedure
        return procedure

    def call(
        self,
        sproc_name: str,
        *args: Any,
        session: Optional["snowflake.snowpark.session.Session"] = None,
        statement_params: Optional[Dict[str, str]] = None,
    ):
        """Invoke a previously registered sproc by name."""
        if sproc_name not in self._registry:
            raise SnowparkSQLException(
                f"[Local Testing] sproc {sproc_name} does not exist."
            )

        return self._registry[sproc_name](
            *args, session=session, statement_params=statement_params
        )
15 changes: 9 additions & 6 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@
_extract_schema_and_data_from_pandas_df,
)
from snowflake.snowpark.mock._plan_builder import MockSnowflakePlanBuilder
from snowflake.snowpark.mock._stored_procedure import MockStoredProcedureRegistration
from snowflake.snowpark.mock._udf import MockUDFRegistration
from snowflake.snowpark.query_history import QueryHistory
from snowflake.snowpark.row import Row
Expand Down Expand Up @@ -441,12 +442,14 @@ def __init__(

if isinstance(conn, MockServerConnection):
self._udf_registration = MockUDFRegistration(self)
self._sp_registration = MockStoredProcedureRegistration(self)
else:
self._udf_registration = UDFRegistration(self)
self._sp_registration = StoredProcedureRegistration(self)

self._udtf_registration = UDTFRegistration(self)
self._udaf_registration = UDAFRegistration(self)
self._sp_registration = StoredProcedureRegistration(self)

self._plan_builder = (
SnowflakePlanBuilder(self)
if isinstance(self._conn, ServerConnection)
Expand Down Expand Up @@ -2805,11 +2808,6 @@ def sproc(self) -> StoredProcedureRegistration:
Returns a :class:`stored_procedure.StoredProcedureRegistration` object that you can use to register stored procedures.
See details of how to use this object in :class:`stored_procedure.StoredProcedureRegistration`.
"""
if isinstance(self, MockServerConnection):
self._conn.log_not_supported_error(
external_feature_name="Session.sproc",
raise_error=NotImplementedError,
)
return self._sp_registration

def _infer_is_return_table(
Expand Down Expand Up @@ -2917,6 +2915,11 @@ def _call(
is_return_table: When set to a non-null value, it signifies whether the return type of sproc_name
is a table return type. This skips infer check and returns a dataframe with appropriate sql call.
"""
if isinstance(self._sp_registration, MockStoredProcedureRegistration):
return self._sp_registration.call(
sproc_name, *args, session=self, statement_params=statement_params
)

validate_object_name(sproc_name)
query = generate_call_python_sp_sql(self, sproc_name, *args)

Expand Down
17 changes: 13 additions & 4 deletions src/snowflake/snowpark/stored_procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,11 @@ def __init__(
self._anonymous_sp_sql = anonymous_sp_sql
self._is_return_table = isinstance(return_type, StructType)

def __call__(
def _validate_call(
self,
*args: Any,
args: List[Any],
session: Optional["snowflake.snowpark.session.Session"] = None,
statement_params: Optional[Dict[str, str]] = None,
) -> Any:
):
if args and isinstance(args[0], snowflake.snowpark.session.Session):
if session:
raise ValueError(
Expand All @@ -98,6 +97,16 @@ def __call__(
f"Incorrect number of arguments passed to the stored procedure. Expected: {len(self._input_types)}, Found: {len(args)}"
)

return args, session

def __call__(
self,
*args: Any,
session: Optional["snowflake.snowpark.session.Session"] = None,
statement_params: Optional[Dict[str, str]] = None,
) -> Any:
args, session = self._validate_call(args, session)

session._conn._telemetry_client.send_function_usage_telemetry(
"StoredProcedure.__call__", TelemetryField.FUNC_CAT_USAGE.value
)
Expand Down
2 changes: 2 additions & 0 deletions tests/integ/scala/test_function_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ def test_random(session):
df.select(random()).collect()


@pytest.mark.localtest
def test_sqrt(session):
Utils.check_answer(
TestData.test_data1(session).select(sqrt(col("NUM"))),
Expand Down Expand Up @@ -552,6 +553,7 @@ def test_log(session):
)


@pytest.mark.localtest
def test_pow(session):
Utils.check_answer(
TestData.double2(session).select(pow(col("A"), col("B"))),
Expand Down
4 changes: 2 additions & 2 deletions tests/integ/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def test_strtok_to_array(session):
assert res[0] == "a" and res[1] == "b" and res[2] == "c"


@pytest.mark.local
@pytest.mark.localtest
@pytest.mark.parametrize("use_col", [True, False])
@pytest.mark.parametrize(
"values,expected",
Expand All @@ -431,7 +431,7 @@ def test_greatest(session, use_col, values, expected):
assert res[0][0] == expected


@pytest.mark.local
@pytest.mark.localtest
@pytest.mark.parametrize("use_col", [True, False])
@pytest.mark.parametrize(
"values,expected",
Expand Down
Loading
Loading