-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SNOW-1183322: [Local Testing] Add support for registering sprocs #1338
Changes from all commits
ae7584c
62f3a8b
ed92ff9
e52c329
1fecb98
6b648ae
08c3569
f23b398
de3d58f
e7421b4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,238 @@ | ||
# | ||
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. | ||
# | ||
|
||
import json | ||
import sys | ||
import typing | ||
from types import ModuleType | ||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union | ||
|
||
import snowflake.snowpark | ||
from snowflake.snowpark._internal.udf_utils import ( | ||
check_python_runtime_version, | ||
process_registration_inputs, | ||
) | ||
from snowflake.snowpark._internal.utils import TempObjectType | ||
from snowflake.snowpark.column import Column | ||
from snowflake.snowpark.dataframe import DataFrame | ||
from snowflake.snowpark.exceptions import SnowparkSQLException | ||
from snowflake.snowpark.mock import CUSTOM_JSON_ENCODER | ||
from snowflake.snowpark.mock._plan import calculate_expression | ||
from snowflake.snowpark.mock._snowflake_data_type import ColumnEmulator | ||
from snowflake.snowpark.stored_procedure import ( | ||
StoredProcedure, | ||
StoredProcedureRegistration, | ||
) | ||
from snowflake.snowpark.types import ( | ||
ArrayType, | ||
DataType, | ||
MapType, | ||
StructType, | ||
_FractionalType, | ||
_IntegralType, | ||
) | ||
|
||
from ._telemetry import LocalTestOOBTelemetryService | ||
|
||
if sys.version_info <= (3, 9): | ||
from typing import Iterable | ||
else: | ||
from collections.abc import Iterable | ||
|
||
|
||
def sproc_types_are_compatible(x, y): | ||
if ( | ||
isinstance(x, type(y)) | ||
or isinstance(x, _IntegralType) | ||
and isinstance(y, _IntegralType) | ||
or isinstance(x, _FractionalType) | ||
and isinstance(y, _FractionalType) | ||
): | ||
return True | ||
return False | ||
|
||
|
||
class MockStoredProcedure(StoredProcedure): | ||
def __call__( | ||
self, | ||
*args: Any, | ||
session: Optional["snowflake.snowpark.session.Session"] = None, | ||
statement_params: Optional[Dict[str, str]] = None, | ||
) -> Any: | ||
args, session = self._validate_call(args, session) | ||
|
||
# Unpack columns if passed | ||
parsed_args = [] | ||
for arg, expected_type in zip(args, self._input_types): | ||
if isinstance(arg, Column): | ||
expr = arg._expression | ||
|
||
# If expression does not define its datatype we cannot verify it's compatibale. | ||
# This is potentially unsafe. | ||
if expr.datatype and not sproc_types_are_compatible( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Question: in what case is expr.datatype empty here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The datatype of the sqrt function in this example is None: |
||
expr.datatype, expected_type | ||
): | ||
raise ValueError( | ||
f"Unexpected type {expr.datatype} for sproc argument of type {expected_type}" | ||
) | ||
|
||
# Expression may be a nested expression. Expression should not need any input data | ||
# and should only return one value so that it can be passed as a literal value. | ||
# We pass in a single None value so that the expression evaluator has some data to | ||
# pass to the expressions. | ||
resolved_expr = calculate_expression( | ||
expr, | ||
ColumnEmulator(data=[None]), | ||
session._analyzer, | ||
{}, | ||
) | ||
|
||
# If the length of the resolved expression is not a single value we cannot pass it as a literal. | ||
if len(resolved_expr) != 1: | ||
raise ValueError( | ||
"[Local Testing] Unexpected argument type {expr.__class__.__name__} for call to sproc" | ||
) | ||
parsed_args.append(resolved_expr[0]) | ||
sfc-gh-aling marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
parsed_args.append(arg) | ||
|
||
result = self.func(session, *parsed_args) | ||
|
||
# Semi-structured types are serialized in json | ||
if isinstance( | ||
self._return_type, | ||
( | ||
ArrayType, | ||
MapType, | ||
StructType, | ||
), | ||
) and not isinstance(result, DataFrame): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have no support for table sprocs yet, why do we need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've seen examples in the tests where a sproc returns a Dataframe, but it's signature says it returns a Structtype. If it's truly a struct type it needs to be serialized, but if it's just a dataframe it does not. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. Could you please point me to the test? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See here. |
||
result = json.dumps(result, indent=2, cls=CUSTOM_JSON_ENCODER) | ||
|
||
return result | ||
|
||
|
||
class MockStoredProcedureRegistration(StoredProcedureRegistration): | ||
def __init__(self, *args, **kwargs) -> None: | ||
super().__init__(*args, **kwargs) | ||
self._registry: Dict[str, Callable] = dict() | ||
|
||
def register_from_file( | ||
self, | ||
file_path: str, | ||
func_name: str, | ||
return_type: Optional[DataType] = None, | ||
input_types: Optional[List[DataType]] = None, | ||
name: Optional[Union[str, Iterable[str]]] = None, | ||
is_permanent: bool = False, | ||
stage_location: Optional[str] = None, | ||
imports: Optional[List[Union[str, Tuple[str, str]]]] = None, | ||
packages: Optional[List[Union[str, ModuleType]]] = None, | ||
replace: bool = False, | ||
if_not_exists: bool = False, | ||
parallel: int = 4, | ||
execute_as: typing.Literal["caller", "owner"] = "owner", | ||
strict: bool = False, | ||
external_access_integrations: Optional[List[str]] = None, | ||
secrets: Optional[Dict[str, str]] = None, | ||
*, | ||
statement_params: Optional[Dict[str, str]] = None, | ||
source_code_display: bool = True, | ||
skip_upload_on_content_match: bool = False, | ||
) -> StoredProcedure: | ||
LocalTestOOBTelemetryService.get_instance().log_not_supported_error( | ||
external_feature_name="register sproc from file", | ||
internal_feature_name="MockStoredProcedureRegistration.register_from_file", | ||
parameters_info={}, | ||
raise_error=NotImplementedError, | ||
) | ||
|
||
def _do_register_sp( | ||
self, | ||
func: Union[Callable, Tuple[str, str]], | ||
return_type: DataType, | ||
input_types: List[DataType], | ||
sp_name: str, | ||
stage_location: Optional[str], | ||
imports: Optional[List[Union[str, Tuple[str, str]]]], | ||
packages: Optional[List[Union[str, ModuleType]]], | ||
replace: bool, | ||
if_not_exists: bool, | ||
parallel: int, | ||
strict: bool, | ||
*, | ||
source_code_display: bool = False, | ||
statement_params: Optional[Dict[str, str]] = None, | ||
execute_as: typing.Literal["caller", "owner"] = "owner", | ||
anonymous: bool = False, | ||
api_call_source: str, | ||
skip_upload_on_content_match: bool = False, | ||
is_permanent: bool = False, | ||
external_access_integrations: Optional[List[str]] = None, | ||
secrets: Optional[Dict[str, str]] = None, | ||
force_inline_code: bool = False, | ||
) -> StoredProcedure: | ||
( | ||
udf_name, | ||
is_pandas_udf, | ||
is_dataframe_input, | ||
return_type, | ||
input_types, | ||
) = process_registration_inputs( | ||
self._session, | ||
TempObjectType.PROCEDURE, | ||
func, | ||
return_type, | ||
input_types, | ||
sp_name, | ||
anonymous, | ||
) | ||
|
||
if is_pandas_udf: | ||
raise TypeError("pandas stored procedure is not supported") | ||
|
||
if packages or imports: | ||
LocalTestOOBTelemetryService.get_instance().log_not_supported_error( | ||
external_feature_name="uploading imports and packages for sprocs", | ||
internal_feature_name="MockStoredProcedureRegistration._do_register_sp", | ||
parameters_info={}, | ||
raise_error=NotImplementedError, | ||
) | ||
|
||
check_python_runtime_version(self._session._runtime_version_from_requirement) | ||
|
||
if udf_name in self._registry and not replace: | ||
raise SnowparkSQLException( | ||
sfc-gh-jrose marked this conversation as resolved.
Show resolved
Hide resolved
|
||
f"002002 (42710): SQL compilation error: \nObject '{udf_name}' already exists.", | ||
error_code="1304", | ||
) | ||
|
||
sproc = MockStoredProcedure( | ||
func, | ||
return_type, | ||
input_types, | ||
udf_name, | ||
execute_as=execute_as, | ||
) | ||
|
||
self._registry[udf_name] = sproc | ||
|
||
return sproc | ||
|
||
def call( | ||
self, | ||
sproc_name: str, | ||
*args: Any, | ||
session: Optional["snowflake.snowpark.session.Session"] = None, | ||
statement_params: Optional[Dict[str, str]] = None, | ||
): | ||
|
||
if sproc_name not in self._registry: | ||
raise SnowparkSQLException( | ||
f"[Local Testing] sproc {sproc_name} does not exist." | ||
) | ||
|
||
return self._registry[sproc_name]( | ||
*args, session=session, statement_params=statement_params | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When we work on coercion, this function should be replaced by like
DataType.can_coerce_to(DataType)
. @sfc-gh-aling