Skip to content

Commit

Permalink
Add warehouse check in snowpark deploy (#1516)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-turbaszek authored Sep 2, 2024
1 parent 41ebb8e commit cda9a6a
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/snowflake/cli/_plugins/snowpark/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@
)


@app.command("deploy", requires_connection=True)
@app.command("deploy", requires_connection=True, require_warehouse=True)
@with_project_definition()
def deploy(
replace: bool = ReplaceOption(
Expand Down
11 changes: 9 additions & 2 deletions src/snowflake/cli/api/commands/snow_typer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple

import typer
from click import ClickException
from snowflake.cli.api.commands.decorators import (
global_options,
global_options_with_connection,
Expand All @@ -33,6 +34,7 @@
from snowflake.cli.api.exceptions import CommandReturnTypeError
from snowflake.cli.api.output.types import CommandResult
from snowflake.cli.api.sanitizers import sanitize_for_terminal
from snowflake.cli.api.sql_execution import SqlExecutionMixin

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -71,6 +73,7 @@ def command(
requires_global_options: bool = True,
requires_connection: bool = False,
is_enabled: Callable[[], bool] | None = None,
require_warehouse: bool = False,
**kwargs,
):
"""
Expand All @@ -97,7 +100,7 @@ def custom_command(command_callable):
def command_callable_decorator(*args, **kw):
"""Wrapper around command callable. This is what happens at "runtime"."""
execution = ExecutionMetadata()
self.pre_execute(execution)
self.pre_execute(execution, require_warehouse=require_warehouse)
try:
result = command_callable(*args, **kw)
self.process_result(result)
Expand All @@ -116,7 +119,7 @@ def command_callable_decorator(*args, **kw):
return custom_command

@staticmethod
def pre_execute(execution: ExecutionMetadata):
def pre_execute(execution: ExecutionMetadata, require_warehouse: bool = False):
"""
Callback executed before running any command callable (after context execution).
Pay attention to make this method safe to use if performed operations are not necessary
Expand All @@ -127,6 +130,10 @@ def pre_execute(execution: ExecutionMetadata):
log.debug("Executing command pre execution callback")
run_pre_execute_commands()
log_command_usage(execution)
if require_warehouse and not SqlExecutionMixin().session_has_warehouse():
raise ClickException(
"The command requires warehouse. No warehouse found in current connection."
)

@staticmethod
def process_result(result):
Expand Down
6 changes: 6 additions & 0 deletions src/snowflake/cli/api/sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ def use_role(self, new_role: str):
if is_different_role:
self._execute_query(f"use role {prev_role}")

def session_has_warehouse(self) -> bool:
result = self._execute_query(
"select current_warehouse() is not null as result", cursor_class=DictCursor
).fetchone()
return bool(result.get("RESULT"))

@contextmanager
def use_warehouse(self, new_wh: str):
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/api/commands/test_snow_typer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def class_factory(
):
class _CustomTyper(SnowTyper):
@staticmethod
def pre_execute(execution):
def pre_execute(execution, require_warehouse):
if pre_execute:
pre_execute(execution)

Expand Down
15 changes: 15 additions & 0 deletions tests/snowpark/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,15 @@
pytest.skip("Requires further refactor to work on Windows", allow_module_level=True)


mock_session_has_warehouse = mock.patch(
"snowflake.cli.api.sql_execution.SqlExecutionMixin.session_has_warehouse",
lambda _: True,
)


@mock.patch("snowflake.connector.connect")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager")
@mock_session_has_warehouse
def test_deploy_function(
mock_object_manager,
mock_connector,
Expand Down Expand Up @@ -72,6 +79,7 @@ def test_deploy_function(

@mock.patch("snowflake.connector.connect")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager")
@mock_session_has_warehouse
def test_deploy_function_with_external_access(
mock_object_manager,
mock_connector,
Expand Down Expand Up @@ -122,6 +130,7 @@ def test_deploy_function_with_external_access(

@mock.patch("snowflake.connector.connect")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager")
@mock_session_has_warehouse
def test_deploy_function_secrets_without_external_access(
mock_object_manager,
mock_conn,
Expand Down Expand Up @@ -150,6 +159,7 @@ def test_deploy_function_secrets_without_external_access(


@mock.patch("snowflake.connector.connect")
@mock_session_has_warehouse
def test_deploy_function_no_changes(
mock_connector,
runner,
Expand Down Expand Up @@ -190,6 +200,7 @@ def test_deploy_function_no_changes(


@mock.patch("snowflake.connector.connect")
@mock_session_has_warehouse
def test_deploy_function_needs_update_because_packages_changes(
mock_connector,
runner,
Expand Down Expand Up @@ -240,6 +251,7 @@ def test_deploy_function_needs_update_because_packages_changes(


@mock.patch("snowflake.connector.connect")
@mock_session_has_warehouse
def test_deploy_function_needs_update_because_handler_changes(
mock_connector,
runner,
Expand Down Expand Up @@ -293,6 +305,7 @@ def test_deploy_function_needs_update_because_handler_changes(
@mock.patch("snowflake.connector.connect")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.describe")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.show")
@mock_session_has_warehouse
def test_deploy_function_fully_qualified_name_duplicated_database(
mock_om_show,
mock_om_describe,
Expand All @@ -318,6 +331,7 @@ def test_deploy_function_fully_qualified_name_duplicated_database(
@mock.patch("snowflake.connector.connect")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.describe")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.show")
@mock_session_has_warehouse
def test_deploy_function_fully_qualified_name_duplicated_schema(
mock_om_show,
mock_om_describe,
Expand Down Expand Up @@ -348,6 +362,7 @@ def test_deploy_function_fully_qualified_name_duplicated_schema(
@mock.patch("snowflake.connector.connect")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.describe")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.show")
@mock_session_has_warehouse
def test_deploy_function_fully_qualified_name(
mock_om_show,
mock_om_describe,
Expand Down
31 changes: 31 additions & 0 deletions tests/snowpark/test_procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@
pytest.skip("Requires further refactor to work on Windows", allow_module_level=True)


mock_session_has_warehouse = mock.patch(
"snowflake.cli.api.sql_execution.SqlExecutionMixin.session_has_warehouse",
lambda _: True,
)


@mock_session_has_warehouse
def test_deploy_function_no_procedure(runner, project_directory):
with project_directory("empty_project"):
result = runner.invoke(
Expand All @@ -48,6 +55,7 @@ def test_deploy_function_no_procedure(runner, project_directory):
@mock.patch("snowflake.connector.connect")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.describe")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.show")
@mock_session_has_warehouse
def test_deploy_procedure(
mock_om_show,
mock_om_describe,
Expand Down Expand Up @@ -111,6 +119,7 @@ def test_deploy_procedure(
@mock.patch("snowflake.connector.connect")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.describe")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.show")
@mock_session_has_warehouse
def test_deploy_procedure_with_external_access(
mock_om_show,
mock_om_describe,
Expand Down Expand Up @@ -171,6 +180,7 @@ def test_deploy_procedure_with_external_access(
@mock.patch("snowflake.connector.connect")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.describe")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.show")
@mock_session_has_warehouse
def test_deploy_procedure_secrets_without_external_access(
mock_om_show,
mock_om_describe,
Expand Down Expand Up @@ -204,6 +214,7 @@ def test_deploy_procedure_secrets_without_external_access(
@mock.patch("snowflake.connector.connect")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.describe")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.show")
@mock_session_has_warehouse
def test_deploy_procedure_fails_if_integration_does_not_exists(
mock_om_show,
mock_om_describe,
Expand Down Expand Up @@ -239,6 +250,7 @@ def test_deploy_procedure_fails_if_integration_does_not_exists(
@mock.patch("snowflake.connector.connect")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.describe")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.show")
@mock_session_has_warehouse
def test_deploy_procedure_fails_if_object_exists_and_no_replace(
mock_om_show,
mock_om_describe,
Expand Down Expand Up @@ -271,6 +283,7 @@ def test_deploy_procedure_fails_if_object_exists_and_no_replace(
@mock.patch("snowflake.connector.connect")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.describe")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.show")
@mock_session_has_warehouse
def test_deploy_procedure_replace_nothing_to_update(
mock_om_show,
mock_om_describe,
Expand Down Expand Up @@ -326,6 +339,7 @@ def test_deploy_procedure_replace_nothing_to_update(
@mock.patch("snowflake.connector.connect")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.describe")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.show")
@mock_session_has_warehouse
def test_deploy_procedure_replace_updates_single_object(
mock_om_show,
mock_om_describe,
Expand Down Expand Up @@ -379,6 +393,7 @@ def test_deploy_procedure_replace_updates_single_object(
@mock.patch("snowflake.connector.connect")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.describe")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.show")
@mock_session_has_warehouse
def test_deploy_procedure_replace_creates_missing_object(
mock_om_show,
mock_om_describe,
Expand Down Expand Up @@ -424,6 +439,7 @@ def test_deploy_procedure_replace_creates_missing_object(
@mock.patch("snowflake.connector.connect")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.describe")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.show")
@mock_session_has_warehouse
def test_deploy_procedure_fully_qualified_name(
mock_om_show,
mock_om_describe,
Expand All @@ -449,6 +465,7 @@ def test_deploy_procedure_fully_qualified_name(
@mock.patch("snowflake.connector.connect")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.describe")
@mock.patch("snowflake.cli._plugins.snowpark.commands.ObjectManager.show")
@mock_session_has_warehouse
def test_deploy_procedure_fully_qualified_name_duplicated_schema(
mock_om_show,
mock_om_describe,
Expand Down Expand Up @@ -516,3 +533,17 @@ def test_command_aliases(mock_connector, runner, mock_ctx, command, parameters):

queries = ctx.get_queries()
assert queries[0] == queries[1]


@mock.patch(
"snowflake.cli.api.sql_execution.SqlExecutionMixin.session_has_warehouse",
lambda _: False,
)
def test_snowpark_fail_if_no_active_warehouse(runner, mock_ctx, project_directory):
with project_directory("snowpark_procedures"):
result = runner.invoke(["snowpark", "deploy"])
assert result.exit_code == 1, result.output
assert (
"The command requires warehouse. No warehouse found in current connection."
in result.output
)

0 comments on commit cda9a6a

Please sign in to comment.