Skip to content

Commit

Permalink
slightly different approach for the validator
Browse files Browse the repository at this point in the history
  • Loading branch information
bcdurak committed May 7, 2024
1 parent 12c5305 commit 9356d81
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 36 deletions.
50 changes: 20 additions & 30 deletions src/mlstacks/models/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from typing import Dict, Optional

from pydantic import BaseModel, field_validator
from pydantic import BaseModel, field_validator, model_validator

from mlstacks.constants import (
INVALID_COMPONENT_FLAVOR_ERROR_MESSAGE,
Expand Down Expand Up @@ -90,45 +90,35 @@ def validate_name(cls, name: str) -> str: # noqa: N805
raise ValueError(INVALID_NAME_ERROR_MESSAGE)
return name

@field_validator("component_type", mode="after")
def validate_component_type(self, component_type: str) -> str:
"""Validate the component type.
@model_validator(mode="after")
def validate_component_type_and_flavor(self) -> "Component":
"""Validate the component type and flavor.
Artifact Store, Container Registry, Experiment Tracker, Orchestrator,
MLOps Platform, and Model Deployer may be used with aws, gcp, and k3d
providers. Step Operator may only be used with aws and gcp.
Args:
component_type: The component type.
Moreover, only certain flavors are allowed for a given
provider-component type combination. For more information, consult
the tables for your specified provider at the MLStacks documentation:
https://mlstacks.zenml.io/stacks/stack-specification.
Returns:
The validated component type.
The validated component instance.
Raises:
ValueError: If the component type is invalid.
ValueError: If the component type or flavor is invalid.
"""
if not is_valid_component_type(component_type, self.provider):
if not is_valid_component_type(
component_type=self.component_type, provider=self.provider
):
raise ValueError(INVALID_COMPONENT_TYPE_ERROR_MESSAGE)
return component_type

@field_validator("component_flavor", mode="after")
def validate_component_flavor(self, component_flavor: str) -> str:
"""Validate the component flavor.
Only certain flavors are allowed for a given provider-component
type combination. For more information, consult the tables for
your specified provider at the MLStacks documentation:
https://mlstacks.zenml.io/stacks/stack-specification.
Args:
component_flavor: The component flavor.
Returns:
The validated component flavor.

Raises:
ValueError: If the component flavor is invalid.
"""
if not is_valid_component_flavor(component_flavor, dict(self)):
if not is_valid_component_flavor(
component_flavor=self.component_flavor,
component_type=self.component_type,
provider=self.provider,
):
raise ValueError(INVALID_COMPONENT_FLAVOR_ERROR_MESSAGE)
return component_flavor

return self
12 changes: 6 additions & 6 deletions src/mlstacks/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
"""Util functions for Pydantic models and validation."""

import re
from typing import Any, Dict

from mlstacks.constants import ALLOWED_COMPONENT_TYPES, PERMITTED_NAME_REGEX

Expand Down Expand Up @@ -49,25 +48,26 @@ def is_valid_component_type(component_type: str, provider: str) -> bool:


def is_valid_component_flavor(
component_flavor: str, specs: Dict[str, Any]
component_flavor: str,
component_type: str,
provider: str,
) -> bool:
"""Check if the component flavor is valid.
Used for components.
Args:
component_flavor: The component flavor.
specs: The previously validated component specs.
component_type: The component type.
provider: The provider.
Returns:
True if the component flavor is valid, False otherwise.
"""
try:
is_valid = (
component_flavor
in ALLOWED_COMPONENT_TYPES[specs["provider"]][
specs["component_type"]
]
in ALLOWED_COMPONENT_TYPES[provider][component_type]
)
except KeyError:
return False
Expand Down

0 comments on commit 9356d81

Please sign in to comment.