Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrading to pydantic v2 #155

Merged
merged 6 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ exclude = [
# pydantic = {version = "2.0.2"}
# compatible with Core ZenML
python = ">=3.8,<3.12"
pydantic = { version = "<1.11,>=1.9.0" }
pydantic = { version = "~2.7" }
pyyaml = { version = ">=6.0.1" }
click = { version = "^8.0.1,<8.1.4" }
python-terraform = { version = "^0.10.1" }
Expand Down
1 change: 1 addition & 0 deletions src/mlstacks/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""CLI for mlstacks."""

import random
import shutil
import string
Expand Down
67 changes: 24 additions & 43 deletions src/mlstacks/models/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
# permissions and limitations under the License.
"""Component model."""

from typing import Any, Dict, Optional
from typing import Dict, Optional

from pydantic import BaseModel, validator
from pydantic import BaseModel, field_validator, model_validator

from mlstacks.constants import (
INVALID_COMPONENT_FLAVOR_ERROR_MESSAGE,
Expand Down Expand Up @@ -67,8 +67,9 @@ class Component(BaseModel):
component_flavor: ComponentFlavorEnum
metadata: Optional[ComponentMetadata] = None

@validator("name")
def validate_name(cls, name: str) -> str: # noqa
@field_validator("name")
@classmethod
def validate_name(cls, name: str) -> str:
"""Validate the name.

Name must start with an alphanumeric character and can only contain
Expand All @@ -89,55 +90,35 @@ def validate_name(cls, name: str) -> str: # noqa
raise ValueError(INVALID_NAME_ERROR_MESSAGE)
return name

@validator("component_type")
def validate_component_type(
cls, # noqa
component_type: str,
values: Dict[str, Any],
) -> str:
"""Validate the component type.
@model_validator(mode="after")
def validate_component_type_and_flavor(self) -> "Component":
"""Validate the component type and flavor.

Artifact Store, Container Registry, Experiment Tracker, Orchestrator,
MLOps Platform, and Model Deployer may be used with aws, gcp, and k3d
providers. Step Operator may only be used with aws and gcp.

Args:
component_type: The component type.
values: The previously validated component specs.
Moreover, only certain flavors are allowed for a given
provider-component type combination. For more information, consult
the tables for your specified provider at the MLStacks documentation:
https://mlstacks.zenml.io/stacks/stack-specification.

Returns:
The validated component type.
The validated component instance.

Raises:
ValueError: If the component type is invalid.
ValueError: If the component type or flavor is invalid.
"""
if not is_valid_component_type(component_type, values["provider"]):
if not is_valid_component_type(
component_type=self.component_type, provider=self.provider
):
raise ValueError(INVALID_COMPONENT_TYPE_ERROR_MESSAGE)
return component_type

@validator("component_flavor")
def validate_component_flavor(
cls, # noqa
component_flavor: str,
values: Dict[str, Any],
) -> str:
"""Validate the component flavor.

Only certain flavors are allowed for a given provider-component
type combination. For more information, consult the tables for
your specified provider at the MLStacks documentation:
https://mlstacks.zenml.io/stacks/stack-specification.

Args:
component_flavor: The component flavor.
values: The previously validated component specs.

Returns:
The validated component flavor.

Raises:
ValueError: If the component flavor is invalid.
"""
if not is_valid_component_flavor(component_flavor, values):
if not is_valid_component_flavor(
component_flavor=self.component_flavor,
component_type=self.component_type,
provider=self.provider,
):
raise ValueError(INVALID_COMPONENT_FLAVOR_ERROR_MESSAGE)
return component_flavor

return self
16 changes: 9 additions & 7 deletions src/mlstacks/models/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Stack model."""

from typing import Dict, List, Optional

from pydantic import BaseModel, validator
from pydantic import BaseModel, field_validator

from mlstacks.constants import INVALID_NAME_ERROR_MESSAGE
from mlstacks.enums import (
Expand Down Expand Up @@ -44,15 +45,16 @@ class Stack(BaseModel):
spec_type: SpecTypeEnum = SpecTypeEnum.STACK
name: str
provider: ProviderEnum
default_region: Optional[str]
default_region: Optional[str] = None
default_tags: Optional[Dict[str, str]] = None
deployment_method: Optional[
DeploymentMethodEnum
] = DeploymentMethodEnum.KUBERNETES
deployment_method: Optional[DeploymentMethodEnum] = (
DeploymentMethodEnum.KUBERNETES
)
components: List[Component] = []

@validator("name")
def validate_name(cls, name: str) -> str: # noqa
@field_validator("name")
@classmethod
def validate_name(cls, name: str) -> str:
"""Validate the name.

Name must start with an alphanumeric character and can only contain
Expand Down
1 change: 1 addition & 0 deletions src/mlstacks/utils/environment_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Environment utilities for mlstacks."""

import os


Expand Down
12 changes: 6 additions & 6 deletions src/mlstacks/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
"""Util functions for Pydantic models and validation."""

import re
from typing import Any, Dict

from mlstacks.constants import ALLOWED_COMPONENT_TYPES, PERMITTED_NAME_REGEX

Expand Down Expand Up @@ -49,25 +48,26 @@ def is_valid_component_type(component_type: str, provider: str) -> bool:


def is_valid_component_flavor(
component_flavor: str, specs: Dict[str, Any]
component_flavor: str,
component_type: str,
provider: str,
) -> bool:
"""Check if the component flavor is valid.

Used for components.

Args:
component_flavor: The component flavor.
specs: The previously validated component specs.
component_type: The component type.
provider: The provider.

Returns:
True if the component flavor is valid, False otherwise.
"""
try:
is_valid = (
component_flavor
in ALLOWED_COMPONENT_TYPES[specs["provider"]][
specs["component_type"]
]
in ALLOWED_COMPONENT_TYPES[provider][component_type]
)
except KeyError:
return False
Expand Down
1 change: 1 addition & 0 deletions src/mlstacks/utils/yaml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Utility functions for loading YAML files into Python objects."""

from pathlib import Path
from typing import Any, Dict, Union

Expand Down
1 change: 0 additions & 1 deletion tests/unit/utils/test_zenml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# permissions and limitations under the License.
"""Tests for utilities for mlstacks-ZenML interaction."""


from mlstacks.models.component import Component
from mlstacks.models.stack import Stack
from mlstacks.utils.zenml_utils import has_valid_flavor_combinations
Expand Down
Loading