diff --git a/src/snowflake/cli/plugins/nativeapp/codegen/snowpark/callback_source.py.jinja b/src/snowflake/cli/plugins/nativeapp/codegen/snowpark/callback_source.py.jinja index 3e8fa51e48..d7553c392e 100644 --- a/src/snowflake/cli/plugins/nativeapp/codegen/snowpark/callback_source.py.jinja +++ b/src/snowflake/cli/plugins/nativeapp/codegen/snowpark/callback_source.py.jinja @@ -1,6 +1,8 @@ +import contextlib import functools +import inspect import sys -from typing import Any, Callable, List +from typing import Callable try: import snowflake.snowpark @@ -12,107 +14,155 @@ except ModuleNotFoundError as exc: ) sys.exit(1) -orig_globals = globals().copy() - -found_correct_version = hasattr( +__snowflake_internal_found_correct_version = hasattr( snowflake.snowpark.context, "_is_execution_environment_sandboxed_for_client" ) and hasattr(snowflake.snowpark.context, "_should_continue_registration") -if not found_correct_version: +if not __snowflake_internal_found_correct_version: print( "Did not find the minimum required version for snowflake-snowpark-python package. Please upgrade to v1.15.0 or higher.", file=sys.stderr, ) sys.exit(1) -__snowflake_cli_native_app_internal_callback_return_list: List[Any] = [] +__snowflake_global_collected_extension_fn_json = [] + +def __snowflake_internal_create_extension_fn_registration_callback(): + def __snowflake_internal_try_extract_lineno(extension_fn): + try: + return inspect.getsourcelines(extension_fn)[1] + except Exception: + return None + + def __snowflake_internal_extract_extension_fn_name(extension_fn): + try: + import snowflake.snowpark._internal.utils as snowpark_utils + + if hasattr(snowpark_utils, 'TEMP_OBJECT_NAME_PREFIX'): + if extension_fn.object_name.startswith(snowpark_utils.TEMP_OBJECT_NAME_PREFIX): + # The object name is a generated one, don't use it + return None + + except Exception: + # ignore any exception and fall back to using the object name reported from Snowpark + pass + + return extension_fn.object_name + + def __snowflake_internal_create_package_list(extension_fn): + return [pkg_spec.strip() for pkg_spec in extension_fn.all_packages.split(",")] + + def __snowflake_internal_make_extension_fn_signature(extension_fn): + # Try to fetch the original argument names from the extension function + try: + args_spec = inspect.getfullargspec(extension_fn.func) + original_arg_names = args_spec[0] + start_index = len(original_arg_names) - len(extension_fn.input_sql_types) + signature = [] + defaults_start_index = len(original_arg_names) - len(args_spec.defaults or []) + for i in range(len(extension_fn.input_sql_types)): + arg = { + 'name': original_arg_names[start_index + i], + 'type': extension_fn.input_sql_types[i] + } + if i >= defaults_start_index: + arg['default'] = args_spec.defaults[defaults_start_index + i] + signature.append(arg) + + return signature + except Exception as e: + msg = str(e) + pass # ignore, we'll use the fallback strategy + + # Failed to extract the original arguments through reflection, fall back to alternative approach + return [ + {"name": input_arg.name, "type": input_type} + for (input_arg, input_type) in zip(extension_fn.input_args, extension_fn.input_sql_types) + ] + + def __snowflake_internal_to_extension_fn_type(object_type): + if object_type.name == "AGGREGATE_FUNCTION": + return "aggregate function" + if object_type.name == "TABLE_FUNCTION": + return "table function" + return object_type.name.lower() + + def __snowflake_internal_extension_fn_to_json(extension_fn): + if not isinstance(extension_fn.func, Callable): + # Unsupported case: extension function is a tuple + return + + if extension_fn.anonymous: + # unsupported, native application extension functions need to be explicitly named + return + + # Collect basic properties of the extension function + extension_fn_json = { + "type": __snowflake_internal_to_extension_fn_type(extension_fn.object_type), + "lineno": __snowflake_internal_try_extract_lineno(extension_fn.func), + "name": __snowflake_internal_extract_extension_fn_name(extension_fn), + "handler": extension_fn.func.__name__, + "imports": extension_fn.all_imports or [], + "packages": __snowflake_internal_create_package_list(extension_fn), + "runtime": extension_fn.runtime_version, + "returns": extension_fn.return_sql.upper().replace("RETURNS ", "").strip(), + "signature": __snowflake_internal_make_extension_fn_signature(extension_fn), + "external_access_integrations": extension_fn.external_access_integrations or [], + "secrets": extension_fn.secrets or {}, + } + if extension_fn.object_type.name == "PROCEDURE" and extension_fn.execute_as is not None: + extension_fn_json['execute_as_caller'] = (extension_fn.execute_as == 'caller') -def __snowflake_cli_native_app_internal_callback_replacement(): - global __snowflake_cli_native_app_internal_callback_return_list + if extension_fn.native_app_params is not None: + schema = extension_fn.native_app_params.get("schema") + if schema is not None: + extension_fn_json["schema"] = schema + app_roles = extension_fn.native_app_params.get("application_roles") + if app_roles is not None: + extension_fn_json["application_roles"] = app_roles - def __snowflake_cli_native_app_internal_transform_snowpark_object_to_json( - extension_function_properties, - ): + return extension_fn_json - {% raw %}ext_fn = extension_function_properties - extension_function_dict = { - "object_type": ext_fn.object_type.name, - "object_name": ext_fn.object_name, - "input_args": [ - {"name": input_arg.name, "datatype": type(input_arg.datatype).__name__} - for input_arg in ext_fn.input_args - ], - "input_sql_types": ext_fn.input_sql_types, - "return_sql": ext_fn.return_sql, - "runtime_version": ext_fn.runtime_version, - "all_imports": ext_fn.all_imports, - "all_packages": ext_fn.all_packages, - "handler": ext_fn.handler, - "external_access_integrations": ext_fn.external_access_integrations, - "secrets": ext_fn.secrets, - "inline_python_code": ext_fn.inline_python_code, - "raw_imports": ext_fn.raw_imports, - "replace": ext_fn.replace, - "if_not_exists": ext_fn.if_not_exists, - "execute_as": ext_fn.execute_as, - "anonymous": ext_fn.anonymous, - # Set func based on type - "func": ext_fn.func.__name__ - if isinstance(ext_fn.func, Callable) - else ext_fn.func, - } - # Set native app params based on dictionary - if ext_fn.native_app_params is not None: - extension_function_dict["schema"] = ext_fn.native_app_params.get( - "schema", None - ) - extension_function_dict["application_roles"] = ext_fn.native_app_params.get( - "application_roles", None - ) - else: - extension_function_dict["schema"] = extension_function_dict[ - "application_roles" - ] = None - # Imports and handler will be set at a later time.{% endraw %} - return extension_function_dict - - def __snowflake_cli_native_app_internal_callback_append_to_list( - callback_return_list, extension_function_properties + def __snowflake_internal_collect_extension_fn( + collected_extension_fn_json_list, extension_function_properties ): - extension_function_dict = ( - __snowflake_cli_native_app_internal_transform_snowpark_object_to_json( - extension_function_properties - ) - ) - callback_return_list.append(extension_function_dict) + extension_fn_json = __snowflake_internal_extension_fn_to_json(extension_function_properties) + collected_extension_fn_json_list.append(extension_fn_json) return False return functools.partial( - __snowflake_cli_native_app_internal_callback_append_to_list, - __snowflake_cli_native_app_internal_callback_return_list, + __snowflake_internal_collect_extension_fn, + __snowflake_global_collected_extension_fn_json, ) - -with open("{{py_file}}", mode="r", encoding="utf-8") as udf_code: - code = udf_code.read() - - snowflake.snowpark.context._is_execution_environment_sandboxed_for_client = ( # noqa: SLF001 True ) snowflake.snowpark.context._should_continue_registration = ( # noqa: SLF001 - __snowflake_cli_native_app_internal_callback_replacement() + __snowflake_internal_create_extension_fn_registration_callback() ) snowflake.snowpark.session._is_execution_environment_sandboxed_for_client = ( # noqa: SLF001 True ) +for global_key in list(globals().keys()): + if global_key.startswith("__snowflake_internal"): + del globals()[global_key] + +del globals()["global_key"] # make sure to clean up the loop variable as well + try: - exec(code, orig_globals) + import importlib + with contextlib.redirect_stdout(None): + with contextlib.redirect_stderr(None): + __snowflake_internal_spec = importlib.util.spec_from_file_location("", "{{py_file}}") + __snowflake_internal_module = importlib.util.module_from_spec(__snowflake_internal_spec) + __snowflake_internal_spec.loader.exec_module(__snowflake_internal_module) except Exception as exc: # Catch any error print("An exception occurred while executing file: ", exc, file=sys.stderr) sys.exit(1) + import json -print(json.dumps(__snowflake_cli_native_app_internal_callback_return_list)) +print(json.dumps(__snowflake_global_collected_extension_fn_json)) diff --git a/src/snowflake/cli/plugins/nativeapp/codegen/snowpark/extension_function_utils.py b/src/snowflake/cli/plugins/nativeapp/codegen/snowpark/extension_function_utils.py index f5a8b2b07a..ff931fca3d 100644 --- a/src/snowflake/cli/plugins/nativeapp/codegen/snowpark/extension_function_utils.py +++ b/src/snowflake/cli/plugins/nativeapp/codegen/snowpark/extension_function_utils.py @@ -1,17 +1,23 @@ from __future__ import annotations -from pathlib import Path from typing import ( - Any, - Dict, List, Optional, - Tuple, - Type, - Union, + Sequence, ) from click.exceptions import ClickException +from snowflake.cli.api.project.schemas.snowpark.argument import Argument +from snowflake.cli.api.project.util import ( + is_valid_identifier, + is_valid_string_literal, + to_identifier, + to_string_literal, +) +from snowflake.cli.plugins.nativeapp.codegen.snowpark.models import ( + ExtensionFunctionTypeEnum, + NativeAppExtensionFunction, +) class MalformedExtensionFunctionError(ClickException): @@ -21,341 +27,59 @@ def __init__(self, message: str): super().__init__(message=message) -# This prefix is created by Snowpark for an extension function when the user has not supplied any themselves. -# TODO: move to sandbox execution to omit object name in this case: https://github.com/snowflakedb/snowflake-cli/pull/1056/files#r1599784063 -TEMP_OBJECT_NAME_PREFIX = "SNOWPARK_TEMP_" - - -def get_object_type_as_text(name: str) -> str: - """ - Replace underscores with spaces in a given string. - - Parameters: - name (str): Any arbitrary string - Returns: - A string that has replaced underscores with spaces. - """ - return name.replace("_", " ") - - -def _sanitize_str_attribute( - ex_fn: Dict[str, Any], - attr: str, - make_uppercase: bool = False, - py_file: Optional[Path] = None, - raise_err: bool = False, -): - """ - Sanitizes a single key-value pair of the specified dictionary. As part of the sanitization, - it goes through a few checks. A key must be created if it does not already exist. - Then, it checks the type of the value of the key, i.e. if it is of type str, and if it contains any leading or trailing whitespaces. - A user is able to specify if they want to re-assign a key to an uppercase instance of its original value. - If any of the sanitization checks fail and the user wants to raise an error, it throws a MalformedExtensionFunctionError. - """ - assign_to_none = True - if ex_fn.get(attr, None): - if not isinstance(ex_fn[attr], str): - raise MalformedExtensionFunctionError( - f"Attribute '{attr}' of extension function must be of type 'str'." - ) - - if ( - len(ex_fn[attr].strip()) > 0 - ): # To prevent where attr value is " " etc, which should still be invalid - assign_to_none = False - if make_uppercase: - ex_fn[attr] = ex_fn[attr].upper() - - if assign_to_none: - ex_fn[attr] = None - - if assign_to_none and raise_err: - raise MalformedExtensionFunctionError( - _create_missing_attr_message(attribute=attr, py_file=py_file) - ) - - -def _sanitize_list_or_dict_attribute( - ex_fn: Dict[str, Any], - attr: str, - expected_type: Type, - default_value: Any = None, - py_file: Optional[Path] = None, - raise_err: bool = False, -): - """ - Sanitizes a single key-value pair of the specified dictionary. As part of the sanitization, - it goes through a few checks. A key must be created if it does not already exist. - Then, it checks the type of the value of the key. It also checks for the length of the value, which is why the value must be of type list or dict. - A user is able to specity a default value that they want to assign a newly created key to. - If any of the sanitization checks fail and the user wants to raise an error, it throws a MalformedExtensionFunctionError. - """ - assign_to_default = True - if ex_fn.get(attr, None): - if not isinstance(ex_fn[attr], expected_type): - raise MalformedExtensionFunctionError( - f"Attribute '{attr}' of extension function must be of type '{expected_type}'." - ) - - if len(ex_fn[attr]) > 0: - assign_to_default = False - - if assign_to_default: - ex_fn[attr] = default_value - - if assign_to_default and raise_err: - raise MalformedExtensionFunctionError( - _create_missing_attr_message(attribute=attr, py_file=py_file) - ) - - -def _create_missing_attr_message(attribute: str, py_file: Optional[Path]): - """ - This message string is used to create an instance of the MalformedExtensionFunctionError. - """ - if py_file is None: - raise ValueError("Python file path must not be None.") - return f"Required attribute '{attribute}' of extension function is missing or incorrectly defined for an extension function defined in python file {py_file.absolute()}." - - -def _is_function_wellformed(ex_fn: Dict[str, Any]) -> bool: - """ - Checks if the specified dictionary contains a key called 'func'. - if it does, then the value must be of type str or a list of fixed size 2. - It further checks the item at 1st index of this list. - """ - if ex_fn.get("func", None): - if isinstance(ex_fn["func"], str): - return ex_fn["func"].strip() != "" - elif isinstance(ex_fn["func"], list) and (len(ex_fn["func"]) == 2): - return isinstance(ex_fn["func"][1], str) and ex_fn["func"][1].strip() != "" - return False - - -def sanitize_extension_function_data(ex_fn: Dict[str, Any], py_file: Path): - """ - Helper function to sanitize different attributes of a dictionary. As part of the sanitization, validations and default assignments are performed. - This helper function is needed because different callback sources can create dictionaries with different/missing keys, and since - Snowflake CLI may not own all callback implementations, the dictionaries need to have the minimum set of keys and their default - values to be used in creation of the SQL DDL statements. - - Parameters: - ex_fn (Dict[str, Any]): A dictionary of key value pairs to sanitize - py_file (Path): The python file from which this dictionary was created. - Returns: - A boolean value, True if everything has been successfully validated and assigned, False if an error was encountered. - """ - # TODO: accumulate errors/warnings instead of per-attribute interruption: https://github.com/snowflakedb/snowflake-cli/pull/1056/files#r1599904008 - - # Must have keys to create an extension function in SQL for Native Apps - _sanitize_str_attribute( - ex_fn=ex_fn, - attr="object_type", - make_uppercase=True, - py_file=py_file, - raise_err=True, - ) - - _sanitize_str_attribute( - ex_fn=ex_fn, - attr="object_name", - make_uppercase=True, - py_file=py_file, - raise_err=True, - ) - - _sanitize_str_attribute( - ex_fn=ex_fn, - attr="return_sql", - make_uppercase=True, - py_file=py_file, - raise_err=True, - ) - - if not _is_function_wellformed(ex_fn=ex_fn): - raise MalformedExtensionFunctionError( - _create_missing_attr_message(attribute="func", py_file=py_file) - ) - - default_raw_imports: List[Union[str, Tuple[str, str]]] = [] - _sanitize_list_or_dict_attribute( - ex_fn=ex_fn, - attr="raw_imports", - expected_type=list, - default_value=default_raw_imports, - py_file=py_file, - raise_err=True, - ) - - _sanitize_str_attribute(ex_fn=ex_fn, attr="schema", make_uppercase=True) - # Custom message, hence throwing an error separately - if ex_fn["schema"] is None: - raise MalformedExtensionFunctionError( - f"Required attribute 'schema' in 'native_app_params' of extension function is missing for an extension function defined in python file {py_file.absolute()}." - ) - - # Other optional keys - ex_fn["anonymous"] = ex_fn.get("anonymous", False) - ex_fn["replace"] = ex_fn.get("replace", False) - ex_fn["if_not_exists"] = ex_fn.get("if_not_exists", False) +def get_sql_object_type(extension_fn: NativeAppExtensionFunction) -> Optional[str]: + if extension_fn.function_type == ExtensionFunctionTypeEnum.PROCEDURE: + return "PROCEDURE" + elif extension_fn.function_type in ( + ExtensionFunctionTypeEnum.FUNCTION, + ExtensionFunctionTypeEnum.TABLE_FUNCTION, + ): + return "FUNCTION" + elif extension_fn.function_type == extension_fn.function_type.AGGREGATE_FUNCTION: + return "AGGREGATE FUNCTION" + else: + return None - if ex_fn["replace"] and ex_fn["if_not_exists"]: - raise MalformedExtensionFunctionError( - "Options 'replace' and 'if_not_exists' are incompatible." - ) - default_input_args: List[Dict[str, Any]] = [] - _sanitize_list_or_dict_attribute( - ex_fn=ex_fn, - attr="input_args", - expected_type=list, - default_value=default_input_args, - ) - default_input_types: List[str] = [] - _sanitize_list_or_dict_attribute( - ex_fn=ex_fn, - attr="input_sql_types", - expected_type=list, - default_value=default_input_types, - ) - if len(ex_fn["input_args"]) != len(ex_fn["input_sql_types"]): - raise MalformedExtensionFunctionError( - "The number of extension function parameters does not match the number of parameter types." - ) +def get_sql_argument_signature(arg: Argument) -> str: + formatted = f"{arg.name} {arg.arg_type}" + if arg.default is not None: + formatted = f"{formatted} DEFAULT {arg.default}" + return formatted - _sanitize_str_attribute(ex_fn=ex_fn, attr="all_imports") - _sanitize_str_attribute(ex_fn=ex_fn, attr="all_packages") - _sanitize_list_or_dict_attribute( - ex_fn=ex_fn, attr="external_access_integrations", expected_type=list - ) - _sanitize_list_or_dict_attribute(ex_fn=ex_fn, attr="secrets", expected_type=dict) - _sanitize_str_attribute(ex_fn=ex_fn, attr="inline_python_code") - _sanitize_str_attribute(ex_fn=ex_fn, attr="execute_as", make_uppercase=True) - _sanitize_str_attribute(ex_fn=ex_fn, attr="handler") - _sanitize_str_attribute( - ex_fn=ex_fn, attr="runtime_version", py_file=py_file, raise_err=True - ) - has_app_roles = ( - ex_fn.get("application_roles", None) and len(ex_fn["application_roles"]) > 0 - ) - if has_app_roles: - if all(isinstance(app_role, str) for app_role in ex_fn["application_roles"]): - ex_fn["application_roles"] = [ - app_role.upper() for app_role in ex_fn["application_roles"] - ] +def get_qualified_object_name(extension_fn: NativeAppExtensionFunction) -> str: + qualified_name = to_identifier(extension_fn.name) + if extension_fn.schema_name: + if is_valid_identifier(extension_fn.schema_name): + qualified_name = f"{extension_fn.schema_name}.{qualified_name}" else: - raise MalformedExtensionFunctionError( - f"Attribute 'application_roles' of extension function must be a list of strings." + full_schema = ".".join( + [ + to_identifier(schema_part) + for schema_part in extension_fn.schema_name.split(".") + ] ) - else: - ex_fn["application_roles"] = [] - - -def _get_handler_path_without_suffix( - file_path: Path, deploy_root: Path, suffix_str_to_rm: Optional[str] = None -) -> str: - """ - Get a handler for an extension function based on the file path on the stage. If a specific suffix needs to be removed from the path, - then that is also taken into account. - """ - return "NotImplementedHandler" - + qualified_name = f"{full_schema}.{qualified_name}" -def _get_handler( - dest_file: Path, func: Union[str, Tuple[str, str]], deploy_root: Path -) -> Optional[str]: - """ - Gets the handler for the extension function to be used in the creation of the SQL statement. - """ - if isinstance(func, str): - return f"{_get_handler_path_without_suffix(file_path=dest_file, suffix_str_to_rm='.py', deploy_root=deploy_root)}.{func}" - else: - # isinstance(self.func, Tuple[str, str]) is only possible if using decorator.register_from_file(), which is not allowed in codegen as of now. - # When allowed, refer to https://github.com/snowflakedb/snowpark-python/blob/v1.15.0/src/snowflake/snowpark/_internal/udf_utils.py#L1092 on resolving handler name - raise MalformedExtensionFunctionError( - f"Could not determine handler name for {func[1]}." - ) + return qualified_name -def _get_schema_and_name_for_extension_function( - object_name: str, schema: Optional[str], func: str -) -> Optional[str]: +def ensure_string_literal(value: str) -> str: """ - Gets the name of the extension function to be used in the creation of the SQL statement. - It will use the schema and the python function's name as the object name if the function name is determined to be a Snowpark-generated placeholder. - Otherwise, it will honor the user's input for object name. + Returns the string literal representation of the given value, or the value itself if + it was already a valid string literal. """ - if object_name.startswith(TEMP_OBJECT_NAME_PREFIX): - return f"{schema}.{func}" if schema else func - else: - return f"{schema}.{object_name}" if schema else object_name + if is_valid_string_literal(value): + return value + return to_string_literal(value) -def _is_single_quoted(name: str) -> bool: - """ - Helper function to do a generic check on whether the provided string is surrounded by single quotes. - """ - return name.startswith("'") and name.endswith("'") - - -def _ensure_single_quoted(obj_lst: List[str]) -> List[str]: - """ - Helper function to ensure that a list of object strings is transformed to a list of object strings surrounded by single quotes. +def ensure_all_string_literals(values: Sequence[str]) -> List[str]: """ - return [obj if _is_single_quoted(obj) else f"'{obj}'" for obj in obj_lst] + Ensures that all provided values are valid string literals. - -def _get_all_imports(raw_imports: List[Union[str, Tuple[str, str]]]) -> str: - """ - Creates a string containing all the relevant imports for an extension function. This string is used in the creation of the SQL statement. - - Parameters: - raw_imports (List[Union[str, Tuple[str, str]]]): The raw imports that will be used to create the final string. - The function needes to handle different input types, similar to snowpark. - Example 1: [("tests/resources/test_udf_dir/test_udf_file.py", "resources.test_udf_dir.test_udf_file")] - Example 2: session.add_import("tests/resources/test_udf_dir/test_udf_file.py") - Example 3: session.add_import("tests/resources/test_udf_dir/test_udf_file.py", import_path="resources.test_udf_dir.test_udf_file") - Returns: - A string containing all the imports. - """ - all_urls: List[str] = [] - for raw_import in raw_imports: # Example 1 - if isinstance(raw_import, str): # Example 2 - all_urls.append(raw_import) - else: # Example 3 - local_path = Path(raw_import[0]) - stage_import = raw_import[1] - local_path_suffix = local_path.suffix - if local_path_suffix != "": - # 1. We use suffix check here instead of local_path.is_file() as local_path may not exist, making is_file() False. - # We do not provide validation on local_path existing, and hence should not fail or treat it differently than any other file. - # 2. stage_import may already have a suffix, but we do not provide validation on it. - # It is on the user to know and use Snowpark's decorator attributes correctly. - without_suffix = "/".join(stage_import.split(".")) - all_urls.append(f"{without_suffix}{local_path_suffix}") - else: - file_path = "/".join(stage_import.split(".")) - all_urls.append(file_path) - return ",".join(_ensure_single_quoted(all_urls)) - - -def enrich_ex_fn(ex_fn: Dict[str, Any], py_file: Path, deploy_root: Path): - """ - Sets additional properties for a given extension function dictionary, that could not be set earlier due to missing information or limited access to the execution context. - Parameters: - ex_fn (Dict[str, Any]): A dictionary of key value pairs to sanitize - py_file (Path): The python file from which this dictionary was created. - deploy_root (Path): The deploy root of the the project. Returns: - The original but edited extension function dictionary + A list with all values transformed to be valid string literals (as necessary). """ - ex_fn["handler"] = _get_handler( - dest_file=py_file, func=ex_fn["func"], deploy_root=deploy_root - ) - ex_fn["object_name"] = _get_schema_and_name_for_extension_function( - object_name=ex_fn["object_name"], - schema=ex_fn["schema"], - func=ex_fn["func"], - ) - ex_fn["all_imports"] = _get_all_imports(raw_imports=ex_fn["raw_imports"] or []) + return [ensure_string_literal(value) for value in values] diff --git a/src/snowflake/cli/plugins/nativeapp/codegen/snowpark/models.py b/src/snowflake/cli/plugins/nativeapp/codegen/snowpark/models.py new file mode 100644 index 0000000000..e74f755d1a --- /dev/null +++ b/src/snowflake/cli/plugins/nativeapp/codegen/snowpark/models.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from enum import Enum +from typing import List, Optional + +from pydantic import Field +from snowflake.cli.api.project.schemas.snowpark.callable import _CallableBase +from snowflake.cli.api.project.schemas.updatable_model import IdentifierField + + +class ExtensionFunctionTypeEnum(str, Enum): + PROCEDURE = "procedure" + FUNCTION = "function" + TABLE_FUNCTION = "table function" + AGGREGATE_FUNCTION = "aggregate function" + + +class NativeAppExtensionFunction(_CallableBase): + function_type: ExtensionFunctionTypeEnum = Field( + title="The type of extension function, one of 'procedure', 'function', 'table function' or 'aggregate function'.", + alias="type", + ) + lineno: Optional[int] = Field( + title="The line number of the extension function", default=None + ) + name: Optional[str] = Field( + title="The name of the extension function", default=None + ) + packages: Optional[List[str]] = Field( + title="List of packages (with optional version constraints) to be loaded for the function", + default={}, + ) + schema_name: Optional[str] = IdentifierField( + title=f"Name of the schema for the function", + default=None, + alias="schema", + ) + application_roles: Optional[List[str]] = Field( + title="Application roles granted usage to the function", + default=[], + ) + execute_as_caller: Optional[bool] = Field( + title="Determine whether the procedure is executed with the privileges of " + "the owner or with the privileges of the caller", + default=False, + ) diff --git a/src/snowflake/cli/plugins/nativeapp/codegen/snowpark/python_processor.py b/src/snowflake/cli/plugins/nativeapp/codegen/snowpark/python_processor.py index 7193e8274f..68bd16d293 100644 --- a/src/snowflake/cli/plugins/nativeapp/codegen/snowpark/python_processor.py +++ b/src/snowflake/cli/plugins/nativeapp/codegen/snowpark/python_processor.py @@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional from snowflake.cli.api.console import cli_console as cc +from snowflake.cli.api.project.errors import SchemaValidationError from snowflake.cli.api.project.schemas.native_app.native_app import NativeApp from snowflake.cli.api.project.schemas.native_app.path_mapping import ( PathMapping, @@ -23,10 +24,17 @@ execute_script_in_sandbox, ) from snowflake.cli.plugins.nativeapp.codegen.snowpark.extension_function_utils import ( - enrich_ex_fn, - get_object_type_as_text, - sanitize_extension_function_data, + ensure_all_string_literals, + ensure_string_literal, + get_qualified_object_name, + get_sql_argument_signature, + get_sql_object_type, ) +from snowflake.cli.plugins.nativeapp.codegen.snowpark.models import ( + ExtensionFunctionTypeEnum, + NativeAppExtensionFunction, +) +from snowflake.cli.plugins.stage.diff import to_stage_path DEFAULT_TIMEOUT = 30 TEMPLATE_PATH = Path(__file__).parent / "callback_source.py.jinja" @@ -83,6 +91,10 @@ def _determine_virtual_env( return {} +def _is_python_file_artifact(src: Path, dest: Path): + return src.is_file() and src.suffix == ".py" + + def _execute_in_sandbox( py_file: str, deploy_root: Path, kwargs: Dict[str, Any] ) -> Optional[List[Dict[str, Any]]]: @@ -98,8 +110,6 @@ def _execute_in_sandbox( timeout=DEFAULT_TIMEOUT, **kwargs, ) - cc.message(f"stdout: {completed_process.stdout}") - cc.message(f"stderr: {completed_process.stderr}") except SandboxExecutionError as sdbx_err: cc.warning( f"Could not fetch Snowpark objects from {py_file} due to {sdbx_err}, continuing execution for the rest of the python files." @@ -154,217 +164,182 @@ def process( artifact_to_process: PathMapping, processor_mapping: Optional[ProcessorMapping], **kwargs, - ) -> Dict[Path, str]: + ) -> str: # String output is temporary until we have better e2e testing mechanism """ - Intended to be the main method which can perform all relevant processing, and/or write to a target file, which depends on the type of processor. - For SnowparkAnnotationProcessor, the target file is the setup script. + Collects code annotations from Snowpark python files containing extension functions and augments the existing + setup script with generated SQL that registers these functions. """ - kwargs = ( - _determine_virtual_env(self.project_root, processor_mapping) - if processor_mapping is not None - else {} - ) - - # 1. Get the artifact src to dest mapping bundle_map = BundleMap( project_root=self.project_root, deploy_root=self.deploy_root ) bundle_map.add(artifact_to_process) - # 2. Get raw extension functions through Snowpark callback - dest_file_py_file_to_collected_raw_ex_fns: Dict[Path, Optional[Any]] = {} - - def is_python_file_artifact(src: Path, dest: Path): - return src.is_file() and src.suffix == ".py" - - for src_file, dest_file in bundle_map.all_mappings( - absolute=True, expand_directories=True, predicate=is_python_file_artifact - ): - try: - collected_raw_ex_fns = _execute_in_sandbox( - py_file=str(dest_file.resolve()), - deploy_root=self.deploy_root, - kwargs=kwargs, - ) - except Exception as exc: - cc.warning( - f"Error processing extension functions in {src_file}: {exc}" - ) # Display the actual file for the user to inspect - cc.warning("Skipping generating code of all objects from this file.") - collected_raw_ex_fns = None - - if not collected_raw_ex_fns: - continue + collected_extension_functions_by_path = self.collect_extension_functions( + bundle_map, processor_mapping + ) - cc.message(f"This is the file path in deploy root: {dest_file}\n") - cc.message("This is the list of collected extension functions:") - cc.message(pprint.pformat(collected_raw_ex_fns)) + collected_output = [] + for py_file, extension_fns in collected_extension_functions_by_path.items(): + for extension_fn in extension_fns: + create_stmt = generate_create_sql_ddl_statement(extension_fn) + if create_stmt is None: + continue - filtered_collection = list( - filter( - lambda item: (item is not None) and (len(item) > 0), - collected_raw_ex_fns, - ) - ) - if len(filtered_collection) != len(collected_raw_ex_fns): - cc.warning( - "Discovered extension functions that have value None or do not contain any information." + cc.message( + "-- Generating Snowpark annotation SQL code for {}".format(py_file) ) - cc.warning( - "Skipping generating code of all such objects from this file." + cc.message(create_stmt) + collected_output.append( + f"-- {py_file.relative_to(bundle_map.deploy_root())}" ) + collected_output.append(create_stmt) - # 4. Enrich the raw extension functions by setting additional properties - for raw_ex_fn in filtered_collection: - sanitize_extension_function_data(ex_fn=raw_ex_fn, py_file=dest_file) - enrich_ex_fn( - ex_fn=raw_ex_fn, - py_file=dest_file, - deploy_root=self.deploy_root, - ) - - dest_file_py_file_to_collected_raw_ex_fns[dest_file] = filtered_collection + grant_statements = generate_grant_sql_ddl_statements(extension_fn) + if grant_statements is not None: + cc.message(grant_statements) + collected_output.append(grant_statements) - # For each extension function, generate its related SQL statements - dest_file_py_file_to_ddl_map: Dict[ - Path, str - ] = self.generate_sql_ddl_statements(dest_file_py_file_to_collected_raw_ex_fns) + return "\n".join(collected_output) - # TODO: Temporary for testing, while feature is being built in phases - return dest_file_py_file_to_ddl_map + def _normalize(self, extension_fn: NativeAppExtensionFunction, py_file: Path): + if extension_fn.name is None: + # The extension function was not named explicitly, use the name of the Python function object as its name + extension_fn.name = extension_fn.handler - def generate_sql_ddl_statements( - self, dest_file_py_file_to_collected_raw_ex_fns: Dict[Path, Optional[Any]] - ) -> Dict[Path, str]: - """ - Generates SQL DDL statements based on the entities collected from a set of python files in the artifact_to_process. - """ - dest_file_py_file_to_ddl_map: Dict[Path, str] = {} - for py_file in dest_file_py_file_to_collected_raw_ex_fns: + # Compute the fully qualified handler + extension_fn.handler = f"{py_file.stem}.{extension_fn.handler}" - collected_ex_fns = dest_file_py_file_to_collected_raw_ex_fns[ - py_file - ] # Collected entities is List[Dict[str, Any]] - if collected_ex_fns is None: - continue + if extension_fn.imports is None: + extension_fn.imports = [] + extension_fn.imports.append(f"/{to_stage_path(py_file)}") - ddl_lst_per_ef: List[str] = [] - for ex_fn in collected_ex_fns: - create_sql = generate_create_sql_ddl_statements(ex_fn) - if create_sql: - ddl_lst_per_ef.append(create_sql) - grant_sql = generate_grant_sql_ddl_statements(ex_fn) - if grant_sql: - ddl_lst_per_ef.append(grant_sql) + def collect_extension_functions( + self, bundle_map: BundleMap, processor_mapping: Optional[ProcessorMapping] + ) -> Dict[Path, List[NativeAppExtensionFunction]]: + kwargs = ( + _determine_virtual_env(self.project_root, processor_mapping) + if processor_mapping is not None + else {} + ) - if len(ddl_lst_per_ef) > 0: - dest_file_py_file_to_ddl_map[py_file] = "\n".join(ddl_lst_per_ef) + collected_extension_fns_by_path: Dict[ + Path, List[NativeAppExtensionFunction] + ] = {} - return dest_file_py_file_to_ddl_map + for src_file, dest_file in bundle_map.all_mappings( + absolute=True, expand_directories=True, predicate=_is_python_file_artifact + ): + collected_extension_function_json = _execute_in_sandbox( + py_file=str(dest_file.resolve()), + deploy_root=self.deploy_root, + kwargs=kwargs, + ) + if collected_extension_function_json is None: + cc.warning(f"Error processing extension functions in {src_file}") + cc.warning("Skipping generating code of all objects from this file.") + continue -def generate_create_sql_ddl_statements(ex_fn: Dict[str, Any]) -> Optional[str]: + collected_extension_functions = [] + for extension_function_json in collected_extension_function_json: + try: + extension_fn = NativeAppExtensionFunction(**extension_function_json) + self._normalize( + extension_fn, + py_file=dest_file.relative_to(bundle_map.deploy_root()), + ) + collected_extension_functions.append(extension_fn) + except SchemaValidationError: + cc.warning("Invalid extension function definition") + + if collected_extension_functions: + cc.message(f"This is the file path in deploy root: {dest_file}\n") + cc.message("This is the list of collected extension functions:") + cc.message(pprint.pformat(collected_extension_functions)) + + collected_extension_fns_by_path[ + dest_file + ] = collected_extension_functions + + return collected_extension_fns_by_path + + +def generate_create_sql_ddl_statement( + extension_fn: NativeAppExtensionFunction, +) -> Optional[str]: """ - Generates a "CREATE FUNCTION/PROCEDURE ... " SQL DDL statement based on a dictionary of extension function properties. + Generates a "CREATE FUNCTION/PROCEDURE ... " SQL DDL statement based on an extension function definition. Logic for this create statement has been lifted from snowflake-snowpark-python v1.15.0 package. - Anonymous procedures are not allowed in Native Apps, and hence if a user passes in the two corresponding properties, - this function will skip the DDL generation. """ - object_type = ex_fn["object_type"] - object_name = ex_fn["object_name"] - - if object_type == "PROCEDURE" and ex_fn["anonymous"]: - cc.warning( - dedent( - f"""{object_type.replace(' ', '-')} {object_name} cannot be an anonymous procedure in a Snowflake Native App. - Skipping generation of 'CREATE FUNCTION/PROCEDURE ...' SQL statement for this object.""" - ) - ) + object_type = get_sql_object_type(extension_fn) + if object_type is None: + cc.warning(f"Unsupported extension function type: {extension_fn.function_type}") return None - replace_in_sql = f" OR REPLACE " if ex_fn["replace"] else "" - - sql_func_args = ",".join( - [ - f"{a['name']} {t}" - for a, t in zip(ex_fn["input_args"], ex_fn["input_sql_types"]) - ] + arguments_in_sql = ", ".join( + [get_sql_argument_signature(arg) for arg in extension_fn.signature] ) - imports_in_sql = ( - f"\nIMPORTS=({ex_fn['all_imports']})" if ex_fn["all_imports"] else "" - ) + create_query = dedent( + f""" + CREATE OR REPLACE + {object_type} {get_qualified_object_name(extension_fn)}({arguments_in_sql}) + RETURNS {extension_fn.returns} + LANGUAGE PYTHON + RUNTIME_VERSION={extension_fn.runtime} + """ + ).strip() - packages_in_sql = ( - f"\nPACKAGES=({ex_fn['all_packages']})" if ex_fn["all_packages"] else "" - ) + if extension_fn.imports: + create_query += ( + f"\nIMPORTS=({', '.join(ensure_all_string_literals(extension_fn.imports))})" + ) - external_access_integrations = ex_fn["external_access_integrations"] - external_access_integrations_in_sql = ( - f"""\nEXTERNAL_ACCESS_INTEGRATIONS=({','.join(external_access_integrations)})""" - if external_access_integrations - else "" - ) + if extension_fn.packages: + create_query += f"\nPACKAGES=({', '.join(ensure_all_string_literals(extension_fn.packages))})" - secrets = ex_fn["secrets"] - secrets_in_sql = ( - f"""\nSECRETS=({",".join([f"'{k}'={v}" for k, v in secrets.items()])})""" - if secrets - else "" - ) + if extension_fn.external_access_integrations: + create_query += f"\nEXTERNAL_ACCESS_INTEGRATIONS=({', '.join(ensure_all_string_literals(extension_fn.external_access_integrations))})" + + if extension_fn.secrets: + create_query += f"""\nSECRETS=({', '.join([f"{ensure_string_literal(k)}={v}" for k, v in extension_fn.secrets.items()])})""" + + create_query += f"\nHANDLER={ensure_string_literal(extension_fn.handler)}" - execute_as = ex_fn["execute_as"] - if execute_as is None: - execute_as_sql = "" - else: - execute_as_sql = f"\nEXECUTE AS {execute_as}" - - inline_python_code = ex_fn["inline_python_code"] - if inline_python_code: - inline_python_code_in_sql = f"""\ -AS $$ -{inline_python_code} -$$ -""" - else: - inline_python_code_in_sql = "" - - create_query = f"""\ -CREATE{replace_in_sql} -{get_object_type_as_text(object_type)} {'IF NOT EXISTS' if ex_fn["if_not_exists"] else ''}{object_name}({sql_func_args}) -{ex_fn["return_sql"]} -LANGUAGE PYTHON -RUNTIME_VERSION={ex_fn["runtime_version"]} {imports_in_sql}{packages_in_sql}{external_access_integrations_in_sql}{secrets_in_sql} -HANDLER='{ex_fn["handler"]}'{execute_as_sql} -{inline_python_code_in_sql}""" + if extension_fn.function_type == ExtensionFunctionTypeEnum.PROCEDURE: + if extension_fn.execute_as_caller: + create_query += f"\nEXECUTE AS CALLER" + else: + create_query += f"\nEXECUTE AS OWNER" + create_query += ";\n" return create_query -def generate_grant_sql_ddl_statements(ex_fn: Dict[str, Any]) -> Optional[str]: +def generate_grant_sql_ddl_statements( + extension_fn: NativeAppExtensionFunction, +) -> Optional[str]: """ Generates a "GRANT USAGE TO ... " SQL DDL statement based on a dictionary of extension function properties. If no application roles are present, then the function returns None. """ - if ex_fn["application_roles"] is None: + if not extension_fn.application_roles: cc.warning( "Skipping generation of 'GRANT USAGE ON ...' SQL statement for this object due to lack of application roles." ) return None grant_sql_statements = [] - for app_role in ex_fn["application_roles"]: + for app_role in extension_fn.application_roles: grant_sql_statement = dedent( f"""\ - GRANT USAGE ON {get_object_type_as_text(ex_fn["object_type"])} {ex_fn["object_name"]} + GRANT USAGE ON {get_sql_object_type(extension_fn)} {get_qualified_object_name(extension_fn)} TO APPLICATION ROLE {app_role}; """ - ) + ).strip() grant_sql_statements.append(grant_sql_statement) - if len(grant_sql_statements) == 0: - return None return "\n".join(grant_sql_statements) diff --git a/tests/nativeapp/codegen/snowpark/__snapshots__/test_python_processor.ambr b/tests/nativeapp/codegen/snowpark/__snapshots__/test_python_processor.ambr index 625a31c7f6..e60ae6761b 100644 --- a/tests/nativeapp/codegen/snowpark/__snapshots__/test_python_processor.ambr +++ b/tests/nativeapp/codegen/snowpark/__snapshots__/test_python_processor.ambr @@ -1,116 +1,78 @@ # serializer version: 1 # name: test_generate_create_sql_ddl_statements_w_all_entries ''' - CREATE OR REPLACE - TABLE FUNCTION SNOWPARK_TEMP_FUNCTION_WZUNHMZJKA(arg1 INT) - RETURNS INT + CREATE OR REPLACE + PROCEDURE DATA.my_function(first int DEFAULT 42) + RETURNS int LANGUAGE PYTHON - RUNTIME_VERSION=3.11 - IMPORTS=('path_one', 'path_two') - PACKAGES=('package_one', 'package_two') - EXTERNAL_ACCESS_INTEGRATIONS=(integration_one,integration_two) - SECRETS=('key1'=secret_one,'key2'=integration_two) - HANDLER='dummy_handler' - EXECUTE AS OWNER - AS $$ - dummy_inline_code - $$ + RUNTIME_VERSION=3.11 + IMPORTS=('/path/to/import1.py', '/path/to/import2.zip') + PACKAGES=('package_one==1.0.2', 'package_two') + EXTERNAL_ACCESS_INTEGRATIONS=('integration_one', 'integration_two') + SECRETS=('key1'=secret_one, 'key2'=integration_two) + HANDLER='my_function_handler' + EXECUTE AS OWNER; ''' # --- # name: test_generate_create_sql_ddl_statements_w_select_entries ''' - CREATE - TABLE FUNCTION SNOWPARK_TEMP_FUNCTION_WZUNHMZJKA(arg1 INT) - RETURNS INT + CREATE OR REPLACE + PROCEDURE my_function(first int DEFAULT 42) + RETURNS int LANGUAGE PYTHON - RUNTIME_VERSION=3.11 - HANDLER='dummy_handler' + RUNTIME_VERSION=3.11 + HANDLER='my_function_handler' + EXECUTE AS OWNER; ''' # --- # name: test_generate_grant_sql_ddl_statements ''' - GRANT USAGE ON TABLE FUNCTION CORE.MYFUNC + GRANT USAGE ON PROCEDURE DATA.my_function TO APPLICATION ROLE APP_ADMIN; - - GRANT USAGE ON TABLE FUNCTION CORE.MYFUNC + GRANT USAGE ON PROCEDURE DATA.my_function TO APPLICATION ROLE APP_VIEWER; - ''' # --- -# name: test_generate_sql_ddl_statements - ''' - CREATE OR REPLACE - TABLE FUNCTION DATA.(arg1 INT) - RETURNS INT - LANGUAGE PYTHON - RUNTIME_VERSION=3.11 - IMPORTS=('a/b/c.py') - PACKAGES=('package_one', 'package_two') - EXTERNAL_ACCESS_INTEGRATIONS=(integration_one,integration_two) - SECRETS=('key1'=secret_one,'key2'=integration_two) - HANDLER='NotImplementedHandler.' - EXECUTE AS OWNER - AS $$ - dummy_inline_code - $$ - - GRANT USAGE ON TABLE FUNCTION DATA. - TO APPLICATION ROLE APP_ADMIN; - - GRANT USAGE ON TABLE FUNCTION DATA. - TO APPLICATION ROLE APP_VIEWER; - - ''' +# name: test_process_no_collected_functions + '' # --- -# name: test_generate_sql_ddl_statements.1 +# name: test_process_with_collected_functions ''' - CREATE OR REPLACE - TABLE FUNCTION DATA.(arg1 INT) - RETURNS INT + -- stagepath/main.py + CREATE OR REPLACE + PROCEDURE DATA.my_function(first int DEFAULT 42) + RETURNS int LANGUAGE PYTHON - RUNTIME_VERSION=3.11 - IMPORTS=('a/b/c.py') - PACKAGES=('package_one', 'package_two') - EXTERNAL_ACCESS_INTEGRATIONS=(integration_one,integration_two) - SECRETS=('key1'=secret_one,'key2'=integration_two) - HANDLER='NotImplementedHandler.' - EXECUTE AS OWNER - AS $$ - dummy_inline_code - $$ + RUNTIME_VERSION=3.11 + IMPORTS=('/path/to/import1.py', '/path/to/import2.zip', '/stagepath/main.py') + PACKAGES=('package_one==1.0.2', 'package_two') + EXTERNAL_ACCESS_INTEGRATIONS=('integration_one', 'integration_two') + SECRETS=('key1'=secret_one, 'key2'=integration_two) + HANDLER='main.my_function_handler' + EXECUTE AS OWNER; - GRANT USAGE ON TABLE FUNCTION DATA. + GRANT USAGE ON PROCEDURE DATA.my_function TO APPLICATION ROLE APP_ADMIN; - - GRANT USAGE ON TABLE FUNCTION DATA. + GRANT USAGE ON PROCEDURE DATA.my_function TO APPLICATION ROLE APP_VIEWER; - - ''' -# --- -# name: test_generate_sql_ddl_statements_filtered_create - ''' - CREATE OR REPLACE - TABLE FUNCTION DATA.(arg1 INT) - RETURNS INT + -- stagepath/data.py + CREATE OR REPLACE + PROCEDURE DATA.my_function(first int DEFAULT 42) + RETURNS int LANGUAGE PYTHON - RUNTIME_VERSION=3.11 - IMPORTS=('a/b/c.py') - PACKAGES=('package_one', 'package_two') - EXTERNAL_ACCESS_INTEGRATIONS=(integration_one,integration_two) - SECRETS=('key1'=secret_one,'key2'=integration_two) - HANDLER='NotImplementedHandler.' - EXECUTE AS OWNER - AS $$ - dummy_inline_code - $$ + RUNTIME_VERSION=3.11 + IMPORTS=('/path/to/import1.py', '/path/to/import2.zip', '/stagepath/data.py') + PACKAGES=('package_one==1.0.2', 'package_two') + EXTERNAL_ACCESS_INTEGRATIONS=('integration_one', 'integration_two') + SECRETS=('key1'=secret_one, 'key2'=integration_two) + HANDLER='data.my_function_handler' + EXECUTE AS OWNER; - GRANT USAGE ON TABLE FUNCTION DATA. + GRANT USAGE ON PROCEDURE DATA.my_function TO APPLICATION ROLE APP_ADMIN; - - GRANT USAGE ON TABLE FUNCTION DATA. + GRANT USAGE ON PROCEDURE DATA.my_function TO APPLICATION ROLE APP_VIEWER; - ''' # --- diff --git a/tests/nativeapp/codegen/snowpark/test_extension_function_utils.py b/tests/nativeapp/codegen/snowpark/test_extension_function_utils.py index d58939a1aa..bc41373a66 100644 --- a/tests/nativeapp/codegen/snowpark/test_extension_function_utils.py +++ b/tests/nativeapp/codegen/snowpark/test_extension_function_utils.py @@ -1,199 +1,88 @@ -from pathlib import Path - import pytest import snowflake.cli.plugins.nativeapp.codegen.snowpark.extension_function_utils as ef_utils - -# -------------------------------------------------------- -# ------------- get_object_type_as_text ------------------ -# -------------------------------------------------------- +from snowflake.cli.api.project.schemas.snowpark.argument import Argument +from snowflake.cli.plugins.nativeapp.codegen.snowpark.models import ( + NativeAppExtensionFunction, +) @pytest.mark.parametrize( - ("input_param, expected"), + "function_type, expected", [ - ("TABLE_FUNCTION", "TABLE FUNCTION"), - ("FUNCTION", "FUNCTION"), - ("AGGREGATE-FUNCTION", "AGGREGATE-FUNCTION"), + ("procedure", "PROCEDURE"), + ("function", "FUNCTION"), + ("aggregate function", "AGGREGATE FUNCTION"), + ("table function", "FUNCTION"), ], ) -def test_get_object_type_as_text(input_param, expected): - actual = ef_utils.get_object_type_as_text(input_param) - assert actual == expected - - -# -------------------------------------------------------- -# --------- sanitize_extension_function_data ------------- -# -------------------------------------------------------- - - -def test_sanitize_extension_function_data_required_keys(snapshot): - """ - This test will start off with an empty dictionary which will act as the extension function. - With every exception sanitize_extension_function_data() hits, we add in the required info - to progress to the next exception/execution. - """ - ex_fn = {} - some_path = Path("some/path") - - # Test for absence or malformed object_type, object_name and return_sql - with pytest.raises(ef_utils.MalformedExtensionFunctionError) as err: - ef_utils.sanitize_extension_function_data(ex_fn=ex_fn, py_file=some_path) - assert "object_type" in err.value.message - - ex_fn["object_type"] = ["function"] - with pytest.raises(ef_utils.MalformedExtensionFunctionError) as err: - ef_utils.sanitize_extension_function_data(ex_fn=ex_fn, py_file=some_path) - assert "object_type" in err.value.message - - ex_fn["object_type"] = "function" - with pytest.raises(ef_utils.MalformedExtensionFunctionError) as err: - ef_utils.sanitize_extension_function_data(ex_fn=ex_fn, py_file=some_path) - assert "object_name" in err.value.message - - ex_fn["object_name"] = "some_name" - with pytest.raises(ef_utils.MalformedExtensionFunctionError) as err: - ef_utils.sanitize_extension_function_data(ex_fn=ex_fn, py_file=some_path) - assert "return_sql" in err.value.message - - ex_fn["return_sql"] = "returns null" - - # Test for absence of func, and malformed func - with pytest.raises(ef_utils.MalformedExtensionFunctionError) as err: - ef_utils.sanitize_extension_function_data(ex_fn=ex_fn, py_file=some_path) - assert "func" in err.value.message - - wrong_func_possibilities = [[], "", " ", (None, ""), (None, " ")] - for val in wrong_func_possibilities: - ex_fn["func"] = val - with pytest.raises(ef_utils.MalformedExtensionFunctionError) as err: - ef_utils.sanitize_extension_function_data(ex_fn=ex_fn, py_file=some_path) - assert "func" in err.value.message - - right_func_possibilities = [[None, "dummy"], "dummy"] - for val in right_func_possibilities: - ex_fn["func"] = val - with pytest.raises(ef_utils.MalformedExtensionFunctionError) as err: - ef_utils.sanitize_extension_function_data(ex_fn=ex_fn, py_file=some_path) - assert "raw_imports" in err.value.message - - # Test for absence or malformed schema and runtime_version - ex_fn["raw_imports"] = {""} - with pytest.raises(ef_utils.MalformedExtensionFunctionError) as err: - ef_utils.sanitize_extension_function_data(ex_fn=ex_fn, py_file=some_path) - assert "raw_imports" in err.value.message - - ex_fn["raw_imports"] = ["some/path"] - with pytest.raises(ef_utils.MalformedExtensionFunctionError) as err: - ef_utils.sanitize_extension_function_data(ex_fn=ex_fn, py_file=some_path) - assert "schema" in err.value.message - - ex_fn["schema"] = "core" - with pytest.raises(ef_utils.MalformedExtensionFunctionError) as err: - ef_utils.sanitize_extension_function_data(ex_fn=ex_fn, py_file=some_path) - assert "runtime_version" in err.value.message - - ex_fn["runtime_version"] = "3.8" - ef_utils.sanitize_extension_function_data(ex_fn=ex_fn, py_file=some_path) - - assert ex_fn == snapshot - - -def test_sanitize_extension_function_data_other_malformed_keys(): - ex_fn = { - "object_type": "function", - "object_name": "some_name", - "return_sql": "returns null", - "func": "dummy_func", - "raw_imports": ["some/path"], - "schema": "core", - "replace": True, - "if_not_exists": True, - "runtime_version": "3.8", - } - some_path = Path("some/path") - - with pytest.raises(ef_utils.MalformedExtensionFunctionError) as err: - ef_utils.sanitize_extension_function_data(ex_fn=ex_fn, py_file=some_path) - assert "incompatible" in err.value.message - - ex_fn["if_not_exists"] = False - ex_fn["input_args"] = ["dummy"] - ex_fn["input_sql_types"] = ["dummy", "values"] - with pytest.raises(ef_utils.MalformedExtensionFunctionError) as err: - ef_utils.sanitize_extension_function_data(ex_fn=ex_fn, py_file=some_path) - assert "number of extension function parameters" in err.value.message - - ex_fn["input_sql_types"] = ["dummy"] - ef_utils.sanitize_extension_function_data(ex_fn=ex_fn, py_file=some_path) - assert ex_fn["application_roles"] == [] - - ex_fn["application_roles"] = ["app_viewer", None, {}] - with pytest.raises(ef_utils.MalformedExtensionFunctionError) as err: - ef_utils.sanitize_extension_function_data(ex_fn=ex_fn, py_file=some_path) - assert "application_roles" in err.value.message - - ex_fn["application_roles"] = ["app_viewer", "app_admin"] - ef_utils.sanitize_extension_function_data(ex_fn=ex_fn, py_file=some_path) - assert ex_fn["application_roles"] == ["APP_VIEWER", "APP_ADMIN"] - - -# -------------------------------------------------------- -# ------------------- enrich_ex_fn ----------------------- -# -------------------------------------------------------- - - -def test_enrich_ex_fn(snapshot): - ex_fn = { - "func": [None, "dummy"], - "object_name": "SNOWPARK_TEMP_", - "schema": None, - "raw_imports": [], - } - - with pytest.raises(ef_utils.MalformedExtensionFunctionError) as err: - ef_utils.enrich_ex_fn( - ex_fn=ex_fn, - py_file=Path("some", "file.py"), - deploy_root=Path("output", "deploy"), - ) - assert "determine handler name" in err.value.message - - ex_fn["func"] = "dummy" - ef_utils.enrich_ex_fn( - ex_fn=ex_fn, - py_file=Path("some", "file.py"), - deploy_root=Path("output", "deploy"), +def test_get_sql_object_type( + function_type, expected, native_app_extension_function_raw_data +): + native_app_extension_function_raw_data["type"] = function_type + extension_fn = NativeAppExtensionFunction(**native_app_extension_function_raw_data) + assert ef_utils.get_sql_object_type(extension_fn) == expected + + +def test_get_sql_argument_signature(): + arg = Argument(name="foo", type="int") + assert ef_utils.get_sql_argument_signature(arg) == "foo int" + + arg = Argument(name="foo", type="int", default="42") + assert ef_utils.get_sql_argument_signature(arg) == "foo int DEFAULT 42" + + +def test_get_qualified_object_name(native_app_extension_function): + native_app_extension_function.name = "foo" + native_app_extension_function.schema_name = None + + assert ef_utils.get_qualified_object_name(native_app_extension_function) == "foo" + + native_app_extension_function.name = "foo" + native_app_extension_function.schema_name = "my_schema" + + assert ( + ef_utils.get_qualified_object_name(native_app_extension_function) + == "my_schema.foo" ) - assert ex_fn["object_name"] == "dummy" - ex_fn["object_name"] = "MY_FUNC" - ef_utils.enrich_ex_fn( - ex_fn=ex_fn, - py_file=Path("some", "file.py"), - deploy_root=Path("output", "deploy"), + native_app_extension_function.name = "foo" + native_app_extension_function.schema_name = "my schema" + + assert ( + ef_utils.get_qualified_object_name(native_app_extension_function) + == '"my schema".foo' ) - assert ex_fn["object_name"] == "MY_FUNC" - ex_fn["schema"] = "" - ef_utils.enrich_ex_fn( - ex_fn=ex_fn, - py_file=Path("some", "file.py"), - deploy_root=Path("output", "deploy"), + native_app_extension_function.name = "foo" + native_app_extension_function.schema_name = "my.full.schema" + + assert ( + ef_utils.get_qualified_object_name(native_app_extension_function) + == "my.full.schema.foo" ) - assert ex_fn["object_name"] == "MY_FUNC" - - ex_fn["schema"] = "core" - ex_fn["raw_imports"] = [ - "a/b/c.py", - "a/b/c", - ["a/b/c.py", "a.b.c"], - ["a/b/c.jar", "a.b.c"], - ["a/b/c", "a.b.c"], - ] - ef_utils.enrich_ex_fn( - ex_fn=ex_fn, - py_file=Path("some", "file.py"), - deploy_root=Path("output", "deploy"), + + native_app_extension_function.name = "foo" + native_app_extension_function.schema_name = "my.full schema.with special chars" + + assert ( + ef_utils.get_qualified_object_name(native_app_extension_function) + == 'my."full schema"."with special chars".foo' ) - assert ex_fn["object_name"] == "core.MY_FUNC" - assert ex_fn["all_imports"] == snapshot + + +def test_ensure_string_literal(): + assert ef_utils.ensure_string_literal("") == "''" + assert ef_utils.ensure_string_literal("abc") == "'abc'" + assert ef_utils.ensure_string_literal("'abc'") == "'abc'" + assert ef_utils.ensure_string_literal("'abc def'") == "'abc def'" + assert ef_utils.ensure_string_literal("'abc") == r"'\'abc'" + assert ef_utils.ensure_string_literal("abc'") == r"'abc\''" + + +def test_ensure_all_string_literals(): + assert ef_utils.ensure_all_string_literals([]) == [] + assert ef_utils.ensure_all_string_literals(["", "foo", "'bar'"]) == [ + "''", + "'foo'", + "'bar'", + ] diff --git a/tests/nativeapp/codegen/snowpark/test_python_processor.py b/tests/nativeapp/codegen/snowpark/test_python_processor.py index 521c343e02..4e2a213e8b 100644 --- a/tests/nativeapp/codegen/snowpark/test_python_processor.py +++ b/tests/nativeapp/codegen/snowpark/test_python_processor.py @@ -13,7 +13,7 @@ SnowparkAnnotationProcessor, _determine_virtual_env, _execute_in_sandbox, - generate_create_sql_ddl_statements, + generate_create_sql_ddl_statement, generate_grant_sql_ddl_statements, ) @@ -138,36 +138,25 @@ def test_execute_in_sandbox_all_possible_none_cases(mock_sandbox): # -------------------------------------------------------- -# ------- generate_create_sql_ddl_statements ------------- +# ------- generate_create_sql_ddl_statement ------------- # -------------------------------------------------------- def test_generate_create_sql_ddl_statements_w_all_entries( - native_app_codegen_full_json, snapshot + native_app_extension_function, snapshot ): - assert generate_create_sql_ddl_statements(native_app_codegen_full_json) == snapshot + assert generate_create_sql_ddl_statement(native_app_extension_function) == snapshot def test_generate_create_sql_ddl_statements_w_select_entries( - native_app_codegen_full_json, snapshot + native_app_extension_function, snapshot ): - native_app_codegen_full_json["replace"] = False - native_app_codegen_full_json["all_imports"] = "" - native_app_codegen_full_json["all_packages"] = "" - native_app_codegen_full_json["external_access_integrations"] = None - native_app_codegen_full_json["secrets"] = None - native_app_codegen_full_json["execute_as"] = None - native_app_codegen_full_json["inline_python_code"] = None - assert generate_create_sql_ddl_statements(native_app_codegen_full_json) == snapshot - - -def test_generate_create_sql_ddl_statements_none(): - ex_fn = { - "object_type": "PROCEDURE", - "object_name": "CORE.MYFUNC", - "anonymous": True, - } - assert generate_create_sql_ddl_statements(ex_fn=ex_fn) is None + native_app_extension_function.imports = None + native_app_extension_function.packages = None + native_app_extension_function.schema_name = None + native_app_extension_function.secrets = None + native_app_extension_function.external_access_integrations = None + assert generate_create_sql_ddl_statement(native_app_extension_function) == snapshot # -------------------------------------------------------- @@ -175,20 +164,8 @@ def test_generate_create_sql_ddl_statements_none(): # -------------------------------------------------------- -def test_generate_grant_sql_ddl_statements(snapshot): - ex_fn = { - "object_type": "TABLE_FUNCTION", - "object_name": "CORE.MYFUNC", - "application_roles": ["APP_ADMIN", "APP_VIEWER"], - } - assert generate_grant_sql_ddl_statements(ex_fn=ex_fn) == snapshot - - -def test_generate_grant_sql_ddl_statements_none(): - ex_fn = {"application_roles": None} - assert generate_grant_sql_ddl_statements(ex_fn=ex_fn) is None - ex_fn["application_roles"] = [] - assert generate_grant_sql_ddl_statements(ex_fn=ex_fn) is None +def test_generate_grant_sql_ddl_statements(native_app_extension_function, snapshot): + assert generate_grant_sql_ddl_statements(native_app_extension_function) == snapshot # -------------------------------------------------------- @@ -214,42 +191,19 @@ def test_generate_grant_sql_ddl_statements_none(): "output/deploy/stagepath/data.py": "# this is a file\n", } -# Test when exception is thrown while collecting information from callback -@mock.patch( - "snowflake.cli.plugins.nativeapp.codegen.snowpark.python_processor._execute_in_sandbox", - side_effect=SandboxExecutionError("dummy"), -) -def test_process_exception(mock_sandbox, native_app_project_instance): - with temp_local_dir(default_dir_structure) as local_path: - native_app_project_instance.native_app.artifacts = [ - { - "src": "a/b/c/*.py", # Will pick "a/b/c/main.py" - "dest": "stagepath/", - "processors": ["SNOWPARK"], - } - ] - artifact_to_process = native_app_project_instance.native_app.artifacts[0] - dest_file_py_file_to_ddl_map = SnowparkAnnotationProcessor( - project_definition=native_app_project_instance, - project_root=local_path, - deploy_root=Path(local_path, "output/deploy"), - ).process( - artifact_to_process=artifact_to_process, - processor_mapping=ProcessorMapping(name="SNOWPARK"), - ) - assert len(dest_file_py_file_to_ddl_map) == 0 - @mock.patch( "snowflake.cli.plugins.nativeapp.codegen.snowpark.python_processor._execute_in_sandbox", ) -def test_generate_sql_ddl_statements_empty(mock_sandbox, native_app_project_instance): +def test_process_no_collected_functions( + mock_sandbox, native_app_project_instance, snapshot +): with temp_local_dir(minimal_dir_structure) as local_path: native_app_project_instance.native_app.artifacts = [ {"src": "a/b/c/*.py", "dest": "stagepath/", "processors": ["SNOWPARK"]} ] mock_sandbox.side_effect = [None, []] - dest_file_py_file_to_ddl_map = SnowparkAnnotationProcessor( + output = SnowparkAnnotationProcessor( project_definition=native_app_project_instance, project_root=local_path, deploy_root=Path(local_path, "output/deploy"), @@ -257,14 +211,17 @@ def test_generate_sql_ddl_statements_empty(mock_sandbox, native_app_project_inst artifact_to_process=native_app_project_instance.native_app.artifacts[0], processor_mapping=ProcessorMapping(name="SNOWPARK"), ) - assert len(dest_file_py_file_to_ddl_map) == 0 + assert output == snapshot @mock.patch( "snowflake.cli.plugins.nativeapp.codegen.snowpark.python_processor._execute_in_sandbox", ) -def test_generate_sql_ddl_statements( - mock_sandbox, native_app_project_instance, native_app_codegen_full_json, snapshot +def test_process_with_collected_functions( + mock_sandbox, + native_app_project_instance, + native_app_extension_function_raw_data, + snapshot, ): with temp_local_dir(minimal_dir_structure) as local_path: processor_mapping = ProcessorMapping( @@ -279,10 +236,10 @@ def test_generate_sql_ddl_statements( } ] mock_sandbox.side_effect = [ - [native_app_codegen_full_json], - [copy.deepcopy(native_app_codegen_full_json)], + [native_app_extension_function_raw_data], + [copy.deepcopy(native_app_extension_function_raw_data)], ] - dest_file_py_file_to_ddl_map = SnowparkAnnotationProcessor( + output = SnowparkAnnotationProcessor( project_definition=native_app_project_instance, project_root=local_path, deploy_root=Path(local_path, "output/deploy"), @@ -290,38 +247,4 @@ def test_generate_sql_ddl_statements( artifact_to_process=native_app_project_instance.native_app.artifacts[0], processor_mapping=processor_mapping, ) - assert len(dest_file_py_file_to_ddl_map) == 2 - values = list(dest_file_py_file_to_ddl_map.values()) - assert values[0] == snapshot - assert values[1] == snapshot - - -@mock.patch( - "snowflake.cli.plugins.nativeapp.codegen.snowpark.python_processor._execute_in_sandbox", -) -def test_generate_sql_ddl_statements_filtered_create( - mock_sandbox, native_app_project_instance, native_app_codegen_full_json, snapshot -): - with temp_local_dir(minimal_dir_structure) as local_path: - native_app_project_instance.native_app.artifacts = [ - {"src": "a/b/c/*.py", "dest": "stagepath/", "processors": ["SNOWPARK"]} - ] - copy_instance = copy.deepcopy(native_app_codegen_full_json) - copy_instance["object_type"] = "PROCEDURE" - copy_instance["anonymous"] = True - mock_sandbox.side_effect = [ - [native_app_codegen_full_json, None], - [copy_instance, {}], - ] - - dest_file_py_file_to_ddl_map = SnowparkAnnotationProcessor( - project_definition=native_app_project_instance, - project_root=local_path, - deploy_root=Path(local_path, "output/deploy"), - ).process( - artifact_to_process=native_app_project_instance.native_app.artifacts[0], - processor_mapping=ProcessorMapping(name="SNOWPARK"), - ) - - assert len(dest_file_py_file_to_ddl_map) == 1 - assert list(dest_file_py_file_to_ddl_map.values())[0] == snapshot + assert output == snapshot diff --git a/tests/testing_utils/fixtures.py b/tests/testing_utils/fixtures.py index 17496d8a32..88de07d308 100644 --- a/tests/testing_utils/fixtures.py +++ b/tests/testing_utils/fixtures.py @@ -9,7 +9,7 @@ from contextlib import contextmanager from io import StringIO from pathlib import Path -from typing import Generator, List, NamedTuple, Optional, Union +from typing import Any, Dict, Generator, List, NamedTuple, Optional, Union from unittest import mock import pytest @@ -18,6 +18,9 @@ from snowflake.cli.api.project.schemas.snowpark.argument import Argument from snowflake.cli.api.project.schemas.snowpark.callable import FunctionSchema from snowflake.cli.app.cli_app import app_factory +from snowflake.cli.plugins.nativeapp.codegen.snowpark.models import ( + NativeAppExtensionFunction, +) from snowflake.connector.cursor import SnowflakeCursor from snowflake.connector.errors import ProgrammingError from typer import Typer @@ -342,27 +345,28 @@ def native_app_project_instance(): ) -@pytest.fixture() -def native_app_codegen_full_json(): +@pytest.fixture +def native_app_extension_function_raw_data() -> Dict[str, Any]: return { - "object_type": "TABLE_FUNCTION", - "object_name": "SNOWPARK_TEMP_FUNCTION_WZUNHMZJKA", - "input_args": [{"name": "arg1", "datatype": "IntegerType"}], - "input_sql_types": ["INT"], - "return_sql": "RETURNS INT", - "runtime_version": "3.11", - "handler": "dummy_handler", + "type": "procedure", + "lineno": 42, + "name": "my_function", + "signature": [{"name": "first", "type": "int", "default": "42"}], + "returns": "int", + "runtime": "3.11", + "handler": "my_function_handler", "external_access_integrations": ["integration_one", "integration_two"], "secrets": {"key1": "secret_one", "key2": "integration_two"}, - "inline_python_code": "dummy_inline_code", - "raw_imports": ["a/b/c.py"], - "all_packages": "'package_one', 'package_two'", - "all_imports": "'path_one', 'path_two'", - "replace": True, - "if_not_exists": False, - "execute_as": "OWNER", - "anonymous": False, - "func": "", + "packages": ["package_one==1.0.2", "package_two"], + "imports": ["/path/to/import1.py", "/path/to/import2.zip"], + "execute_as_caller": False, "schema": "DATA", "application_roles": ["APP_ADMIN", "APP_VIEWER"], } + + +@pytest.fixture +def native_app_extension_function( + native_app_extension_function_raw_data, +) -> NativeAppExtensionFunction: + return NativeAppExtensionFunction(**native_app_extension_function_raw_data)