From e03ba65216a31bc309b158b2ad4e8ed7aab7d03e Mon Sep 17 00:00:00 2001 From: Jan Sikorski Date: Mon, 26 Aug 2024 11:41:07 +0200 Subject: [PATCH] Post-review-fixes --- .../api/project/schemas/entities/common.py | 13 +++++++-- .../api/project/schemas/project_definition.py | 29 +++++++++++-------- tests/project/test_project_definition_v2.py | 3 ++ 3 files changed, 31 insertions(+), 14 deletions(-) diff --git a/src/snowflake/cli/api/project/schemas/entities/common.py b/src/snowflake/cli/api/project/schemas/entities/common.py index 4b748bf639..8ae50946ae 100644 --- a/src/snowflake/cli/api/project/schemas/entities/common.py +++ b/src/snowflake/cli/api/project/schemas/entities/common.py @@ -17,7 +17,7 @@ from abc import ABC from typing import Generic, List, Optional, TypeVar, Union -from pydantic import Field, PrivateAttr +from pydantic import Field, PrivateAttr, field_validator from snowflake.cli.api.identifiers import FQN from snowflake.cli.api.project.schemas.identifier_model import Identifier from snowflake.cli.api.project.schemas.native_app.application import ( @@ -41,11 +41,20 @@ class MetaField(UpdatableModel): title="Actions that will be executed after the application object is created/upgraded", default=None, ) - use_mixins: Optional[str | List[str]] = Field( + use_mixins: Optional[List[str]] = Field( title="Name of the mixin used to fill the entity fields", default=None, ) + @field_validator("use_mixins", mode="before") + @classmethod + def ensure_use_mixins_is_a_list( + cls, mixins: Optional[str | List[str]] + ) -> Optional[List[str]]: + if isinstance(mixins, str): + return [mixins] + return mixins + class DefaultsField(UpdatableModel): schema_: Optional[str] = Field( diff --git a/src/snowflake/cli/api/project/schemas/project_definition.py b/src/snowflake/cli/api/project/schemas/project_definition.py index 0f62749238..6545dc5aa2 100644 --- a/src/snowflake/cli/api/project/schemas/project_definition.py +++ b/src/snowflake/cli/api/project/schemas/project_definition.py @@ -201,18 +201,23 @@ def apply_mixins(cls, data: Dict) -> Dict: """ Applies mixins to those entities, whose meta field contains the mixin name. """ - if "mixins" in data and "entities" in data: - for entity in data["entities"].values(): - entity_mixins = entity_mixins_to_list( - entity.get("meta", {}).get("use_mixins") - ) - entity_fields = get_allowed_fields_for_entity(entity) - if entity_fields and entity_mixins: - for mixin_name in entity_mixins: - if mixin_name in data["mixins"]: - for key, value in data["mixins"][mixin_name].items(): - if key in entity_fields: - entity[key] = value + if "mixins" not in data or "entities" not in data: + return data + + for entity in data["entities"].values(): + entity_mixins = entity_mixins_to_list( + entity.get("meta", {}).get("use_mixins") + ) + + entity_fields = get_allowed_fields_for_entity(entity) + if entity_fields and entity_mixins: + for mixin_name in entity_mixins: + if mixin_name in data["mixins"]: + for key, value in data["mixins"][mixin_name].items(): + if key in entity_fields: + entity[key] = value + else: + raise ValueError(f"Mixin {mixin_name} not found in mixins") return data def get_entities_by_type(self, entity_type: str): diff --git a/tests/project/test_project_definition_v2.py b/tests/project/test_project_definition_v2.py index 2f6849065a..d6678b334e 100644 --- a/tests/project/test_project_definition_v2.py +++ b/tests/project/test_project_definition_v2.py @@ -385,6 +385,9 @@ def test_v1_to_v2_conversion( _assert_entities_are_equal(v1_function, v2_function) +# TODO: +# 1. rewrite projects to have one big definition covering all complex positive cases +# 2. Add negative case - entity uses non-existent mixin @pytest.mark.parametrize( "project_name,stage1,stage2", [("mixins_basic", "foo", "bar"), ("mixins_defaults_hierarchy", "foo", "baz")],