Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Execute user script refactor #1695

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
DenyAlwaysPolicy,
PolicyBase,
)
from snowflake.cli._plugins.nativeapp.sf_facade import get_snowflake_facade
from snowflake.cli._plugins.nativeapp.utils import needs_confirmation
from snowflake.cli._plugins.stage.diff import DiffResult
from snowflake.cli._plugins.stage.manager import StageManager
Expand Down Expand Up @@ -429,16 +430,16 @@ def deploy(
console.warning(e.message)
if not policy.should_proceed("Proceed with using this package?"):
raise typer.Abort() from e
with get_sql_executor().use_role(package_role):
cls.apply_package_scripts(
console=console,
package_scripts=package_scripts,
package_warehouse=package_warehouse,
project_root=project_root,
package_role=package_role,
package_name=package_name,
)

cls.apply_package_scripts(
console=console,
package_scripts=package_scripts,
package_warehouse=package_warehouse,
project_root=project_root,
package_role=package_role,
package_name=package_name,
)
with get_sql_executor().use_role(package_role):
# 3. Upload files from deploy root local folder to the above stage
stage_schema = extract_schema(stage_fqn)
diff = sync_deploy_root_with_stage(
Expand Down Expand Up @@ -1093,15 +1094,12 @@ def apply_package_scripts(
)

# once we're sure all the templates expanded correctly, execute all of them
with cls.use_package_warehouse(
package_warehouse=package_warehouse,
):
try:
for i, queries in enumerate(queued_queries):
console.step(f"Applying package script: {package_scripts[i]}")
get_sql_executor().execute_queries(queries)
except ProgrammingError as err:
generic_sql_error_handler(err)
for i, queries in enumerate(queued_queries):
script_name = package_scripts[i]
console.step(f"Applying package script: {script_name}")
get_snowflake_facade().execute_user_script(
queries, script_name, package_role, package_warehouse
)

@classmethod
def create_app_package(
Expand All @@ -1115,7 +1113,7 @@ def create_app_package(
Creates the application package with our up-to-date stage if none exists.
"""

# 1. Check for existing existing application package
# 1. Check for existing application package
show_obj_row = cls.get_existing_app_pkg_info(
package_name=package_name,
package_role=package_role,
Expand Down
6 changes: 6 additions & 0 deletions src/snowflake/cli/_plugins/nativeapp/sf_facade.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from snowflake.cli._plugins.nativeapp.sf_sql_facade import SnowflakeSQLFacade


def get_snowflake_facade() -> SnowflakeSQLFacade:
"""Returns a Snowflake Facade"""
return SnowflakeSQLFacade()
204 changes: 204 additions & 0 deletions src/snowflake/cli/_plugins/nativeapp/sf_sql_facade.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
# Copyright (c) 2024 Snowflake Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from contextlib import contextmanager

from click import ClickException
from cryptography.utils import cached_property
from snowflake.cli.api.constants import ObjectType
from snowflake.cli.api.entities.common import get_sql_executor
from snowflake.cli.api.errno import (
DOES_NOT_EXIST_OR_CANNOT_BE_PERFORMED,
NO_WAREHOUSE_SELECTED_IN_SESSION,
)
from snowflake.cli.api.exceptions import CouldNotUseObjectError
from snowflake.cli.api.project.util import to_identifier
from snowflake.cli.api.sql_execution import SqlExecutor
from snowflake.connector import DatabaseError, ProgrammingError


class UnknownSQLError(DatabaseError):
"""Exception raised when the root of the SQL error is unidentified by us."""

exit_code = 3

def __init__(self, msg):
self.msg = f"Unknown SQL error occurred. {msg}"
super().__init__(self.msg)

def __str__(self):
return self.msg


class UserScriptError(ClickException):
"""Exception raised when user-provided scripts fail."""

def __init__(self, script_name, msg):
super().__init__(f"Failed to run script {script_name}. {msg}")


class SnowflakeSQLFacade:
def __init__(self, sql_executor: SqlExecutor | None = None):
self._sql_executor = (
sql_executor if sql_executor is not None else get_sql_executor()
)

@cached_property
def _log(self):
return logging.getLogger(__name__)

def _use_object(self, object_type: ObjectType, name: str):
"""
Call sql to use snowflake object with error handling
@param object_type: ObjectType, type of snowflake object to use
@param name: object name, has to be a valid snowflake identifier.
"""
try:
self._sql_executor.execute_query(f"use {object_type} {name}")
except ProgrammingError as err:
if err.errno == DOES_NOT_EXIST_OR_CANNOT_BE_PERFORMED:
raise CouldNotUseObjectError(object_type, name) from err

raise ProgrammingError(f"Failed to use {object_type} {name}") from err
except Exception as err:
raise UnknownSQLError(f"Failed to use {object_type} {name}") from err

@contextmanager
def _use_warehouse_optional(self, new_wh: str | None):
"""
Switches to a different warehouse for a while, then switches back.
This is a no-op if the requested warehouse is already active or if no warehouse is passed in.
@param new_wh: Name of the warehouse to use. If not a valid Snowflake identifier, will be converted before use.
"""
if new_wh is None:
yield
return

valid_wh_name = to_identifier(new_wh)

wh_result = self._sql_executor.execute_query(
f"select current_warehouse()"
).fetchone()
Comment on lines +91 to +93
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what should be the behaviour if this fails? It shouldn't fail, I know, but we should control the contract if it ever does. Right now we'd probably just let a ProgrammingError bubble up. Is that what we want?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think with our error classifications now bubbbling it up as a ProgrammingError is fair (categorizes it as not 100% the user's fault). Would you say if we get a failure here it's fair to assume we need act on it and see what's wrong?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a weird state for sure if something fails, but I'm not convinced that ProgrammingError is the right contract. We know what failed here, and it's not wrong SQL.


# If user has an assigned default warehouse, prev_wh will contain a value even if the warehouse is suspended.
try:
prev_wh = wh_result[0]
except:
prev_wh = None
# new_wh is not None, and should already be a valid identifier, no additional check is performed here.
is_different_wh = valid_wh_name != prev_wh
if is_different_wh:
self._log.debug(f"Using warehouse: {valid_wh_name}")
self._use_object(ObjectType.WAREHOUSE, valid_wh_name)
try:
yield
finally:
if is_different_wh and prev_wh is not None:
self._log.debug(f"Switching back to warehouse:{prev_wh}")
self._use_object(ObjectType.WAREHOUSE, prev_wh)

@contextmanager
def _use_role_optional(self, new_role: str | None):
"""
Switches to a different role for a while, then switches back.
This is a no-op if the requested role is already active or if no role is passed in.
@param new_role: Name of the role to use. If not a valid Snowflake identifier, will be converted before use.
"""
if new_role is None:
yield
return

valid_role_name = to_identifier(new_role)

prev_role = self._sql_executor.current_role()

is_different_role = valid_role_name.lower() != prev_role.lower()
if is_different_role:
self._log.debug(f"Assuming different role: {valid_role_name}")
self._use_object(ObjectType.ROLE, valid_role_name)
try:
yield
finally:
if is_different_role:
self._log.debug(f"Switching back to role:{prev_role}")
self._use_object(ObjectType.ROLE, prev_role)

@contextmanager
def _use_database_optional(self, database_name: str | None):
"""
Switch to database `database_name`, then switches back.
This is a no-op if the requested database is already selected or if no database_name is passed in.
@param database_name: Name of the database to use. If not a valid Snowflake identifier, will be converted before use.
"""

if database_name is None:
yield
return

valid_name = to_identifier(database_name)

db_result = self._sql_executor.execute_query(
f"select current_database()"
).fetchone()
try:
prev_db = db_result[0]
except:
prev_db = None

is_different_db = valid_name != prev_db
if is_different_db:
self._log.debug(f"Using database {valid_name}")
self._use_object(ObjectType.DATABASE, valid_name)

try:
yield
finally:
if is_different_db and prev_db is not None:
self._log.debug(f"Switching back to database:{prev_db}")
self._use_object(ObjectType.DATABASE, prev_db)

def execute_user_script(
self,
queries: str,
script_name: str,
role: str | None = None,
warehouse: str | None = None,
database: str | None = None,
):
"""
Runs the user-provided sql script.
@param queries: Queries to run in this script
@param script_name: Name of the file containing the script. Used to show logs to the user.
@param [Optional] role: Role to switch to while running this script. Current role will be used if no role is passed in.
@param [Optional] warehouse: Warehouse to use while running this script.
@param [Optional] database: Database to use while running this script.
"""
with self._use_role_optional(role):
with self._use_warehouse_optional(warehouse):
with self._use_database_optional(database):
try:
self._sql_executor.execute_queries(queries)
except ProgrammingError as err:
if err.errno == NO_WAREHOUSE_SELECTED_IN_SESSION:
raise UserScriptError(
script_name,
f"{err.msg}. Please provide a warehouse in your project definition file, config.toml file, or via command line",
) from err
else:
raise UserScriptError(script_name, err.msg) from err
except Exception as err:
raise UnknownSQLError(
f"Failed to run script {script_name}"
) from err
Loading
Loading