Skip to content

Commit

Permalink
Project definition v2 entity schemas: application and application pac…
Browse files Browse the repository at this point in the history
…kage (#1280)

* application and application package entity schemas

* add integration test

* simplify entity type, use discriminator

* create types map dynamically

* get_type method

* read entity list from union
  • Loading branch information
sfc-gh-gbloom authored Jul 12, 2024
1 parent e6e19e5 commit 426173a
Show file tree
Hide file tree
Showing 8 changed files with 589 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) 2024 Snowflake Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Literal, Optional

from pydantic import AliasChoices, Field
from snowflake.cli.api.project.schemas.entities.application_package_entity import (
ApplicationPackageEntity,
)
from snowflake.cli.api.project.schemas.entities.common import (
EntityBase,
TargetField,
)
from snowflake.cli.api.project.schemas.updatable_model import (
UpdatableModel,
)


class ApplicationEntity(EntityBase):
type: Literal["application"] # noqa: A003
name: str = Field(
title="Name of the application created when this entity is deployed"
)
from_: ApplicationFromField = Field(
validation_alias=AliasChoices("from"),
title="An application package this entity should be created from",
)
debug: Optional[bool] = Field(
title="Whether to enable debug mode when using a named stage to create an application object",
default=None,
)


class ApplicationFromField(UpdatableModel):
target: TargetField[ApplicationPackageEntity] = Field(
title="Reference to an application package entity",
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) 2024 Snowflake Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from pathlib import Path
from typing import List, Literal, Optional, Union

from pydantic import Field
from snowflake.cli.api.project.schemas.entities.common import (
EntityBase,
)
from snowflake.cli.api.project.schemas.native_app.package import DistributionOptions
from snowflake.cli.api.project.schemas.native_app.path_mapping import PathMapping
from snowflake.cli.api.project.schemas.updatable_model import IdentifierField


class ApplicationPackageEntity(EntityBase):
type: Literal["application package"] # noqa: A003
name: str = Field(
title="Name of the application package created when this entity is deployed"
)
artifacts: List[Union[PathMapping, Path]] = Field(
title="List of paths or file source/destination pairs to add to the deploy root",
)
bundle_root: Optional[Path] = Field(
title="Folder at the root of your project where artifacts necessary to perform the bundle step are stored.",
default=Path("output/bundle/"),
)
deploy_root: Optional[Path] = Field(
title="Folder at the root of your project where the build step copies the artifacts",
default=Path("output/deploy/"),
)
generated_root: Optional[Path] = Field(
title="Subdirectory of the deploy root where files generated by the Snowflake CLI will be written.",
default=Path("__generated/"),
)
stage: Optional[str] = IdentifierField(
title="Identifier of the stage that stores the application artifacts.",
default="app_src.stage",
)
scratch_stage: Optional[str] = IdentifierField(
title="Identifier of the stage that stores temporary scratch data used by the Snowflake CLI.",
default="app_src.stage_snowflake_cli_scratch",
)
distribution: Optional[DistributionOptions] = Field(
title="Distribution of the application package created by the Snowflake CLI",
default="internal",
)
manifest: Path = Field(
title="Path to manifest.yml",
)
85 changes: 85 additions & 0 deletions src/snowflake/cli/api/project/schemas/entities/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) 2024 Snowflake Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from abc import ABC
from typing import Generic, List, Optional, TypeVar

from pydantic import AliasChoices, Field, GetCoreSchemaHandler, ValidationInfo
from pydantic_core import core_schema
from snowflake.cli.api.project.schemas.native_app.application import (
ApplicationPostDeployHook,
)
from snowflake.cli.api.project.schemas.updatable_model import (
IdentifierField,
UpdatableModel,
)


class MetaField(UpdatableModel):
warehouse: Optional[str] = IdentifierField(
title="Warehouse used to run the scripts", default=None
)
role: Optional[str] = IdentifierField(
title="Role to use when creating the entity object",
default=None,
)
post_deploy: Optional[List[ApplicationPostDeployHook]] = Field(
title="Actions that will be executed after the application object is created/upgraded",
default=None,
)


class DefaultsField(UpdatableModel):
schema_: Optional[str] = Field(
title="Schema.",
validation_alias=AliasChoices("schema"),
default=None,
)
stage: Optional[str] = Field(
title="Stage.",
default=None,
)


class EntityBase(ABC, UpdatableModel):
@classmethod
def get_type(cls) -> str:
return cls.model_fields["type"].annotation.__args__[0]

meta: Optional[MetaField] = Field(title="Meta fields", default=None)


TargetType = TypeVar("TargetType")


class TargetField(Generic[TargetType]):
def __init__(self, entity_target_key: str):
self.value = entity_target_key

def __repr__(self):
return self.value

@classmethod
def validate(cls, value: str, info: ValidationInfo) -> TargetField:
return cls(value)

@classmethod
def __get_pydantic_core_schema__(
cls, source_type, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.with_info_after_validator_function(
cls.validate, handler(str), field_name=handler.field_name
)
30 changes: 30 additions & 0 deletions src/snowflake/cli/api/project/schemas/entities/entities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) 2024 Snowflake Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Union, get_args

from snowflake.cli.api.project.schemas.entities.application_entity import (
ApplicationEntity,
)
from snowflake.cli.api.project.schemas.entities.application_package_entity import (
ApplicationPackageEntity,
)

Entity = Union[ApplicationEntity, ApplicationPackageEntity]

ALL_ENTITIES = [*get_args(Entity)]

v2_entity_types_map = {e.get_type(): e for e in ALL_ENTITIES}
76 changes: 67 additions & 9 deletions src/snowflake/cli/api/project/schemas/project_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,27 @@
from typing import Dict, Optional, Union

from packaging.version import Version
from pydantic import Field, ValidationError, field_validator
from pydantic import Field, ValidationError, field_validator, model_validator
from snowflake.cli.api.feature_flags import FeatureFlag
from snowflake.cli.api.project.errors import SchemaValidationError
from snowflake.cli.api.project.schemas.entities.application_entity import (
ApplicationEntity,
)
from snowflake.cli.api.project.schemas.entities.common import (
DefaultsField,
TargetField,
)
from snowflake.cli.api.project.schemas.entities.entities import (
Entity,
v2_entity_types_map,
)
from snowflake.cli.api.project.schemas.native_app.native_app import NativeApp
from snowflake.cli.api.project.schemas.snowpark.snowpark import Snowpark
from snowflake.cli.api.project.schemas.streamlit.streamlit import Streamlit
from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel
from snowflake.cli.api.utils.models import ProjectEnvironment
from snowflake.cli.api.utils.types import Context
from typing_extensions import Annotated


@dataclass
Expand Down Expand Up @@ -105,8 +117,60 @@ def _convert_env(


class DefinitionV20(_ProjectDefinitionBase):
entities: Dict = Field(
title="Entity definitions.",
entities: Dict[str, Annotated[Entity, Field(discriminator="type")]] = Field(
title="Entity definitions."
)

@model_validator(mode="before")
@classmethod
def apply_defaults(cls, data: Dict) -> Dict:
"""
Applies default values that exist on the model but not specified in yml
"""
if "defaults" in data and "entities" in data:
for key, entity in data["entities"].items():
entity_type = entity["type"]
if entity_type not in v2_entity_types_map:
continue
entity_model = v2_entity_types_map[entity_type]
for default_key, default_value in data["defaults"].items():
if (
default_key in entity_model.model_fields
and default_key not in entity
):
entity[default_key] = default_value
return data

@field_validator("entities", mode="after")
@classmethod
def validate_entities(cls, entities: Dict[str, Entity]) -> Dict[str, Entity]:
for key, entity in entities.items():
# TODO Automatically detect TargetFields to validate
if entity.type == ApplicationEntity.get_type():
if isinstance(entity.from_.target, TargetField):
target_key = str(entity.from_.target)
target_class = entity.from_.__class__.model_fields["target"]
target_type = target_class.annotation.__args__[0]
cls._validate_target_field(target_key, target_type, entities)
return entities

@classmethod
def _validate_target_field(
cls, target_key: str, target_type: Entity, entities: Dict[str, Entity]
):
if target_key not in entities:
raise ValueError(f"No such target: {target_key}")
else:
# Validate the target type
actual_target_type = entities[target_key].__class__
if target_type and target_type is not actual_target_type:
raise ValueError(
f"Target type mismatch. Expected {target_type.__name__}, got {actual_target_type.__name__}"
)

defaults: Optional[DefaultsField] = Field(
title="Default key/value entity values that are merged recursively for each entity.",
default=None,
)

env: Union[Dict[str, str], ProjectEnvironment, None] = Field(
Expand All @@ -125,12 +189,6 @@ def _convert_env(
return env
return ProjectEnvironment(default_env=(env or {}), override_env={})

@field_validator("entities")
@classmethod
def validate_entities(cls, entities: Dict) -> Dict:
# TODO Add entities validation logic
return entities


def build_project_definition(**data):
"""
Expand Down
Loading

0 comments on commit 426173a

Please sign in to comment.