Skip to content

Commit

Permalink
[NADE] Rewrote extension function collection to use a Pydantic-based …
Browse files Browse the repository at this point in the history
…schema validator (#1096)
  • Loading branch information
sfc-gh-bdufour authored May 22, 2024
1 parent d99449c commit 9e2ea4f
Show file tree
Hide file tree
Showing 8 changed files with 527 additions and 954 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import contextlib
import functools
import inspect
import sys
from typing import Any, Callable, List
from typing import Callable

try:
import snowflake.snowpark
Expand All @@ -12,107 +14,155 @@ except ModuleNotFoundError as exc:
)
sys.exit(1)

# Snapshot of the module globals taken at this point; it is the namespace
# handed to exec() further down in this file.
orig_globals = globals().copy()

# Feature detection: both private hooks must exist on
# snowflake.snowpark.context. If either is missing, the minimum-version
# message is printed to stderr and the script exits with status 1.
# NOTE(review): the assignment and the `if not ...` test each appear twice
# (short name and `__snowflake_internal_`-prefixed name). This looks like the
# old and new side of a diff shown together - confirm which one is current.
found_correct_version = hasattr(
__snowflake_internal_found_correct_version = hasattr(
snowflake.snowpark.context, "_is_execution_environment_sandboxed_for_client"
) and hasattr(snowflake.snowpark.context, "_should_continue_registration")

if not found_correct_version:
if not __snowflake_internal_found_correct_version:
print(
"Did not find the minimum required version for snowflake-snowpark-python package. Please upgrade to v1.15.0 or higher.",
file=sys.stderr,
)
sys.exit(1)

# Module-level accumulators for the collected extension functions; the second
# one is bound into the registration callback via functools.partial and is
# dumped as JSON at the end of the script.
__snowflake_cli_native_app_internal_callback_return_list: List[Any] = []
__snowflake_global_collected_extension_fn_json = []

def __snowflake_internal_create_extension_fn_registration_callback():
def __snowflake_internal_try_extract_lineno(extension_fn):
    """Best-effort lookup of the line on which ``extension_fn`` is defined.

    Returns the starting line number reported by ``inspect.getsourcelines``,
    or ``None`` when the source cannot be inspected for any reason.
    """
    try:
        _, first_lineno = inspect.getsourcelines(extension_fn)
    except Exception:
        # Source is unavailable or the object is not inspectable.
        return None
    return first_lineno

def __snowflake_internal_extract_extension_fn_name(extension_fn):
    """Return the object name to report for ``extension_fn``.

    Returns ``None`` when the name starts with Snowpark's
    ``TEMP_OBJECT_NAME_PREFIX`` (a generated name that should not be used).
    In every other case, including any failure while checking for the
    prefix, ``extension_fn.object_name`` is returned as-is.
    """
    try:
        import snowflake.snowpark._internal.utils as internal_utils

        prefix_known = hasattr(internal_utils, 'TEMP_OBJECT_NAME_PREFIX')
        if prefix_known and extension_fn.object_name.startswith(
            internal_utils.TEMP_OBJECT_NAME_PREFIX
        ):
            # Generated temporary name: not meaningful to the caller.
            return None
    except Exception:
        # Deliberate best-effort: fall through to the name Snowpark reported.
        pass

    return extension_fn.object_name

def __snowflake_internal_create_package_list(extension_fn):
    """Split the comma-separated ``all_packages`` string into package specs.

    Each spec is stripped of surrounding whitespace. Blank entries are
    dropped, so an empty or missing ``all_packages`` (or a stray trailing
    comma) yields no ``""`` pseudo-package in the result.

    Returns:
        A list of non-empty package spec strings, in their original order.
    """
    raw_packages = extension_fn.all_packages or ""
    stripped_specs = (pkg_spec.strip() for pkg_spec in raw_packages.split(","))
    return [pkg_spec for pkg_spec in stripped_specs if pkg_spec]

def __snowflake_internal_make_extension_fn_signature(extension_fn):
    """Build the argument descriptors for an extension function.

    The preferred strategy reflects on ``extension_fn.func`` to recover the
    original Python argument names and default values. The SQL-visible
    arguments are taken to be the *last* ``len(input_sql_types)`` positional
    parameters, so any leading parameters are skipped.

    If reflection fails for any reason, or the handler declares fewer
    positional parameters than there are SQL types, the names reported in
    ``extension_fn.input_args`` are used instead and no defaults are emitted.

    Returns:
        A list of dicts with ``name`` and ``type`` keys, plus ``default`` for
        arguments whose Python parameter declares a default value.
    """
    # Try to fetch the original argument names from the extension function
    try:
        args_spec = inspect.getfullargspec(extension_fn.func)
        original_arg_names = args_spec.args
        input_sql_types = extension_fn.input_sql_types
        start_index = len(original_arg_names) - len(input_sql_types)

        # A negative start index would wrap around and pick the wrong names,
        # so only trust reflection when there are enough parameters.
        if start_index >= 0:
            defaults = args_spec.defaults or ()
            # Defaults always belong to the trailing positional parameters.
            defaults_start_index = len(original_arg_names) - len(defaults)

            signature = []
            for offset, sql_type in enumerate(input_sql_types):
                # Index into the full Python parameter list, not into the
                # SQL argument list: the two differ by start_index.
                arg_index = start_index + offset
                arg = {
                    'name': original_arg_names[arg_index],
                    'type': sql_type,
                }
                if arg_index >= defaults_start_index:
                    arg['default'] = defaults[arg_index - defaults_start_index]
                signature.append(arg)

            return signature
    except Exception:
        # Deliberate best-effort: fall through to the fallback strategy.
        pass

    # Failed to extract the original arguments through reflection, fall back to alternative approach
    return [
        {"name": input_arg.name, "type": input_type}
        for (input_arg, input_type) in zip(extension_fn.input_args, extension_fn.input_sql_types)
    ]

def __snowflake_internal_to_extension_fn_type(object_type):
    """Map an object type to the lower-case type label used in the JSON output.

    ``AGGREGATE_FUNCTION`` and ``TABLE_FUNCTION`` become the two-word labels
    ``"aggregate function"`` and ``"table function"``; any other type name is
    simply lower-cased.
    """
    kind = object_type.name
    two_word_labels = {
        "AGGREGATE_FUNCTION": "aggregate function",
        "TABLE_FUNCTION": "table function",
    }
    if kind in two_word_labels:
        return two_word_labels[kind]
    return kind.lower()

def __snowflake_internal_extension_fn_to_json(extension_fn):
# Converts the extension function properties object into a plain dict that is
# later serialized with json.dumps. Returns None (bare `return`) when the
# function is unsupported: `func` is not callable, or the function is
# anonymous. Otherwise returns the populated dict.
if not isinstance(extension_fn.func, Callable):
# Unsupported case: extension function is a tuple
return

if extension_fn.anonymous:
# unsupported, native application extension functions need to be explicitly named
return

# Collect basic properties of the extension function
extension_fn_json = {
"type": __snowflake_internal_to_extension_fn_type(extension_fn.object_type),
# may be None when the source line cannot be determined
"lineno": __snowflake_internal_try_extract_lineno(extension_fn.func),
# may be None when Snowpark generated a temporary object name
"name": __snowflake_internal_extract_extension_fn_name(extension_fn),
"handler": extension_fn.func.__name__,
"imports": extension_fn.all_imports or [],
"packages": __snowflake_internal_create_package_list(extension_fn),
"runtime": extension_fn.runtime_version,
# upper-case the clause and remove the "RETURNS " keyword, leaving the type
"returns": extension_fn.return_sql.upper().replace("RETURNS ", "").strip(),
"signature": __snowflake_internal_make_extension_fn_signature(extension_fn),
"external_access_integrations": extension_fn.external_access_integrations or [],
"secrets": extension_fn.secrets or {},
}

# execute_as is only reported for procedures, as a boolean "caller or not" flag
if extension_fn.object_type.name == "PROCEDURE" and extension_fn.execute_as is not None:
extension_fn_json['execute_as_caller'] = (extension_fn.execute_as == 'caller')

# NOTE(review): the next two lines (a `def` header with no body of its own and
# a `global` statement for the old return list) look like removed lines of the
# diff shown inline - confirm they are not part of the current function.
def __snowflake_cli_native_app_internal_callback_replacement():
global __snowflake_cli_native_app_internal_callback_return_list
# Optional native app parameters: copy schema / application_roles only when set
if extension_fn.native_app_params is not None:
schema = extension_fn.native_app_params.get("schema")
if schema is not None:
extension_fn_json["schema"] = schema
app_roles = extension_fn.native_app_params.get("application_roles")
if app_roles is not None:
extension_fn_json["application_roles"] = app_roles

# NOTE(review): the three-line `def ..._transform_snowpark_object_to_json(` header
# below also appears to be removed-side diff residue - confirm.
def __snowflake_cli_native_app_internal_transform_snowpark_object_to_json(
extension_function_properties,
):
return extension_fn_json

{% raw %}ext_fn = extension_function_properties
extension_function_dict = {
"object_type": ext_fn.object_type.name,
"object_name": ext_fn.object_name,
"input_args": [
{"name": input_arg.name, "datatype": type(input_arg.datatype).__name__}
for input_arg in ext_fn.input_args
],
"input_sql_types": ext_fn.input_sql_types,
"return_sql": ext_fn.return_sql,
"runtime_version": ext_fn.runtime_version,
"all_imports": ext_fn.all_imports,
"all_packages": ext_fn.all_packages,
"handler": ext_fn.handler,
"external_access_integrations": ext_fn.external_access_integrations,
"secrets": ext_fn.secrets,
"inline_python_code": ext_fn.inline_python_code,
"raw_imports": ext_fn.raw_imports,
"replace": ext_fn.replace,
"if_not_exists": ext_fn.if_not_exists,
"execute_as": ext_fn.execute_as,
"anonymous": ext_fn.anonymous,
# Set func based on type
"func": ext_fn.func.__name__
if isinstance(ext_fn.func, Callable)
else ext_fn.func,
}
# Set native app params based on dictionary
if ext_fn.native_app_params is not None:
extension_function_dict["schema"] = ext_fn.native_app_params.get(
"schema", None
)
extension_function_dict["application_roles"] = ext_fn.native_app_params.get(
"application_roles", None
)
else:
extension_function_dict["schema"] = extension_function_dict[
"application_roles"
] = None
# Imports and handler will be set at a later time.{% endraw %}
return extension_function_dict

def __snowflake_cli_native_app_internal_callback_append_to_list(
callback_return_list, extension_function_properties
def __snowflake_internal_collect_extension_fn(
collected_extension_fn_json_list, extension_function_properties
):
# Callback body: converts the reported extension function to its JSON dict,
# appends it to the supplied list, and returns False.
# NOTE(review): two headers and two bodies are visible here. The
# `..._callback_append_to_list` header and the `extension_function_dict` /
# `callback_return_list` statements use the old names and look like
# removed-side diff lines - confirm which version is current.
extension_function_dict = (
__snowflake_cli_native_app_internal_transform_snowpark_object_to_json(
extension_function_properties
)
)
callback_return_list.append(extension_function_dict)
# NOTE(review): __snowflake_internal_extension_fn_to_json returns None for
# unsupported functions (non-callable func, anonymous) and the result is
# appended unconditionally, so the list can contain None entries that end up
# as `null` in the JSON output - confirm the consumer tolerates that.
extension_fn_json = __snowflake_internal_extension_fn_to_json(extension_function_properties)
collected_extension_fn_json_list.append(extension_fn_json)
# NOTE(review): this callback is installed as Snowpark's
# `_should_continue_registration`; False presumably tells Snowpark not to
# proceed with the actual registration - confirm against Snowpark.
return False

return functools.partial(
__snowflake_cli_native_app_internal_callback_append_to_list,
__snowflake_cli_native_app_internal_callback_return_list,
__snowflake_internal_collect_extension_fn,
__snowflake_global_collected_extension_fn_json,
)


# Read the templated source file as UTF-8 text; `code` is what the exec() call
# below runs.
with open("{{py_file}}", mode="r", encoding="utf-8") as udf_code:
code = udf_code.read()


# Set Snowpark's private sandbox flags (on both the context and session
# modules) and install the collector callback as the registration hook.
snowflake.snowpark.context._is_execution_environment_sandboxed_for_client = ( # noqa: SLF001
True
)
# NOTE(review): two callback factory calls follow inside one parenthesis (old
# and new name); this looks like both sides of a diff - confirm which is current.
snowflake.snowpark.context._should_continue_registration = ( # noqa: SLF001
__snowflake_cli_native_app_internal_callback_replacement()
__snowflake_internal_create_extension_fn_registration_callback()
)
snowflake.snowpark.session._is_execution_environment_sandboxed_for_client = ( # noqa: SLF001
True
)

# Remove every module-level name starting with "__snowflake_internal" from
# globals(). The key list is copied first because the dict is mutated while
# looping. Names with other prefixes (e.g. "__snowflake_global...") are kept.
for global_key in list(globals().keys()):
if global_key.startswith("__snowflake_internal"):
del globals()[global_key]

del globals()["global_key"] # make sure to clean up the loop variable as well

# Run the user file. Any exception is reported on stderr and the script exits
# with status 1.
# NOTE(review): both `exec(code, orig_globals)` and the importlib-based module
# load are visible; they look like the old and new side of a diff - confirm.
# NOTE(review): `import importlib` alone does not guarantee that the
# `importlib.util` submodule is loaded; this relies on something else having
# imported it already - consider `import importlib.util`.
try:
exec(code, orig_globals)
import importlib
# stdout and stderr are redirected to None while the file is executed
with contextlib.redirect_stdout(None):
with contextlib.redirect_stderr(None):
__snowflake_internal_spec = importlib.util.spec_from_file_location("<string>", "{{py_file}}")
__snowflake_internal_module = importlib.util.module_from_spec(__snowflake_internal_spec)
__snowflake_internal_spec.loader.exec_module(__snowflake_internal_module)
except Exception as exc: # Catch any error
print("An exception occurred while executing file: ", exc, file=sys.stderr)
sys.exit(1)


# Print the collected extension functions as a JSON document on stdout.
# NOTE(review): two print statements (old and new list name) are visible here -
# confirm which one is current.
import json
print(json.dumps(__snowflake_cli_native_app_internal_callback_return_list))
print(json.dumps(__snowflake_global_collected_extension_fn_json))
Loading

0 comments on commit 9e2ea4f

Please sign in to comment.