diff --git a/src/snowflake/cli/api/project/schemas/entities/common.py b/src/snowflake/cli/api/project/schemas/entities/common.py index 4868f1f396..1ee7b9564a 100644 --- a/src/snowflake/cli/api/project/schemas/entities/common.py +++ b/src/snowflake/cli/api/project/schemas/entities/common.py @@ -17,7 +17,7 @@ from abc import ABC from typing import Generic, List, Optional, TypeVar, Union -from pydantic import Field, PrivateAttr +from pydantic import Field, PrivateAttr, field_validator from snowflake.cli.api.identifiers import FQN from snowflake.cli.api.project.schemas.identifier_model import Identifier from snowflake.cli.api.project.schemas.native_app.application import ( @@ -41,6 +41,19 @@ class MetaField(UpdatableModel): title="Actions that will be executed after the application object is created/upgraded", default=None, ) + use_mixins: Optional[List[str]] = Field( + title="Name of the mixin used to fill the entity fields", + default=None, + ) + + @field_validator("use_mixins", mode="before") + @classmethod + def ensure_use_mixins_is_a_list( + cls, mixins: Optional[str | List[str]] + ) -> Optional[List[str]]: + if isinstance(mixins, str): + return [mixins] + return mixins class DefaultsField(UpdatableModel): diff --git a/src/snowflake/cli/api/project/schemas/project_definition.py b/src/snowflake/cli/api/project/schemas/project_definition.py index ce74e38739..6545dc5aa2 100644 --- a/src/snowflake/cli/api/project/schemas/project_definition.py +++ b/src/snowflake/cli/api/project/schemas/project_definition.py @@ -15,7 +15,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union from packaging.version import Version from pydantic import Field, ValidationError, field_validator, model_validator @@ -123,15 +123,11 @@ def apply_defaults(cls, data: Dict) -> Dict: """ if "defaults" in data and "entities" in data: for key, entity in data["entities"].items(): - entity_type = entity["type"] - if entity_type not in v2_entity_model_types_map: + entity_fields = get_allowed_fields_for_entity(entity) + if not entity_fields: continue - entity_model = v2_entity_model_types_map[entity_type] for default_key, default_value in data["defaults"].items(): - if ( - default_key in entity_model.model_fields - and default_key not in entity - ): + if default_key in entity_fields and default_key not in entity: entity[default_key] = default_value return data @@ -194,6 +190,36 @@ def _validate_target_field( default=None, ) + mixins: Optional[Dict[str, Dict]] = Field( + title="Mixins to apply to entities", + default=None, + ) + + @model_validator(mode="before") + @classmethod + def apply_mixins(cls, data: Dict) -> Dict: + """ + Applies mixins to those entities, whose meta field contains the mixin name. + """ + if "mixins" not in data or "entities" not in data: + return data + + for entity in data["entities"].values(): + entity_mixins = entity_mixins_to_list( + entity.get("meta", {}).get("use_mixins") + ) + + entity_fields = get_allowed_fields_for_entity(entity) + if entity_fields and entity_mixins: + for mixin_name in entity_mixins: + if mixin_name in data["mixins"]: + for key, value in data["mixins"][mixin_name].items(): + if key in entity_fields: + entity[key] = value + else: + raise ValueError(f"Mixin {mixin_name} not found in mixins") + return data + def get_entities_by_type(self, entity_type: str): return {i: e for i, e in self.entities.items() if e.get_type() == entity_type} @@ -222,3 +248,26 @@ def get_version_map(): if FeatureFlag.ENABLE_PROJECT_DEFINITION_V2.is_enabled(): version_map["2"] = DefinitionV20 return version_map + + +def entity_mixins_to_list(entity_mixins: Optional[str | List[str]]) -> List[str]: + """ + Convert an optional string or a list of strings to a list of strings. + """ + if entity_mixins is None: + return [] + if isinstance(entity_mixins, str): + return [entity_mixins] + return entity_mixins + + +def get_allowed_fields_for_entity(entity: Dict[str, Any]) -> List[str]: + """ + Get the allowed fields for the given entity. + """ + entity_type = entity.get("type") + if entity_type not in v2_entity_model_types_map: + return [] + + entity_model = v2_entity_model_types_map[entity_type] + return entity_model.model_fields diff --git a/tests/project/test_project_definition_v2.py b/tests/project/test_project_definition_v2.py index 16ded60fa4..d6678b334e 100644 --- a/tests/project/test_project_definition_v2.py +++ b/tests/project/test_project_definition_v2.py @@ -385,6 +385,40 @@ def test_v1_to_v2_conversion( _assert_entities_are_equal(v1_function, v2_function) +# TODO: +# 1. rewrite projects to have one big definition covering all complex positive cases +# 2. Add negative case - entity uses non-existent mixin +@pytest.mark.parametrize( + "project_name,stage1,stage2", + [("mixins_basic", "foo", "bar"), ("mixins_defaults_hierarchy", "foo", "baz")], +) +def test_mixins(project_directory, project_name, stage1, stage2): + with project_directory(project_name) as project_dir: + definition = DefinitionManager(project_dir).project_definition + + assert definition.entities["function1"].stage == stage1 + assert definition.entities["function1"].handler == "app.hello" + assert definition.entities["function2"].stage == stage2 + assert definition.entities["function1"].handler == "app.hello" + + +def test_mixins_for_different_entities(project_directory): + with project_directory("mixins_different_entities") as project_dir: + definition = DefinitionManager(project_dir).project_definition + + assert definition.entities["function1"].stage == "foo" + assert definition.entities["streamlit1"].main_file == "streamlit_app.py" + + +def test_list_of_mixins_in_correct_order(project_directory): + with project_directory("mixins_list_applied_in_order") as project_dir: + definition = DefinitionManager(project_dir).project_definition + + assert definition.entities["function1"].stage == "foo" + assert definition.entities["function2"].stage == "baz" + assert definition.entities["streamlit1"].stage == "bar" + + def _assert_entities_are_equal( v1_entity: _CallableBase, v2_entity: SnowparkEntityModel ): diff --git a/tests/test_data/projects/mixins_basic/snowflake.yml b/tests/test_data/projects/mixins_basic/snowflake.yml new file mode 100644 index 0000000000..921415c943 --- /dev/null +++ b/tests/test_data/projects/mixins_basic/snowflake.yml @@ -0,0 +1,28 @@ +definition_version: '2' +entities: + function1: + artifacts: + - src + handler: app.hello + identifier: name + meta: + use_mixins: my_mixin + returns: string + signature: + - name: name + type: string + type: function + function2: + artifacts: + - src + handler: app.hello + identifier: name + returns: string + signature: + - name: name + type: string + stage: bar + type: function +mixins: + my_mixin: + stage: foo diff --git a/tests/test_data/projects/mixins_defaults_hierarchy/snowflake.yml b/tests/test_data/projects/mixins_defaults_hierarchy/snowflake.yml new file mode 100644 index 0000000000..9a5d517dad --- /dev/null +++ b/tests/test_data/projects/mixins_defaults_hierarchy/snowflake.yml @@ -0,0 +1,29 @@ +definition_version: '2' +entities: + function1: + artifacts: + - src + handler: app.hello + identifier: name + meta: + use_mixins: my_mixin + returns: string + signature: + - name: name + type: string + type: function + function2: + artifacts: + - src + handler: app.hello2 + identifier: name + returns: string + signature: + - name: name + type: string + type: function +defaults: + stage: baz +mixins: + my_mixin: + stage: foo diff --git a/tests/test_data/projects/mixins_different_entities/environment.yml b/tests/test_data/projects/mixins_different_entities/environment.yml new file mode 100644 index 0000000000..ac8feac3e8 --- /dev/null +++ b/tests/test_data/projects/mixins_different_entities/environment.yml @@ -0,0 +1,5 @@ +name: sf_env +channels: + - snowflake +dependencies: + - pandas diff --git a/tests/test_data/projects/mixins_different_entities/pages/my_page.py b/tests/test_data/projects/mixins_different_entities/pages/my_page.py new file mode 100644 index 0000000000..bc3ecbccba --- /dev/null +++ b/tests/test_data/projects/mixins_different_entities/pages/my_page.py @@ -0,0 +1,3 @@ +import streamlit as st + +st.title("Example page") diff --git a/tests/test_data/projects/mixins_different_entities/snowflake.yml b/tests/test_data/projects/mixins_different_entities/snowflake.yml new file mode 100644 index 0000000000..ed99c5fdf3 --- /dev/null +++ b/tests/test_data/projects/mixins_different_entities/snowflake.yml @@ -0,0 +1,45 @@ +definition_version: '2' +entities: + function1: + artifacts: + - src + handler: app.hello + identifier: name + meta: + use_mixins: my_mixin + returns: string + signature: + - name: name + type: string + type: function + function2: + artifacts: + - src + handler: app.hello + identifier: name + returns: string + signature: + - name: name + type: string + type: function + streamlit1: + artifacts: + - streamlit_app.py + - environment.yml + - pages + identifier: + name: test_streamlit + pages_dir: non_existent_dir + query_warehouse: test_warehouse + stage: streamlit + title: My Fancy Streamlit + type: streamlit + meta: + use_mixins: my_mixin +defaults: + stage: baz +mixins: + my_mixin: + stage: foo + main_file: streamlit_app.py + pages_dir: pages diff --git a/tests/test_data/projects/mixins_different_entities/streamlit_app.py b/tests/test_data/projects/mixins_different_entities/streamlit_app.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_data/projects/mixins_list_applied_in_order/environment.yml b/tests/test_data/projects/mixins_list_applied_in_order/environment.yml new file mode 100644 index 0000000000..ac8feac3e8 --- /dev/null +++ b/tests/test_data/projects/mixins_list_applied_in_order/environment.yml @@ -0,0 +1,5 @@ +name: sf_env +channels: + - snowflake +dependencies: + - pandas diff --git a/tests/test_data/projects/mixins_list_applied_in_order/pages/my_page.py b/tests/test_data/projects/mixins_list_applied_in_order/pages/my_page.py new file mode 100644 index 0000000000..bc3ecbccba --- /dev/null +++ b/tests/test_data/projects/mixins_list_applied_in_order/pages/my_page.py @@ -0,0 +1,3 @@ +import streamlit as st + +st.title("Example page") diff --git a/tests/test_data/projects/mixins_list_applied_in_order/snowflake.yml b/tests/test_data/projects/mixins_list_applied_in_order/snowflake.yml new file mode 100644 index 0000000000..9a536fa768 --- /dev/null +++ b/tests/test_data/projects/mixins_list_applied_in_order/snowflake.yml @@ -0,0 +1,58 @@ +definition_version: '2' +entities: + function1: + artifacts: + - src + handler: app.hello + identifier: name + meta: + use_mixins: + - second_mixin + - first_mixin + returns: string + signature: + - name: name + type: string + type: function + function2: + artifacts: + - src + handler: app.hello + identifier: name + returns: string + meta: + use_mixins: + - third_mixin + signature: + - name: name + type: string + type: function + streamlit1: + artifacts: + - streamlit_app.py + - environment.yml + - pages + identifier: + name: test_streamlit + pages_dir: non_existent_dir + query_warehouse: test_warehouse + stage: streamlit + title: My Fancy Streamlit + type: streamlit + meta: + use_mixins: + - first_mixin + - second_mixin +mixins: + first_mixin: + stage: foo + main_file: streamlit_app.py + pages_dir: non_existent_pages_dir + second_mixin: + stage: bar + main_file: streamlit_app.py + pages_dir: pages + third_mixin: + stage: baz + main_file: streamlit_app.py + pages_dir: pages diff --git a/tests/test_data/projects/mixins_list_applied_in_order/streamlit_app.py b/tests/test_data/projects/mixins_list_applied_in_order/streamlit_app.py new file mode 100644 index 0000000000..e69de29bb2