diff --git a/src/snowflake/cli/api/commands/flags.py b/src/snowflake/cli/api/commands/flags.py index 1bcef29fd8..87c248f51d 100644 --- a/src/snowflake/cli/api/commands/flags.py +++ b/src/snowflake/cli/api/commands/flags.py @@ -415,7 +415,10 @@ def _password_callback(value: str): None, "--variable", "-D", - help="Variables for the template. For example: `-D \"<key>=<value>\"`, string values must be in `''`.", + help='Variables for the execution context. For example: `-D "<key>=<value>"`. ' + "For SQL files variables are used to expand the template and any unknown variable will cause an error. " + "For Python files variables are used to update os.environ dictionary. " + "In case of SQL files string values must be quoted in `''` (consider embedding quoting in the file).", hidden=True, show_default=False, ) @@ -545,8 +548,10 @@ def __init__(self, key: str, value: str): self.value = value -def parse_key_value_variables(variables: List[str]) -> List[Variable]: +def parse_key_value_variables(variables: List[str] | None) -> List[Variable]: """Util for parsing key=value input. 
Useful for commands accepting multiple input options.""" + if not variables: + return [] result = [] for p in variables: if "=" not in p: diff --git a/src/snowflake/cli/plugins/stage/manager.py b/src/snowflake/cli/plugins/stage/manager.py index c44207622a..2788ea75a6 100644 --- a/src/snowflake/cli/plugins/stage/manager.py +++ b/src/snowflake/cli/plugins/stage/manager.py @@ -8,10 +8,15 @@ from dataclasses import dataclass from os import path from pathlib import Path -from typing import Callable, Dict, List, Optional, Union +from textwrap import dedent +from typing import Dict, List, Optional, Union from click import ClickException -from snowflake.cli.api.commands.flags import OnErrorType, parse_key_value_variables +from snowflake.cli.api.commands.flags import ( + OnErrorType, + Variable, + parse_key_value_variables, +) from snowflake.cli.api.console import cli_console from snowflake.cli.api.project.util import to_string_literal from snowflake.cli.api.secure_path import SecurePath @@ -48,6 +53,10 @@ def path(self) -> str: class StageManager(SqlExecutionMixin): + def __init__(self): + super().__init__() + self._python_exe_procedure = None + @staticmethod def get_standard_stage_prefix(name: str) -> str: # Handle embedded stages @@ -221,21 +230,15 @@ def execute( filtered_file_list, key=lambda f: (path.dirname(f), path.basename(f)) ) - sql_variables = self._parse_execute_variables(variables) + parsed_variables = parse_key_value_variables(variables) + sql_variables = self._parse_execute_variables(parsed_variables) + python_variables = {str(v.key).upper(): v.value for v in parsed_variables} results = [] - if any(f.endswith(".py") for f in sorted_file_list): - # Bootstrap Snowpark session - python_exec_sproc = self._bootstrap_snowpark_execution_environment() - else: - python_exec_sproc = None - for file in sorted_file_list: if file.endswith(".py"): result = self._execute_python( - file=file, - on_error=on_error, - python_execution_procedure=python_exec_sproc, + file=file, 
on_error=on_error, variables=python_variables ) else: result = self._call_execute_immediate( @@ -304,12 +307,10 @@ def _filter_supported_files(files: List[str]) -> List[str]: return [f for f in files if Path(f).suffix in EXECUTE_SUPPORTED_FILES_FORMATS] @staticmethod - def _parse_execute_variables(variables: Optional[List[str]]) -> Optional[str]: + def _parse_execute_variables(variables: List[Variable]) -> Optional[str]: if not variables: return None - - parsed_variables = parse_key_value_variables(variables) - query_parameters = [f"{v.key}=>{v.value}" for v in parsed_variables] + query_parameters = [f"{v.key}=>{v.value}" for v in variables] return f" using ({', '.join(query_parameters)})" @staticmethod @@ -358,29 +359,46 @@ def _bootstrap_snowpark_execution_environment(self): from snowflake.snowpark.functions import sproc self.snowpark_session.add_packages("snowflake-snowpark-python") + self.snowpark_session.add_packages("snowflake.core") @sproc(is_permanent=False) - def _python_execution_procedure(_: Session, file_path: str) -> None: - """Snowpark session-scoped stored procedure to execute content of provided pyrthon file.""" + def _python_execution_procedure( + _: Session, file_path: str, variables: Dict | None = None + ) -> None: + """Snowpark session-scoped stored procedure to execute content of provided python file.""" + import json + from snowflake.snowpark.files import SnowflakeFile with SnowflakeFile.open(file_path, require_scoped_url=False) as f: file_content: str = f.read() # type: ignore - exec(file_content) + + wrapper = dedent( + f"""\ + import os + os.environ.update({json.dumps(variables)}) + """ + ) + + exec(wrapper + file_content) return _python_execution_procedure - def _execute_python( - self, file: str, on_error: OnErrorType, python_execution_procedure: Callable - ): + def _execute_python(self, file: str, on_error: OnErrorType, variables: Dict): """ Executes Python file from stage using a Snowpark temporary procedure. 
Currently, there's no option to pass input to the execution. """ from snowflake.snowpark.exceptions import SnowparkSQLException + # Bootstrap Snowpark session + if self._python_exe_procedure is None: + self._python_exe_procedure = ( + self._bootstrap_snowpark_execution_environment() + ) + try: - python_execution_procedure(self.get_standard_stage_prefix(file)) + self._python_exe_procedure(self.get_standard_stage_prefix(file), variables) # type: ignore return StageManager._success_result(file=file) except SnowparkSQLException as e: StageManager._handle_execution_exception(on_error=on_error, exception=e) diff --git a/tests/stage/__snapshots__/test_stage.ambr b/tests/stage/__snapshots__/test_stage.ambr index e40a791d26..1c2ac7a803 100644 --- a/tests/stage/__snapshots__/test_stage.ambr +++ b/tests/stage/__snapshots__/test_stage.ambr @@ -35,7 +35,7 @@ | @db.schema.exe/a/s3.sql | SUCCESS | None | | @db.schema.exe/a/b/s4.sql | SUCCESS | None | +---------------------------------------------+ - + ''' # --- # name: test_execute[@db.schema.exe/s1.sql-@db.schema.exe-expected_files18] @@ -46,7 +46,7 @@ |-----------------------+---------+-------| | @db.schema.exe/s1.sql | SUCCESS | None | +-----------------------------------------+ - + ''' # --- # name: test_execute[@exe-@exe-expected_files0] @@ -61,7 +61,7 @@ | @exe/a/s3.sql | SUCCESS | None | | @exe/a/b/s4.sql | SUCCESS | None | +-----------------------------------+ - + ''' # --- # name: test_execute[db.schema.exe-@db.schema.exe-expected_files17] @@ -76,7 +76,7 @@ | @db.schema.exe/a/s3.sql | SUCCESS | None | | @db.schema.exe/a/b/s4.sql | SUCCESS | None | +---------------------------------------------+ - + ''' # --- # name: test_execute[exe-@exe-expected_files2] @@ -91,7 +91,7 @@ | @exe/a/s3.sql | SUCCESS | None | | @exe/a/b/s4.sql | SUCCESS | None | +-----------------------------------+ - + ''' # --- # name: test_execute[exe/*-@exe-expected_files4] @@ -106,7 +106,7 @@ | @exe/a/s3.sql | SUCCESS | None | | @exe/a/b/s4.sql | 
SUCCESS | None | +-----------------------------------+ - + ''' # --- # name: test_execute[exe/*.sql-@exe-expected_files5] @@ -121,7 +121,7 @@ | @exe/a/s3.sql | SUCCESS | None | | @exe/a/b/s4.sql | SUCCESS | None | +-----------------------------------+ - + ''' # --- # name: test_execute[exe/-@exe-expected_files3] @@ -136,7 +136,7 @@ | @exe/a/s3.sql | SUCCESS | None | | @exe/a/b/s4.sql | SUCCESS | None | +-----------------------------------+ - + ''' # --- # name: test_execute[exe/a-@exe-expected_files6] @@ -149,7 +149,7 @@ | @exe/a/s3.sql | SUCCESS | None | | @exe/a/b/s4.sql | SUCCESS | None | +-----------------------------------+ - + ''' # --- # name: test_execute[exe/a/*-@exe-expected_files8] @@ -162,7 +162,7 @@ | @exe/a/s3.sql | SUCCESS | None | | @exe/a/b/s4.sql | SUCCESS | None | +-----------------------------------+ - + ''' # --- # name: test_execute[exe/a/*.sql-@exe-expected_files9] @@ -175,7 +175,7 @@ | @exe/a/s3.sql | SUCCESS | None | | @exe/a/b/s4.sql | SUCCESS | None | +-----------------------------------+ - + ''' # --- # name: test_execute[exe/a/-@exe-expected_files7] @@ -188,7 +188,7 @@ | @exe/a/s3.sql | SUCCESS | None | | @exe/a/b/s4.sql | SUCCESS | None | +-----------------------------------+ - + ''' # --- # name: test_execute[exe/a/b-@exe-expected_files10] @@ -199,7 +199,7 @@ |-----------------+---------+-------| | @exe/a/b/s4.sql | SUCCESS | None | +-----------------------------------+ - + ''' # --- # name: test_execute[exe/a/b/*-@exe-expected_files12] @@ -210,7 +210,7 @@ |-----------------+---------+-------| | @exe/a/b/s4.sql | SUCCESS | None | +-----------------------------------+ - + ''' # --- # name: test_execute[exe/a/b/*.sql-@exe-expected_files13] @@ -254,7 +254,7 @@ |-------------+---------+-------| | @exe/s1.sql | SUCCESS | None | +-------------------------------+ - + ''' # --- # name: test_execute[snow://exe-@exe-expected_files1] @@ -269,7 +269,7 @@ | @exe/a/s3.sql | SUCCESS | None | | @exe/a/b/s4.sql | SUCCESS | None | 
+-----------------------------------+ - + ''' # --- # name: test_execute_continue_on_error @@ -290,7 +290,7 @@ # name: test_execute_raise_invalid_file_extension_error ''' ╭─ Error ──────────────────────────────────────────────────────────────────────╮ - │ Invalid file extension, only `.sql` files are allowed. │ + │ Invalid file extension, only .sql,.py files are allowed. │ ╰──────────────────────────────────────────────────────────────────────────────╯ ''' diff --git a/tests/stage/test_stage.py b/tests/stage/test_stage.py index ea96124ac0..311a2418b3 100644 --- a/tests/stage/test_stage.py +++ b/tests/stage/test_stage.py @@ -6,6 +6,7 @@ from snowflake.cli.plugins.stage.manager import StageManager from snowflake.connector import ProgrammingError from snowflake.connector.cursor import DictCursor +from snowflake.snowpark.exceptions import SnowparkSQLException STAGE_MANAGER = "snowflake.cli.plugins.stage.manager.StageManager" @@ -850,13 +851,16 @@ def test_execute_no_files_for_stage_path( @mock.patch(f"{STAGE_MANAGER}._execute_query") -def test_execute_stop_on_error(mock_execute, mock_cursor, runner): +@mock.patch(f"{STAGE_MANAGER}._bootstrap_snowpark_execution_environment") +def test_execute_stop_on_error(mock_bootstrap, mock_execute, mock_cursor, runner): error_message = "Error" mock_execute.side_effect = [ mock_cursor( [ {"name": "exe/s1.sql"}, + {"name": "exe/p1.py"}, {"name": "exe/s2.sql"}, + {"name": "exe/p2.py"}, {"name": "exe/s3.sql"}, ], [], @@ -873,16 +877,25 @@ def test_execute_stop_on_error(mock_execute, mock_cursor, runner): mock.call(f"execute immediate from @exe/s1.sql"), mock.call(f"execute immediate from @exe/s2.sql"), ] + assert mock_bootstrap.return_value.mock_calls == [ + mock.call("@exe/p1.py"), + mock.call("@exe/p2.py"), + ] assert e.value.msg == error_message @mock.patch(f"{STAGE_MANAGER}._execute_query") -def test_execute_continue_on_error(mock_execute, mock_cursor, runner, snapshot): 
+@mock.patch(f"{STAGE_MANAGER}._bootstrap_snowpark_execution_environment") +def test_execute_continue_on_error( + mock_bootstrap, mock_execute, mock_cursor, runner, snapshot +): mock_execute.side_effect = [ mock_cursor( [ {"name": "exe/s1.sql"}, + {"name": "exe/p1.py"}, {"name": "exe/s2.sql"}, + {"name": "exe/p2.py"}, {"name": "exe/s3.sql"}, ], [], @@ -892,6 +905,8 @@ def test_execute_continue_on_error(mock_execute, mock_cursor, runner, snapshot): mock_cursor([{"3": 3}], []), ] + mock_bootstrap.return_value.side_effect = ["ok", SnowparkSQLException("Test error")] + result = runner.invoke(["stage", "execute", "exe", "--on-error", "continue"]) assert result.exit_code == 0 @@ -902,3 +917,8 @@ def test_execute_continue_on_error(mock_execute, mock_cursor, runner, snapshot): mock.call(f"execute immediate from @exe/s2.sql"), mock.call(f"execute immediate from @exe/s3.sql"), ] + + assert mock_bootstrap.return_value.mock_calls == [ + mock.call("@exe/p1.py"), + mock.call("@exe/p2.py"), + ] diff --git a/tests_integration/__snapshots__/test_stage.ambr b/tests_integration/__snapshots__/test_stage.ambr index cd4d3fb8c5..6f48aa4944 100644 --- a/tests_integration/__snapshots__/test_stage.ambr +++ b/tests_integration/__snapshots__/test_stage.ambr @@ -18,11 +18,16 @@ }), ]) # --- -# name: test_stage_execute.1 +# name: test_stage_execute_python list([ dict({ 'Error': None, - 'File': '@test_stage_execute/script_template.sql', + 'File': 'test_stage_execute/script1.py', + 'Status': 'SUCCESS', + }), + dict({ + 'Error': None, + 'File': '@test_stage_execute/script_template.py', 'Status': 'SUCCESS', }), ]) diff --git a/tests_integration/test_data/projects/stage_execute/script1.py b/tests_integration/test_data/projects/stage_execute/script1.py index 1378a3919f..c9f9b6de4b 100644 --- a/tests_integration/test_data/projects/stage_execute/script1.py +++ b/tests_integration/test_data/projects/stage_execute/script1.py @@ -1,8 +1 @@ -import os -from snowflake.core import Root -from 
snowflake.snowpark.context import get_active_session - -session = get_active_session() -database = Root(session).databases[os.environ["TEST_DATABASE_NAME"]] - -assert database.name.upper() == os.environ["TEST_DATABASE_NAME"].upper() +print("ok") diff --git a/tests_integration/test_data/projects/stage_execute/script_template.py b/tests_integration/test_data/projects/stage_execute/script_template.py index 1378a3919f..ffe0b96bdf 100644 --- a/tests_integration/test_data/projects/stage_execute/script_template.py +++ b/tests_integration/test_data/projects/stage_execute/script_template.py @@ -1,8 +1,13 @@ import os from snowflake.core import Root -from snowflake.snowpark.context import get_active_session +from snowflake.core.database import DatabaseResource +from snowflake.core.schema import Schema +from snowflake.snowpark.session import Session -session = get_active_session() -database = Root(session).databases[os.environ["TEST_DATABASE_NAME"]] +session = Session.builder.getOrCreate() +database: DatabaseResource = Root(session).databases[os.environ["TEST_DATABASE_NAME"]] assert database.name.upper() == os.environ["TEST_DATABASE_NAME"].upper() + +# Make a side effect that we can check in tests +database.schemas.create(Schema(name=os.environ["TEST_ID"])) diff --git a/tests_integration/test_stage.py b/tests_integration/test_stage.py index a5c9c0b10f..5fa32c22e8 100644 --- a/tests_integration/test_stage.py +++ b/tests_integration/test_stage.py @@ -1,9 +1,11 @@ import glob import os import tempfile +import time from pathlib import Path import pytest +from snowflake.connector import DictCursor from tests_integration.test_utils import ( contains_row_with, @@ -209,3 +211,54 @@ def test_stage_execute(runner, test_database, test_root_path, snapshot): "Error": None, } ] + + +@pytest.mark.integration +def test_stage_execute_python( + snowflake_session, runner, test_database, test_root_path, snapshot +): + project_path = test_root_path / "test_data/projects/stage_execute" + 
stage_name = "test_stage_execute" + + result = runner.invoke_with_connection_json(["stage", "create", stage_name]) + assert contains_row_with( + result.json, + {"status": f"Stage area {stage_name.upper()} successfully created."}, + ) + + files = [ + "script1.py", + "script_template.py", + ] + for name in files: + result = runner.invoke_with_connection_json( + [ + "stage", + "copy", + f"{project_path}/{name}", + f"@{stage_name}", + ] + ) + assert result.exit_code == 0, result.output + assert contains_row_with(result.json, {"status": "UPLOADED"}) + + test_id = f"FOO{time.time_ns()}" + result = runner.invoke_with_connection_json( + [ + "stage", + "execute", + f"{stage_name}/", + "-D", + f"test_database_name={test_database}", + "-D", + f"TEST_ID={test_id}", + ] + ) + assert result.exit_code == 0 + assert result.json == snapshot + + # Assert side effect created by executed script + *_, schemas = snowflake_session.execute_string( + f"show schemas like '{test_id}' in database {test_database};" + ) + assert len(list(schemas)) == 1