From 7210a6f1bfd032bc1896bf2f6762c3d20c4734bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bar=C4=B1=C5=9F=20Can=20Durak?= <36421093+bcdurak@users.noreply.github.com> Date: Mon, 24 Jun 2024 15:42:52 +0200 Subject: [PATCH] Upgrading to `pydantic` v2 (#155) * upgrading to pydantic v2 * slightly different approach for the validator * removed unused noqa * upgrading ruff * Bump pydantic version for airflow recipe --------- Co-authored-by: Michael Schuster --- gcp-airflow/composer.tf | 2 +- pyproject.toml | 4 +- src/mlstacks/cli/cli.py | 1 + src/mlstacks/models/component.py | 67 +++++++++---------------- src/mlstacks/models/stack.py | 16 +++--- src/mlstacks/utils/environment_utils.py | 1 + src/mlstacks/utils/model_utils.py | 12 ++--- src/mlstacks/utils/yaml_utils.py | 1 + tests/unit/utils/test_zenml_utils.py | 1 - 9 files changed, 45 insertions(+), 60 deletions(-) diff --git a/gcp-airflow/composer.tf b/gcp-airflow/composer.tf index c655d4c3..adca786f 100644 --- a/gcp-airflow/composer.tf +++ b/gcp-airflow/composer.tf @@ -14,7 +14,7 @@ resource "google_composer_environment" "zenml-airflow" { software_config { image_version = "composer-2-airflow-2" pypi_packages = { - pydantic = "~=1.9.2" + pydantic = "~=2.7.1" } } diff --git a/pyproject.toml b/pyproject.toml index 0576d1f1..2325bc03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,7 @@ exclude = [ # pydantic = {version = "2.0.2"} # compatible with Core ZenML python = ">=3.8,<3.12" -pydantic = { version = "<1.11,>=1.9.0" } +pydantic = { version = "~2.7" } pyyaml = { version = ">=6.0.1" } click = { version = "^8.0.1,<8.1.4" } python-terraform = { version = "^0.10.1" } @@ -74,7 +74,7 @@ rich = { version = "^12.0.0" } analytics-python = { version = "^1.4.0" } # dev dependencies -ruff = { version = "^0.1.7", optional = true } +ruff = { version = ">=0.1.7", optional = true } pytest = { version = "^7.4.0", optional = true } mypy = { version = "^1.7.1", optional = true } darglint = { version = "^1.8.1", optional = true } diff --git a/src/mlstacks/cli/cli.py b/src/mlstacks/cli/cli.py index b90a14c1..d2afc883 100644 --- a/src/mlstacks/cli/cli.py +++ b/src/mlstacks/cli/cli.py @@ -11,6 +11,7 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. """CLI for mlstacks.""" + import random import shutil import string diff --git a/src/mlstacks/models/component.py b/src/mlstacks/models/component.py index 4eed4395..10afa265 100644 --- a/src/mlstacks/models/component.py +++ b/src/mlstacks/models/component.py @@ -12,9 +12,9 @@ # permissions and limitations under the License. """Component model.""" -from typing import Any, Dict, Optional +from typing import Dict, Optional -from pydantic import BaseModel, validator +from pydantic import BaseModel, field_validator, model_validator from mlstacks.constants import ( INVALID_COMPONENT_FLAVOR_ERROR_MESSAGE, @@ -67,8 +67,9 @@ class Component(BaseModel): component_flavor: ComponentFlavorEnum metadata: Optional[ComponentMetadata] = None - @validator("name") - def validate_name(cls, name: str) -> str: # noqa + @field_validator("name") + @classmethod + def validate_name(cls, name: str) -> str: """Validate the name. Name must start with an alphanumeric character and can only contain @@ -89,55 +90,35 @@ def validate_name(cls, name: str) -> str: # noqa raise ValueError(INVALID_NAME_ERROR_MESSAGE) return name - @validator("component_type") - def validate_component_type( - cls, # noqa - component_type: str, - values: Dict[str, Any], - ) -> str: - """Validate the component type. + @model_validator(mode="after") + def validate_component_type_and_flavor(self) -> "Component": + """Validate the component type and flavor. Artifact Store, Container Registry, Experiment Tracker, Orchestrator, MLOps Platform, and Model Deployer may be used with aws, gcp, and k3d providers. Step Operator may only be used with aws and gcp. - Args: - component_type: The component type. - values: The previously validated component specs. + Moreover, only certain flavors are allowed for a given + provider-component type combination. For more information, consult + the tables for your specified provider at the MLStacks documentation: + https://mlstacks.zenml.io/stacks/stack-specification. Returns: - The validated component type. + The validated component instance. Raises: - ValueError: If the component type is invalid. + ValueError: If the component type or flavor is invalid. """ - if not is_valid_component_type(component_type, values["provider"]): + if not is_valid_component_type( + component_type=self.component_type, provider=self.provider + ): raise ValueError(INVALID_COMPONENT_TYPE_ERROR_MESSAGE) - return component_type - - @validator("component_flavor") - def validate_component_flavor( - cls, # noqa - component_flavor: str, - values: Dict[str, Any], - ) -> str: - """Validate the component flavor. - - Only certain flavors are allowed for a given provider-component - type combination. For more information, consult the tables for - your specified provider at the MLStacks documentation: - https://mlstacks.zenml.io/stacks/stack-specification. - - Args: - component_flavor: The component flavor. - values: The previously validated component specs. - Returns: - The validated component flavor. - - Raises: - ValueError: If the component flavor is invalid. - """ - if not is_valid_component_flavor(component_flavor, values): + if not is_valid_component_flavor( + component_flavor=self.component_flavor, + component_type=self.component_type, + provider=self.provider, + ): raise ValueError(INVALID_COMPONENT_FLAVOR_ERROR_MESSAGE) - return component_flavor + + return self diff --git a/src/mlstacks/models/stack.py b/src/mlstacks/models/stack.py index b32bdba4..14ff2876 100644 --- a/src/mlstacks/models/stack.py +++ b/src/mlstacks/models/stack.py @@ -11,9 +11,10 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. """Stack model.""" + from typing import Dict, List, Optional -from pydantic import BaseModel, validator +from pydantic import BaseModel, field_validator from mlstacks.constants import INVALID_NAME_ERROR_MESSAGE from mlstacks.enums import ( @@ -44,15 +45,16 @@ class Stack(BaseModel): spec_type: SpecTypeEnum = SpecTypeEnum.STACK name: str provider: ProviderEnum - default_region: Optional[str] + default_region: Optional[str] = None default_tags: Optional[Dict[str, str]] = None - deployment_method: Optional[ - DeploymentMethodEnum - ] = DeploymentMethodEnum.KUBERNETES + deployment_method: Optional[DeploymentMethodEnum] = ( + DeploymentMethodEnum.KUBERNETES + ) components: List[Component] = [] - @validator("name") - def validate_name(cls, name: str) -> str: # noqa + @field_validator("name") + @classmethod + def validate_name(cls, name: str) -> str: """Validate the name. Name must start with an alphanumeric character and can only contain diff --git a/src/mlstacks/utils/environment_utils.py b/src/mlstacks/utils/environment_utils.py index d43c1a4b..cbb3bb4d 100644 --- a/src/mlstacks/utils/environment_utils.py +++ b/src/mlstacks/utils/environment_utils.py @@ -11,6 +11,7 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. """Environment utilities for mlstacks.""" + import os diff --git a/src/mlstacks/utils/model_utils.py b/src/mlstacks/utils/model_utils.py index e42c23d5..bf573572 100644 --- a/src/mlstacks/utils/model_utils.py +++ b/src/mlstacks/utils/model_utils.py @@ -13,7 +13,6 @@ """Util functions for Pydantic models and validation.""" import re -from typing import Any, Dict from mlstacks.constants import ALLOWED_COMPONENT_TYPES, PERMITTED_NAME_REGEX @@ -49,7 +48,9 @@ def is_valid_component_type(component_type: str, provider: str) -> bool: def is_valid_component_flavor( - component_flavor: str, specs: Dict[str, Any] + component_flavor: str, + component_type: str, + provider: str, ) -> bool: """Check if the component flavor is valid. @@ -57,7 +58,8 @@ def is_valid_component_flavor( Args: component_flavor: The component flavor. - specs: The previously validated component specs. + component_type: The component type. + provider: The provider. Returns: True if the component flavor is valid, False otherwise. @@ -65,9 +67,7 @@ def is_valid_component_flavor( try: is_valid = ( component_flavor - in ALLOWED_COMPONENT_TYPES[specs["provider"]][ - specs["component_type"] - ] + in ALLOWED_COMPONENT_TYPES[provider][component_type] ) except KeyError: return False diff --git a/src/mlstacks/utils/yaml_utils.py b/src/mlstacks/utils/yaml_utils.py index 0ef0734a..d29c8d9c 100644 --- a/src/mlstacks/utils/yaml_utils.py +++ b/src/mlstacks/utils/yaml_utils.py @@ -11,6 +11,7 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. """Utility functions for loading YAML files into Python objects.""" + from pathlib import Path from typing import Any, Dict, Union diff --git a/tests/unit/utils/test_zenml_utils.py b/tests/unit/utils/test_zenml_utils.py index cc72771c..26f4ea51 100644 --- a/tests/unit/utils/test_zenml_utils.py +++ b/tests/unit/utils/test_zenml_utils.py @@ -12,7 +12,6 @@ # permissions and limitations under the License. """Tests for utilities for mlstacks-ZenML interaction.""" - from mlstacks.models.component import Component from mlstacks.models.stack import Stack from mlstacks.utils.zenml_utils import has_valid_flavor_combinations