diff --git a/src/mlstacks/models/component.py b/src/mlstacks/models/component.py index dfc1e544..789fb44d 100644 --- a/src/mlstacks/models/component.py +++ b/src/mlstacks/models/component.py @@ -14,7 +14,7 @@ from typing import Dict, Optional -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, field_validator, model_validator from mlstacks.constants import ( INVALID_COMPONENT_FLAVOR_ERROR_MESSAGE, @@ -90,45 +90,35 @@ def validate_name(cls, name: str) -> str: # noqa: N805 raise ValueError(INVALID_NAME_ERROR_MESSAGE) return name - @field_validator("component_type", mode="after") - def validate_component_type(self, component_type: str) -> str: - """Validate the component type. + @model_validator(mode="after") + def validate_component_type_and_flavor(self) -> "Component": + """Validate the component type and flavor. Artifact Store, Container Registry, Experiment Tracker, Orchestrator, MLOps Platform, and Model Deployer may be used with aws, gcp, and k3d providers. Step Operator may only be used with aws and gcp. - Args: - component_type: The component type. + Moreover, only certain flavors are allowed for a given + provider-component type combination. For more information, consult + the tables for your specified provider at the MLStacks documentation: + https://mlstacks.zenml.io/stacks/stack-specification. Returns: - The validated component type. + The validated component instance. Raises: - ValueError: If the component type is invalid. + ValueError: If the component type or flavor is invalid. """ - if not is_valid_component_type(component_type, self.provider): + if not is_valid_component_type( + component_type=self.component_type, provider=self.provider + ): raise ValueError(INVALID_COMPONENT_TYPE_ERROR_MESSAGE) - return component_type - - @field_validator("component_flavor", mode="after") - def validate_component_flavor(self, component_flavor: str) -> str: - """Validate the component flavor. - - Only certain flavors are allowed for a given provider-component - type combination. For more information, consult the tables for - your specified provider at the MLStacks documentation: - https://mlstacks.zenml.io/stacks/stack-specification. - - Args: - component_flavor: The component flavor. - - Returns: - The validated component flavor. - Raises: - ValueError: If the component flavor is invalid. - """ - if not is_valid_component_flavor(component_flavor, dict(self)): + if not is_valid_component_flavor( + component_flavor=self.component_flavor, + component_type=self.component_type, + provider=self.provider, + ): raise ValueError(INVALID_COMPONENT_FLAVOR_ERROR_MESSAGE) - return component_flavor + + return self diff --git a/src/mlstacks/utils/model_utils.py b/src/mlstacks/utils/model_utils.py index e42c23d5..bf573572 100644 --- a/src/mlstacks/utils/model_utils.py +++ b/src/mlstacks/utils/model_utils.py @@ -13,7 +13,6 @@ """Util functions for Pydantic models and validation.""" import re -from typing import Any, Dict from mlstacks.constants import ALLOWED_COMPONENT_TYPES, PERMITTED_NAME_REGEX @@ -49,7 +48,9 @@ def is_valid_component_type(component_type: str, provider: str) -> bool: def is_valid_component_flavor( - component_flavor: str, specs: Dict[str, Any] + component_flavor: str, + component_type: str, + provider: str, ) -> bool: """Check if the component flavor is valid. @@ -57,7 +58,8 @@ def is_valid_component_flavor( Args: component_flavor: The component flavor. - specs: The previously validated component specs. + component_type: The component type. + provider: The provider. Returns: True if the component flavor is valid, False otherwise. @@ -65,9 +67,7 @@ def is_valid_component_flavor( try: is_valid = ( component_flavor - in ALLOWED_COMPONENT_TYPES[specs["provider"]][ - specs["component_type"] - ] + in ALLOWED_COMPONENT_TYPES[provider][component_type] ) except KeyError: return False