diff --git a/.github/workflows/test_integration.yaml b/.github/workflows/test_integration.yaml index c038051615..95952f6236 100644 --- a/.github/workflows/test_integration.yaml +++ b/.github/workflows/test_integration.yaml @@ -35,7 +35,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip hatch diff --git a/.gitignore b/.gitignore index 8157fa62fa..0e4e3f615f 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,7 @@ gen_docs/ /venv/ .env .vscode +tmp/ ^app.zip ^snowflake.yml diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index d0927ad250..a8f03749ff 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -20,6 +20,7 @@ ## Deprecations ## New additions +* Support for Python remote execution via `snow stage execute` and `snow git execute`, similar to the existing EXECUTE IMMEDIATE support. ## Fixes and improvements * The `snow app run` command now allows upgrading to unversioned mode from a versioned or release mode application installation diff --git a/pyproject.toml b/pyproject.toml index 254b9a3836..e0a9ba654b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "setuptools==70.2.0", 'snowflake.core==0.8.0; python_version < "3.12"', "snowflake-connector-python[secure-local-storage]==3.11.0", + 'snowflake-snowpark-python>=1.15.0;python_version < "3.12"', "tomlkit==0.12.5", "typer==0.12.3", "urllib3>=1.24.3,<2.3", diff --git a/snyk/requirements.txt b/snyk/requirements.txt index 74f5981dd5..84a73b7e62 100644 --- a/snyk/requirements.txt +++ b/snyk/requirements.txt @@ -8,6 +8,7 @@ requirements-parser==0.9.0 setuptools==70.2.0 snowflake.core==0.8.0; python_version < "3.12" snowflake-connector-python[secure-local-storage]==3.11.0 +snowflake-snowpark-python>=1.15.0;python_version < "3.12" tomlkit==0.12.5 typer==0.12.3 urllib3>=1.24.3,<2.3 diff --git a/src/snowflake/cli/api/commands/flags.py b/src/snowflake/cli/api/commands/flags.py index ff839c22e8..9010cd30d2 100644 --- a/src/snowflake/cli/api/commands/flags.py +++ b/src/snowflake/cli/api/commands/flags.py @@ -433,7 +433,10 @@ def _password_callback(value: str): None, "--variable", "-D", - help="Variables for the template. For example: `-D \"=\"`, string values must be in `''`.", + help='Variables for the execution context. For example: `-D "="`. ' + "For SQL files variables are used to expand the template and any unknown variable will cause an error. " + "For Python files variables are used to update the os.environ dictionary. Provided keys are capitalized to adhere to best practices. " + "In case of SQL files string values must be quoted in `''` (consider embedding quoting in the file).", show_default=False, ) @@ -617,11 +620,12 @@ def __init__(self, key: str, value: str): def parse_key_value_variables(variables: Optional[List[str]]) -> List[Variable]: """Util for parsing key=value input. 
Useful for commands accepting multiple input options.""" + if not variables: + return [] result: List[Variable] = [] if variables is None: return result - for p in variables: if "=" not in p: raise ClickException(f"Invalid variable: '{p}'") diff --git a/src/snowflake/cli/api/identifiers.py b/src/snowflake/cli/api/identifiers.py index b9f5c37c95..886cbc72e4 100644 --- a/src/snowflake/cli/api/identifiers.py +++ b/src/snowflake/cli/api/identifiers.py @@ -53,11 +53,17 @@ def name(self) -> str: return self._name @property - def identifier(self) -> str: + def prefix(self) -> str: if self.database: - return f"{self.database}.{self.schema if self.schema else 'PUBLIC'}.{self.name}" + return f"{self.database}.{self.schema if self.schema else 'PUBLIC'}" if self.schema: - return f"{self.schema}.{self.name}" + return f"{self.schema}" + return "" + + @property + def identifier(self) -> str: + if self.prefix: + return f"{self.prefix}.{self.name}" return self.name @property @@ -96,6 +102,13 @@ def from_string(cls, identifier: str) -> "FQN": unqualified_name = unqualified_name + signature return cls(name=unqualified_name, schema=schema, database=database) + @classmethod + def from_stage(cls, stage: str) -> "FQN": + name = stage + if stage.startswith("@"): + name = stage[1:] + return cls.from_string(name) + @classmethod def from_identifier_model(cls, model: ObjectIdentifierBaseModel) -> "FQN": """Create an instance from object model.""" diff --git a/src/snowflake/cli/api/sql_execution.py b/src/snowflake/cli/api/sql_execution.py index cb7b89a3fb..cb81cc3865 100644 --- a/src/snowflake/cli/api/sql_execution.py +++ b/src/snowflake/cli/api/sql_execution.py @@ -41,12 +41,22 @@ class SqlExecutionMixin: def __init__(self): - pass + self._snowpark_session = None @property def _conn(self): return cli_context.connection + @property + def snowpark_session(self): + if not self._snowpark_session: + from snowflake.snowpark.session import Session + + self._snowpark_session = Session.builder.configs( + {"connection": self._conn} + ).create() + return self._snowpark_session + @cached_property def _log(self): return logging.getLogger(__name__) diff --git a/src/snowflake/cli/plugins/cortex/commands.py b/src/snowflake/cli/plugins/cortex/commands.py index 2b36ae0416..0f866786b4 100644 --- a/src/snowflake/cli/plugins/cortex/commands.py +++ b/src/snowflake/cli/plugins/cortex/commands.py @@ -24,6 +24,7 @@ from snowflake.cli.api.cli_global_context import cli_context from snowflake.cli.api.commands.flags import readable_file_option from snowflake.cli.api.commands.snow_typer import SnowTyperFactory +from snowflake.cli.api.constants import PYTHON_3_12 from snowflake.cli.api.output.types import ( CollectionResult, CommandResult, @@ -45,7 +46,7 @@ help="Provides access to Snowflake Cortex.", ) -SEARCH_COMMAND_ENABLED = sys.version_info < (3, 12) +SEARCH_COMMAND_ENABLED = sys.version_info < PYTHON_3_12 @app.command( diff --git a/src/snowflake/cli/plugins/stage/manager.py b/src/snowflake/cli/plugins/stage/manager.py index 25ffc33272..d324c8c73a 100644 --- a/src/snowflake/cli/plugins/stage/manager.py +++ b/src/snowflake/cli/plugins/stage/manager.py @@ -18,15 +18,24 @@ import glob import logging import re +import sys from contextlib import nullcontext from dataclasses import dataclass from os import path from pathlib import Path +from tempfile import TemporaryDirectory +from textwrap import dedent from typing import Dict, List, Optional, Union from click import ClickException -from snowflake.cli.api.commands.flags import OnErrorType, 
parse_key_value_variables +from snowflake.cli.api.commands.flags import ( + OnErrorType, + Variable, + parse_key_value_variables, +) from snowflake.cli.api.console import cli_console +from snowflake.cli.api.constants import PYTHON_3_12 +from snowflake.cli.api.identifiers import FQN from snowflake.cli.api.project.util import to_string_literal from snowflake.cli.api.secure_path import SecurePath from snowflake.cli.api.sql_execution import SqlExecutionMixin @@ -34,12 +43,19 @@ from snowflake.connector import DictCursor, ProgrammingError from snowflake.connector.cursor import SnowflakeCursor +if sys.version_info < PYTHON_3_12: + # Because Snowpark works only below 3.12 and to use @sproc Session must be imported here. + from snowflake.snowpark import Session + log = logging.getLogger(__name__) UNQUOTED_FILE_URI_REGEX = r"[\w/*?\-.=&{}$#[\]\"\\!@%^+:]+" -EXECUTE_SUPPORTED_FILES_FORMATS = {".sql"} USER_STAGE_PREFIX = "@~" +EXECUTE_SUPPORTED_FILES_FORMATS = ( + ".sql", + ".py", +) # tuple to preserve order but it's a set @dataclass @@ -59,6 +75,10 @@ def path(self) -> str: def add_stage_prefix(self, file_path: str) -> str: raise NotImplementedError + def get_full_stage_path(self, path: str): + if prefix := FQN.from_stage(self.stage).prefix: + return prefix + "." + path + return path @dataclass class DefaultStagePathParts(StagePathParts): @@ -118,6 +138,10 @@ def add_stage_prefix(self, file_path: str) -> str: class StageManager(SqlExecutionMixin): + def __init__(self): + super().__init__() + self._python_exe_procedure = None + @staticmethod def get_standard_stage_prefix(name: str) -> str: # Handle embedded stages @@ -282,22 +306,40 @@ def execute( filtered_file_list, key=lambda f: (path.dirname(f), path.basename(f)) ) - sql_variables = self._parse_execute_variables(variables) + parsed_variables = parse_key_value_variables(variables) + sql_variables = self._parse_execute_variables(parsed_variables) + python_variables = {str(v.key): v.value for v in parsed_variables} results = [] + + if any(file.endswith(".py") for file in sorted_file_path_list): + self._python_exe_procedure = self._bootstrap_snowpark_execution_environment( + stage_path_parts + ) + for file_path in sorted_file_path_list: - results.append( - self._call_execute_immediate( - stage_path_parts=stage_path_parts, - file_path=file_path, + file_stage_path = stage_path_parts.add_stage_prefix(file_path) + if file_path.endswith(".py"): + result = self._execute_python( + file_stage_path=file_stage_path, + on_error=on_error, + variables=python_variables, + ) + else: + result = self._call_execute_immediate( + file_stage_path=file_stage_path, variables=sql_variables, on_error=on_error, ) - ) + results.append(result) return results - def _get_files_list_from_stage(self, stage_path_parts: StagePathParts) -> List[str]: - files_list_result = self.list_files(stage_path_parts.stage).fetchall() + def _get_files_list_from_stage( + self, stage_path_parts: StagePathParts, pattern: str | None = None + ) -> List[str]: + files_list_result = self.list_files( + stage_path_parts.stage, pattern=pattern + ).fetchall() if not files_list_result: raise ClickException(f"No files found on stage '{stage_path_parts.stage}'") @@ -319,9 +361,8 @@ def _filter_files_list( return filtered_files else: raise ClickException( - "Invalid file extension, only `.sql` files are allowed." + f"Invalid file extension, only {', '.join(EXECUTE_SUPPORTED_FILES_FORMATS)} files are allowed." 
) - # Filter with fnmatch if contains `*` or `?` if glob.has_magic(stage_path): filtered_files = fnmatch.filter(files_on_stage, stage_path) @@ -335,34 +376,42 @@ def _filter_supported_files(files: List[str]) -> List[str]: return [f for f in files if Path(f).suffix in EXECUTE_SUPPORTED_FILES_FORMATS] @staticmethod - def _parse_execute_variables(variables: Optional[List[str]]) -> Optional[str]: + def _parse_execute_variables(variables: List[Variable]) -> Optional[str]: if not variables: return None - - parsed_variables = parse_key_value_variables(variables) - query_parameters = [f"{v.key}=>{v.value}" for v in parsed_variables] + query_parameters = [f"{v.key}=>{v.value}" for v in variables] return f" using ({', '.join(query_parameters)})" + @staticmethod + def _success_result(file: str): + cli_console.warning(f"SUCCESS - {file}") + return {"File": file, "Status": "SUCCESS", "Error": None} + + @staticmethod + def _error_result(file: str, msg: str): + cli_console.warning(f"FAILURE - {file}") + return {"File": file, "Status": "FAILURE", "Error": msg} + + @staticmethod + def _handle_execution_exception(on_error: OnErrorType, exception: Exception): + if on_error == OnErrorType.BREAK: + raise exception + def _call_execute_immediate( self, - stage_path_parts: StagePathParts, - file_path: str, + file_stage_path: str, variables: Optional[str], on_error: OnErrorType, ) -> Dict: - file_stage_path = stage_path_parts.add_stage_prefix(file_path) try: query = f"execute immediate from {file_stage_path}" if variables: query += variables self._execute_query(query) - cli_console.step(f"SUCCESS - {file_stage_path}") - return {"File": file_stage_path, "Status": "SUCCESS", "Error": None} + return StageManager._success_result(file=file_stage_path) except ProgrammingError as e: - cli_console.warning(f"FAILURE - {file_stage_path}") - if on_error == OnErrorType.BREAK: - raise e - return {"File": file_stage_path, "Status": "FAILURE", "Error": e.msg} + StageManager._handle_execution_exception(on_error=on_error, exception=e) + return StageManager._error_result(file=file_stage_path, msg=e.msg) @staticmethod def _stage_path_part_factory(stage_path: str) -> StagePathParts: @@ -370,3 +419,101 @@ def _stage_path_part_factory(stage_path: str) -> StagePathParts: if stage_path.startswith(USER_STAGE_PREFIX): return UserStagePathParts(stage_path) return DefaultStagePathParts(stage_path) + + def _check_for_requirements_file( + self, stage_path_parts: StagePathParts + ) -> List[str]: + """Looks for requirements.txt file on stage.""" + req_files_on_stage = self._get_files_list_from_stage( + stage_path_parts, pattern=r".*requirements\.txt$" + ) + if not req_files_on_stage: + return [] + + # Construct all possible path for requirements file for this context + # We don't use os.path or pathlib to preserve compatibility on Windows + req_file_name = "requirements.txt" + path_parts = stage_path_parts.path.split("/") + possible_req_files = [] + + while path_parts: + current_file = "/".join([*path_parts, req_file_name]) + possible_req_files.append(str(current_file)) + path_parts = path_parts[:-1] + + # Now for every possible path check if the file exists on stage, + # if yes break, we use the first possible file + requirements_file = None + for req_file in possible_req_files: + if req_file in req_files_on_stage: + requirements_file = req_file + break + + # If we haven't found any matching requirements + if requirements_file is None: + return [] + + # req_file at this moment is the first found requirements file + with TemporaryDirectory() as 
tmp_dir: + self.get( + stage_path_parts.get_full_stage_path(requirements_file), Path(tmp_dir) + ) + requirements = (Path(tmp_dir) / "requirements.txt").read_text().splitlines() + + return requirements + + def _bootstrap_snowpark_execution_environment( + self, stage_path_parts: StagePathParts + ): + """Prepares Snowpark session for executing Python code remotely.""" + if sys.version_info >= PYTHON_3_12: + raise ClickException( + f"Executing python files is not supported in Python >= 3.12. Current version: {sys.version}" + ) + + from snowflake.snowpark.functions import sproc + + self.snowpark_session.add_packages("snowflake-snowpark-python") + self.snowpark_session.add_packages("snowflake.core") + requirements = self._check_for_requirements_file(stage_path_parts) + for req in requirements: + self.snowpark_session.add_packages(req) + + @sproc(is_permanent=False) + def _python_execution_procedure( + _: Session, file_path: str, variables: Dict | None = None + ) -> None: + """Snowpark session-scoped stored procedure to execute content of provided python file.""" + import json + + from snowflake.snowpark.files import SnowflakeFile + + with SnowflakeFile.open(file_path, require_scoped_url=False) as f: + file_content: str = f.read() # type: ignore + + wrapper = dedent( + f"""\ + import os + os.environ.update({json.dumps(variables)}) + """ + ) + + exec(wrapper + file_content) + + return _python_execution_procedure + + def _execute_python( + self, file_stage_path: str, on_error: OnErrorType, variables: Dict + ): + """ + Executes Python file from stage using a Snowpark temporary procedure. + Currently, there's no option to pass input to the execution. + """ + from snowflake.snowpark.exceptions import SnowparkSQLException + + try: + self._python_exe_procedure(self.get_standard_stage_prefix(file_stage_path), variables) # type: ignore + return StageManager._success_result(file=file_stage_path) + except SnowparkSQLException as e: + StageManager._handle_execution_exception(on_error=on_error, exception=e) + return StageManager._error_result(file=file_stage_path, msg=e.message) diff --git a/tests/__snapshots__/test_help_messages.ambr b/tests/__snapshots__/test_help_messages.ambr index 0ec9037f6c..c517dd736d 100644 --- a/tests/__snapshots__/test_help_messages.ambr +++ b/tests/__snapshots__/test_help_messages.ambr @@ -1412,9 +1412,18 @@ | --on-error [break|continue] What to do when an error occurs. | | Defaults to break. | | [default: break] | - | --variable -D TEXT Variables for the template. For | - | example: `-D "="`, string | - | values must be in `''`. | + | --variable -D TEXT Variables for the execution context. | + | For example: `-D "="`. For | + | SQL files variables are use to expand | + | the template and any unknown variable | + | will cause an error. For Python files | + | variables are used to update | + | os.environ dictionary. Provided keys | + | are capitalized to adhere to best | + | practices.In case of SQL files string | + | values must be quoted in `''` | + | (consider embedding quoting in the | + | file). | | --help -h Show this message and exit. | +------------------------------------------------------------------------------+ +- Connection configuration ---------------------------------------------------+ @@ -6926,9 +6935,18 @@ | --on-error [break|continue] What to do when an error occurs. | | Defaults to break. | | [default: break] | - | --variable -D TEXT Variables for the template. For | - | example: `-D "="`, string | - | values must be in `''`. 
| + | --variable -D TEXT Variables for the execution context. | + | For example: `-D "="`. For | + | SQL files variables are use to expand | + | the template and any unknown variable | + | will cause an error. For Python files | + | variables are used to update | + | os.environ dictionary. Provided keys | + | are capitalized to adhere to best | + | practices.In case of SQL files string | + | values must be quoted in `''` | + | (consider embedding quoting in the | + | file). | | --help -h Show this message and exit. | +------------------------------------------------------------------------------+ +- Connection configuration ---------------------------------------------------+ diff --git a/tests/stage/__snapshots__/test_stage.ambr b/tests/stage/__snapshots__/test_stage.ambr index a4b189c607..823f08b8ca 100644 --- a/tests/stage/__snapshots__/test_stage.ambr +++ b/tests/stage/__snapshots__/test_stage.ambr @@ -259,16 +259,20 @@ # --- # name: test_execute_continue_on_error ''' + SUCCESS - @exe/p1.py + FAILURE - @exe/p2.py SUCCESS - @exe/s1.sql FAILURE - @exe/s2.sql SUCCESS - @exe/s3.sql - +-------------------------------+ - | File | Status | Error | - |-------------+---------+-------| - | @exe/s1.sql | SUCCESS | None | - | @exe/s2.sql | FAILURE | Error | - | @exe/s3.sql | SUCCESS | None | - +-------------------------------+ + +------------------------------------+ + | File | Status | Error | + |-------------+---------+------------| + | @exe/p1.py | SUCCESS | None | + | @exe/p2.py | FAILURE | Test error | + | @exe/s1.sql | SUCCESS | None | + | @exe/s2.sql | FAILURE | Error | + | @exe/s3.sql | SUCCESS | None | + +------------------------------------+ ''' # --- @@ -336,7 +340,7 @@ # name: test_execute_raise_invalid_file_extension_error ''' +- Error ----------------------------------------------------------------------+ - | Invalid file extension, only `.sql` files are allowed. | + | Invalid file extension, only .sql, .py files are allowed. | +------------------------------------------------------------------------------+ ''' diff --git a/tests/stage/test_stage.py b/tests/stage/test_stage.py index b9c1d97764..76cb16add9 100644 --- a/tests/stage/test_stage.py +++ b/tests/stage/test_stage.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import sys from pathlib import Path from tempfile import TemporaryDirectory from unittest import mock @@ -24,6 +24,10 @@ STAGE_MANAGER = "snowflake.cli.plugins.stage.manager.StageManager" +skip_python_3_12 = pytest.mark.skipif( + sys.version_info >= (3, 12), reason="Snowpark is not supported in Python >= 3.12" +) + @mock.patch(f"{STAGE_MANAGER}._execute_query") def test_stage_list(mock_execute, runner, mock_cursor): @@ -779,8 +783,12 @@ def test_execute_from_user_stage( @mock.patch(f"{STAGE_MANAGER}._execute_query") -def test_execute_with_variables(mock_execute, mock_cursor, runner): - mock_execute.return_value = mock_cursor([{"name": "exe/s1.sql"}], []) +@mock.patch(f"{STAGE_MANAGER}._bootstrap_snowpark_execution_environment") +@skip_python_3_12 +def test_execute_with_variables(mock_bootstrap, mock_execute, mock_cursor, runner): + mock_execute.return_value = mock_cursor( + [{"name": "exe/s1.sql"}, {"name": "exe/s2.py"}], [] + ) result = runner.invoke( [ @@ -792,7 +800,7 @@ def test_execute_with_variables(mock_execute, mock_cursor, runner): "-D", "key2=1", "-D", - "key3=TRUE", + "KEY3=TRUE", "-D", "key4=NULL", "-D", @@ -804,9 +812,19 @@ def test_execute_with_variables(mock_execute, mock_cursor, runner): assert mock_execute.mock_calls == [ mock.call("ls @exe", cursor_class=DictCursor), mock.call( - f"execute immediate from @exe/s1.sql using (key1=>'string value', key2=>1, key3=>TRUE, key4=>NULL, key5=>'var=value')" + f"execute immediate from @exe/s1.sql using (key1=>'string value', key2=>1, KEY3=>TRUE, key4=>NULL, key5=>'var=value')" ), ] + mock_bootstrap.return_value.assert_called_once_with( + "@exe/s2.py", + { + "key1": "'string value'", + "key2": "1", + "KEY3": "TRUE", + "key4": "NULL", + "key5": "'var=value'", + }, + ) @mock.patch(f"{STAGE_MANAGER}._execute_query") @@ -897,13 +915,17 @@ def test_execute_no_files_for_stage_path( @mock.patch(f"{STAGE_MANAGER}._execute_query") -def test_execute_stop_on_error(mock_execute, mock_cursor, runner): +@mock.patch(f"{STAGE_MANAGER}._bootstrap_snowpark_execution_environment") +@skip_python_3_12 +def test_execute_stop_on_error(mock_bootstrap, mock_execute, mock_cursor, runner): error_message = "Error" mock_execute.side_effect = [ mock_cursor( [ {"name": "exe/s1.sql"}, + {"name": "exe/p1.py"}, {"name": "exe/s2.sql"}, + {"name": "exe/p2.py"}, {"name": "exe/s3.sql"}, ], [], @@ -920,16 +942,28 @@ def test_execute_stop_on_error(mock_execute, mock_cursor, runner): mock.call(f"execute immediate from @exe/s1.sql"), mock.call(f"execute immediate from @exe/s2.sql"), ] + assert mock_bootstrap.return_value.mock_calls == [ + mock.call("@exe/p1.py", {}), + mock.call("@exe/p2.py", {}), + ] assert e.value.msg == error_message @mock.patch(f"{STAGE_MANAGER}._execute_query") -def test_execute_continue_on_error(mock_execute, mock_cursor, runner, snapshot): +@mock.patch(f"{STAGE_MANAGER}._bootstrap_snowpark_execution_environment") +@skip_python_3_12 +def test_execute_continue_on_error( + mock_bootstrap, mock_execute, mock_cursor, runner, snapshot +): + from snowflake.snowpark.exceptions import SnowparkSQLException + mock_execute.side_effect = [ mock_cursor( [ {"name": "exe/s1.sql"}, + {"name": "exe/p1.py"}, {"name": "exe/s2.sql"}, + {"name": "exe/p2.py"}, {"name": "exe/s3.sql"}, ], [], @@ -939,6 +973,8 @@ def test_execute_continue_on_error(mock_execute, mock_cursor, runner, snapshot): mock_cursor([{"3": 3}], []), ] + mock_bootstrap.return_value.side_effect = ["ok", SnowparkSQLException("Test error")] + result = runner.invoke(["stage", "execute", "exe", "--on-error", 
"continue"]) assert result.exit_code == 0 @@ -950,6 +986,11 @@ def test_execute_continue_on_error(mock_execute, mock_cursor, runner, snapshot): mock.call(f"execute immediate from @exe/s3.sql"), ] + assert mock_bootstrap.return_value.mock_calls == [ + mock.call("@exe/p1.py", {}), + mock.call("@exe/p2.py", {}), + ] + @mock.patch("snowflake.connector.connect") @pytest.mark.parametrize( @@ -972,3 +1013,58 @@ def test_command_aliases(mock_connector, runner, mock_ctx, command, parameters): queries = ctx.get_queries() assert queries[0] == queries[1] + + +@pytest.mark.parametrize( + "files, selected, packages", + [ + ([], None, []), + (["my_stage/dir/parallel/requirements.txt"], None, []), + ( + ["my_stage/dir/files/requirements.txt"], + "db.schema.my_stage/dir/files/requirements.txt", + ["aaa", "bbb"], + ), + ( + [ + "my_stage/requirements.txt", + "my_stage/dir/requirements.txt", + "my_stage/dir/files/requirements.txt", + ], + "db.schema.my_stage/dir/files/requirements.txt", + ["aaa", "bbb"], + ), + ( + ["my_stage/requirements.txt"], + "db.schema.my_stage/requirements.txt", + ["aaa", "bbb"], + ), + ], +) +@pytest.mark.parametrize( + "input_path", ["@db.schema.my_stage/dir/files", "@db.schema.my_stage/dir/files/"] +) +def test_stage_manager_check_for_requirements_file( + files, selected, packages, input_path +): + class _MockGetter: + def __init__(self): + self.download_file = None + + def __call__(self, file_on_stage, target_dir): + self.download_file = file_on_stage + (Path(target_dir) / "requirements.txt").write_text("\n".join(packages)) + + get_mock = _MockGetter() + sm = StageManager() + with mock.patch.object( + sm, "_get_files_list_from_stage", lambda parts, pattern: files + ): + with mock.patch.object(StageManager, "get", get_mock) as get_mock: + result = sm._check_for_requirements_file( # noqa: SLF001 + stage_path_parts=sm._stage_path_part_factory(input_path) # noqa: SLF001 + ) + + assert result == packages + + assert get_mock.download_file == selected diff --git a/tests_integration/__snapshots__/test_stage.ambr b/tests_integration/__snapshots__/test_stage.ambr index f9a443fddd..a3f98c619d 100644 --- a/tests_integration/__snapshots__/test_stage.ambr +++ b/tests_integration/__snapshots__/test_stage.ambr @@ -77,6 +77,20 @@ }), ]) # --- +# name: test_stage_execute_python + list([ + dict({ + 'Error': None, + 'File': '@test_stage_execute/script1.py', + 'Status': 'SUCCESS', + }), + dict({ + 'Error': None, + 'File': '@test_stage_execute/script_template.py', + 'Status': 'SUCCESS', + }), + ]) +# --- # name: test_user_stage_execute list([ dict({ diff --git a/tests_integration/test_data/projects/stage_execute/requirements.txt b/tests_integration/test_data/projects/stage_execute/requirements.txt new file mode 100644 index 0000000000..997c65761f --- /dev/null +++ b/tests_integration/test_data/projects/stage_execute/requirements.txt @@ -0,0 +1,2 @@ +scikit-learn +matplotlib diff --git a/tests_integration/test_data/projects/stage_execute/script1.py b/tests_integration/test_data/projects/stage_execute/script1.py new file mode 100644 index 0000000000..c9f9b6de4b --- /dev/null +++ b/tests_integration/test_data/projects/stage_execute/script1.py @@ -0,0 +1 @@ +print("ok") diff --git a/tests_integration/test_data/projects/stage_execute/script_template.py b/tests_integration/test_data/projects/stage_execute/script_template.py new file mode 100644 index 0000000000..d1f7c80499 --- /dev/null +++ b/tests_integration/test_data/projects/stage_execute/script_template.py @@ -0,0 +1,19 @@ +import os +from 
snowflake.core import Root +from snowflake.core.database import DatabaseResource +from snowflake.core.schema import Schema +from snowflake.snowpark.session import Session + +session = Session.builder.getOrCreate() +database: DatabaseResource = Root(session).databases[os.environ["test_database_name"]] + +assert database.name.upper() == os.environ["test_database_name"].upper() + +# Make a side effect that we can check in tests +database.schemas.create(Schema(name=os.environ["TEST_ID"])) + +# Check if an external dependency works +from sklearn import show_versions +import matplotlib + +show_versions() diff --git a/tests_integration/test_snowpark.py b/tests_integration/test_snowpark.py index e6cfe7e08a..31d5a62f4a 100644 --- a/tests_integration/test_snowpark.py +++ b/tests_integration/test_snowpark.py @@ -14,6 +14,7 @@ from __future__ import annotations +import sys from pathlib import Path from textwrap import dedent @@ -715,6 +716,7 @@ def test_build_skip_version_check( @pytest.mark.integration +@pytest.mark.skipif(sys.version_info >= (3, 12), reason="Unknown issues") @pytest.mark.parametrize( "flags", [ @@ -760,6 +762,7 @@ def test_build_with_non_anaconda_dependencies( @pytest.mark.integration +@pytest.mark.skipif(sys.version_info >= (3, 12), reason="Unknown issues") def test_build_shared_libraries_error( runner, project_directory, alter_requirements_txt, test_database ): diff --git a/tests_integration/test_stage.py b/tests_integration/test_stage.py index c0f0288c17..8825ae9fd8 100644 --- a/tests_integration/test_stage.py +++ b/tests_integration/test_stage.py @@ -14,10 +14,13 @@ import glob import os +import sys import tempfile +import time from pathlib import Path import pytest +from snowflake.connector import DictCursor from tests_integration.test_utils import ( contains_row_with, @@ -291,6 +294,57 @@ def test_user_stage_execute(runner, test_database, test_root_path, snapshot): assert result.json == snapshot +@pytest.mark.integration +@pytest.mark.skipif( + sys.version_info >= (3, 12), reason="Snowpark is not supported in Python >= 3.12" +) +def test_stage_execute_python( + snowflake_session, runner, test_database, test_root_path, snapshot +): + project_path = test_root_path / "test_data/projects/stage_execute" + stage_name = "test_stage_execute" + + result = runner.invoke_with_connection_json(["stage", "create", stage_name]) + assert contains_row_with( + result.json, + {"status": f"Stage area {stage_name.upper()} successfully created."}, + ) + + files = ["script1.py", "script_template.py", "requirements.txt"] + for name in files: + result = runner.invoke_with_connection_json( + [ + "stage", + "copy", + f"{project_path}/{name}", + f"@{stage_name}", + ] + ) + assert result.exit_code == 0, result.output + assert contains_row_with(result.json, {"status": "UPLOADED"}) + + test_id = f"FOO{time.time_ns()}" + result = runner.invoke_with_connection_json( + [ + "stage", + "execute", + f"{stage_name}/", + "-D", + f"test_database_name={test_database}", + "-D", + f"TEST_ID={test_id}", + ] + ) + assert result.exit_code == 0 + assert result.json == snapshot + + # Assert side effect created by executed script + *_, schemas = snowflake_session.execute_string( + f"show schemas like '{test_id}' in database {test_database};" + ) + assert len(list(schemas)) == 1 + + @pytest.mark.integration def test_stage_diff(runner, snowflake_session, test_database, tmp_path, snapshot): stage_name = "test_stage"
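Reviewer note: below is a minimal usage sketch of the new Python remote execution path, mirroring the integration test above. The stage, file, schema, and variable names (my_stage, example_script.py, TARGET_SCHEMA) are illustrative assumptions, not part of this change.

# example_script.py -- hypothetical script executed remotely via `snow stage execute`.
# The file runs inside a temporary, session-scoped Snowpark stored procedure, and
# values passed on the command line with `-D key=value` are exposed through os.environ.
import os

from snowflake.snowpark.session import Session

# An active session already exists inside the stored procedure, so getOrCreate() reuses it.
session = Session.builder.getOrCreate()

schema_name = os.environ["TARGET_SCHEMA"]  # supplied with: -D "TARGET_SCHEMA=MY_SCHEMA"
session.sql(f"create schema if not exists {schema_name}").collect()

# Typical invocation (shell), after uploading the script and an optional requirements.txt:
#   snow stage copy example_script.py @my_stage
#   snow stage execute "my_stage/" -D "TARGET_SCHEMA=MY_SCHEMA"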