Skip to content

Commit

Permalink
fixup! fixup! Add snowpark as dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-turbaszek committed May 8, 2024
1 parent f22fcf9 commit 4db55b4
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 57 deletions.
9 changes: 7 additions & 2 deletions src/snowflake/cli/api/commands/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,10 @@ def _password_callback(value: str):
None,
"--variable",
"-D",
help="Variables for the template. For example: `-D \"<key>=<value>\"`, string values must be in `''`.",
help='Variables for the execution context. For example: `-D "<key>=<value>"`. '
"For SQL files variables are use to expand the template and any unknown variable will cause an error. "
"For Python files variables are used to update os.environ dictionary. "
"In case of SQL files string values must be quoted in `''` (consider embedding quoting in the file).",
hidden=True,
show_default=False,
)
Expand Down Expand Up @@ -545,8 +548,10 @@ def __init__(self, key: str, value: str):
self.value = value


def parse_key_value_variables(variables: List[str]) -> List[Variable]:
def parse_key_value_variables(variables: List[str] | None) -> List[Variable]:
"""Util for parsing key=value input. Useful for commands accepting multiple input options."""
if not variables:
return []
result = []
for p in variables:
if "=" not in p:
Expand Down
64 changes: 41 additions & 23 deletions src/snowflake/cli/plugins/stage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,15 @@
from dataclasses import dataclass
from os import path
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
from textwrap import dedent
from typing import Dict, List, Optional, Union

from click import ClickException
from snowflake.cli.api.commands.flags import OnErrorType, parse_key_value_variables
from snowflake.cli.api.commands.flags import (
OnErrorType,
Variable,
parse_key_value_variables,
)
from snowflake.cli.api.console import cli_console
from snowflake.cli.api.project.util import to_string_literal
from snowflake.cli.api.secure_path import SecurePath
Expand Down Expand Up @@ -48,6 +53,10 @@ def path(self) -> str:


class StageManager(SqlExecutionMixin):
def __init__(self):
super().__init__()
self._python_exe_procedure = None

@staticmethod
def get_standard_stage_prefix(name: str) -> str:
# Handle embedded stages
Expand Down Expand Up @@ -221,21 +230,15 @@ def execute(
filtered_file_list, key=lambda f: (path.dirname(f), path.basename(f))
)

sql_variables = self._parse_execute_variables(variables)
parsed_variables = parse_key_value_variables(variables)
sql_variables = self._parse_execute_variables(parsed_variables)
python_variables = {str(v.key).upper(): v.value for v in parsed_variables}
results = []

if any(f.endswith(".py") for f in sorted_file_list):
# Bootstrap Snowpark session
python_exec_sproc = self._bootstrap_snowpark_execution_environment()
else:
python_exec_sproc = None

for file in sorted_file_list:
if file.endswith(".py"):
result = self._execute_python(
file=file,
on_error=on_error,
python_execution_procedure=python_exec_sproc,
file=file, on_error=on_error, variables=python_variables
)
else:
result = self._call_execute_immediate(
Expand Down Expand Up @@ -304,12 +307,10 @@ def _filter_supported_files(files: List[str]) -> List[str]:
return [f for f in files if Path(f).suffix in EXECUTE_SUPPORTED_FILES_FORMATS]

@staticmethod
def _parse_execute_variables(variables: Optional[List[str]]) -> Optional[str]:
def _parse_execute_variables(variables: List[Variable]) -> Optional[str]:
if not variables:
return None

parsed_variables = parse_key_value_variables(variables)
query_parameters = [f"{v.key}=>{v.value}" for v in parsed_variables]
query_parameters = [f"{v.key}=>{v.value}" for v in variables]
return f" using ({', '.join(query_parameters)})"

@staticmethod
Expand Down Expand Up @@ -358,29 +359,46 @@ def _bootstrap_snowpark_execution_environment(self):
from snowflake.snowpark.functions import sproc

self.snowpark_session.add_packages("snowflake-snowpark-python")
self.snowpark_session.add_packages("snowflake.core")

@sproc(is_permanent=False)
def _python_execution_procedure(_: Session, file_path: str) -> None:
"""Snowpark session-scoped stored procedure to execute content of provided pyrthon file."""
def _python_execution_procedure(
_: Session, file_path: str, variables: Dict | None = None
) -> None:
"""Snowpark session-scoped stored procedure to execute content of provided python file."""
import json

from snowflake.snowpark.files import SnowflakeFile

with SnowflakeFile.open(file_path, require_scoped_url=False) as f:
file_content: str = f.read() # type: ignore
exec(file_content)

wrapper = dedent(
f"""\
import os
os.environ.update({json.dumps(variables)})
"""
)

exec(wrapper + file_content)

return _python_execution_procedure

def _execute_python(
self, file: str, on_error: OnErrorType, python_execution_procedure: Callable
):
def _execute_python(self, file: str, on_error: OnErrorType, variables: Dict):
"""
Executes Python file from stage using a Snowpark temporary procedure.
Currently, there's no option to pass input to the execution.
"""
from snowflake.snowpark.exceptions import SnowparkSQLException

# Bootstrap Snowpark session
if self._python_exe_procedure is None:
self._python_exe_procedure = (
self._bootstrap_snowpark_execution_environment()
)

try:
python_execution_procedure(self.get_standard_stage_prefix(file))
self._python_exe_procedure(self.get_standard_stage_prefix(file), variables) # type: ignore
return StageManager._success_result(file=file)
except SnowparkSQLException as e:
StageManager._handle_execution_exception(on_error=on_error, exception=e)
Expand Down
34 changes: 17 additions & 17 deletions tests/stage/__snapshots__/test_stage.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
| @db.schema.exe/a/s3.sql | SUCCESS | None |
| @db.schema.exe/a/b/s4.sql | SUCCESS | None |
+---------------------------------------------+

'''
# ---
# name: test_execute[@db.schema.exe/[email protected]_files18]
Expand All @@ -46,7 +46,7 @@
|-----------------------+---------+-------|
| @db.schema.exe/s1.sql | SUCCESS | None |
+-----------------------------------------+

'''
# ---
# name: test_execute[@exe-@exe-expected_files0]
Expand All @@ -61,7 +61,7 @@
| @exe/a/s3.sql | SUCCESS | None |
| @exe/a/b/s4.sql | SUCCESS | None |
+-----------------------------------+

'''
# ---
# name: test_execute[[email protected]_files17]
Expand All @@ -76,7 +76,7 @@
| @db.schema.exe/a/s3.sql | SUCCESS | None |
| @db.schema.exe/a/b/s4.sql | SUCCESS | None |
+---------------------------------------------+

'''
# ---
# name: test_execute[exe-@exe-expected_files2]
Expand All @@ -91,7 +91,7 @@
| @exe/a/s3.sql | SUCCESS | None |
| @exe/a/b/s4.sql | SUCCESS | None |
+-----------------------------------+

'''
# ---
# name: test_execute[exe/*-@exe-expected_files4]
Expand All @@ -106,7 +106,7 @@
| @exe/a/s3.sql | SUCCESS | None |
| @exe/a/b/s4.sql | SUCCESS | None |
+-----------------------------------+

'''
# ---
# name: test_execute[exe/*.sql-@exe-expected_files5]
Expand All @@ -121,7 +121,7 @@
| @exe/a/s3.sql | SUCCESS | None |
| @exe/a/b/s4.sql | SUCCESS | None |
+-----------------------------------+

'''
# ---
# name: test_execute[exe/-@exe-expected_files3]
Expand All @@ -136,7 +136,7 @@
| @exe/a/s3.sql | SUCCESS | None |
| @exe/a/b/s4.sql | SUCCESS | None |
+-----------------------------------+

'''
# ---
# name: test_execute[exe/a-@exe-expected_files6]
Expand All @@ -149,7 +149,7 @@
| @exe/a/s3.sql | SUCCESS | None |
| @exe/a/b/s4.sql | SUCCESS | None |
+-----------------------------------+

'''
# ---
# name: test_execute[exe/a/*-@exe-expected_files8]
Expand All @@ -162,7 +162,7 @@
| @exe/a/s3.sql | SUCCESS | None |
| @exe/a/b/s4.sql | SUCCESS | None |
+-----------------------------------+

'''
# ---
# name: test_execute[exe/a/*.sql-@exe-expected_files9]
Expand All @@ -175,7 +175,7 @@
| @exe/a/s3.sql | SUCCESS | None |
| @exe/a/b/s4.sql | SUCCESS | None |
+-----------------------------------+

'''
# ---
# name: test_execute[exe/a/-@exe-expected_files7]
Expand All @@ -188,7 +188,7 @@
| @exe/a/s3.sql | SUCCESS | None |
| @exe/a/b/s4.sql | SUCCESS | None |
+-----------------------------------+

'''
# ---
# name: test_execute[exe/a/b-@exe-expected_files10]
Expand All @@ -199,7 +199,7 @@
|-----------------+---------+-------|
| @exe/a/b/s4.sql | SUCCESS | None |
+-----------------------------------+

'''
# ---
# name: test_execute[exe/a/b/*-@exe-expected_files12]
Expand All @@ -210,7 +210,7 @@
|-----------------+---------+-------|
| @exe/a/b/s4.sql | SUCCESS | None |
+-----------------------------------+

'''
# ---
# name: test_execute[exe/a/b/*.sql-@exe-expected_files13]
Expand Down Expand Up @@ -254,7 +254,7 @@
|-------------+---------+-------|
| @exe/s1.sql | SUCCESS | None |
+-------------------------------+

'''
# ---
# name: test_execute[snow://exe-@exe-expected_files1]
Expand All @@ -269,7 +269,7 @@
| @exe/a/s3.sql | SUCCESS | None |
| @exe/a/b/s4.sql | SUCCESS | None |
+-----------------------------------+

'''
# ---
# name: test_execute_continue_on_error
Expand All @@ -290,7 +290,7 @@
# name: test_execute_raise_invalid_file_extension_error
'''
╭─ Error ──────────────────────────────────────────────────────────────────────╮
│ Invalid file extension, only `.sql` files are allowed.
│ Invalid file extension, only .sql,.py files are allowed. │
╰──────────────────────────────────────────────────────────────────────────────╯

'''
Expand Down
24 changes: 22 additions & 2 deletions tests/stage/test_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from snowflake.cli.plugins.stage.manager import StageManager
from snowflake.connector import ProgrammingError
from snowflake.connector.cursor import DictCursor
from snowflake.snowpark.exceptions import SnowparkSQLException

STAGE_MANAGER = "snowflake.cli.plugins.stage.manager.StageManager"

Expand Down Expand Up @@ -850,13 +851,16 @@ def test_execute_no_files_for_stage_path(


@mock.patch(f"{STAGE_MANAGER}._execute_query")
def test_execute_stop_on_error(mock_execute, mock_cursor, runner):
@mock.patch(f"{STAGE_MANAGER}._bootstrap_snowpark_execution_environment")
def test_execute_stop_on_error(mock_bootstrap, mock_execute, mock_cursor, runner):
error_message = "Error"
mock_execute.side_effect = [
mock_cursor(
[
{"name": "exe/s1.sql"},
{"name": "exe/p1.py"},
{"name": "exe/s2.sql"},
{"name": "exe/p2.py"},
{"name": "exe/s3.sql"},
],
[],
Expand All @@ -873,16 +877,25 @@ def test_execute_stop_on_error(mock_execute, mock_cursor, runner):
mock.call(f"execute immediate from @exe/s1.sql"),
mock.call(f"execute immediate from @exe/s2.sql"),
]
assert mock_bootstrap.return_value.mock_calls == [
mock.call("@exe/p1.py"),
mock.call("@exe/p2.py"),
]
assert e.value.msg == error_message


@mock.patch(f"{STAGE_MANAGER}._execute_query")
def test_execute_continue_on_error(mock_execute, mock_cursor, runner, snapshot):
@mock.patch(f"{STAGE_MANAGER}._bootstrap_snowpark_execution_environment")
def test_execute_continue_on_error(
mock_bootstrap, mock_execute, mock_cursor, runner, snapshot
):
mock_execute.side_effect = [
mock_cursor(
[
{"name": "exe/s1.sql"},
{"name": "exe/p1.py"},
{"name": "exe/s2.sql"},
{"name": "exe/p2.py"},
{"name": "exe/s3.sql"},
],
[],
Expand All @@ -892,6 +905,8 @@ def test_execute_continue_on_error(mock_execute, mock_cursor, runner, snapshot):
mock_cursor([{"3": 3}], []),
]

mock_bootstrap.return_value.side_effect = ["ok", SnowparkSQLException("Test error")]

result = runner.invoke(["stage", "execute", "exe", "--on-error", "continue"])

assert result.exit_code == 0
Expand All @@ -902,3 +917,8 @@ def test_execute_continue_on_error(mock_execute, mock_cursor, runner, snapshot):
mock.call(f"execute immediate from @exe/s2.sql"),
mock.call(f"execute immediate from @exe/s3.sql"),
]

assert mock_bootstrap.return_value.mock_calls == [
mock.call("@exe/p1.py"),
mock.call("@exe/p2.py"),
]
9 changes: 7 additions & 2 deletions tests_integration/__snapshots__/test_stage.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,16 @@
}),
])
# ---
# name: test_stage_execute.1
# name: test_stage_execute_python
list([
dict({
'Error': None,
'File': '@test_stage_execute/script_template.sql',
'File': 'test_stage_execute/script1.py',
'Status': 'SUCCESS',
}),
dict({
'Error': None,
'File': '@test_stage_execute/script_template.py',
'Status': 'SUCCESS',
}),
])
Expand Down
Loading

0 comments on commit 4db55b4

Please sign in to comment.