From a44f54a8332b1ac92c3ed0b1542f369a33488b08 Mon Sep 17 00:00:00 2001 From: Michel El Nacouzi Date: Wed, 18 Dec 2024 17:00:41 -0500 Subject: [PATCH 1/2] Add release channels add-accounts remove-accounts commands (#1955) --- RELEASE-NOTES.md | 1 + .../nativeapp/entities/application.py | 32 +-- .../nativeapp/entities/application_package.py | 179 ++++++++---- .../nativeapp/release_channel/commands.py | 69 +++++ .../cli/_plugins/nativeapp/sf_sql_facade.py | 96 ++++++- src/snowflake/cli/api/entities/common.py | 2 + src/snowflake/cli/api/errno.py | 1 + tests/__snapshots__/test_help_messages.ambr | 218 +++++++++++++- .../test_application_package_entity.py | 267 ++++++++++++++++-- tests/nativeapp/test_run_processor.py | 2 +- tests/nativeapp/test_sf_sql_facade.py | 183 ++++++++++++ tests/nativeapp/utils.py | 6 + 12 files changed, 943 insertions(+), 113 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index bdf2072b11..03b8dd9d7e 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -28,6 +28,7 @@ * `snow app version create` now returns version, patch, and label in JSON format. * Add ability to specify release channel when creating application instance from release directive: `snow app run --from-release-directive --channel=` * Add ability to list release channels through `snow app release-channel list` command +* Add ability to add and remove accounts from release channels through `snow app release-channel add-accounts` and snow app release-channel remove-accounts` commands. ## Fixes and improvements * Fixed crashes with older x86_64 Intel CPUs. diff --git a/src/snowflake/cli/_plugins/nativeapp/entities/application.py b/src/snowflake/cli/_plugins/nativeapp/entities/application.py index bd2096ba1f..b43d4ee8ea 100644 --- a/src/snowflake/cli/_plugins/nativeapp/entities/application.py +++ b/src/snowflake/cli/_plugins/nativeapp/entities/application.py @@ -26,7 +26,6 @@ from snowflake.cli._plugins.nativeapp.constants import ( ALLOWED_SPECIAL_COMMENTS, COMMENT_COL, - DEFAULT_CHANNEL, OWNER_COL, ) from snowflake.cli._plugins.nativeapp.entities.application_package import ( @@ -86,8 +85,6 @@ append_test_resource_suffix, extract_schema, identifier_for_url, - identifier_in_list, - same_identifiers, to_identifier, unquote_identifier, ) @@ -360,8 +357,8 @@ def action_deploy( # same-account release directive if from_release_directive: - release_channel = _get_verified_release_channel( - package_entity, release_channel + release_channel = package_entity.get_sanitized_release_channel( + release_channel ) self.create_or_upgrade_app( @@ -1025,28 +1022,3 @@ def _application_objects_to_str( def _application_object_to_str(obj: ApplicationOwnedObject) -> str: return f"({obj['type']}) {obj['name']}" - - -def _get_verified_release_channel( - package_entity: ApplicationPackageEntity, - release_channel: Optional[str], -) -> Optional[str]: - release_channel = release_channel or DEFAULT_CHANNEL - available_release_channels = get_snowflake_facade().show_release_channels( - package_entity.name, role=package_entity.role - ) - if available_release_channels: - release_channel_names = [c["name"] for c in available_release_channels] - if not identifier_in_list(release_channel, release_channel_names): - raise UsageError( - f"Release channel '{release_channel}' is not available for application package {package_entity.name}. Available release channels: ({', '.join(release_channel_names)})." - ) - else: - if same_identifiers(release_channel, DEFAULT_CHANNEL): - return None - else: - raise UsageError( - f"Release channels are not enabled for application package {package_entity.name}." - ) - - return release_channel diff --git a/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py b/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py index af95ce2e10..389a5a8c82 100644 --- a/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py +++ b/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py @@ -8,7 +8,7 @@ from typing import Any, List, Literal, Optional, Set, Union import typer -from click import BadOptionUsage, ClickException +from click import BadOptionUsage, ClickException, UsageError from pydantic import Field, field_validator from snowflake.cli._plugins.connection.util import UIParameter from snowflake.cli._plugins.nativeapp.artifacts import ( @@ -97,7 +97,6 @@ VALID_IDENTIFIER_REGEX, append_test_resource_suffix, extract_schema, - identifier_in_list, identifier_to_show_like_pattern, same_identifiers, sql_match, @@ -622,6 +621,69 @@ def action_version_drop( f"Version {version} in application package {self.name} dropped successfully." ) + def _validate_target_accounts(self, accounts: list[str]) -> None: + """ + Validates the target accounts provided by the user. + """ + for account in accounts: + if not re.fullmatch( + f"{VALID_IDENTIFIER_REGEX}\\.{VALID_IDENTIFIER_REGEX}", account + ): + raise ClickException( + f"Target account {account} is not in a valid format. Make sure you provide the target account in the format 'org.account'." + ) + + def get_sanitized_release_channel( + self, release_channel: Optional[str] + ) -> Optional[str]: + """ + Sanitize the release channel name provided by the user and validate it against the available release channels. + + A return value of None indicates that release channels should not be used. Returns None if: + - Release channel is not provided + - Release channels are not enabled in the application package and the user provided the default release channel + """ + if not release_channel: + return None + + available_release_channels = get_snowflake_facade().show_release_channels( + self.name, self.role + ) + + if not available_release_channels and same_identifiers( + release_channel, DEFAULT_CHANNEL + ): + return None + + self.validate_release_channel(release_channel, available_release_channels) + return release_channel + + def validate_release_channel( + self, + release_channel: str, + available_release_channels: Optional[list[ReleaseChannel]] = None, + ) -> None: + """ + Validates the release channel provided by the user and make sure it is a valid release channel for the application package. + """ + + if available_release_channels is None: + available_release_channels = get_snowflake_facade().show_release_channels( + self.name, self.role + ) + if not available_release_channels: + raise UsageError( + f"Release channels are not enabled for application package {self.name}." + ) + for channel in available_release_channels: + if same_identifiers(release_channel, channel["name"]): + return + + raise UsageError( + f"Release channel {release_channel} is not available in application package {self.name}. " + f"Available release channels are: ({', '.join(channel['name'] for channel in available_release_channels)})." + ) + def action_release_directive_list( self, action_ctx: ActionContext, @@ -636,25 +698,7 @@ def action_release_directive_list( If `like` is provided, only release directives matching the SQL LIKE pattern are listed. """ - available_release_channels = get_snowflake_facade().show_release_channels( - self.name, self.role - ) - - # assume no release channel used if user selects default channel and release channels are not enabled - if ( - release_channel - and same_identifiers(release_channel, DEFAULT_CHANNEL) - and not available_release_channels - ): - release_channel = None - - release_channel_names = [c.get("name") for c in available_release_channels] - if release_channel and not identifier_in_list( - release_channel, release_channel_names - ): - raise ClickException( - f"Release channel {release_channel} does not exist in application package {self.name}." - ) + release_channel = self.get_sanitized_release_channel(release_channel) release_directives = get_snowflake_facade().show_release_directives( package_name=self.name, @@ -686,13 +730,7 @@ def action_release_directive_set( For non-default release directives, update the existing release directive if target accounts are not provided. """ if target_accounts: - for account in target_accounts: - if not re.fullmatch( - f"{VALID_IDENTIFIER_REGEX}\\.{VALID_IDENTIFIER_REGEX}", account - ): - raise ClickException( - f"Target account {account} is not in a valid format. Make sure you provide the target account in the format 'org.account'." - ) + self._validate_target_accounts(target_accounts) if target_accounts and same_identifiers(release_directive, DEFAULT_DIRECTIVE): raise BadOptionUsage( @@ -700,18 +738,7 @@ def action_release_directive_set( "Target accounts can only be specified for non-default named release directives.", ) - available_release_channels = get_snowflake_facade().show_release_channels( - self.name, self.role - ) - - release_channel_names = [c.get("name") for c in available_release_channels] - - if not same_identifiers( - release_channel, DEFAULT_CHANNEL - ) and not identifier_in_list(release_channel, release_channel_names): - raise ClickException( - f"Release channel {release_channel} does not exist in application package {self.name}." - ) + sanitized_release_channel = self.get_sanitized_release_channel(release_channel) if ( not same_identifiers(release_directive, DEFAULT_DIRECTIVE) @@ -722,7 +749,7 @@ def action_release_directive_set( get_snowflake_facade().modify_release_directive( package_name=self.name, release_directive=release_directive, - release_channel=release_channel, + release_channel=sanitized_release_channel, version=version, patch=patch, role=self.role, @@ -731,7 +758,7 @@ def action_release_directive_set( get_snowflake_facade().set_release_directive( package_name=self.name, release_directive=release_directive, - release_channel=release_channel if available_release_channels else None, + release_channel=sanitized_release_channel, target_accounts=target_accounts, version=version, patch=patch, @@ -739,7 +766,10 @@ def action_release_directive_set( ) def action_release_directive_unset( - self, action_ctx: ActionContext, release_directive: str, release_channel: str + self, + action_ctx: ActionContext, + release_directive: str, + release_channel: str, ): """ Unsets a release directive from the specified release channel. @@ -749,21 +779,10 @@ def action_release_directive_unset( "Cannot unset default release directive. Please specify a non-default release directive." ) - available_release_channels = get_snowflake_facade().show_release_channels( - self.name, self.role - ) - release_channel_names = [c.get("name") for c in available_release_channels] - if not same_identifiers( - release_channel, DEFAULT_CHANNEL - ) and not identifier_in_list(release_channel, release_channel_names): - raise ClickException( - f"Release channel {release_channel} does not exist in application package {self.name}." - ) - get_snowflake_facade().unset_release_directive( package_name=self.name, release_directive=release_directive, - release_channel=release_channel if available_release_channels else None, + release_channel=self.get_sanitized_release_channel(release_channel), role=self.role, ) @@ -861,6 +880,56 @@ def _bundle(self, action_ctx: ActionContext = None): return bundle_map + def action_release_channel_add_accounts( + self, + action_ctx: ActionContext, + release_channel: str, + target_accounts: list[str], + *args, + **kwargs, + ): + """ + Adds target accounts to a release channel. + """ + + if not target_accounts: + raise ClickException("No target accounts provided.") + + self.validate_release_channel(release_channel) + self._validate_target_accounts(target_accounts) + + get_snowflake_facade().add_accounts_to_release_channel( + package_name=self.name, + release_channel=release_channel, + target_accounts=target_accounts, + role=self.role, + ) + + def action_release_channel_remove_accounts( + self, + action_ctx: ActionContext, + release_channel: str, + target_accounts: list[str], + *args, + **kwargs, + ): + """ + Removes target accounts from a release channel. + """ + + if not target_accounts: + raise ClickException("No target accounts provided.") + + self.validate_release_channel(release_channel) + self._validate_target_accounts(target_accounts) + + get_snowflake_facade().remove_accounts_from_release_channel( + package_name=self.name, + release_channel=release_channel, + target_accounts=target_accounts, + role=self.role, + ) + def _bundle_children(self, action_ctx: ActionContext) -> List[str]: # Create _children directory children_artifacts_dir = self.children_artifacts_deploy_root diff --git a/src/snowflake/cli/_plugins/nativeapp/release_channel/commands.py b/src/snowflake/cli/_plugins/nativeapp/release_channel/commands.py index 4614f6e06b..54c484c768 100644 --- a/src/snowflake/cli/_plugins/nativeapp/release_channel/commands.py +++ b/src/snowflake/cli/_plugins/nativeapp/release_channel/commands.py @@ -30,6 +30,7 @@ from snowflake.cli.api.output.types import ( CollectionResult, CommandResult, + MessageResult, ) app = SnowTyperFactory( @@ -69,3 +70,71 @@ def release_channel_list( if cli_context.output_format == OutputFormat.JSON: return CollectionResult(channels) + + +@app.command("add-accounts", requires_connection=True) +@with_project_definition() +@force_project_definition_v2() +def release_channel_add_accounts( + channel: str = typer.Argument( + show_default=False, + help="The release channel to add accounts to.", + ), + target_accounts: list[str] = typer.Option( + show_default=False, + help="The accounts to add to the release channel. Format has to be `org1.account1,org2.account2`.", + ), + **options, +) -> CommandResult: + """ + Adds accounts to a release channel. + """ + + cli_context = get_cli_context() + ws = WorkspaceManager( + project_definition=cli_context.project_definition, + project_root=cli_context.project_root, + ) + package_id = options["package_entity_id"] + ws.perform_action( + package_id, + EntityActions.RELEASE_CHANNEL_ADD_ACCOUNTS, + release_channel=channel, + target_accounts=target_accounts, + ) + + return MessageResult("Successfully added accounts to the release channel.") + + +@app.command("remove-accounts", requires_connection=True) +@with_project_definition() +@force_project_definition_v2() +def release_channel_remove_accounts( + channel: str = typer.Argument( + show_default=False, + help="The release channel to remove accounts from.", + ), + target_accounts: list[str] = typer.Option( + show_default=False, + help="The accounts to remove from the release channel. Format has to be `org1.account1,org2.account2`.", + ), + **options, +) -> CommandResult: + """ + Removes accounts from a release channel. + """ + + cli_context = get_cli_context() + ws = WorkspaceManager( + project_definition=cli_context.project_definition, + project_root=cli_context.project_root, + ) + package_id = options["package_entity_id"] + ws.perform_action( + package_id, + EntityActions.RELEASE_CHANNEL_REMOVE_ACCOUNTS, + release_channel=channel, + target_accounts=target_accounts, + ) + + return MessageResult("Successfully removed accounts from the release channel.") diff --git a/src/snowflake/cli/_plugins/nativeapp/sf_sql_facade.py b/src/snowflake/cli/_plugins/nativeapp/sf_sql_facade.py index 4c30ead793..914c7b1605 100644 --- a/src/snowflake/cli/_plugins/nativeapp/sf_sql_facade.py +++ b/src/snowflake/cli/_plugins/nativeapp/sf_sql_facade.py @@ -53,6 +53,7 @@ APPLICATION_REQUIRES_TELEMETRY_SHARING, CANNOT_DISABLE_MANDATORY_TELEMETRY, CANNOT_DISABLE_RELEASE_CHANNELS, + CANNOT_MODIFY_RELEASE_CHANNEL_ACCOUNTS, DOES_NOT_EXIST_OR_CANNOT_BE_PERFORMED, DOES_NOT_EXIST_OR_NOT_AUTHORIZED, INSUFFICIENT_PRIVILEGES, @@ -1159,7 +1160,8 @@ def show_release_channels( ) -> list[ReleaseChannel]: """ Show release channels in a package. - @param package_name: Name of the package + + @param package_name: Name of the application package @param [Optional] role: Role to switch to while running this script. Current role will be used if no role is passed in. """ @@ -1208,6 +1210,98 @@ def show_release_channels( return results + def add_accounts_to_release_channel( + self, + package_name: str, + release_channel: str, + target_accounts: List[str], + role: str | None = None, + ): + """ + Adds accounts to a release channel. + + @param package_name: Name of the application package + @param release_channel: Name of the release channel + @param target_accounts: List of target accounts to add to the release channel + @param [Optional] role: Role to switch to while running this script. Current role will be used if no role is passed in. + """ + + package_name = to_identifier(package_name) + release_channel = to_identifier(release_channel) + + with self._use_role_optional(role): + try: + self._sql_executor.execute_query( + f"alter application package {package_name} modify release channel {release_channel} add accounts = ({','.join(target_accounts)})" + ) + except ProgrammingError as err: + if ( + err.errno == ACCOUNT_DOES_NOT_EXIST + or err.errno == ACCOUNT_HAS_TOO_MANY_QUALIFIERS + ): + raise UserInputError( + f"Invalid account passed in.\n{str(err.msg)}" + ) from err + if err.errno == CANNOT_MODIFY_RELEASE_CHANNEL_ACCOUNTS: + raise UserInputError( + f"Cannot modify accounts for release channel {release_channel} in application package {package_name}." + ) from err + handle_unclassified_error( + err, + f"Failed to add accounts to release channel {release_channel} in application package {package_name}.", + ) + except Exception as err: + handle_unclassified_error( + err, + f"Failed to add accounts to release channel {release_channel} in application package {package_name}.", + ) + + def remove_accounts_from_release_channel( + self, + package_name: str, + release_channel: str, + target_accounts: List[str], + role: str | None = None, + ): + """ + Removes accounts from a release channel. + + @param package_name: Name of the application package + @param release_channel: Name of the release channel + @param target_accounts: List of target accounts to remove from the release channel + @param [Optional] role: Role to switch to while running this script. Current role will be used if no role is passed in. + """ + + package_name = to_identifier(package_name) + release_channel = to_identifier(release_channel) + + with self._use_role_optional(role): + try: + self._sql_executor.execute_query( + f"alter application package {package_name} modify release channel {release_channel} remove accounts = ({','.join(target_accounts)})" + ) + except ProgrammingError as err: + if ( + err.errno == ACCOUNT_DOES_NOT_EXIST + or err.errno == ACCOUNT_HAS_TOO_MANY_QUALIFIERS + ): + raise UserInputError( + f"Invalid account passed in.\n{str(err.msg)}" + ) from err + if err.errno == CANNOT_MODIFY_RELEASE_CHANNEL_ACCOUNTS: + raise UserInputError( + f"Cannot modify accounts for release channel {release_channel} in application package {package_name}." + ) from err + handle_unclassified_error( + err, + f"Failed to remove accounts from release channel {release_channel} in application package {package_name}.", + ) + except Exception as err: + handle_unclassified_error( + err, + f"Failed to remove accounts from release channel {release_channel} in application package {package_name}.", + ) + def _strip_empty_lines(text: str) -> str: """ diff --git a/src/snowflake/cli/api/entities/common.py b/src/snowflake/cli/api/entities/common.py index c444dc0897..cbedb87825 100644 --- a/src/snowflake/cli/api/entities/common.py +++ b/src/snowflake/cli/api/entities/common.py @@ -22,6 +22,8 @@ class EntityActions(str, Enum): RELEASE_DIRECTIVE_LIST = "action_release_directive_list" RELEASE_CHANNEL_LIST = "action_release_channel_list" + RELEASE_CHANNEL_ADD_ACCOUNTS = "action_release_channel_add_accounts" + RELEASE_CHANNEL_REMOVE_ACCOUNTS = "action_release_channel_remove_accounts" RELEASE_CHANNEL_ADD_VERSION = "action_release_channel_add_version" RELEASE_CHANNEL_REMOVE_VERSION = "action_release_channel_remove_version" diff --git a/src/snowflake/cli/api/errno.py b/src/snowflake/cli/api/errno.py index 88796f2590..c13fdef719 100644 --- a/src/snowflake/cli/api/errno.py +++ b/src/snowflake/cli/api/errno.py @@ -68,6 +68,7 @@ VERSION_DOES_NOT_EXIST = 93031 ACCOUNT_DOES_NOT_EXIST = 1999 ACCOUNT_HAS_TOO_MANY_QUALIFIERS = 906 +CANNOT_MODIFY_RELEASE_CHANNEL_ACCOUNTS = 512017 ERR_JAVASCRIPT_EXECUTION = 100132 diff --git a/tests/__snapshots__/test_help_messages.ambr b/tests/__snapshots__/test_help_messages.ambr index 91f09a4acf..8583b2599a 100644 --- a/tests/__snapshots__/test_help_messages.ambr +++ b/tests/__snapshots__/test_help_messages.ambr @@ -578,6 +578,110 @@ +------------------------------------------------------------------------------+ + ''' +# --- +# name: test_help_messages[app.release-channel.add-accounts] + ''' + + Usage: default app release-channel add-accounts [OPTIONS] CHANNEL + + Adds accounts to a release channel. + + +- Arguments ------------------------------------------------------------------+ + | * channel TEXT The release channel to add accounts to. | + | [required] | + +------------------------------------------------------------------------------+ + +- Options --------------------------------------------------------------------+ + | * --target-accounts TEXT The accounts to add to the release | + | channel. Format has to be | + | org1.account1,org2.account2. | + | [required] | + | --package-entity-id TEXT The ID of the package entity on which | + | to operate when definition_version is | + | 2 or higher. | + | --app-entity-id TEXT The ID of the application entity on | + | which to operate when | + | definition_version is 2 or higher. | + | --project -p TEXT Path where Snowflake project resides. | + | Defaults to current working directory. | + | --env TEXT String in format of key=value. | + | Overrides variables from env section | + | used for templates. | + | --help -h Show this message and exit. | + +------------------------------------------------------------------------------+ + +- Connection configuration ---------------------------------------------------+ + | --connection,--environment -c TEXT Name of the connection, as | + | defined in your config.toml | + | file. Default: default. | + | --host TEXT Host address for the | + | connection. Overrides the | + | value specified for the | + | connection. | + | --port INTEGER Port for the connection. | + | Overrides the value | + | specified for the | + | connection. | + | --account,--accountname TEXT Name assigned to your | + | Snowflake account. Overrides | + | the value specified for the | + | connection. | + | --user,--username TEXT Username to connect to | + | Snowflake. Overrides the | + | value specified for the | + | connection. | + | --password TEXT Snowflake password. | + | Overrides the value | + | specified for the | + | connection. | + | --authenticator TEXT Snowflake authenticator. | + | Overrides the value | + | specified for the | + | connection. | + | --private-key-file,--privateā€¦ TEXT Snowflake private key file | + | path. Overrides the value | + | specified for the | + | connection. | + | --token-file-path TEXT Path to file with an OAuth | + | token that should be used | + | when connecting to Snowflake | + | --database,--dbname TEXT Database to use. Overrides | + | the value specified for the | + | connection. | + | --schema,--schemaname TEXT Database schema to use. | + | Overrides the value | + | specified for the | + | connection. | + | --role,--rolename TEXT Role to use. Overrides the | + | value specified for the | + | connection. | + | --warehouse TEXT Warehouse to use. Overrides | + | the value specified for the | + | connection. | + | --temporary-connection -x Uses connection defined with | + | command line parameters, | + | instead of one defined in | + | config | + | --mfa-passcode TEXT Token to use for | + | multi-factor authentication | + | (MFA) | + | --enable-diag Run Python connector | + | diagnostic test | + | --diag-log-path TEXT Diagnostic report path | + | --diag-allowlist-path TEXT Diagnostic report path to | + | optional allowlist | + +------------------------------------------------------------------------------+ + +- Global configuration -------------------------------------------------------+ + | --format [TABLE|JSON] Specifies the output format. | + | [default: TABLE] | + | --verbose -v Displays log entries for log levels info | + | and higher. | + | --debug Displays log entries for log levels debug | + | and higher; debug logs contain additional | + | information. | + | --silent Turns off intermediate output to console. | + +------------------------------------------------------------------------------+ + + ''' # --- # name: test_help_messages[app.release-channel.list] @@ -678,6 +782,110 @@ +------------------------------------------------------------------------------+ + ''' +# --- +# name: test_help_messages[app.release-channel.remove-accounts] + ''' + + Usage: default app release-channel remove-accounts [OPTIONS] CHANNEL + + Removes accounts from a release channel. + + +- Arguments ------------------------------------------------------------------+ + | * channel TEXT The release channel to remove accounts from. | + | [required] | + +------------------------------------------------------------------------------+ + +- Options --------------------------------------------------------------------+ + | * --target-accounts TEXT The accounts to remove from the | + | release channel. Format has to be | + | org1.account1,org2.account2. | + | [required] | + | --package-entity-id TEXT The ID of the package entity on which | + | to operate when definition_version is | + | 2 or higher. | + | --app-entity-id TEXT The ID of the application entity on | + | which to operate when | + | definition_version is 2 or higher. | + | --project -p TEXT Path where Snowflake project resides. | + | Defaults to current working directory. | + | --env TEXT String in format of key=value. | + | Overrides variables from env section | + | used for templates. | + | --help -h Show this message and exit. | + +------------------------------------------------------------------------------+ + +- Connection configuration ---------------------------------------------------+ + | --connection,--environment -c TEXT Name of the connection, as | + | defined in your config.toml | + | file. Default: default. | + | --host TEXT Host address for the | + | connection. Overrides the | + | value specified for the | + | connection. | + | --port INTEGER Port for the connection. | + | Overrides the value | + | specified for the | + | connection. | + | --account,--accountname TEXT Name assigned to your | + | Snowflake account. Overrides | + | the value specified for the | + | connection. | + | --user,--username TEXT Username to connect to | + | Snowflake. Overrides the | + | value specified for the | + | connection. | + | --password TEXT Snowflake password. | + | Overrides the value | + | specified for the | + | connection. | + | --authenticator TEXT Snowflake authenticator. | + | Overrides the value | + | specified for the | + | connection. | + | --private-key-file,--privateā€¦ TEXT Snowflake private key file | + | path. Overrides the value | + | specified for the | + | connection. | + | --token-file-path TEXT Path to file with an OAuth | + | token that should be used | + | when connecting to Snowflake | + | --database,--dbname TEXT Database to use. Overrides | + | the value specified for the | + | connection. | + | --schema,--schemaname TEXT Database schema to use. | + | Overrides the value | + | specified for the | + | connection. | + | --role,--rolename TEXT Role to use. Overrides the | + | value specified for the | + | connection. | + | --warehouse TEXT Warehouse to use. Overrides | + | the value specified for the | + | connection. | + | --temporary-connection -x Uses connection defined with | + | command line parameters, | + | instead of one defined in | + | config | + | --mfa-passcode TEXT Token to use for | + | multi-factor authentication | + | (MFA) | + | --enable-diag Run Python connector | + | diagnostic test | + | --diag-log-path TEXT Diagnostic report path | + | --diag-allowlist-path TEXT Diagnostic report path to | + | optional allowlist | + +------------------------------------------------------------------------------+ + +- Global configuration -------------------------------------------------------+ + | --format [TABLE|JSON] Specifies the output format. | + | [default: TABLE] | + | --verbose -v Displays log entries for log levels info | + | and higher. | + | --debug Displays log entries for log levels debug | + | and higher; debug logs contain additional | + | information. | + | --silent Turns off intermediate output to console. | + +------------------------------------------------------------------------------+ + + ''' # --- # name: test_help_messages[app.release-channel] @@ -691,7 +899,10 @@ | --help -h Show this message and exit. | +------------------------------------------------------------------------------+ +- Commands -------------------------------------------------------------------+ - | list Lists the release channels available for an application package. | + | add-accounts Adds accounts to a release channel. | + | list Lists the release channels available for an application | + | package. | + | remove-accounts Removes accounts from a release channel. | +------------------------------------------------------------------------------+ @@ -10253,7 +10464,10 @@ | --help -h Show this message and exit. | +------------------------------------------------------------------------------+ +- Commands -------------------------------------------------------------------+ - | list Lists the release channels available for an application package. | + | add-accounts Adds accounts to a release channel. | + | list Lists the release channels available for an application | + | package. | + | remove-accounts Removes accounts from a release channel. | +------------------------------------------------------------------------------+ diff --git a/tests/nativeapp/test_application_package_entity.py b/tests/nativeapp/test_application_package_entity.py index bbab822962..1185e7d700 100644 --- a/tests/nativeapp/test_application_package_entity.py +++ b/tests/nativeapp/test_application_package_entity.py @@ -20,7 +20,7 @@ import pytest import pytz import yaml -from click import ClickException +from click import ClickException, UsageError from snowflake.cli._plugins.connection.util import UIParameter from snowflake.cli._plugins.nativeapp.constants import ( LOOSE_FILES_MAGIC_VERSION, @@ -38,8 +38,10 @@ APP_PACKAGE_ENTITY, APPLICATION_PACKAGE_ENTITY_MODULE, SQL_EXECUTOR_EXECUTE, + SQL_FACADE_ADD_ACCOUNTS_TO_RELEASE_CHANNEL, SQL_FACADE_GET_UI_PARAMETER, SQL_FACADE_MODIFY_RELEASE_DIRECTIVE, + SQL_FACADE_REMOVE_ACCOUNTS_FROM_RELEASE_CHANNEL, SQL_FACADE_SET_RELEASE_DIRECTIVE, SQL_FACADE_SHOW_RELEASE_CHANNELS, SQL_FACADE_SHOW_RELEASE_DIRECTIVES, @@ -226,9 +228,6 @@ def test_given_channels_disabled_and_no_directives_when_release_directive_list_t ) assert result == [] - show_release_channels.assert_called_once_with( - pkg_model.fqn.name, pkg_model.meta.role - ) show_release_directives.assert_called_once_with( package_name=pkg_model.fqn.name, @@ -254,9 +253,6 @@ def test_given_channels_disabled_and_directives_present_when_release_directive_l ) assert result == [{"name": "my_directive"}] - show_release_channels.assert_called_once_with( - pkg_model.fqn.name, pkg_model.meta.role - ) show_release_directives.assert_called_once_with( package_name=pkg_model.fqn.name, @@ -285,9 +281,6 @@ def test_given_multiple_directives_and_like_pattern_when_release_directive_list_ ) assert result == [{"name": "abcdef"}] - show_release_channels.assert_called_once_with( - pkg_model.fqn.name, pkg_model.meta.role - ) show_release_directives.assert_called_once_with( package_name=pkg_model.fqn.name, @@ -314,10 +307,6 @@ def test_given_channels_enabled_and_no_channel_specified_when_release_directive_ assert result == [{"name": "my_directive"}] - show_release_channels.assert_called_once_with( - pkg_model.fqn.name, pkg_model.meta.role - ) - show_release_directives.assert_called_once_with( package_name=pkg_model.fqn.name, role=pkg_model.meta.role, @@ -366,14 +355,14 @@ def test_given_channels_disabled_and_non_default_channel_selected_when_release_d pkg_model = application_package_entity._entity_model # noqa SLF001 pkg_model.meta.role = "package_role" - with pytest.raises(ClickException) as e: + with pytest.raises(UsageError) as e: application_package_entity.action_release_directive_list( action_ctx=action_context, release_channel="non_default", like="%%" ) assert ( str(e.value) - == f"Release channel non_default does not exist in application package {pkg_model.fqn.name}." + == f"Release channels are not enabled for application package {pkg_model.fqn.name}." ) show_release_channels.assert_called_once_with( pkg_model.fqn.name, pkg_model.meta.role @@ -394,14 +383,14 @@ def test_given_channels_enabled_and_invalid_channel_selected_when_release_direct pkg_model = application_package_entity._entity_model # noqa SLF001 pkg_model.meta.role = "package_role" - with pytest.raises(ClickException) as e: + with pytest.raises(UsageError) as e: application_package_entity.action_release_directive_list( action_ctx=action_context, release_channel="invalid_channel", like="%%" ) assert ( str(e.value) - == f"Release channel invalid_channel does not exist in application package {pkg_model.fqn.name}." + == f"Release channel invalid_channel is not available in application package {pkg_model.fqn.name}. Available release channels are: (my_channel)." ) show_release_channels.assert_called_once_with( pkg_model.fqn.name, pkg_model.meta.role @@ -555,7 +544,7 @@ def test_given_no_channels_with_non_default_channel_used_when_release_directive_ pkg_model = application_package_entity._entity_model # noqa SLF001 pkg_model.meta.role = "package_role" - with pytest.raises(ClickException) as e: + with pytest.raises(UsageError) as e: application_package_entity.action_release_directive_set( action_ctx=action_context, version="1.0", @@ -567,7 +556,7 @@ def test_given_no_channels_with_non_default_channel_used_when_release_directive_ assert ( str(e.value) - == f"Release channel non_default does not exist in application package {pkg_model.fqn.name}." + == f"Release channels are not enabled for application package {pkg_model.fqn.name}." ) show_release_channels.assert_called_once_with( @@ -776,7 +765,7 @@ def test_given_channels_disabled_and_non_default_channel_selected_when_release_d pkg_model = application_package_entity._entity_model # noqa SLF001 pkg_model.meta.role = "package_role" - with pytest.raises(ClickException) as e: + with pytest.raises(UsageError) as e: application_package_entity.action_release_directive_unset( action_ctx=action_context, release_channel="non_default", @@ -785,7 +774,7 @@ def test_given_channels_disabled_and_non_default_channel_selected_when_release_d assert ( str(e.value) - == f"Release channel non_default does not exist in application package {pkg_model.fqn.name}." + == f"Release channels are not enabled for application package {pkg_model.fqn.name}." ) show_release_channels.assert_called_once_with( @@ -806,7 +795,7 @@ def test_given_channels_enabled_and_non_existing_channel_selected_when_release_d pkg_model = application_package_entity._entity_model # noqa SLF001 pkg_model.meta.role = "package_role" - with pytest.raises(ClickException) as e: + with pytest.raises(UsageError) as e: application_package_entity.action_release_directive_unset( action_ctx=action_context, release_channel="non_existing", @@ -815,7 +804,7 @@ def test_given_channels_enabled_and_non_existing_channel_selected_when_release_d assert ( str(e.value) - == f"Release channel non_existing does not exist in application package {pkg_model.fqn.name}." + == f"Release channel non_existing is not available in application package {pkg_model.fqn.name}. Available release channels are: (my_channel)." ) show_release_channels.assert_called_once_with( @@ -1022,3 +1011,233 @@ def test_given_release_channels_with_a_selected_channel_to_filter_when_list_rele assert result == [test_channel_1] assert capsys.readouterr().out == os_agnostic_snapshot + + +@mock.patch(SQL_FACADE_SHOW_RELEASE_CHANNELS) +@mock.patch(SQL_FACADE_ADD_ACCOUNTS_TO_RELEASE_CHANNEL) +def test_given_release_channel_and_accounts_when_add_accounts_to_release_channel_then_success( + add_accounts_to_release_channel, + show_release_channels, + application_package_entity, + action_context, +): + pkg_model = application_package_entity._entity_model # noqa SLF001 + pkg_model.meta.role = "package_role" + + show_release_channels.return_value = [{"name": "test_channel"}] + + application_package_entity.action_release_channel_add_accounts( + action_ctx=action_context, + release_channel="test_channel", + target_accounts=["org1.acc1", "org2.acc2"], + ) + + add_accounts_to_release_channel.assert_called_once_with( + package_name=pkg_model.fqn.name, + role=pkg_model.meta.role, + release_channel="test_channel", + target_accounts=["org1.acc1", "org2.acc2"], + ) + + +@mock.patch(SQL_FACADE_SHOW_RELEASE_CHANNELS) +@mock.patch(SQL_FACADE_ADD_ACCOUNTS_TO_RELEASE_CHANNEL) +def test_given_release_channels_disabled_when_add_accounts_to_release_channel_then_error( + add_accounts_to_release_channel, + show_release_channels, + application_package_entity, + action_context, +): + pkg_model = application_package_entity._entity_model # noqa SLF001 + pkg_model.meta.role = "package_role" + + show_release_channels.return_value = [] + + with pytest.raises(UsageError) as e: + application_package_entity.action_release_channel_add_accounts( + action_ctx=action_context, + release_channel="invalid_channel", + target_accounts=["org1.acc1", "org2.acc2"], + ) + + assert ( + str(e.value) + == f"Release channels are not enabled for application package {pkg_model.fqn.name}." + ) + + add_accounts_to_release_channel.assert_not_called() + + +@mock.patch(SQL_FACADE_SHOW_RELEASE_CHANNELS) +@mock.patch(SQL_FACADE_ADD_ACCOUNTS_TO_RELEASE_CHANNEL) +def test_given_invalid_release_channel_when_add_accounts_to_release_channel_then_error( + add_accounts_to_release_channel, + show_release_channels, + application_package_entity, + action_context, +): + pkg_model = application_package_entity._entity_model # noqa SLF001 + pkg_model.meta.role = "package_role" + + show_release_channels.return_value = [{"name": "test_channel"}] + + with pytest.raises(UsageError) as e: + application_package_entity.action_release_channel_add_accounts( + action_ctx=action_context, + release_channel="invalid_channel", + target_accounts=["org1.acc1", "org2.acc2"], + ) + + assert ( + str(e.value) + == f"Release channel invalid_channel is not available in application package {pkg_model.fqn.name}. Available release channels are: (test_channel)." + ) + + add_accounts_to_release_channel.assert_not_called() + + +@mock.patch(SQL_FACADE_SHOW_RELEASE_CHANNELS) +@mock.patch(SQL_FACADE_ADD_ACCOUNTS_TO_RELEASE_CHANNEL) +@pytest.mark.parametrize( + "account_name", ["org1", "org1.", ".account1", "org1.acc.ount1"] +) +def test_given_invalid_account_names_when_add_accounts_to_release_channel_then_error( + add_accounts_to_release_channel, + show_release_channels, + application_package_entity, + action_context, + account_name, +): + pkg_model = application_package_entity._entity_model # noqa SLF001 + pkg_model.meta.role = "package_role" + + show_release_channels.return_value = [{"name": "test_channel"}] + + with pytest.raises(ClickException) as e: + application_package_entity.action_release_channel_add_accounts( + action_ctx=action_context, + release_channel="test_channel", + target_accounts=[account_name], + ) + + assert ( + str(e.value) + == f"Target account {account_name} is not in a valid format. Make sure you provide the target account in the format 'org.account'." + ) + + add_accounts_to_release_channel.assert_not_called() + + +@mock.patch(SQL_FACADE_SHOW_RELEASE_CHANNELS) +@mock.patch(SQL_FACADE_REMOVE_ACCOUNTS_FROM_RELEASE_CHANNEL) +def test_given_release_channel_and_accounts_when_remove_accounts_from_release_channel_then_success( + remove_accounts_from_release_channel, + show_release_channels, + application_package_entity, + action_context, +): + pkg_model = application_package_entity._entity_model # noqa SLF001 + pkg_model.meta.role = "package_role" + + show_release_channels.return_value = [{"name": "test_channel"}] + + application_package_entity.action_release_channel_remove_accounts( + action_ctx=action_context, + release_channel="test_channel", + target_accounts=["org1.acc1", "org2.acc2"], + ) + + remove_accounts_from_release_channel.assert_called_once_with( + package_name=pkg_model.fqn.name, + role=pkg_model.meta.role, + release_channel="test_channel", + target_accounts=["org1.acc1", "org2.acc2"], + ) + + +@mock.patch(SQL_FACADE_SHOW_RELEASE_CHANNELS) +@mock.patch(SQL_FACADE_REMOVE_ACCOUNTS_FROM_RELEASE_CHANNEL) +def test_given_release_channel_disabled_when_remove_accounts_from_release_channel_then_error( + remove_accounts_from_release_channel, + show_release_channels, + application_package_entity, + action_context, +): + pkg_model = application_package_entity._entity_model # noqa SLF001 + pkg_model.meta.role = "package_role" + + show_release_channels.return_value = [] + + with pytest.raises(UsageError) as e: + application_package_entity.action_release_channel_remove_accounts( + action_ctx=action_context, + release_channel="invalid_channel", + target_accounts=["org1.acc1", "org2.acc2"], + ) + + assert ( + str(e.value) + == f"Release channels are not enabled for application package {pkg_model.fqn.name}." + ) + + remove_accounts_from_release_channel.assert_not_called() + + +@mock.patch(SQL_FACADE_SHOW_RELEASE_CHANNELS) +@mock.patch(SQL_FACADE_REMOVE_ACCOUNTS_FROM_RELEASE_CHANNEL) +def test_given_invalid_release_channel_when_remove_accounts_from_release_channel_then_error( + remove_accounts_from_release_channel, + show_release_channels, + application_package_entity, + action_context, +): + pkg_model = application_package_entity._entity_model # noqa SLF001 + pkg_model.meta.role = "package_role" + + show_release_channels.return_value = [{"name": "test_channel"}] + + with pytest.raises(UsageError) as e: + application_package_entity.action_release_channel_remove_accounts( + action_ctx=action_context, + release_channel="invalid_channel", + target_accounts=["org1.acc1", "org2.acc2"], + ) + + assert ( + str(e.value) + == f"Release channel invalid_channel is not available in application package {pkg_model.fqn.name}. Available release channels are: (test_channel)." + ) + + remove_accounts_from_release_channel.assert_not_called() + + +@mock.patch(SQL_FACADE_SHOW_RELEASE_CHANNELS) +@mock.patch(SQL_FACADE_REMOVE_ACCOUNTS_FROM_RELEASE_CHANNEL) +@pytest.mark.parametrize( + "account_name", ["org1", "org1.", ".account1", "org1.acc.ount1"] +) +def test_given_invalid_account_names_when_remove_accounts_from_release_channel_then_error( + remove_accounts_from_release_channel, + show_release_channels, + application_package_entity, + action_context, + account_name, +): + pkg_model = application_package_entity._entity_model # noqa SLF001 + pkg_model.meta.role = "package_role" + + show_release_channels.return_value = [{"name": "test_channel"}] + + with pytest.raises(ClickException) as e: + application_package_entity.action_release_channel_remove_accounts( + action_ctx=action_context, + release_channel="test_channel", + target_accounts=[account_name], + ) + + assert ( + str(e.value) + == f"Target account {account_name} is not in a valid format. Make sure you provide the target account in the format 'org.account'." + ) + + remove_accounts_from_release_channel.assert_not_called() diff --git a/tests/nativeapp/test_run_processor.py b/tests/nativeapp/test_run_processor.py index 3b478bad2e..90aa4b7cf8 100644 --- a/tests/nativeapp/test_run_processor.py +++ b/tests/nativeapp/test_run_processor.py @@ -2506,7 +2506,7 @@ def test_run_app_from_release_directive_with_channel_not_in_list( assert ( str(err.value) - == "Release channel 'unknown_channel' is not available for application package app_pkg. Available release channels: (channel1, channel2)." + == "Release channel unknown_channel is not available in application package app_pkg. Available release channels are: (channel1, channel2)." ) mock_sql_facade_upgrade_application.assert_not_called() mock_sql_facade_create_application.assert_not_called() diff --git a/tests/nativeapp/test_sf_sql_facade.py b/tests/nativeapp/test_sf_sql_facade.py index 7d678c669d..a5493c8192 100644 --- a/tests/nativeapp/test_sf_sql_facade.py +++ b/tests/nativeapp/test_sf_sql_facade.py @@ -51,6 +51,7 @@ APPLICATION_REQUIRES_TELEMETRY_SHARING, CANNOT_DISABLE_MANDATORY_TELEMETRY, CANNOT_DISABLE_RELEASE_CHANNELS, + CANNOT_MODIFY_RELEASE_CHANNEL_ACCOUNTS, DOES_NOT_EXIST_OR_CANNOT_BE_PERFORMED, DOES_NOT_EXIST_OR_NOT_AUTHORIZED, INSUFFICIENT_PRIVILEGES, @@ -3713,3 +3714,185 @@ def test_drop_version_from_package_with_error( sql_facade.drop_version_from_package( package_name=package_name, version=version, role=role ) + + +def test_add_accounts_to_release_channel_valid_input_then_success( + mock_use_role, mock_execute_query +): + package_name = "test_package" + release_channel = "test_channel" + accounts = ["org1.acc1", "org2.acc2"] + role = "test_role" + + expected_use_objects = [ + (mock_use_role, mock.call(role)), + ] + expected_execute_query = [ + ( + mock_execute_query, + mock.call( + "alter application package test_package modify release channel test_channel add accounts = (org1.acc1,org2.acc2)" + ), + ), + ] + + with assert_in_context(expected_use_objects, expected_execute_query): + sql_facade.add_accounts_to_release_channel( + package_name, release_channel, accounts, role + ) + + +def test_add_accounts_to_release_channel_with_special_chars_in_names( + mock_use_role, mock_execute_query +): + package_name = "test.package" + release_channel = "test.channel" + accounts = ["org1.acc1", "org2.acc2"] + role = "test_role" + + expected_use_objects = [ + (mock_use_role, mock.call(role)), + ] + expected_execute_query = [ + ( + mock_execute_query, + mock.call( + 'alter application package "test.package" modify release channel "test.channel" add accounts = (org1.acc1,org2.acc2)' + ), + ), + ] + + with assert_in_context(expected_use_objects, expected_execute_query): + sql_facade.add_accounts_to_release_channel( + package_name, release_channel, accounts, role + ) + + +@pytest.mark.parametrize( + "error_raised, error_caught, error_message", + [ + ( + ProgrammingError(errno=ACCOUNT_DOES_NOT_EXIST), + UserInputError, + "Invalid account passed in.", + ), + ( + ProgrammingError(errno=ACCOUNT_HAS_TOO_MANY_QUALIFIERS), + UserInputError, + "Invalid account passed in.", + ), + ( + ProgrammingError(errno=CANNOT_MODIFY_RELEASE_CHANNEL_ACCOUNTS), + UserInputError, + "Cannot modify accounts for release channel test_channel in application package test_package.", + ), + ( + ProgrammingError(), + InvalidSQLError, + "Failed to add accounts to release channel test_channel in application package test_package.", + ), + ], +) +@mock.patch(SQL_EXECUTOR_EXECUTE) +def test_add_accounts_to_release_channel_error( + mock_execute_query, error_raised, error_caught, error_message, mock_use_role +): + mock_execute_query.side_effect = error_raised + + with pytest.raises(error_caught) as err: + sql_facade.add_accounts_to_release_channel( + "test_package", "test_channel", ["org1.acc1"], "test_role" + ) + + assert error_message in str(err) + + +def test_remove_accounts_from_release_channel_valid_input_then_success( + mock_use_role, mock_execute_query +): + package_name = "test_package" + release_channel = "test_channel" + accounts = ["org1.acc1", "org2.acc2"] + role = "test_role" + + expected_use_objects = [ + (mock_use_role, mock.call(role)), + ] + expected_execute_query = [ + ( + mock_execute_query, + mock.call( + "alter application package test_package modify release channel test_channel remove accounts = (org1.acc1,org2.acc2)" + ), + ), + ] + + with assert_in_context(expected_use_objects, expected_execute_query): + sql_facade.remove_accounts_from_release_channel( + package_name, release_channel, accounts, role + ) + + +def test_remove_accounts_from_release_channel_with_special_chars_in_names( + mock_use_role, mock_execute_query +): + package_name = "test.package" + release_channel = "test.channel" + accounts = ["org1.acc1", "org2.acc2"] + role = "test_role" + + expected_use_objects = [ + (mock_use_role, mock.call(role)), + ] + expected_execute_query = [ + ( + mock_execute_query, + mock.call( + 'alter application package "test.package" modify release channel "test.channel" remove accounts = (org1.acc1,org2.acc2)' + ), + ), + ] + + with assert_in_context(expected_use_objects, expected_execute_query): + sql_facade.remove_accounts_from_release_channel( + package_name, release_channel, accounts, role + ) + + +@pytest.mark.parametrize( + "error_raised, error_caught, error_message", + [ + ( + ProgrammingError(errno=ACCOUNT_DOES_NOT_EXIST), + UserInputError, + "Invalid account passed in.", + ), + ( + ProgrammingError(errno=ACCOUNT_HAS_TOO_MANY_QUALIFIERS), + UserInputError, + "Invalid account passed in.", + ), + ( + ProgrammingError(errno=CANNOT_MODIFY_RELEASE_CHANNEL_ACCOUNTS), + UserInputError, + "Cannot modify accounts for release channel test_channel in application package test_package.", + ), + ( + ProgrammingError(), + InvalidSQLError, + "Failed to remove accounts from release channel test_channel in application package test_package.", + ), + ], +) +@mock.patch(SQL_EXECUTOR_EXECUTE) +def test_remove_accounts_from_release_channel_error( + mock_execute_query, error_raised, error_caught, error_message, mock_use_role +): + mock_execute_query.side_effect = error_raised + + with pytest.raises(error_caught) as err: + sql_facade.remove_accounts_from_release_channel( + "test_package", "test_channel", ["org1.acc1"], "test_role" + ) + + assert error_message in str(err) diff --git a/tests/nativeapp/utils.py b/tests/nativeapp/utils.py index dcaae91080..d3f7ee4d32 100644 --- a/tests/nativeapp/utils.py +++ b/tests/nativeapp/utils.py @@ -96,6 +96,12 @@ SQL_FACADE_SHOW_RELEASE_CHANNELS = f"{SQL_FACADE}.show_release_channels" SQL_FACADE_DROP_VERSION = f"{SQL_FACADE}.drop_version_from_package" SQL_FACADE_CREATE_VERSION = f"{SQL_FACADE}.create_version_in_package" +SQL_FACADE_ADD_ACCOUNTS_TO_RELEASE_CHANNEL = ( + f"{SQL_FACADE}.add_accounts_to_release_channel" +) +SQL_FACADE_REMOVE_ACCOUNTS_FROM_RELEASE_CHANNEL = ( + f"{SQL_FACADE}.remove_accounts_from_release_channel" +) mock_snowflake_yml_file = dedent( """\ From 3009e8bcaf3de529046b054c4dd611deeee43e09 Mon Sep 17 00:00:00 2001 From: Jan Sikorski <132985823+sfc-gh-jsikorski@users.noreply.github.com> Date: Thu, 19 Dec 2024 16:54:59 +0100 Subject: [PATCH 2/2] Add streamlit entities (#1934) * Actions * fix * text fix * Fixes * Fix Co-authored-by: Patryk Czajka --- .../nativeapp/entities/application_package.py | 2 + .../cli/_plugins/nativeapp/feature_flags.py | 4 +- .../_plugins/streamlit/streamlit_entity.py | 160 ++++++++++++--- tests/nativeapp/test_children.py | 6 +- .../streamlit/__snapshots__/test_actions.ambr | 49 +++++ .../__snapshots__/test_commands.ambr | 36 +--- .../__snapshots__/test_streamlit_entity.ambr | 60 ++++++ tests/streamlit/test_commands.py | 2 +- tests/streamlit/test_streamlit_entity.py | 192 +++++++++++++++--- .../example_streamlit_v2/snowflake.yml | 12 +- .../example_streamlit_v2/utils/utils.py | 0 11 files changed, 416 insertions(+), 107 deletions(-) create mode 100644 tests/streamlit/__snapshots__/test_actions.ambr create mode 100644 tests/streamlit/__snapshots__/test_streamlit_entity.ambr create mode 100644 tests/test_data/projects/example_streamlit_v2/utils/utils.py diff --git a/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py b/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py index 389a5a8c82..453d9fe103 100644 --- a/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py +++ b/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py @@ -966,6 +966,8 @@ def _bundle_children(self, action_ctx: ActionContext) -> List[str]: child_entity.get_deploy_sql( artifacts_dir=child_artifacts_dir.relative_to(self.deploy_root), schema=child_schema, + # TODO Allow users to override the hard-coded value for specific children + replace=True, ) ) if app_role: diff --git a/src/snowflake/cli/_plugins/nativeapp/feature_flags.py b/src/snowflake/cli/_plugins/nativeapp/feature_flags.py index dc7e93bf51..498c430c2c 100644 --- a/src/snowflake/cli/_plugins/nativeapp/feature_flags.py +++ b/src/snowflake/cli/_plugins/nativeapp/feature_flags.py @@ -18,7 +18,9 @@ @unique -class FeatureFlag(FeatureFlagMixin): +class FeatureFlag( + FeatureFlagMixin +): # TODO move this to snowflake.cli.api.feature_flags ENABLE_NATIVE_APP_PYTHON_SETUP = BooleanFlag( "ENABLE_NATIVE_APP_PYTHON_SETUP", False ) diff --git a/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py b/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py index 6b187ba54b..c0c6786822 100644 --- a/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py +++ b/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py @@ -1,23 +1,20 @@ +import functools from pathlib import Path from typing import Optional -from snowflake.cli._plugins.nativeapp.artifacts import build_bundle -from snowflake.cli._plugins.nativeapp.entities.application_package_child_interface import ( - ApplicationPackageChildInterface, -) +from click import ClickException +from snowflake.cli._plugins.connection.util import make_snowsight_url from snowflake.cli._plugins.nativeapp.feature_flags import FeatureFlag from snowflake.cli._plugins.streamlit.streamlit_entity_model import ( StreamlitEntityModel, ) -from snowflake.cli.api.entities.common import EntityBase -from snowflake.cli.api.project.schemas.v1.native_app.path_mapping import PathMapping +from snowflake.cli._plugins.workspace.context import ActionContext +from snowflake.cli.api.entities.common import EntityBase, get_sql_executor +from snowflake.cli.api.secure_path import SecurePath +from snowflake.connector.cursor import SnowflakeCursor -# WARNING: This entity is not implemented yet. The logic below is only for demonstrating the -# required interfaces for composability (used by ApplicationPackageEntity behind a feature flag). -class StreamlitEntity( - EntityBase[StreamlitEntityModel], ApplicationPackageChildInterface -): +class StreamlitEntity(EntityBase[StreamlitEntityModel]): """ A Streamlit app. """ @@ -28,43 +25,140 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @property - def project_root(self) -> Path: + def root(self): return self._workspace_ctx.project_root @property - def deploy_root(self) -> Path: - return self.project_root / "output" / "deploy" + def artifacts(self): + return self._entity_model.artifacts - def action_bundle( - self, - *args, - **kwargs, - ): + @functools.cached_property + def _sql_executor(self): + return get_sql_executor() + + @functools.cached_property + def _conn(self): + return self._sql_executor._conn # noqa + + @property + def model(self): + return self._entity_model # noqa + + def action_bundle(self, action_ctx: ActionContext, *args, **kwargs): return self.bundle() - def bundle(self, bundle_root=None): - return build_bundle( - self.project_root, - bundle_root or self.deploy_root, - [ - PathMapping(src=str(artifact)) - for artifact in self._entity_model.artifacts - ], + def action_deploy(self, action_ctx: ActionContext, *args, **kwargs): + # After adding bundle map- we should use it's mapping here + # To copy artifacts to destination on stage. + + return self._sql_executor.execute_query(self.get_deploy_sql()) + + def action_drop(self, action_ctx: ActionContext, *args, **kwargs): + return self._sql_executor.execute_query(self.get_drop_sql()) + + def action_execute( + self, action_ctx: ActionContext, *args, **kwargs + ) -> SnowflakeCursor: + return self._sql_executor.execute_query(self.get_execute_sql()) + + def action_get_url( + self, action_ctx: ActionContext, *args, **kwargs + ): # maybe this should be a property + name = self._entity_model.fqn.using_connection(self._conn) + return make_snowsight_url( + self._conn, f"/#/streamlit-apps/{name.url_identifier}" ) + def bundle(self, output_dir: Optional[Path] = None): + + if not output_dir: + output_dir = self.root / "output" / self._entity_model.stage + + artifacts = self._entity_model.artifacts + + output_dir.mkdir(parents=True, exist_ok=True) # type: ignore + + output_files = [] + + # This is far from , but will be replaced by bundlemap mappings. + for file in artifacts: + output_file = output_dir / file.name + + if file.is_file(): + SecurePath(file).copy(output_file) + elif file.is_dir(): + output_file.mkdir(parents=True, exist_ok=True) + SecurePath(file).copy(output_file, dirs_exist_ok=True) + + output_files.append(output_file) + + return output_files + + def action_share( + self, action_ctx: ActionContext, to_role: str, *args, **kwargs + ) -> SnowflakeCursor: + return self._sql_executor.execute_query(self.get_share_sql(to_role)) + def get_deploy_sql( self, + if_not_exists: bool = False, + replace: bool = False, + from_stage_name: Optional[str] = None, artifacts_dir: Optional[Path] = None, schema: Optional[str] = None, + *args, + **kwargs, ): - entity_id = self.entity_id - if artifacts_dir: - streamlit_name = f"{schema}.{entity_id}" if schema else entity_id - return f"CREATE OR REPLACE STREAMLIT {streamlit_name} FROM '{artifacts_dir}' MAIN_FILE='{self._entity_model.main_file}';" + if replace and if_not_exists: + raise ClickException("Cannot specify both replace and if_not_exists") + + if replace: + query = "CREATE OR REPLACE " + elif if_not_exists: + query = "CREATE IF NOT EXISTS " else: - return f"CREATE OR REPLACE STREAMLIT {entity_id} MAIN_FILE='{self._entity_model.main_file}';" + query = "CREATE " + + schema_to_use = schema or self._entity_model.fqn.schema + query += f"STREAMLIT {self._entity_model.fqn.set_schema(schema_to_use).sql_identifier}" + + if from_stage_name: + query += f"\nROOT_LOCATION = '{from_stage_name}'" + elif artifacts_dir: + query += f"\nFROM '{artifacts_dir}'" + + query += f"\nMAIN_FILE = '{self._entity_model.main_file}'" + + if self.model.imports: + query += "\n" + self.model.get_imports_sql() + + if self.model.query_warehouse: + query += f"\nQUERY_WAREHOUSE = '{self.model.query_warehouse}'" + + if self.model.title: + query += f"\nTITLE = '{self.model.title}'" + + if self.model.comment: + query += f"\nCOMMENT = '{self.model.comment}'" + + if self.model.external_access_integrations: + query += "\n" + self.model.get_external_access_integrations_sql() + + if self.model.secrets: + query += "\n" + self.model.get_secrets_sql() + + return query + ";" + + def get_drop_sql(self): + return f"DROP STREAMLIT {self._entity_model.fqn};" + + def get_execute_sql(self): + return f"EXECUTE STREAMLIT {self._entity_model.fqn}();" + + def get_share_sql(self, to_role: str) -> str: + return f"GRANT USAGE ON STREAMLIT {self.model.fqn.sql_identifier} TO ROLE {to_role};" - def get_usage_grant_sql(self, app_role: str, schema: Optional[str] = None): + def get_usage_grant_sql(self, app_role: str, schema: Optional[str] = None) -> str: entity_id = self.entity_id streamlit_name = f"{schema}.{entity_id}" if schema else entity_id return ( diff --git a/tests/nativeapp/test_children.py b/tests/nativeapp/test_children.py index fca85666e3..77e44eaee8 100644 --- a/tests/nativeapp/test_children.py +++ b/tests/nativeapp/test_children.py @@ -143,10 +143,12 @@ def test_children_bundle_with_custom_dir(project_directory): dedent( f""" -- AUTO GENERATED CHILDREN SECTION - CREATE OR REPLACE STREAMLIT v_schema.my_streamlit FROM '{custom_dir_path}' MAIN_FILE='streamlit_app.py'; + CREATE OR REPLACE STREAMLIT IDENTIFIER('v_schema.my_streamlit') + FROM '{custom_dir_path}' + MAIN_FILE = 'streamlit_app.py'; CREATE APPLICATION ROLE IF NOT EXISTS my_app_role; GRANT USAGE ON SCHEMA v_schema TO APPLICATION ROLE my_app_role; GRANT USAGE ON STREAMLIT v_schema.my_streamlit TO APPLICATION ROLE my_app_role; - """ +""" ) ) diff --git a/tests/streamlit/__snapshots__/test_actions.ambr b/tests/streamlit/__snapshots__/test_actions.ambr new file mode 100644 index 0000000000..96db5e2f79 --- /dev/null +++ b/tests/streamlit/__snapshots__/test_actions.ambr @@ -0,0 +1,49 @@ +# serializer version: 1 +# name: test_get_deploy_sql[kwargs0] + ''' + CREATE OR REPLACE STREAMLIT IDENTIFIER('test_streamlit') + MAIN_FILE = 'streamlit_app.py' + QUERY_WAREHOUSE = 'test_warehouse' + TITLE = 'My Fancy Streamlit' + + ''' +# --- +# name: test_get_deploy_sql[kwargs1] + ''' + CREATE IF NOT EXISTS STREAMLIT IDENTIFIER('test_streamlit') + MAIN_FILE = 'streamlit_app.py' + QUERY_WAREHOUSE = 'test_warehouse' + TITLE = 'My Fancy Streamlit' + + ''' +# --- +# name: test_get_deploy_sql[kwargs2] + ''' + CREATE STREAMLIT IDENTIFIER('test_streamlit') + ROOT_LOCATION = 'test_stage' + MAIN_FILE = 'streamlit_app.py' + QUERY_WAREHOUSE = 'test_warehouse' + TITLE = 'My Fancy Streamlit' + + ''' +# --- +# name: test_get_deploy_sql[kwargs3] + ''' + CREATE OR REPLACE STREAMLIT IDENTIFIER('test_streamlit') + ROOT_LOCATION = 'test_stage' + MAIN_FILE = 'streamlit_app.py' + QUERY_WAREHOUSE = 'test_warehouse' + TITLE = 'My Fancy Streamlit' + + ''' +# --- +# name: test_get_deploy_sql[kwargs4] + ''' + CREATE IF NOT EXISTS STREAMLIT IDENTIFIER('test_streamlit') + ROOT_LOCATION = 'test_stage' + MAIN_FILE = 'streamlit_app.py' + QUERY_WAREHOUSE = 'test_warehouse' + TITLE = 'My Fancy Streamlit' + + ''' +# --- diff --git a/tests/streamlit/__snapshots__/test_commands.ambr b/tests/streamlit/__snapshots__/test_commands.ambr index d48cbe2286..d86c8a9528 100644 --- a/tests/streamlit/__snapshots__/test_commands.ambr +++ b/tests/streamlit/__snapshots__/test_commands.ambr @@ -1,5 +1,5 @@ # serializer version: 1 -# name: test_artifacts_must_exists +# name: test_artifacts_must_exist ''' +- Error ----------------------------------------------------------------------+ | During evaluation of DefinitionV20 in project definition following errors | @@ -11,38 +11,6 @@ ''' # --- -# name: test_deploy_put_files_on_stage[example_streamlit-merge_definition1] - list([ - "create stage if not exists IDENTIFIER('MockDatabase.MockSchema.streamlit_stage')", - 'put file://streamlit_app.py @MockDatabase.MockSchema.streamlit_stage/test_streamlit auto_compress=false parallel=4 overwrite=True', - 'put file://environment.yml @MockDatabase.MockSchema.streamlit_stage/test_streamlit auto_compress=false parallel=4 overwrite=True', - 'put file://pages/* @MockDatabase.MockSchema.streamlit_stage/test_streamlit/pages auto_compress=false parallel=4 overwrite=True', - ''' - CREATE STREAMLIT IDENTIFIER('MockDatabase.MockSchema.test_streamlit') - ROOT_LOCATION = '@MockDatabase.MockSchema.streamlit_stage/test_streamlit' - MAIN_FILE = 'streamlit_app.py' - QUERY_WAREHOUSE = test_warehouse - TITLE = 'My Fancy Streamlit' - ''', - 'select system$get_snowsight_host()', - 'select current_account_name()', - ]) -# --- -# name: test_deploy_put_files_on_stage[example_streamlit_v2-merge_definition0] - list([ - "create stage if not exists IDENTIFIER('MockDatabase.MockSchema.streamlit_stage')", - 'put file://streamlit_app.py @MockDatabase.MockSchema.streamlit_stage/test_streamlit auto_compress=false parallel=4 overwrite=True', - ''' - CREATE STREAMLIT IDENTIFIER('MockDatabase.MockSchema.test_streamlit') - ROOT_LOCATION = '@MockDatabase.MockSchema.streamlit_stage/test_streamlit' - MAIN_FILE = 'streamlit_app.py' - QUERY_WAREHOUSE = test_warehouse - TITLE = 'My Fancy Streamlit' - ''', - 'select system$get_snowsight_host()', - 'select current_account_name()', - ]) -# --- # name: test_deploy_streamlit_nonexisting_file[example_streamlit-opts0] ''' +- Error ----------------------------------------------------------------------+ @@ -74,7 +42,7 @@ | During evaluation of DefinitionV20 in project definition following errors | | were encountered: | | For field entities.test_streamlit.streamlit you provided '{'artifacts': | - | ['foo.bar'], 'identifier': {'name': 'test_streamlit'}, 'main_file': | + | ['foo.bar'], 'identifier': 'test_streamlit', 'main_file': | | 'streamlit_app.py', 'query_warehouse': 'test_warehouse', 'stage': | | 'streamlit', 'title': 'My Fancy Streamlit', 'type': 'streamlit'}'. This | | caused: Value error, Specified artifact foo.bar does not exist locally. | diff --git a/tests/streamlit/__snapshots__/test_streamlit_entity.ambr b/tests/streamlit/__snapshots__/test_streamlit_entity.ambr new file mode 100644 index 0000000000..a1b817713e --- /dev/null +++ b/tests/streamlit/__snapshots__/test_streamlit_entity.ambr @@ -0,0 +1,60 @@ +# serializer version: 1 +# name: test_get_deploy_sql[kwargs0] + ''' + CREATE OR REPLACE STREAMLIT IDENTIFIER('test_streamlit') + MAIN_FILE = 'streamlit_app.py' + QUERY_WAREHOUSE = 'test_warehouse' + TITLE = 'My Fancy Streamlit'; + ''' +# --- +# name: test_get_deploy_sql[kwargs1] + ''' + CREATE IF NOT EXISTS STREAMLIT IDENTIFIER('test_streamlit') + MAIN_FILE = 'streamlit_app.py' + QUERY_WAREHOUSE = 'test_warehouse' + TITLE = 'My Fancy Streamlit'; + ''' +# --- +# name: test_get_deploy_sql[kwargs2] + ''' + CREATE STREAMLIT IDENTIFIER('test_streamlit') + ROOT_LOCATION = 'test_stage' + MAIN_FILE = 'streamlit_app.py' + QUERY_WAREHOUSE = 'test_warehouse' + TITLE = 'My Fancy Streamlit'; + ''' +# --- +# name: test_get_deploy_sql[kwargs3] + ''' + CREATE OR REPLACE STREAMLIT IDENTIFIER('test_streamlit') + ROOT_LOCATION = 'test_stage' + MAIN_FILE = 'streamlit_app.py' + QUERY_WAREHOUSE = 'test_warehouse' + TITLE = 'My Fancy Streamlit'; + ''' +# --- +# name: test_get_deploy_sql[kwargs4] + ''' + CREATE IF NOT EXISTS STREAMLIT IDENTIFIER('test_streamlit') + ROOT_LOCATION = 'test_stage' + MAIN_FILE = 'streamlit_app.py' + QUERY_WAREHOUSE = 'test_warehouse' + TITLE = 'My Fancy Streamlit'; + ''' +# --- +# name: test_nativeapp_children_interface + ''' + CREATE STREAMLIT IDENTIFIER('test_streamlit') + MAIN_FILE = 'streamlit_app.py' + QUERY_WAREHOUSE = 'test_warehouse' + TITLE = 'My Fancy Streamlit'; + ''' +# --- +# name: test_nativeapp_children_interface.1 + ''' + CREATE STREAMLIT IDENTIFIER('test_streamlit') + MAIN_FILE = 'streamlit_app.py' + QUERY_WAREHOUSE = 'test_warehouse' + TITLE = 'My Fancy Streamlit'; + ''' +# --- diff --git a/tests/streamlit/test_commands.py b/tests/streamlit/test_commands.py index e533a66d39..95d8878fd1 100644 --- a/tests/streamlit/test_commands.py +++ b/tests/streamlit/test_commands.py @@ -281,7 +281,7 @@ def test_deploy_only_streamlit_file_replace( mock_typer.launch.assert_not_called() -def test_artifacts_must_exists( +def test_artifacts_must_exist( runner, mock_ctx, project_directory, alter_snowflake_yml, snapshot ): with project_directory("example_streamlit_v2") as pdir: diff --git a/tests/streamlit/test_streamlit_entity.py b/tests/streamlit/test_streamlit_entity.py index 315e34b8e5..4574939f35 100644 --- a/tests/streamlit/test_streamlit_entity.py +++ b/tests/streamlit/test_streamlit_entity.py @@ -1,20 +1,54 @@ from __future__ import annotations from pathlib import Path +from unittest import mock import pytest +import yaml from snowflake.cli._plugins.streamlit.streamlit_entity import ( StreamlitEntity, ) from snowflake.cli._plugins.streamlit.streamlit_entity_model import ( StreamlitEntityModel, ) -from snowflake.cli._plugins.workspace.context import WorkspaceContext -from snowflake.cli.api.console import cli_console as cc -from snowflake.cli.api.project.definition_manager import DefinitionManager +from snowflake.cli._plugins.workspace.context import ActionContext, WorkspaceContext from tests.testing_utils.mock_config import mock_config_key +STREAMLIT_NAME = "test_streamlit" +CONNECTOR = "snowflake.connector.connect" +CONTEXT = "" +EXECUTE_QUERY = "snowflake.cli.api.sql_execution.BaseSqlExecutor.execute_query" + +GET_UI_PARAMETERS = "snowflake.cli._plugins.connection.util.get_ui_parameters" + + +@pytest.fixture +def example_streamlit_workspace(project_directory): + with mock_config_key("enable_native_app_children", True): + with project_directory("example_streamlit_v2") as pdir: + with Path(pdir / "snowflake.yml").open() as definition_file: + definition = yaml.safe_load(definition_file) + model = StreamlitEntityModel( + **definition.get("entities", {}).get("test_streamlit") + ) + + workspace_context = WorkspaceContext( + console=mock.MagicMock(), + project_root=pdir, + get_default_role=lambda: "test_role", + get_default_warehouse=lambda: "test_warehouse", + ) + + yield ( + StreamlitEntity( + workspace_ctx=workspace_context, entity_model=model + ), + ActionContext( + get_entity=lambda *args: None, + ), + ) + def test_cannot_instantiate_without_feature_flag(): with pytest.raises(NotImplementedError) as err: @@ -22,32 +56,128 @@ def test_cannot_instantiate_without_feature_flag(): assert str(err.value) == "Streamlit entity is not implemented yet" -def test_nativeapp_children_interface(temp_dir): - with mock_config_key("enable_native_app_children", True): - dm = DefinitionManager() - ctx = WorkspaceContext( - console=cc, - project_root=dm.project_root, - get_default_role=lambda: "mock_role", - get_default_warehouse=lambda: "mock_warehouse", - ) - main_file = "main.py" - (Path(temp_dir) / main_file).touch() - model = StreamlitEntityModel( - type="streamlit", - main_file=main_file, - artifacts=[main_file], - ) - sl = StreamlitEntity(model, ctx) - - sl.bundle() - bundle_artifact = Path(temp_dir) / "output" / "deploy" / main_file - deploy_sql_str = sl.get_deploy_sql() - grant_sql_str = sl.get_usage_grant_sql(app_role="app_role") - - assert bundle_artifact.exists() - assert deploy_sql_str == "CREATE OR REPLACE STREAMLIT None MAIN_FILE='main.py';" - assert ( - grant_sql_str - == "GRANT USAGE ON STREAMLIT None TO APPLICATION ROLE app_role;" +def test_nativeapp_children_interface(example_streamlit_workspace, snapshot): + sl, action_context = example_streamlit_workspace + + sl.bundle() + bundle_artifact = sl.root / "output" / sl.model.stage / "streamlit_app.py" + deploy_sql_str = sl.get_deploy_sql() + grant_sql_str = sl.get_usage_grant_sql(app_role="app_role") + + assert bundle_artifact.exists() + assert deploy_sql_str == snapshot + assert ( + grant_sql_str == f"GRANT USAGE ON STREAMLIT None TO APPLICATION ROLE app_role;" + ) + + +def test_bundle(example_streamlit_workspace): + entity, action_ctx = example_streamlit_workspace + entity.action_bundle(action_ctx) + + output = entity.root / "output" / entity._entity_model.stage # noqa + + assert output.exists() + assert (output / "streamlit_app.py").exists() + assert (output / "environment.yml").exists() + assert (output / "pages" / "my_page.py").exists() + + +@mock.patch(EXECUTE_QUERY) +def test_deploy(mock_execute, example_streamlit_workspace): + entity, action_ctx = example_streamlit_workspace + entity.action_deploy(action_ctx) + + mock_execute.assert_called_with( + f"CREATE STREAMLIT IDENTIFIER('{STREAMLIT_NAME}')\nMAIN_FILE = 'streamlit_app.py'\nQUERY_WAREHOUSE = 'test_warehouse'\nTITLE = 'My Fancy Streamlit';" + ) + + +@mock.patch(EXECUTE_QUERY) +def test_drop(mock_execute, example_streamlit_workspace): + entity, action_ctx = example_streamlit_workspace + entity.action_drop(action_ctx) + + mock_execute.assert_called_with(f"DROP STREAMLIT {STREAMLIT_NAME};") + + +@mock.patch(CONNECTOR) +@mock.patch( + GET_UI_PARAMETERS, + return_value={"UI_SNOWSIGHT_ENABLE_REGIONLESS_REDIRECT": "false"}, +) +@mock.patch("click.get_current_context") +def test_get_url( + mock_get_ctx, + mock_param, + mock_connect, + mock_cursor, + example_streamlit_workspace, + mock_ctx, +): + ctx = mock_ctx( + mock_cursor( + rows=[ + {"SYSTEM$GET_SNOWSIGHT_HOST()": "https://snowsight.domain"}, + {"SYSTEM$RETURN_CURRENT_ORG_NAME()": "FOOBARBAZ"}, + {"CURRENT_ACCOUNT_NAME()": "https://snowsight.domain"}, + ], + columns=["SYSTEM$GET_SNOWSIGHT_HOST()"], ) + ) + mock_connect.return_value = ctx + mock_get_ctx.return_value = ctx + + entity, action_ctx = example_streamlit_workspace + result = entity.action_get_url(action_ctx) + + mock_connect.assert_called() + + +@mock.patch(EXECUTE_QUERY) +def test_share(mock_connect, example_streamlit_workspace): + entity, action_ctx = example_streamlit_workspace + entity.action_share(action_ctx, to_role="test_role") + + mock_connect.assert_called_with( + "GRANT USAGE ON STREAMLIT IDENTIFIER('test_streamlit') TO ROLE test_role;" + ) + + +@mock.patch(EXECUTE_QUERY) +def test_execute(mock_execute, example_streamlit_workspace): + entity, action_ctx = example_streamlit_workspace + entity.action_execute(action_ctx) + + mock_execute.assert_called_with(f"EXECUTE STREAMLIT {STREAMLIT_NAME}();") + + +def test_get_execute_sql(example_streamlit_workspace): + entity, action_ctx = example_streamlit_workspace + execute_sql = entity.get_execute_sql() + + assert execute_sql == f"EXECUTE STREAMLIT {STREAMLIT_NAME}();" + + +def test_get_drop_sql(example_streamlit_workspace): + entity, action_ctx = example_streamlit_workspace + drop_sql = entity.get_drop_sql() + + assert drop_sql == f"DROP STREAMLIT {STREAMLIT_NAME};" + + +@pytest.mark.parametrize( + "kwargs", + [ + {"replace": True}, + {"if_not_exists": True}, + {"from_stage_name": "test_stage"}, + {"from_stage_name": "test_stage", "replace": True}, + {"from_stage_name": "test_stage", "if_not_exists": True}, + ], +) +def test_get_deploy_sql(kwargs, example_streamlit_workspace, snapshot): + entity, action_ctx = example_streamlit_workspace + deploy_sql = entity.get_deploy_sql(**kwargs) + + assert deploy_sql == snapshot diff --git a/tests/test_data/projects/example_streamlit_v2/snowflake.yml b/tests/test_data/projects/example_streamlit_v2/snowflake.yml index 362d963f4b..f14d3e9d1b 100644 --- a/tests/test_data/projects/example_streamlit_v2/snowflake.yml +++ b/tests/test_data/projects/example_streamlit_v2/snowflake.yml @@ -1,12 +1,14 @@ definition_version: '2' entities: test_streamlit: - identifier: - name: test_streamlit - type: streamlit - title: My Fancy Streamlit + type: "streamlit" + identifier: test_streamlit + title: "My Fancy Streamlit" + stage: streamlit query_warehouse: test_warehouse main_file: streamlit_app.py - stage: streamlit artifacts: - streamlit_app.py + - utils/utils.py + - pages/ + - environment.yml diff --git a/tests/test_data/projects/example_streamlit_v2/utils/utils.py b/tests/test_data/projects/example_streamlit_v2/utils/utils.py new file mode 100644 index 0000000000..e69de29bb2