From 4b30958c23a2f58681a4ce7c5bdd8951ee46c920 Mon Sep 17 00:00:00 2001 From: Tomasz Urbaszek Date: Mon, 26 Aug 2024 09:10:03 +0200 Subject: [PATCH 1/4] Refactor snowpark commands to be entity-centric (#1483) --- .../cli/_plugins/snowpark/commands.py | 124 ++-------- src/snowflake/cli/_plugins/snowpark/common.py | 217 ++++++++++-------- .../cli/_plugins/snowpark/manager.py | 111 --------- .../api/project/schemas/entities/common.py | 4 +- tests/snowpark/test_common.py | 78 +++++-- tests/snowpark/test_function.py | 4 +- tests/snowpark/test_procedure.py | 4 +- tests/streamlit/test_streamlit_manager.py | 2 +- 8 files changed, 203 insertions(+), 341 deletions(-) delete mode 100644 src/snowflake/cli/_plugins/snowpark/manager.py diff --git a/src/snowflake/cli/_plugins/snowpark/commands.py b/src/snowflake/cli/_plugins/snowpark/commands.py index 009bd7b0d4..8d7c3b5ca2 100644 --- a/src/snowflake/cli/_plugins/snowpark/commands.py +++ b/src/snowflake/cli/_plugins/snowpark/commands.py @@ -16,8 +16,7 @@ import logging from collections import defaultdict -from enum import Enum -from typing import Dict, List, Optional, Set, Tuple +from typing import Dict, Optional, Set, Tuple import typer from click import ClickException, UsageError @@ -36,16 +35,18 @@ from snowflake.cli._plugins.object.manager import ObjectManager from snowflake.cli._plugins.snowpark import package_utils from snowflake.cli._plugins.snowpark.common import ( - check_if_replace_is_required, + EntityToImportPathsMapping, + SnowparkEntities, + SnowparkObject, + SnowparkObjectManager, + StageToArtefactMapping, ) -from snowflake.cli._plugins.snowpark.manager import FunctionManager, ProcedureManager from snowflake.cli._plugins.snowpark.package.anaconda_packages import ( AnacondaPackages, AnacondaPackagesManager, ) from snowflake.cli._plugins.snowpark.package.commands import app as package_app from snowflake.cli._plugins.snowpark.snowpark_project_paths import ( - Artefact, SnowparkProjectPaths, ) from 
snowflake.cli._plugins.snowpark.snowpark_shared import ( @@ -72,7 +73,6 @@ from snowflake.cli.api.console import cli_console from snowflake.cli.api.constants import ( DEFAULT_SIZE_LIMIT_MB, - ObjectType, ) from snowflake.cli.api.exceptions import ( NoProjectDefinitionError, @@ -88,7 +88,6 @@ from snowflake.cli.api.project.schemas.entities.snowpark_entity import ( FunctionEntityModel, ProcedureEntityModel, - SnowparkEntityModel, ) from snowflake.cli.api.project.schemas.project_definition import ( ProjectDefinition, @@ -124,11 +123,6 @@ ) -SnowparkEntities = Dict[str, SnowparkEntityModel] -StageToArtefactMapping = Dict[str, set[Artefact]] -EntityToImportPathsMapping = Dict[str, set[str]] - - @app.command("deploy", requires_connection=True) @with_project_definition() def deploy( @@ -179,13 +173,13 @@ def deploy( # Create snowpark entities with cli_console.phase("Creating Snowpark entities"): + snowpark_manager = SnowparkObjectManager() snowflake_dependencies = _read_snowflake_requirements_file( project_paths.snowflake_requirements ) deploy_status = [] for entity in snowpark_entities.values(): - cli_console.step(f"Creating {entity.type} {entity.fqn}") - operation_result = _deploy_single_object( + operation_result = snowpark_manager.deploy_entity( entity=entity, existing_objects=existing_objects, snowflake_dependencies=snowflake_dependencies, @@ -280,7 +274,7 @@ def _find_existing_objects( def _check_if_all_defined_integrations_exists( om: ObjectManager, - snowpark_entities: Dict[str, FunctionEntityModel | ProcedureEntityModel], + snowpark_entities: SnowparkEntities, ): existing_integrations = { i["name"].lower() @@ -306,72 +300,6 @@ def _check_if_all_defined_integrations_exists( ) -def _deploy_single_object( - entity: SnowparkEntityModel, - existing_objects: Dict[str, SnowflakeCursor], - snowflake_dependencies: List[str], - entities_to_artifact_map: EntityToImportPathsMapping, -): - object_type = entity.get_type() - is_procedure = isinstance(entity, 
ProcedureEntityModel) - - handler = entity.handler - returns = entity.returns - imports = entity.imports - external_access_integrations = entity.external_access_integrations - runtime_ver = entity.runtime - execute_as_caller = None - if is_procedure: - execute_as_caller = entity.execute_as_caller - replace_object = False - - object_exists = entity.entity_id in existing_objects - if object_exists: - replace_object = check_if_replace_is_required( - object_type=object_type, - current_state=existing_objects[entity.entity_id], - handler=handler, - return_type=returns, - snowflake_dependencies=snowflake_dependencies, - external_access_integrations=external_access_integrations, - imports=imports, - stage_artifact_files=entities_to_artifact_map[entity.entity_id], - runtime_ver=runtime_ver, - execute_as_caller=execute_as_caller, - ) - - if object_exists and not replace_object: - return { - "object": entity.udf_sproc_identifier.identifier_with_arg_names_types_defaults, - "type": str(object_type), - "status": "packages updated", - } - - create_or_replace_kwargs = { - "identifier": entity.udf_sproc_identifier, - "handler": handler, - "return_type": returns, - "artifact_files": entities_to_artifact_map[entity.entity_id], - "packages": snowflake_dependencies, - "runtime": entity.runtime, - "external_access_integrations": entity.external_access_integrations, - "secrets": entity.secrets, - "imports": imports, - } - if is_procedure: - create_or_replace_kwargs["execute_as_caller"] = entity.execute_as_caller - - manager = ProcedureManager() if is_procedure else FunctionManager() - manager.create_or_replace(**create_or_replace_kwargs) - - status = "created" if not object_exists else "definition updated" - return { - "object": entity.udf_sproc_identifier.identifier_with_arg_names_types_defaults, - "type": str(object_type), - "status": status, - } - - def _read_snowflake_requirements_file(file_path: SecurePath): if not file_path.exists(): return [] @@ -470,46 +398,24 @@ def 
get_snowpark_entities( return snowpark_entities -class _SnowparkObject(Enum): - """This clas is used only for Snowpark execute where choice is limited.""" - - PROCEDURE = str(ObjectType.PROCEDURE) - FUNCTION = str(ObjectType.FUNCTION) - - -def _execute_object_method( - method_name: str, - object_type: _SnowparkObject, - **kwargs, -): - if object_type == _SnowparkObject.PROCEDURE: - manager = ProcedureManager() - elif object_type == _SnowparkObject.FUNCTION: - manager = FunctionManager() - else: - raise ClickException(f"Unknown object type: {object_type}") - - return getattr(manager, method_name)(**kwargs) - - @app.command("execute", requires_connection=True) def execute( - object_type: _SnowparkObject = ObjectTypeArgument, + object_type: SnowparkObject = ObjectTypeArgument, execution_identifier: str = execution_identifier_argument( "procedure/function", "hello(1, 'world')" ), **options, ) -> CommandResult: """Executes a procedure or function in a specified environment.""" - cursor = _execute_object_method( - "execute", object_type=object_type, execution_identifier=execution_identifier + cursor = SnowparkObjectManager().execute( + execution_identifier=execution_identifier, object_type=object_type ) return SingleQueryResult(cursor) @app.command("list", requires_connection=True) def list_( - object_type: _SnowparkObject = ObjectTypeArgument, + object_type: SnowparkObject = ObjectTypeArgument, like: str = LikeOption, scope: Tuple[str, str] = scope_option( help_example="`list function --in database my_db`" @@ -522,7 +428,7 @@ def list_( @app.command("drop", requires_connection=True) def drop( - object_type: _SnowparkObject = ObjectTypeArgument, + object_type: SnowparkObject = ObjectTypeArgument, identifier: FQN = IdentifierArgument, **options, ): @@ -532,7 +438,7 @@ def drop( @app.command("describe", requires_connection=True) def describe( - object_type: _SnowparkObject = ObjectTypeArgument, + object_type: SnowparkObject = ObjectTypeArgument, identifier: FQN = 
IdentifierArgument, **options, ): diff --git a/src/snowflake/cli/_plugins/snowpark/common.py b/src/snowflake/cli/_plugins/snowpark/common.py index 256631502c..48d258c6ac 100644 --- a/src/snowflake/cli/_plugins/snowpark/common.py +++ b/src/snowflake/cli/_plugins/snowpark/common.py @@ -14,46 +14,128 @@ from __future__ import annotations +import logging import re -from typing import Dict, List, Optional, Set +from enum import Enum +from typing import Dict, List, Set +from click import UsageError from snowflake.cli._plugins.snowpark.models import Requirement -from snowflake.cli._plugins.snowpark.package_utils import ( - generate_deploy_stage_name, -) +from snowflake.cli._plugins.snowpark.snowpark_project_paths import Artefact +from snowflake.cli.api.console import cli_console from snowflake.cli.api.constants import ObjectType from snowflake.cli.api.project.schemas.entities.snowpark_entity import ( - UdfSprocIdentifier, + ProcedureEntityModel, + SnowparkEntityModel, ) from snowflake.cli.api.sql_execution import SqlExecutionMixin from snowflake.connector.cursor import SnowflakeCursor +log = logging.getLogger(__name__) + +SnowparkEntities = Dict[str, SnowparkEntityModel] +StageToArtefactMapping = Dict[str, set[Artefact]] +EntityToImportPathsMapping = Dict[str, set[str]] + DEFAULT_RUNTIME = "3.10" -def check_if_replace_is_required( - object_type: str, +class SnowparkObject(Enum): + """This clas is used only for Snowpark execute where choice is limited.""" + + PROCEDURE = str(ObjectType.PROCEDURE) + FUNCTION = str(ObjectType.FUNCTION) + + +class SnowparkObjectManager(SqlExecutionMixin): + def execute( + self, execution_identifier: str, object_type: SnowparkObject + ) -> SnowflakeCursor: + if object_type == SnowparkObject.FUNCTION: + return self._execute_query(f"select {execution_identifier}") + if object_type == SnowparkObject.PROCEDURE: + return self._execute_query(f"call {execution_identifier}") + raise UsageError(f"Unknown object type: {object_type}.") + + def 
create_or_replace( + self, + entity: SnowparkEntityModel, + artifact_files: set[str], + snowflake_dependencies: list[str], + ) -> str: + entity.imports.extend(artifact_files) + imports = [f"'{x}'" for x in entity.imports] + packages_list = ",".join(f"'{p}'" for p in snowflake_dependencies) + + object_type = entity.get_type() + + query = [ + f"create or replace {object_type} {entity.udf_sproc_identifier.identifier_for_sql}", + f"copy grants", + f"returns {entity.returns}", + "language python", + f"runtime_version={entity.runtime or DEFAULT_RUNTIME}", + f"imports=({', '.join(imports)})", + f"handler='{entity.handler}'", + f"packages=({packages_list})", + ] + + if entity.external_access_integrations: + query.append(entity.get_external_access_integrations_sql()) + + if entity.secrets: + query.append(entity.get_secrets_sql()) + + if isinstance(entity, ProcedureEntityModel) and entity.execute_as_caller: + query.append("execute as caller") + + return self._execute_query("\n".join(query)) + + def deploy_entity( + self, + entity: SnowparkEntityModel, + existing_objects: Dict[str, SnowflakeCursor], + snowflake_dependencies: List[str], + entities_to_artifact_map: EntityToImportPathsMapping, + ): + cli_console.step(f"Creating {entity.type} {entity.fqn}") + object_exists = entity.entity_id in existing_objects + replace_object = False + if object_exists: + replace_object = _check_if_replace_is_required( + entity=entity, + current_state=existing_objects[entity.entity_id], + snowflake_dependencies=snowflake_dependencies, + stage_artifact_files=entities_to_artifact_map[entity.entity_id], + ) + + state = { + "object": entity.udf_sproc_identifier.identifier_with_arg_names_types_defaults, + "type": entity.get_type(), + } + if object_exists and not replace_object: + return {**state, "status": "packages updated"} + + self.create_or_replace( + entity=entity, + artifact_files=entities_to_artifact_map[entity.entity_id], + snowflake_dependencies=snowflake_dependencies, + ) + return { + 
**state, + "status": "created" if not object_exists else "definition updated", + } + + +def _check_if_replace_is_required( + entity: SnowparkEntityModel, current_state, - handler: str, - return_type: str, snowflake_dependencies: List[str], - external_access_integrations: List[str], - imports: List[str], stage_artifact_files: set[str], - runtime_ver: Optional[str] = None, - execute_as_caller: Optional[bool] = None, ) -> bool: - import logging - - log = logging.getLogger(__name__) + object_type = entity.get_type() resource_json = _convert_resource_details_to_dict(current_state) old_dependencies = resource_json["packages"] - log.info( - "Found %d defined Anaconda packages in deployed %s...", - len(old_dependencies), - object_type, - ) - log.info("Checking if app configuration has changed...") if _snowflake_dependencies_differ(old_dependencies, snowflake_dependencies): log.info( @@ -61,7 +143,7 @@ def check_if_replace_is_required( ) return True - if set(external_access_integrations) != set( + if set(entity.external_access_integrations) != set( resource_json.get("external_access_integrations", []) ): log.info( @@ -71,33 +153,33 @@ def check_if_replace_is_required( return True if ( - resource_json["handler"].lower() != handler.lower() + resource_json["handler"].lower() != entity.handler.lower() or _sql_to_python_return_type_mapper(resource_json["returns"]).lower() - != return_type.lower() + != entity.returns.lower() ): log.info( "Return type or handler types do not match. Replacing the %s.", object_type ) return True - if _compare_imports(resource_json, imports, stage_artifact_files): + if _compare_imports(resource_json, entity.imports, stage_artifact_files): log.info("Imports do not match. Replacing the %s", object_type) return True - if runtime_ver is not None and runtime_ver != resource_json.get( + if entity.runtime is not None and entity.runtime != resource_json.get( "runtime_version", "RUNTIME_NOT_SET" ): log.info("Runtime versions do not match. 
Replacing the %s", object_type) return True - if execute_as_caller is not None and ( - resource_json.get("execute as", "OWNER") - != ("CALLER" if execute_as_caller else "OWNER") - ): - log.info( - "Execute as caller settings do not match. Replacing the %s", object_type - ) - return True + if isinstance(entity, ProcedureEntityModel): + if resource_json.get("execute as", "OWNER") != ( + "CALLER" if entity.execute_as_caller else "OWNER" + ): + log.info( + "Execute as caller settings do not match. Replacing the %s", object_type + ) + return True return False @@ -148,71 +230,6 @@ def _sql_to_python_return_type_mapper(resource_return_type: str) -> str: return mapping.get(resource_return_type.lower(), resource_return_type.lower()) -class SnowparkObjectManager(SqlExecutionMixin): - @property - def _object_type(self) -> ObjectType: - raise NotImplementedError() - - @property - def _object_execute(self): - raise NotImplementedError() - - def create(self, *args, **kwargs) -> SnowflakeCursor: - raise NotImplementedError() - - def execute(self, execution_identifier: str) -> SnowflakeCursor: - return self._execute_query(f"{self._object_execute} {execution_identifier}") - - @staticmethod - def artifact_stage_path(identifier: str): - return generate_deploy_stage_name(identifier).lower() - - def create_query( - self, - identifier: UdfSprocIdentifier, - return_type: str, - handler: str, - artifact_files: set[str], - packages: List[str], - imports: List[str], - external_access_integrations: Optional[List[str]] = None, - secrets: Optional[Dict[str, str]] = None, - runtime: Optional[str] = None, - execute_as_caller: bool = False, - ) -> str: - imports.extend(artifact_files) - imports = [f"'{x}'" for x in imports] - packages_list = ",".join(f"'{p}'" for p in packages) - - query = [ - f"create or replace {self._object_type.value.sf_name} {identifier.identifier_for_sql}", - f"copy grants", - f"returns {return_type}", - "language python", - f"runtime_version={runtime or DEFAULT_RUNTIME}", 
- f"imports=({', '.join(imports)})", - f"handler='{handler}'", - f"packages=({packages_list})", - ] - - if external_access_integrations: - external_access_integration_name = ",".join( - f"{e}" for e in external_access_integrations - ) - query.append( - f"external_access_integrations=({external_access_integration_name})" - ) - - if secrets: - secret_name = ",".join(f"'{k}'={v}" for k, v in secrets.items()) - query.append(f"secrets=({secret_name})") - - if execute_as_caller: - query.append("execute as caller") - - return "\n".join(query) - - def _compare_imports( resource_json: dict, imports: List[str], artifact_files: set[str] ) -> bool: diff --git a/src/snowflake/cli/_plugins/snowpark/manager.py b/src/snowflake/cli/_plugins/snowpark/manager.py deleted file mode 100644 index 1bbbf16106..0000000000 --- a/src/snowflake/cli/_plugins/snowpark/manager.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright (c) 2024 Snowflake Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import logging -from typing import Dict, List, Optional - -from snowflake.cli._plugins.snowpark.common import ( - SnowparkObjectManager, -) -from snowflake.cli.api.constants import ObjectType -from snowflake.cli.api.project.schemas.entities.snowpark_entity import ( - UdfSprocIdentifier, -) -from snowflake.connector.cursor import SnowflakeCursor - -log = logging.getLogger(__name__) - - -class FunctionManager(SnowparkObjectManager): - @property - def _object_type(self): - return ObjectType.FUNCTION - - @property - def _object_execute(self): - return "select" - - def create_or_replace( - self, - identifier: UdfSprocIdentifier, - return_type: str, - handler: str, - artifact_files: set[str], - packages: List[str], - imports: List[str], - external_access_integrations: Optional[List[str]] = None, - secrets: Optional[Dict[str, str]] = None, - runtime: Optional[str] = None, - ) -> SnowflakeCursor: - log.debug( - "Creating function %s using @%s", - identifier.identifier_with_arg_names_types_defaults, - artifact_files, - ) - query = self.create_query( - identifier, - return_type, - handler, - artifact_files, - packages, - imports, - external_access_integrations, - secrets, - runtime, - ) - return self._execute_query(query) - - -class ProcedureManager(SnowparkObjectManager): - @property - def _object_type(self): - return ObjectType.PROCEDURE - - @property - def _object_execute(self): - return "call" - - def create_or_replace( - self, - identifier: UdfSprocIdentifier, - return_type: str, - handler: str, - artifact_files: set[str], - packages: List[str], - imports: List[str], - external_access_integrations: Optional[List[str]] = None, - secrets: Optional[Dict[str, str]] = None, - runtime: Optional[str] = None, - execute_as_caller: bool = False, - ) -> SnowflakeCursor: - log.debug( - "Creating procedure %s using @%s", - identifier.identifier_with_arg_names_types_defaults, - artifact_files, - ) - query = self.create_query( - identifier, - 
return_type, - handler, - artifact_files, - packages, - imports, - external_access_integrations, - secrets, - runtime, - execute_as_caller, - ) - return self._execute_query(query) diff --git a/src/snowflake/cli/api/project/schemas/entities/common.py b/src/snowflake/cli/api/project/schemas/entities/common.py index 46922588c1..4868f1f396 100644 --- a/src/snowflake/cli/api/project/schemas/entities/common.py +++ b/src/snowflake/cli/api/project/schemas/entities/common.py @@ -135,5 +135,5 @@ def get_external_access_integrations_sql(self) -> str | None: def get_secrets_sql(self) -> str | None: if not self.secrets: return None - secrets = ", ".join(f"'{key}' = {value}" for key, value in self.secrets.items()) - return f"secrets = ({secrets})" + secrets = ", ".join(f"'{key}'={value}" for key, value in self.secrets.items()) + return f"secrets=({secrets})" diff --git a/tests/snowpark/test_common.py b/tests/snowpark/test_common.py index f381ee7fe7..d9ecf65cb3 100644 --- a/tests/snowpark/test_common.py +++ b/tests/snowpark/test_common.py @@ -18,10 +18,13 @@ import pytest from snowflake.cli._plugins.snowpark.common import ( + _check_if_replace_is_required, _convert_resource_details_to_dict, _snowflake_dependencies_differ, _sql_to_python_return_type_mapper, - check_if_replace_is_required, +) +from snowflake.cli.api.project.schemas.entities.snowpark_entity import ( + ProcedureEntityModel, ) @@ -79,11 +82,51 @@ def test_sql_to_python_return_type_mapper(argument: Tuple[str, str]): [ ({}, False), ({"handler": "app.another_procedure"}, True), - ({"return_type": "variant"}, True), - ({"snowflake_dependencies": ["snowflake-snowpark-python", "pandas"]}, True), + ({"returns": "variant"}, True), ({"external_access_integrations": ["My_Integration"]}, True), ({"imports": ["@FOO.BAR.BAZ/some_project/some_package.zip"]}, True), ({"imports": ["@FOO.BAR.BAZ/my_snowpark_project/app.zip"]}, False), + ({"runtime": "3.9"}, True), + ({"execute_as_caller": False}, True), + ], +) +def 
test_check_if_replace_is_required_entity_changes( + mock_procedure_description, arguments, expected +): + entity_spec = { + "type": "procedure", + "handler": "app.hello_procedure", + "signature": "(NAME VARCHAR)", + "artifacts": [], + "stage": "foo", + "returns": "string", + "external_access_integrations": [], + "imports": [], + "runtime": "3.10", + "execute_as_caller": True, + } + entity_spec.update(arguments) + + entity = ProcedureEntityModel(**entity_spec) + + assert ( + _check_if_replace_is_required( + entity=entity, + current_state=mock_procedure_description, + snowflake_dependencies=[ + "snowflake-snowpark-python", + "pytest<9.0.0,>=7.0.0", + ], + stage_artifact_files={"@FOO.BAR.BAZ/my_snowpark_project/app.zip"}, + ) + == expected + ) + + +@pytest.mark.parametrize( + "arguments, expected", + [ + ({"snowflake_dependencies": ["snowflake-snowpark-python", "pandas"]}, True), ( { "stage_artifact_files": [ @@ -92,26 +135,33 @@ def test_sql_to_python_return_type_mapper(argument: Tuple[str, str]): }, True, ), - ({"runtime_ver": "3.9"}, True), - ({"execute_as_caller": False}, True), ], ) -def test_check_if_replace_is_required(mock_procedure_description, arguments, expected): - replace_arguments = { +def test_check_if_replace_is_required_file_changes( + mock_procedure_description, arguments, expected +): + entity_spec = { + "type": "procedure", "handler": "app.hello_procedure", - "return_type": "string", - "snowflake_dependencies": ["snowflake-snowpark-python", "pytest<9.0.0,>=7.0.0"], + "signature": "(NAME VARCHAR)", + "artifacts": [], + "stage": "foo", + "returns": "string", "external_access_integrations": [], "imports": [], - "stage_artifact_files": ["@FOO.BAR.BAZ/my_snowpark_project/app.zip"], - "runtime_ver": "3.10", + "runtime": "3.10", "execute_as_caller": True, } - replace_arguments.update(arguments) + entity = ProcedureEntityModel(**entity_spec) + kwargs = { + "snowflake_dependencies": ["snowflake-snowpark-python", "pytest<9.0.0,>=7.0.0"], + 
"stage_artifact_files": {"@FOO.BAR.BAZ/my_snowpark_project/app.zip"}, + } + kwargs.update(arguments) assert ( - check_if_replace_is_required( - "procedure", mock_procedure_description, **replace_arguments + _check_if_replace_is_required( + entity=entity, current_state=mock_procedure_description, **kwargs ) == expected ) diff --git a/tests/snowpark/test_function.py b/tests/snowpark/test_function.py index 1f3ce17de3..dab8b3483d 100644 --- a/tests/snowpark/test_function.py +++ b/tests/snowpark/test_function.py @@ -113,8 +113,8 @@ def test_deploy_function_with_external_access( imports=('@MockDatabase.MockSchema.dev_deployment/my_snowpark_project/app.py') handler='app.func1_handler' packages=() - external_access_integrations=(external_1,external_2) - secrets=('cred'=cred_name,'other'=other_name) + external_access_integrations=(external_1, external_2) + secrets=('cred'=cred_name, 'other'=other_name) """ ).strip(), ] diff --git a/tests/snowpark/test_procedure.py b/tests/snowpark/test_procedure.py index 97f7a960d2..275b464f23 100644 --- a/tests/snowpark/test_procedure.py +++ b/tests/snowpark/test_procedure.py @@ -158,8 +158,8 @@ def test_deploy_procedure_with_external_access( imports=('@MockDatabase.MockSchema.dev_deployment/my_snowpark_project/app.py') handler='app.hello' packages=() - external_access_integrations=(external_1,external_2) - secrets=('cred'=cred_name,'other'=other_name) + external_access_integrations=(external_1, external_2) + secrets=('cred'=cred_name, 'other'=other_name) """ ).strip(), ] diff --git a/tests/streamlit/test_streamlit_manager.py b/tests/streamlit/test_streamlit_manager.py index dd17f3d426..9cf4e01509 100644 --- a/tests/streamlit/test_streamlit_manager.py +++ b/tests/streamlit/test_streamlit_manager.py @@ -80,6 +80,6 @@ def test_deploy_streamlit_with_api_integrations( QUERY_WAREHOUSE = My_WH TITLE = 'MyStreamlit' external_access_integrations=(MY_INTERGATION, OTHER) - secrets = ('my_secret' = SecretOfTheSecrets, 'other' = other_secret)""" + 
secrets=('my_secret'=SecretOfTheSecrets, 'other'=other_secret)""" ) ) From 9ac000e313fae0034b458216eedb913a71cd9342 Mon Sep 17 00:00:00 2001 From: Jan Sikorski <132985823+sfc-gh-jsikorski@users.noreply.github.com> Date: Mon, 26 Aug 2024 13:01:50 +0200 Subject: [PATCH 2/4] Add mixins (#1484) * Basic solution SNOW-1636849 Auto-teardown Native App in integration tests (#1478) Changes `with project_directory()` to `with nativeapp_project_directory()`, which automatically runs `snow app teardown` before exiting the project. This allows us to remove the `try`/`finally` in most tests. For tests that were using `with pushd(test_project)`, this has been changed to `with nativeapp_teardown()`, which is what `with nativeapp_project_directory()` uses under the hood. SNOW-1621834 Cast version to identifier when creating/dropping app versions (#1475) When running `snow app version create` and `snow app version drop`, wrap the version in `to_identifier()` so users don't have to specify the quotes around version names that aren't valid identifiers. If the name is already quoted, `to_identifier()` doesn't do anything. 
Added tests Added tests * Added tests Added tests * Post-review-fixes --- .../api/project/schemas/entities/common.py | 15 ++++- .../api/project/schemas/project_definition.py | 65 ++++++++++++++++--- tests/project/test_project_definition_v2.py | 34 ++++++++++ .../projects/mixins_basic/snowflake.yml | 28 ++++++++ .../mixins_defaults_hierarchy/snowflake.yml | 29 +++++++++ .../mixins_different_entities/environment.yml | 5 ++ .../pages/my_page.py | 3 + .../mixins_different_entities/snowflake.yml | 45 +++++++++++++ .../streamlit_app.py | 0 .../environment.yml | 5 ++ .../pages/my_page.py | 3 + .../snowflake.yml | 58 +++++++++++++++++ .../streamlit_app.py | 0 13 files changed, 281 insertions(+), 9 deletions(-) create mode 100644 tests/test_data/projects/mixins_basic/snowflake.yml create mode 100644 tests/test_data/projects/mixins_defaults_hierarchy/snowflake.yml create mode 100644 tests/test_data/projects/mixins_different_entities/environment.yml create mode 100644 tests/test_data/projects/mixins_different_entities/pages/my_page.py create mode 100644 tests/test_data/projects/mixins_different_entities/snowflake.yml create mode 100644 tests/test_data/projects/mixins_different_entities/streamlit_app.py create mode 100644 tests/test_data/projects/mixins_list_applied_in_order/environment.yml create mode 100644 tests/test_data/projects/mixins_list_applied_in_order/pages/my_page.py create mode 100644 tests/test_data/projects/mixins_list_applied_in_order/snowflake.yml create mode 100644 tests/test_data/projects/mixins_list_applied_in_order/streamlit_app.py diff --git a/src/snowflake/cli/api/project/schemas/entities/common.py b/src/snowflake/cli/api/project/schemas/entities/common.py index 4868f1f396..1ee7b9564a 100644 --- a/src/snowflake/cli/api/project/schemas/entities/common.py +++ b/src/snowflake/cli/api/project/schemas/entities/common.py @@ -17,7 +17,7 @@ from abc import ABC from typing import Generic, List, Optional, TypeVar, Union -from pydantic import Field, PrivateAttr 
+from pydantic import Field, PrivateAttr, field_validator from snowflake.cli.api.identifiers import FQN from snowflake.cli.api.project.schemas.identifier_model import Identifier from snowflake.cli.api.project.schemas.native_app.application import ( @@ -41,6 +41,19 @@ class MetaField(UpdatableModel): title="Actions that will be executed after the application object is created/upgraded", default=None, ) + use_mixins: Optional[List[str]] = Field( + title="Name of the mixin used to fill the entity fields", + default=None, + ) + + @field_validator("use_mixins", mode="before") + @classmethod + def ensure_use_mixins_is_a_list( + cls, mixins: Optional[str | List[str]] + ) -> Optional[List[str]]: + if isinstance(mixins, str): + return [mixins] + return mixins class DefaultsField(UpdatableModel): diff --git a/src/snowflake/cli/api/project/schemas/project_definition.py b/src/snowflake/cli/api/project/schemas/project_definition.py index ce74e38739..6545dc5aa2 100644 --- a/src/snowflake/cli/api/project/schemas/project_definition.py +++ b/src/snowflake/cli/api/project/schemas/project_definition.py @@ -15,7 +15,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union from packaging.version import Version from pydantic import Field, ValidationError, field_validator, model_validator @@ -123,15 +123,11 @@ def apply_defaults(cls, data: Dict) -> Dict: """ if "defaults" in data and "entities" in data: for key, entity in data["entities"].items(): - entity_type = entity["type"] - if entity_type not in v2_entity_model_types_map: + entity_fields = get_allowed_fields_for_entity(entity) + if not entity_fields: continue - entity_model = v2_entity_model_types_map[entity_type] for default_key, default_value in data["defaults"].items(): - if ( - default_key in entity_model.model_fields - and default_key not in entity - ): + if default_key in entity_fields and default_key not in entity: 
entity[default_key] = default_value return data @@ -194,6 +190,36 @@ def _validate_target_field( default=None, ) + mixins: Optional[Dict[str, Dict]] = Field( + title="Mixins to apply to entities", + default=None, + ) + + @model_validator(mode="before") + @classmethod + def apply_mixins(cls, data: Dict) -> Dict: + """ + Applies mixins to those entities, whose meta field contains the mixin name. + """ + if "mixins" not in data or "entities" not in data: + return data + + for entity in data["entities"].values(): + entity_mixins = entity_mixins_to_list( + entity.get("meta", {}).get("use_mixins") + ) + + entity_fields = get_allowed_fields_for_entity(entity) + if entity_fields and entity_mixins: + for mixin_name in entity_mixins: + if mixin_name in data["mixins"]: + for key, value in data["mixins"][mixin_name].items(): + if key in entity_fields: + entity[key] = value + else: + raise ValueError(f"Mixin {mixin_name} not found in mixins") + return data + def get_entities_by_type(self, entity_type: str): return {i: e for i, e in self.entities.items() if e.get_type() == entity_type} @@ -222,3 +248,26 @@ def get_version_map(): if FeatureFlag.ENABLE_PROJECT_DEFINITION_V2.is_enabled(): version_map["2"] = DefinitionV20 return version_map + + +def entity_mixins_to_list(entity_mixins: Optional[str | List[str]]) -> List[str]: + """ + Convert an optional string or a list of strings to a list of strings. + """ + if entity_mixins is None: + return [] + if isinstance(entity_mixins, str): + return [entity_mixins] + return entity_mixins + + +def get_allowed_fields_for_entity(entity: Dict[str, Any]) -> List[str]: + """ + Get the allowed fields for the given entity. 
+ """ + entity_type = entity.get("type") + if entity_type not in v2_entity_model_types_map: + return [] + + entity_model = v2_entity_model_types_map[entity_type] + return entity_model.model_fields diff --git a/tests/project/test_project_definition_v2.py b/tests/project/test_project_definition_v2.py index 16ded60fa4..d6678b334e 100644 --- a/tests/project/test_project_definition_v2.py +++ b/tests/project/test_project_definition_v2.py @@ -385,6 +385,40 @@ def test_v1_to_v2_conversion( _assert_entities_are_equal(v1_function, v2_function) +# TODO: +# 1. rewrite projects to have one big definition covering all complex positive cases +# 2. Add negative case - entity uses non-existent mixin +@pytest.mark.parametrize( + "project_name,stage1,stage2", + [("mixins_basic", "foo", "bar"), ("mixins_defaults_hierarchy", "foo", "baz")], +) +def test_mixins(project_directory, project_name, stage1, stage2): + with project_directory(project_name) as project_dir: + definition = DefinitionManager(project_dir).project_definition + + assert definition.entities["function1"].stage == stage1 + assert definition.entities["function1"].handler == "app.hello" + assert definition.entities["function2"].stage == stage2 + assert definition.entities["function1"].handler == "app.hello" + + +def test_mixins_for_different_entities(project_directory): + with project_directory("mixins_different_entities") as project_dir: + definition = DefinitionManager(project_dir).project_definition + + assert definition.entities["function1"].stage == "foo" + assert definition.entities["streamlit1"].main_file == "streamlit_app.py" + + +def test_list_of_mixins_in_correct_order(project_directory): + with project_directory("mixins_list_applied_in_order") as project_dir: + definition = DefinitionManager(project_dir).project_definition + + assert definition.entities["function1"].stage == "foo" + assert definition.entities["function2"].stage == "baz" + assert definition.entities["streamlit1"].stage == "bar" + + def 
_assert_entities_are_equal( v1_entity: _CallableBase, v2_entity: SnowparkEntityModel ): diff --git a/tests/test_data/projects/mixins_basic/snowflake.yml b/tests/test_data/projects/mixins_basic/snowflake.yml new file mode 100644 index 0000000000..921415c943 --- /dev/null +++ b/tests/test_data/projects/mixins_basic/snowflake.yml @@ -0,0 +1,28 @@ +definition_version: '2' +entities: + function1: + artifacts: + - src + handler: app.hello + identifier: name + meta: + use_mixins: my_mixin + returns: string + signature: + - name: name + type: string + type: function + function2: + artifacts: + - src + handler: app.hello + identifier: name + returns: string + signature: + - name: name + type: string + stage: bar + type: function +mixins: + my_mixin: + stage: foo diff --git a/tests/test_data/projects/mixins_defaults_hierarchy/snowflake.yml b/tests/test_data/projects/mixins_defaults_hierarchy/snowflake.yml new file mode 100644 index 0000000000..9a5d517dad --- /dev/null +++ b/tests/test_data/projects/mixins_defaults_hierarchy/snowflake.yml @@ -0,0 +1,29 @@ +definition_version: '2' +entities: + function1: + artifacts: + - src + handler: app.hello + identifier: name + meta: + use_mixins: my_mixin + returns: string + signature: + - name: name + type: string + type: function + function2: + artifacts: + - src + handler: app.hello2 + identifier: name + returns: string + signature: + - name: name + type: string + type: function +defaults: + stage: baz +mixins: + my_mixin: + stage: foo diff --git a/tests/test_data/projects/mixins_different_entities/environment.yml b/tests/test_data/projects/mixins_different_entities/environment.yml new file mode 100644 index 0000000000..ac8feac3e8 --- /dev/null +++ b/tests/test_data/projects/mixins_different_entities/environment.yml @@ -0,0 +1,5 @@ +name: sf_env +channels: + - snowflake +dependencies: + - pandas diff --git a/tests/test_data/projects/mixins_different_entities/pages/my_page.py 
b/tests/test_data/projects/mixins_different_entities/pages/my_page.py new file mode 100644 index 0000000000..bc3ecbccba --- /dev/null +++ b/tests/test_data/projects/mixins_different_entities/pages/my_page.py @@ -0,0 +1,3 @@ +import streamlit as st + +st.title("Example page") diff --git a/tests/test_data/projects/mixins_different_entities/snowflake.yml b/tests/test_data/projects/mixins_different_entities/snowflake.yml new file mode 100644 index 0000000000..ed99c5fdf3 --- /dev/null +++ b/tests/test_data/projects/mixins_different_entities/snowflake.yml @@ -0,0 +1,45 @@ +definition_version: '2' +entities: + function1: + artifacts: + - src + handler: app.hello + identifier: name + meta: + use_mixins: my_mixin + returns: string + signature: + - name: name + type: string + type: function + function2: + artifacts: + - src + handler: app.hello + identifier: name + returns: string + signature: + - name: name + type: string + type: function + streamlit1: + artifacts: + - streamlit_app.py + - environment.yml + - pages + identifier: + name: test_streamlit + pages_dir: non_existent_dir + query_warehouse: test_warehouse + stage: streamlit + title: My Fancy Streamlit + type: streamlit + meta: + use_mixins: my_mixin +defaults: + stage: baz +mixins: + my_mixin: + stage: foo + main_file: streamlit_app.py + pages_dir: pages diff --git a/tests/test_data/projects/mixins_different_entities/streamlit_app.py b/tests/test_data/projects/mixins_different_entities/streamlit_app.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_data/projects/mixins_list_applied_in_order/environment.yml b/tests/test_data/projects/mixins_list_applied_in_order/environment.yml new file mode 100644 index 0000000000..ac8feac3e8 --- /dev/null +++ b/tests/test_data/projects/mixins_list_applied_in_order/environment.yml @@ -0,0 +1,5 @@ +name: sf_env +channels: + - snowflake +dependencies: + - pandas diff --git a/tests/test_data/projects/mixins_list_applied_in_order/pages/my_page.py 
b/tests/test_data/projects/mixins_list_applied_in_order/pages/my_page.py new file mode 100644 index 0000000000..bc3ecbccba --- /dev/null +++ b/tests/test_data/projects/mixins_list_applied_in_order/pages/my_page.py @@ -0,0 +1,3 @@ +import streamlit as st + +st.title("Example page") diff --git a/tests/test_data/projects/mixins_list_applied_in_order/snowflake.yml b/tests/test_data/projects/mixins_list_applied_in_order/snowflake.yml new file mode 100644 index 0000000000..9a536fa768 --- /dev/null +++ b/tests/test_data/projects/mixins_list_applied_in_order/snowflake.yml @@ -0,0 +1,58 @@ +definition_version: '2' +entities: + function1: + artifacts: + - src + handler: app.hello + identifier: name + meta: + use_mixins: + - second_mixin + - first_mixin + returns: string + signature: + - name: name + type: string + type: function + function2: + artifacts: + - src + handler: app.hello + identifier: name + returns: string + meta: + use_mixins: + - third_mixin + signature: + - name: name + type: string + type: function + streamlit1: + artifacts: + - streamlit_app.py + - environment.yml + - pages + identifier: + name: test_streamlit + pages_dir: non_existent_dir + query_warehouse: test_warehouse + stage: streamlit + title: My Fancy Streamlit + type: streamlit + meta: + use_mixins: + - first_mixin + - second_mixin +mixins: + first_mixin: + stage: foo + main_file: streamlit_app.py + pages_dir: non_existent_pages_dir + second_mixin: + stage: bar + main_file: streamlit_app.py + pages_dir: pages + third_mixin: + stage: baz + main_file: streamlit_app.py + pages_dir: pages diff --git a/tests/test_data/projects/mixins_list_applied_in_order/streamlit_app.py b/tests/test_data/projects/mixins_list_applied_in_order/streamlit_app.py new file mode 100644 index 0000000000..e69de29bb2 From 43ef0584dec5e81fa075f889e2a763196495a262 Mon Sep 17 00:00:00 2001 From: Patryk Czajka Date: Mon, 26 Aug 2024 13:54:43 +0200 Subject: [PATCH 3/4] Snow 1625040 unify sql template syntax (#1458) * Add <% ... 
%> syntax to SQL rendering * add unit tests * fix nativeapp rendering * update release notes * Add integration tests * update nativeapp unit tests * Fix Windows paths * self-review * Change SQL rendering to choose syntax depending on template * revert nativeapp changes * refactor nativeapp usage * use SecurePath to open files --- RELEASE-NOTES.md | 1 + .../cli/_plugins/nativeapp/manager.py | 56 ++++++++++--------- src/snowflake/cli/api/rendering/jinja.py | 36 +++++++++--- .../cli/api/rendering/sql_templates.py | 49 ++++++++++++---- tests/test_sql.py | 49 +++++++++++++--- tests_integration/test_sql_templating.py | 19 ++++--- 6 files changed, 150 insertions(+), 60 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index bf1b83c2c8..d1cba58a63 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -34,6 +34,7 @@ * Added `snow spcs service execute-job` command, which supports creating and executing a job service in the current schema. * Added `snow app events` command to fetch logs and traces from local and customer app installations. * Added support for external access (api integrations and secrets) in Streamlit. +* Added support for `<% ... %>` syntax in SQL templating. * Support multiple Streamlit application in single snowflake.yml project definition file. 
## Fixes and improvements diff --git a/src/snowflake/cli/_plugins/nativeapp/manager.py b/src/snowflake/cli/_plugins/nativeapp/manager.py index 46575240a5..0bc7ccf604 100644 --- a/src/snowflake/cli/_plugins/nativeapp/manager.py +++ b/src/snowflake/cli/_plugins/nativeapp/manager.py @@ -23,7 +23,7 @@ from functools import cached_property from pathlib import Path from textwrap import dedent -from typing import Any, Generator, List, NoReturn, Optional, TypedDict +from typing import Any, Callable, Dict, Generator, List, NoReturn, Optional, TypedDict import jinja2 from click import ClickException @@ -67,7 +67,6 @@ ) from snowflake.cli._plugins.stage.manager import StageManager from snowflake.cli._plugins.stage.utils import print_diff_to_console -from snowflake.cli.api.cli_global_context import get_cli_context from snowflake.cli.api.console import cli_console as cc from snowflake.cli.api.errno import ( DOES_NOT_EXIST_OR_CANNOT_BE_PERFORMED, @@ -84,9 +83,13 @@ identifier_for_url, unquote_identifier, ) +from snowflake.cli.api.rendering.jinja import ( + jinja_render_from_str, +) from snowflake.cli.api.rendering.sql_templates import ( - get_sql_cli_jinja_env, + snowflake_sql_jinja_render, ) +from snowflake.cli.api.secure_path import UNLIMITED, SecurePath from snowflake.cli.api.sql_execution import SqlExecutionMixin from snowflake.connector import DictCursor, ProgrammingError @@ -576,30 +579,36 @@ def create_app_package(self) -> None: ) ) - def _expand_script_templates( - self, env: jinja2.Environment, jinja_context: dict[str, Any], scripts: List[str] + def _render_script_templates( + self, + render_from_str: Callable[[str, Dict[str, Any]], str], + jinja_context: dict[str, Any], + scripts: List[str], ) -> List[str]: """ Input: - - env: Jinja2 environment + - render_from_str: function which renders a jinja template from a string and jinja context - jinja_context: a dictionary with the jinja context - - scripts: list of scripts that need to be expanded with Jinja + - scripts: 
list of script paths relative to the project root Returns: - - List of expanded scripts content. + - List of rendered scripts content Size of the return list is the same as the size of the input scripts list. """ scripts_contents = [] for relpath in scripts: + script_full_path = SecurePath(self.project_root) / relpath try: - template = env.get_template(relpath) - result = template.render(**jinja_context) + template_content = script_full_path.read_text( + file_size_limit_mb=UNLIMITED + ) + result = render_from_str(template_content, jinja_context) scripts_contents.append(result) - except jinja2.TemplateNotFound as e: - raise MissingScriptError(e.name) from e + except FileNotFoundError as e: + raise MissingScriptError(relpath) from e except jinja2.TemplateSyntaxError as e: - raise InvalidScriptError(e.name, e, e.lineno) from e + raise InvalidScriptError(relpath, e, e.lineno) from e except jinja2.UndefinedError as e: raise InvalidScriptError(relpath, e) from e @@ -617,14 +626,10 @@ def _apply_package_scripts(self) -> None: "WARNING: native_app.package.scripts is deprecated. Please migrate to using native_app.package.post_deploy." 
) - env = jinja2.Environment( - loader=jinja2.loaders.FileSystemLoader(self.project_root), - keep_trailing_newline=True, - undefined=jinja2.StrictUndefined, - ) - - queued_queries = self._expand_script_templates( - env, dict(package_name=self.package_name), self.package_scripts + queued_queries = self._render_script_templates( + jinja_render_from_str, + dict(package_name=self.package_name), + self.package_scripts, ) # once we're sure all the templates expanded correctly, execute all of them @@ -678,11 +683,10 @@ def _execute_post_deploy_hooks( f"Unsupported {deployed_object_type} post-deploy hook type: {hook}" ) - env = get_sql_cli_jinja_env( - loader=jinja2.loaders.FileSystemLoader(self.project_root) - ) - scripts_content_list = self._expand_script_templates( - env, get_cli_context().template_context, sql_scripts_paths + scripts_content_list = self._render_script_templates( + snowflake_sql_jinja_render, + {}, + sql_scripts_paths, ) for index, sql_script_path in enumerate(sql_scripts_paths): diff --git a/src/snowflake/cli/api/rendering/jinja.py b/src/snowflake/cli/api/rendering/jinja.py index 299cb8ac8c..e65bf6ceac 100644 --- a/src/snowflake/cli/api/rendering/jinja.py +++ b/src/snowflake/cli/api/rendering/jinja.py @@ -17,7 +17,7 @@ from pathlib import Path from textwrap import dedent -from typing import Dict, Optional +from typing import Any, Dict, Optional import jinja2 from jinja2 import Environment, StrictUndefined, loaders @@ -82,8 +82,32 @@ def getitem(self, obj, argument): return self.undefined(obj=obj, name=argument) +def _get_jinja_env(loader: Optional[loaders.BaseLoader] = None) -> Environment: + return env_bootstrap( + IgnoreAttrEnvironment( + loader=loader or loaders.BaseLoader(), + keep_trailing_newline=True, + undefined=StrictUndefined, + ) + ) + + +def jinja_render_from_str(template_content: str, data: Dict[str, Any]) -> str: + """ + Renders a jinja template and outputs either the rendered contents as string or writes to a file. 
+ + Args: + template_content (str): template contents + data (dict): A dictionary of jinja variables and their actual values + + Returns: + The rendered template content as a string. + """ + return _get_jinja_env().from_string(template_content).render(data) + + def jinja_render_from_file( - template_path: Path, data: Dict, output_file_path: Optional[Path] = None + template_path: Path, data: Dict[str, Any], output_file_path: Optional[Path] = None ) -> Optional[str]: """ Renders a jinja template and outputs either the rendered contents as string or writes to a file. @@ -96,12 +120,8 @@ def jinja_render_from_file( Returns: None if file path is provided, else returns the rendered string. """ - env = env_bootstrap( - IgnoreAttrEnvironment( - loader=loaders.FileSystemLoader(template_path.parent), - keep_trailing_newline=True, - undefined=StrictUndefined, - ) + env = _get_jinja_env( + loader=loaders.FileSystemLoader(template_path.parent.as_posix()) ) loaded_template = env.get_template(template_path.name) rendered_result = loaded_template.render(**data) diff --git a/src/snowflake/cli/api/rendering/sql_templates.py b/src/snowflake/cli/api/rendering/sql_templates.py index b2eea68e75..f832417670 100644 --- a/src/snowflake/cli/api/rendering/sql_templates.py +++ b/src/snowflake/cli/api/rendering/sql_templates.py @@ -14,11 +14,13 @@ from __future__ import annotations -from typing import Dict, Optional +from typing import Dict from click import ClickException -from jinja2 import StrictUndefined, loaders +from jinja2 import Environment, StrictUndefined, loaders, meta from snowflake.cli.api.cli_global_context import get_cli_context +from snowflake.cli.api.console.console import cli_console +from snowflake.cli.api.exceptions import InvalidTemplate from snowflake.cli.api.rendering.jinja import ( CONTEXT_KEY, FUNCTION_KEY, @@ -26,26 +28,52 @@ env_bootstrap, ) -_SQL_TEMPLATE_START = "&{" -_SQL_TEMPLATE_END = "}" +_SQL_TEMPLATE_START = "<%" +_SQL_TEMPLATE_END = 
"%>" +_OLD_SQL_TEMPLATE_START = "&{" +_OLD_SQL_TEMPLATE_END = "}" RESERVED_KEYS = [CONTEXT_KEY, FUNCTION_KEY] -def get_sql_cli_jinja_env(*, loader: Optional[loaders.BaseLoader] = None): +def _get_sql_jinja_env(template_start: str, template_end: str) -> Environment: _random_block = "___very___unique___block___to___disable___logic___blocks___" return env_bootstrap( IgnoreAttrEnvironment( - loader=loader or loaders.BaseLoader(), - keep_trailing_newline=True, - variable_start_string=_SQL_TEMPLATE_START, - variable_end_string=_SQL_TEMPLATE_END, + variable_start_string=template_start, + variable_end_string=template_end, + loader=loaders.BaseLoader(), block_start_string=_random_block, block_end_string=_random_block, + keep_trailing_newline=True, undefined=StrictUndefined, ) ) +def _does_template_have_env_syntax(env: Environment, template_content: str) -> bool: + template = env.parse(template_content) + return bool(meta.find_undeclared_variables(template)) + + +def choose_sql_jinja_env_based_on_template_syntax(template_content: str) -> Environment: + old_syntax_env = _get_sql_jinja_env(_OLD_SQL_TEMPLATE_START, _OLD_SQL_TEMPLATE_END) + new_syntax_env = _get_sql_jinja_env(_SQL_TEMPLATE_START, _SQL_TEMPLATE_END) + has_old_syntax = _does_template_have_env_syntax(old_syntax_env, template_content) + has_new_syntax = _does_template_have_env_syntax(new_syntax_env, template_content) + if has_old_syntax and has_new_syntax: + raise InvalidTemplate( + f"The SQL query mixes {_OLD_SQL_TEMPLATE_START} ... {_OLD_SQL_TEMPLATE_END} syntax" + f" and {_SQL_TEMPLATE_START} ... {_SQL_TEMPLATE_END} syntax." + ) + if has_old_syntax: + cli_console.warning( + f"Warning: {_OLD_SQL_TEMPLATE_START} ... {_OLD_SQL_TEMPLATE_END} syntax is deprecated." + f" Use {_SQL_TEMPLATE_START} ... {_SQL_TEMPLATE_END} syntax instead." 
+ ) + return old_syntax_env + return new_syntax_env + + def snowflake_sql_jinja_render(content: str, data: Dict | None = None) -> str: data = data or {} @@ -57,4 +85,5 @@ def snowflake_sql_jinja_render(content: str, data: Dict | None = None) -> str: context_data = get_cli_context().template_context context_data.update(data) - return get_sql_cli_jinja_env().from_string(content).render(**context_data) + env = choose_sql_jinja_env_based_on_template_syntax(content) + return env.from_string(content).render(context_data) diff --git a/tests/test_sql.py b/tests/test_sql.py index a1ba6cd26b..12169b6115 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryDirectory from unittest import mock @@ -321,6 +320,7 @@ def test_use_command(mock_execute_query, _object): "select &{ aaa }.&{ bbb }", "select &aaa.&bbb", "select &aaa.&{ bbb }", + "select <% aaa %>.<% bbb %>", ], ) @mock.patch("snowflake.cli._plugins.sql.commands.SqlManager._execute_string") @@ -332,7 +332,29 @@ def test_rendering_of_sql(mock_execute_query, query, runner): ) -@pytest.mark.parametrize("query", ["select &{ foo }", "select &foo"]) +@mock.patch("snowflake.cli._plugins.sql.commands.SqlManager._execute_string") +def test_old_template_syntax_causes_warning(mock_execute_query, runner): + result = runner.invoke(["sql", "-q", "select &{ aaa }", "-D", "aaa=foo"]) + assert result.exit_code == 0 + assert ( + "Warning: &{ ... } syntax is deprecated. Use <% ... %> syntax instead." 
+ in result.output + ) + mock_execute_query.assert_called_once_with("select foo", cursor_class=VerboseCursor) + + +@mock.patch("snowflake.cli._plugins.sql.commands.SqlManager._execute_string") +def test_mixed_template_syntax_error(mock_execute_query, runner): + result = runner.invoke( + ["sql", "-q", "select <% aaa %>.&{ bbb }", "-D", "aaa=foo", "-D", "bbb=bar"] + ) + assert result.exit_code == 1 + assert "The SQL query mixes &{ ... } syntax and <% ... %> syntax." in result.output + + +@pytest.mark.parametrize( + "query", ["select &{ foo }", "select &foo", "select <% foo %>"] +) def test_execution_fails_if_unknown_variable(runner, query): result = runner.invoke(["sql", "-q", query, "-D", "bbb=1"]) assert "SQL template rendering error: 'foo' is undefined" in result.output @@ -356,12 +378,15 @@ def test_snowsql_compatibility(text, expected): assert transpile_snowsql_templates(text) == expected +@pytest.mark.parametrize("template_start,template_end", [("&{", "}"), ("<%", "%>")]) @mock.patch("snowflake.cli._plugins.sql.commands.SqlManager._execute_string") def test_uses_variables_from_snowflake_yml( - mock_execute_query, project_directory, runner + mock_execute_query, project_directory, runner, template_start, template_end ): with project_directory("sql_templating"): - result = runner.invoke(["sql", "-q", "select &{ ctx.env.sf_var }"]) + result = runner.invoke( + ["sql", "-q", f"select {template_start} ctx.env.sf_var {template_end}"] + ) assert result.exit_code == 0 mock_execute_query.assert_called_once_with( @@ -369,12 +394,19 @@ def test_uses_variables_from_snowflake_yml( ) +@pytest.mark.parametrize("template_start,template_end", [("&{", "}"), ("<%", "%>")]) @mock.patch("snowflake.cli._plugins.sql.commands.SqlManager._execute_string") def test_uses_variables_from_snowflake_local_yml( - mock_execute_query, project_directory, runner + mock_execute_query, project_directory, runner, template_start, template_end ): with project_directory("sql_templating"): - result = 
runner.invoke(["sql", "-q", "select &{ ctx.env.sf_var_override }"]) + result = runner.invoke( + [ + "sql", + "-q", + f"select {template_start} ctx.env.sf_var_override {template_end}", + ] + ) assert result.exit_code == 0 mock_execute_query.assert_called_once_with( @@ -382,16 +414,17 @@ def test_uses_variables_from_snowflake_local_yml( ) +@pytest.mark.parametrize("template_start,template_end", [("&{", "}"), ("<%", "%>")]) @mock.patch("snowflake.cli._plugins.sql.commands.SqlManager._execute_string") def test_uses_variables_from_cli_are_added_outside_context( - mock_execute_query, project_directory, runner + mock_execute_query, project_directory, runner, template_start, template_end ): with project_directory("sql_templating"): result = runner.invoke( [ "sql", "-q", - "select &{ ctx.env.sf_var } &{ other }", + f"select {template_start} ctx.env.sf_var {template_end} {template_start} other {template_end}", "-D", "other=other_value", ] diff --git a/tests_integration/test_sql_templating.py b/tests_integration/test_sql_templating.py index 7d69acdf1f..2d0de60a5d 100644 --- a/tests_integration/test_sql_templating.py +++ b/tests_integration/test_sql_templating.py @@ -18,7 +18,7 @@ @pytest.mark.integration def test_sql_env_value_from_cli_param(runner, snowflake_session): result = runner.invoke_with_connection_json( - ["sql", "-q", "select '&{ctx.env.test}'", "--env", "test=value_from_cli"] + ["sql", "-q", "select '<% ctx.env.test %>'", "--env", "test=value_from_cli"] ) assert result.exit_code == 0 @@ -28,7 +28,7 @@ def test_sql_env_value_from_cli_param(runner, snowflake_session): @pytest.mark.integration def test_sql_env_value_from_cli_param_that_is_blank(runner, snowflake_session): result = runner.invoke_with_connection_json( - ["sql", "-q", "select '&{ctx.env.test}'", "--env", "test="] + ["sql", "-q", "select '<% ctx.env.test %>'", "--env", "test="] ) assert result.exit_code == 0 @@ -38,7 +38,7 @@ def test_sql_env_value_from_cli_param_that_is_blank(runner, snowflake_session): 
@pytest.mark.integration def test_sql_undefined_env_causing_error(runner, snowflake_session): result = runner.invoke_with_connection_json( - ["sql", "-q", "select '&{ctx.env.test}'"] + ["sql", "-q", "select '<% ctx.env.test %>'"] ) assert result.exit_code == 1 @@ -48,7 +48,7 @@ def test_sql_undefined_env_causing_error(runner, snowflake_session): @pytest.mark.integration def test_sql_env_value_from_os_env(runner, snowflake_session): result = runner.invoke_with_connection_json( - ["sql", "-q", "select '&{ctx.env.test}'"], env={"test": "value_from_os_env"} + ["sql", "-q", "select '<% ctx.env.test %>'"], env={"test": "value_from_os_env"} ) assert result.exit_code == 0 @@ -58,7 +58,7 @@ def test_sql_env_value_from_os_env(runner, snowflake_session): @pytest.mark.integration def test_sql_env_value_from_cli_param_overriding_os_env(runner, snowflake_session): result = runner.invoke_with_connection_json( - ["sql", "-q", "select '&{ctx.env.test}'", "--env", "test=value_from_cli"], + ["sql", "-q", "select '<% ctx.env.test %>'", "--env", "test=value_from_cli"], env={"test": "value_from_os_env"}, ) @@ -72,7 +72,7 @@ def test_sql_env_value_from_cli_duplicate_arg(runner, snowflake_session): [ "sql", "-q", - "select '&{ctx.env.Test}'", + "select '<% ctx.env.Test %>'", "--env", "Test=firstArg", "--env", @@ -84,13 +84,16 @@ def test_sql_env_value_from_cli_duplicate_arg(runner, snowflake_session): assert result.json == [{"'SECONDARG'": "secondArg"}] +@pytest.mark.parametrize("t_start,t_end", [("&{", "}"), ("<%", "%>")]) @pytest.mark.integration -def test_sql_env_value_from_cli_multiple_args(runner, snowflake_session): +def test_sql_env_value_from_cli_multiple_args( + runner, snowflake_session, t_start, t_end +): result = runner.invoke_with_connection_json( [ "sql", "-q", - "select '&{ctx.env.Test1}-&{ctx.env.Test2}'", + f"select '{t_start}ctx.env.Test1{t_end}-{t_start}ctx.env.Test2{t_end}'", "--env", "Test1=test1", "--env", From 54e8d46ec496ce9f5406ed1efea92aed328b13c7 Mon Sep 17 
00:00:00 2001 From: Michel El Nacouzi Date: Mon, 26 Aug 2024 08:26:02 -0400 Subject: [PATCH 4/4] Fix integer indices in error reporting (#1487) --- src/snowflake/cli/api/project/errors.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/snowflake/cli/api/project/errors.py b/src/snowflake/cli/api/project/errors.py index e11a283825..3c272821b6 100644 --- a/src/snowflake/cli/api/project/errors.py +++ b/src/snowflake/cli/api/project/errors.py @@ -29,10 +29,25 @@ class SchemaValidationError(ClickException): def __init__(self, error: ValidationError): errors = error.errors() message = f"During evaluation of {error.title} in project definition following errors were encountered:\n" + + def calculate_location(e): + if e["loc"] is None: + return None + + # show numbers as list indexes and strings as dictionary keys. Example: key1[0].key2 + result = "".join( + f"[{item}]" if isinstance(item, int) else f".{item}" + for item in e["loc"] + ) + + # remove leading dot from the string if any: + return result[1:] if result.startswith(".") else result + message += "\n".join( [ self.message_templates.get(e["type"], self.generic_message).format( - **e, location=".".join(e["loc"]) if e["loc"] is not None else None + **e, + location=calculate_location(e), ) for e in errors ]