From 426173ad8759a86cd300f071de0f0c9fb27cc11a Mon Sep 17 00:00:00 2001 From: Guy Bloom Date: Fri, 12 Jul 2024 09:00:12 -0400 Subject: [PATCH] Project definition v2 entity schemas: application and application package (#1280) * application and application package entity schemas * add integration test * simplify entity type, use discriminator * create types map dynamically * get_type method * read entity list from union --- .../schemas/entities/application_entity.py | 50 ++++ .../entities/application_package_entity.py | 63 +++++ .../api/project/schemas/entities/common.py | 85 +++++++ .../api/project/schemas/entities/entities.py | 30 +++ .../api/project/schemas/project_definition.py | 76 +++++- tests/project/test_project_definition_v2.py | 224 ++++++++++++++++++ .../project_definition_v2/snowflake.yml | 38 +++ .../workspaces/test_validate_schema.py | 32 +++ 8 files changed, 589 insertions(+), 9 deletions(-) create mode 100644 src/snowflake/cli/api/project/schemas/entities/application_entity.py create mode 100644 src/snowflake/cli/api/project/schemas/entities/application_package_entity.py create mode 100644 src/snowflake/cli/api/project/schemas/entities/common.py create mode 100644 src/snowflake/cli/api/project/schemas/entities/entities.py create mode 100644 tests/project/test_project_definition_v2.py create mode 100644 tests_integration/test_data/projects/project_definition_v2/snowflake.yml create mode 100644 tests_integration/workspaces/test_validate_schema.py diff --git a/src/snowflake/cli/api/project/schemas/entities/application_entity.py b/src/snowflake/cli/api/project/schemas/entities/application_entity.py new file mode 100644 index 0000000000..983e9b15d0 --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/entities/application_entity.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Literal, Optional + +from pydantic import AliasChoices, Field +from snowflake.cli.api.project.schemas.entities.application_package_entity import ( + ApplicationPackageEntity, +) +from snowflake.cli.api.project.schemas.entities.common import ( + EntityBase, + TargetField, +) +from snowflake.cli.api.project.schemas.updatable_model import ( + UpdatableModel, +) + + +class ApplicationEntity(EntityBase): + type: Literal["application"] # noqa: A003 + name: str = Field( + title="Name of the application created when this entity is deployed" + ) + from_: ApplicationFromField = Field( + validation_alias=AliasChoices("from"), + title="An application package this entity should be created from", + ) + debug: Optional[bool] = Field( + title="Whether to enable debug mode when using a named stage to create an application object", + default=None, + ) + + +class ApplicationFromField(UpdatableModel): + target: TargetField[ApplicationPackageEntity] = Field( + title="Reference to an application package entity", + ) diff --git a/src/snowflake/cli/api/project/schemas/entities/application_package_entity.py b/src/snowflake/cli/api/project/schemas/entities/application_package_entity.py new file mode 100644 index 0000000000..d9684e69e3 --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/entities/application_package_entity.py @@ -0,0 +1,63 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from pathlib import Path +from typing import List, Literal, Optional, Union + +from pydantic import Field +from snowflake.cli.api.project.schemas.entities.common import ( + EntityBase, +) +from snowflake.cli.api.project.schemas.native_app.package import DistributionOptions +from snowflake.cli.api.project.schemas.native_app.path_mapping import PathMapping +from snowflake.cli.api.project.schemas.updatable_model import IdentifierField + + +class ApplicationPackageEntity(EntityBase): + type: Literal["application package"] # noqa: A003 + name: str = Field( + title="Name of the application package created when this entity is deployed" + ) + artifacts: List[Union[PathMapping, Path]] = Field( + title="List of paths or file source/destination pairs to add to the deploy root", + ) + bundle_root: Optional[Path] = Field( + title="Folder at the root of your project where artifacts necessary to perform the bundle step are stored.", + default=Path("output/bundle/"), + ) + deploy_root: Optional[Path] = Field( + title="Folder at the root of your project where the build step copies the artifacts", + default=Path("output/deploy/"), + ) + generated_root: Optional[Path] = Field( + title="Subdirectory of the deploy root where files generated by the Snowflake CLI will be written.", + default=Path("__generated/"), + ) + stage: Optional[str] = IdentifierField( + title="Identifier of the stage that stores the application artifacts.", + default="app_src.stage", + ) + scratch_stage: Optional[str] = IdentifierField( + title="Identifier of the stage that stores temporary scratch data used by the Snowflake CLI.", + default="app_src.stage_snowflake_cli_scratch", + ) + distribution: Optional[DistributionOptions] = Field( + title="Distribution of the application package created by the Snowflake CLI", + default="internal", + ) + manifest: Path = Field( + title="Path to manifest.yml", + ) diff --git a/src/snowflake/cli/api/project/schemas/entities/common.py b/src/snowflake/cli/api/project/schemas/entities/common.py new file mode 100644 index 0000000000..5b50ac1f3a --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/entities/common.py @@ -0,0 +1,85 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC +from typing import Generic, List, Optional, TypeVar + +from pydantic import AliasChoices, Field, GetCoreSchemaHandler, ValidationInfo +from pydantic_core import core_schema +from snowflake.cli.api.project.schemas.native_app.application import ( + ApplicationPostDeployHook, +) +from snowflake.cli.api.project.schemas.updatable_model import ( + IdentifierField, + UpdatableModel, +) + + +class MetaField(UpdatableModel): + warehouse: Optional[str] = IdentifierField( + title="Warehouse used to run the scripts", default=None + ) + role: Optional[str] = IdentifierField( + title="Role to use when creating the entity object", + default=None, + ) + post_deploy: Optional[List[ApplicationPostDeployHook]] = Field( + title="Actions that will be executed after the application object is created/upgraded", + default=None, + ) + + +class DefaultsField(UpdatableModel): + schema_: Optional[str] = Field( + title="Schema.", + validation_alias=AliasChoices("schema"), + default=None, + ) + stage: Optional[str] = Field( + title="Stage.", + default=None, + ) + + +class EntityBase(ABC, UpdatableModel): + @classmethod + def get_type(cls) -> str: + return cls.model_fields["type"].annotation.__args__[0] + + meta: Optional[MetaField] = Field(title="Meta fields", default=None) + + +TargetType = TypeVar("TargetType") + + +class TargetField(Generic[TargetType]): + def __init__(self, entity_target_key: str): + self.value = entity_target_key + + def __repr__(self): + return self.value + + @classmethod + def validate(cls, value: str, info: ValidationInfo) -> TargetField: + return cls(value) + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type, handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + return core_schema.with_info_after_validator_function( + cls.validate, handler(str), field_name=handler.field_name + ) diff --git a/src/snowflake/cli/api/project/schemas/entities/entities.py b/src/snowflake/cli/api/project/schemas/entities/entities.py new file mode 100644 index 0000000000..d6cfb85057 --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/entities/entities.py @@ -0,0 +1,30 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Union, get_args + +from snowflake.cli.api.project.schemas.entities.application_entity import ( + ApplicationEntity, +) +from snowflake.cli.api.project.schemas.entities.application_package_entity import ( + ApplicationPackageEntity, +) + +Entity = Union[ApplicationEntity, ApplicationPackageEntity] + +ALL_ENTITIES = [*get_args(Entity)] + +v2_entity_types_map = {e.get_type(): e for e in ALL_ENTITIES} diff --git a/src/snowflake/cli/api/project/schemas/project_definition.py b/src/snowflake/cli/api/project/schemas/project_definition.py index 420decc1c2..03385b79c6 100644 --- a/src/snowflake/cli/api/project/schemas/project_definition.py +++ b/src/snowflake/cli/api/project/schemas/project_definition.py @@ -18,15 +18,27 @@ from typing import Dict, Optional, Union from packaging.version import Version -from pydantic import Field, ValidationError, field_validator +from pydantic import Field, ValidationError, field_validator, model_validator from snowflake.cli.api.feature_flags import FeatureFlag from snowflake.cli.api.project.errors import SchemaValidationError +from snowflake.cli.api.project.schemas.entities.application_entity import ( + ApplicationEntity, +) +from snowflake.cli.api.project.schemas.entities.common import ( + DefaultsField, + TargetField, +) +from snowflake.cli.api.project.schemas.entities.entities import ( + Entity, + v2_entity_types_map, +) from snowflake.cli.api.project.schemas.native_app.native_app import NativeApp from snowflake.cli.api.project.schemas.snowpark.snowpark import Snowpark from snowflake.cli.api.project.schemas.streamlit.streamlit import Streamlit from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel from snowflake.cli.api.utils.models import ProjectEnvironment from snowflake.cli.api.utils.types import Context +from typing_extensions import Annotated @dataclass @@ -105,8 +117,60 @@ def _convert_env( class DefinitionV20(_ProjectDefinitionBase): - entities: Dict = Field( - title="Entity definitions.", + entities: Dict[str, Annotated[Entity, Field(discriminator="type")]] = Field( + title="Entity definitions." + ) + + @model_validator(mode="before") + @classmethod + def apply_defaults(cls, data: Dict) -> Dict: + """ + Applies default values that exist on the model but not specified in yml + """ + if "defaults" in data and "entities" in data: + for key, entity in data["entities"].items(): + entity_type = entity["type"] + if entity_type not in v2_entity_types_map: + continue + entity_model = v2_entity_types_map[entity_type] + for default_key, default_value in data["defaults"].items(): + if ( + default_key in entity_model.model_fields + and default_key not in entity + ): + entity[default_key] = default_value + return data + + @field_validator("entities", mode="after") + @classmethod + def validate_entities(cls, entities: Dict[str, Entity]) -> Dict[str, Entity]: + for key, entity in entities.items(): + # TODO Automatically detect TargetFields to validate + if entity.type == ApplicationEntity.get_type(): + if isinstance(entity.from_.target, TargetField): + target_key = str(entity.from_.target) + target_class = entity.from_.__class__.model_fields["target"] + target_type = target_class.annotation.__args__[0] + cls._validate_target_field(target_key, target_type, entities) + return entities + + @classmethod + def _validate_target_field( + cls, target_key: str, target_type: Entity, entities: Dict[str, Entity] + ): + if target_key not in entities: + raise ValueError(f"No such target: {target_key}") + else: + # Validate the target type + actual_target_type = entities[target_key].__class__ + if target_type and target_type is not actual_target_type: + raise ValueError( + f"Target type mismatch. Expected {target_type.__name__}, got {actual_target_type.__name__}" + ) + + defaults: Optional[DefaultsField] = Field( + title="Default key/value entity values that are merged recursively for each entity.", + default=None, ) env: Union[Dict[str, str], ProjectEnvironment, None] = Field( @@ -125,12 +189,6 @@ def _convert_env( return env return ProjectEnvironment(default_env=(env or {}), override_env={}) - @field_validator("entities") - @classmethod - def validate_entities(cls, entities: Dict) -> Dict: - # TODO Add entities validation logic - return entities - def build_project_definition(**data): """ diff --git a/tests/project/test_project_definition_v2.py b/tests/project/test_project_definition_v2.py new file mode 100644 index 0000000000..68788a6285 --- /dev/null +++ b/tests/project/test_project_definition_v2.py @@ -0,0 +1,224 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest +from snowflake.cli.api.project.errors import SchemaValidationError +from snowflake.cli.api.project.schemas.entities.entities import ( + v2_entity_types_map, +) +from snowflake.cli.api.project.schemas.project_definition import ( + DefinitionV20, +) + +from tests.testing_utils.mock_config import mock_config_key + + +@pytest.mark.parametrize( + "definition_input,expected_error", + [ + [{}, "Your project definition is missing the following field: 'entities'"], + [{"entities": {}}, None], + [{"entities": {}, "defaults": {}, "env": {}}, None], + [ + {"entities": {}, "extra": "field"}, + "You provided field 'extra' with value 'field' that is not supported in given version", + ], + [ + {"entities": {"entity": {"type": "invalid_type"}}}, + "Input tag 'invalid_type' found using 'type' does not match any of the expected tags", + ], + # Application package tests + [ + {"entities": {"pkg": {"type": "application package"}}}, + [ + "missing the following field: 'entities.pkg.application package.name'", + "missing the following field: 'entities.pkg.application package.artifacts'", + "missing the following field: 'entities.pkg.application package.manifest'", + ], + ], + [ + { + "entities": { + "pkg": { + "type": "application package", + "name": "", + "artifacts": [], + "manifest": "", + } + } + }, + None, + ], + [ + { + "entities": { + "pkg": { + "type": "application package", + "name": "", + "artifacts": [], + "manifest": "", + "bundle_root": "", + "deploy_root": "", + "generated_root": "", + "stage": "stage", + "scratch_stage": "scratch_stage", + "distribution": "internal", + } + } + }, + None, + ], + [ + { + "entities": { + "pkg": { + "type": "application package", + "name": "", + "artifacts": [], + "manifest": "", + "distribution": "invalid", + } + } + }, + "Input should be 'internal', 'external', 'INTERNAL' or 'EXTERNAL'", + ], + # Application tests + [ + {"entities": {"app": {"type": "application"}}}, + [ + "Your project definition is missing the following field: 'entities.app.application.name'", + "Your project definition is missing the following field: 'entities.app.application.from'", + ], + ], + [ + { + "entities": { + "app": { + "type": "application", + "name": "", + "from": {"target": "non_existing"}, + } + } + }, + "No such target: non_existing", + ], + [ + { + "entities": { + "pkg": { + "type": "application package", + "name": "", + "artifacts": [], + "manifest": "", + }, + "app": { + "type": "application", + "name": "", + "from": {"target": "pkg"}, + }, + } + }, + None, + ], + # Meta fields + [ + { + "entities": { + "pkg": { + "type": "application package", + "name": "", + "artifacts": [], + "manifest": "", + "meta": { + "warehouse": "warehouse", + "role": "role", + "post_deploy": [{"sql_script": "script.sql"}], + }, + }, + "app": { + "type": "application", + "name": "", + "from": {"target": "pkg"}, + "meta": { + "warehouse": "warehouse", + "role": "role", + "post_deploy": [{"sql_script": "script.sql"}], + }, + }, + } + }, + None, + ], + ], +) +def test_project_definition_v2_schema(definition_input, expected_error): + definition_input["definition_version"] = "2" + with mock_config_key("enable_project_definition_v2", True): + try: + DefinitionV20(**definition_input) + except SchemaValidationError as err: + if expected_error: + if type(expected_error) == str: + assert expected_error in str(err) + else: + for err_msg in expected_error: + assert err_msg in str(err) + else: + raise err + + +def test_defaults_are_applied(): + definition_input = { + "definition_version": "2", + "entities": { + "pkg": { + "type": "application package", + "name": "", + "artifacts": [], + "manifest": "", + } + }, + "defaults": {"stage": "default_stage"}, + } + with mock_config_key("enable_project_definition_v2", True): + project = DefinitionV20(**definition_input) + assert project.entities["pkg"].stage == "default_stage" + + +def test_defaults_do_not_override_values(): + definition_input = { + "definition_version": "2", + "entities": { + "pkg": { + "type": "application package", + "name": "", + "artifacts": [], + "manifest": "", + "stage": "pkg_stage", + } + }, + "defaults": {"stage": "default_stage"}, + } + with mock_config_key("enable_project_definition_v2", True): + project = DefinitionV20(**definition_input) + assert project.entities["pkg"].stage == "pkg_stage" + + +# Verify that each entity type has the correct "type" field +def test_entity_types(): + v2_entity_types_map + for entity_type, entity_class in v2_entity_types_map.items(): + model_entity_type = entity_class.get_type() + assert model_entity_type == entity_type diff --git a/tests_integration/test_data/projects/project_definition_v2/snowflake.yml b/tests_integration/test_data/projects/project_definition_v2/snowflake.yml new file mode 100644 index 0000000000..e5b597a8d1 --- /dev/null +++ b/tests_integration/test_data/projects/project_definition_v2/snowflake.yml @@ -0,0 +1,38 @@ +definition_version: 2 + +entities: + pkg: + type: application package + name: my_app_pkg_<% ctx.env.foo %> + artifacts: + - src: src/**/* + dest: / + bundle_root: output/bundle/ + deploy_root: output/deploy/ + generated_root: __generated/ + stage: app_src.stage2 + scratch_stage: app_src.stage_snowflake_cli_scratch + distribution: internal + manifest: src/manifest.yml + meta: + warehouse: my_wh + role: my_role + post_deploy: + - sql_script: scripts/post_pkg_deploy.sql + app: + type: application + name: my_app_<% ctx.env.foo %> + from: + target: pkg + meta: + warehouse: my_wh + role: my_role + post_deploy: + - sql_script: scripts/post_app_deploy.sql + +defaults: + schema: default_schema + stage: default_stage + +env: + foo: bar diff --git a/tests_integration/workspaces/test_validate_schema.py b/tests_integration/workspaces/test_validate_schema.py new file mode 100644 index 0000000000..3c4220ce77 --- /dev/null +++ b/tests_integration/workspaces/test_validate_schema.py @@ -0,0 +1,32 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import os +from unittest import mock + + +@pytest.mark.integration +@mock.patch.dict( + os.environ, + { + "SNOWFLAKE_CLI_FEATURES_ENABLE_PROJECT_DEFINITION_V2": "true", + }, + clear=True, +) +def test_validate_project_definition_v2(runner, snowflake_session, project_directory): + with project_directory("project_definition_v2") as tmp_dir: + result = runner.invoke_with_connection_json(["ws", "validate"]) + + assert result.exit_code == 0