diff --git a/src/mlstacks/utils/model_utils.py b/src/mlstacks/utils/model_utils.py index 3737e056..e42c23d5 100644 --- a/src/mlstacks/utils/model_utils.py +++ b/src/mlstacks/utils/model_utils.py @@ -69,7 +69,7 @@ def is_valid_component_flavor( specs["component_type"] ] ) - except ValueError: + except KeyError: return False return is_valid diff --git a/src/mlstacks/utils/yaml_utils.py b/src/mlstacks/utils/yaml_utils.py index cd1cc5ca..200a3132 100644 --- a/src/mlstacks/utils/yaml_utils.py +++ b/src/mlstacks/utils/yaml_utils.py @@ -58,6 +58,9 @@ def load_component_yaml(path: str) -> Component: Returns: The component model. + + Raises: + FileNotFoundError: If the file is not found. """ try: with open(path) as file: @@ -97,6 +100,9 @@ def load_stack_yaml(path: str) -> Stack: Returns: The stack model. + + Raises: + ValueError: If the stack and component have different providers """ with open(path) as yaml_file: stack_data = yaml.safe_load(yaml_file) diff --git a/src/mlstacks/utils/test_utils.py b/tests/test_utils.py similarity index 94% rename from src/mlstacks/utils/test_utils.py rename to tests/test_utils.py index ea493ab7..6d675268 100644 --- a/src/mlstacks/utils/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at: diff --git a/tests/unit/models/test_component.py b/tests/unit/models/test_component.py index 0cd3731d..ce8f0b1d 100644 --- a/tests/unit/models/test_component.py +++ b/tests/unit/models/test_component.py @@ -63,7 +63,6 @@ def test_component_metadata(instance): ) -# @given(st.builds(Component, name=st.from_regex(PERMITTED_NAME_REGEX), provider=st.sampled_from(["aws", "gcp", "k3d"]))) @given(valid_components()) def test_component(instance): print(f"instance: {instance}") @@ -73,11 +72,7 @@ def test_component(instance): assert instance.name is not None assert instance.spec_version is not None assert instance.spec_type is not None - print("!!!!!!!!!!!!!!!!!!!!!!!!") - print("just prior to component type test") assert isinstance(instance.component_type, ComponentTypeEnum) - print("just after component type test") - print("!!!!!!!!!!!!!!!!!!!!!!!!") assert isinstance(instance.component_flavor, ComponentFlavorEnum) assert isinstance(instance.provider, str) assert instance.provider is not None diff --git a/tests/unit/utils/test_terraform_utils.py b/tests/unit/utils/test_terraform_utils.py index 34377165..69abd74d 100644 --- a/tests/unit/utils/test_terraform_utils.py +++ b/tests/unit/utils/test_terraform_utils.py @@ -36,7 +36,7 @@ remote_state_bucket_exists, tf_definitions_present, ) -from mlstacks.utils.test_utils import get_allowed_providers +from tests.test_utils import get_allowed_providers EXISTING_S3_BUCKET_URL = "s3://public-flavor-logos" EXISTING_S3_BUCKET_REGION = "eu-central-1" @@ -117,8 +117,6 @@ def test_enable_key_function_handles_components_without_flavors( name=dummy_name, component_flavor=comp_flavor, component_type=comp_type, - # provider=random.choice(list(ProviderEnum)).value, - # Not sure why the above line was used when only "aws" is valid here provider=comp_provider, ) key = _compose_enable_key(c) @@ -130,7 +128,6 @@ def test_component_variable_parsing_works(): metadata = ComponentMetadata() component_flavor = "zenml" - # random_test = random.choice(list(ProviderEnum)).value allowed_providers = get_allowed_providers() random_test = random.choice(allowed_providers) @@ -159,7 +156,6 @@ def test_component_var_parsing_works_for_env_vars(): # EXCLUDE AZURE allowed_providers = get_allowed_providers() random_test = random.choice(allowed_providers) - # random_test = random.choice(list(ProviderEnum)).value components = [ Component( @@ -180,7 +176,6 @@ def test_component_var_parsing_works_for_env_vars(): def test_tf_variable_parsing_from_stack_works(): """Tests that the Terraform variables extraction (from a stack) works.""" - # provider = random.choice(list(ProviderEnum)).value allowed_providers = get_allowed_providers() provider = random.choice(allowed_providers) diff --git a/tests/unit/utils/test_zenml_utils.py b/tests/unit/utils/test_zenml_utils.py index 3274d119..cc72771c 100644 --- a/tests/unit/utils/test_zenml_utils.py +++ b/tests/unit/utils/test_zenml_utils.py @@ -12,7 +12,6 @@ # permissions and limitations under the License. """Tests for utilities for mlstacks-ZenML interaction.""" -import pydantic from mlstacks.models.component import Component from mlstacks.models.stack import Stack @@ -46,34 +45,21 @@ def test_flavor_combination_validator_fails_aws_gcp(): Tests a known failure case. (AWS Stack with a GCP artifact store.) """ - # valid_stack = Stack( - # name="aria-stack", - # provider="aws", - # components=[], - # ) - # invalid_component = Component( - # name="blupus-component", - # component_type="artifact_store", - # component_flavor="gcp", - # provider=valid_stack.provider, - # ) - # assert not has_valid_flavor_combinations( - # stack=valid_stack, - # components=[invalid_component], - # ) - - valid = True - try: - Component( - name="blupus-component", - component_type="artifact_store", - component_flavor="gcp", - provider="aws", - ) - except pydantic.error_wrappers.ValidationError: - valid = False - - assert not valid + valid_stack = Stack( + name="aria-stack", + provider="aws", + components=[], + ) + invalid_component = Component( + name="blupus-component", + component_type="artifact_store", + component_flavor="gcp", + provider="gcp", + ) + assert not has_valid_flavor_combinations( + stack=valid_stack, + components=[invalid_component], + ) def test_flavor_combination_validator_fails_k3d_s3(): @@ -81,31 +67,18 @@ def test_flavor_combination_validator_fails_k3d_s3(): Tests a known failure case. (K3D Stack with a S3 artifact store.) """ - # valid_stack = Stack( - # name="aria-stack", - # provider="k3d", - # components=[], - # ) - # invalid_component = Component( - # name="blupus-component", - # component_type="artifact_store", - # component_flavor="s3", - # provider="k3d", - # ) - # assert not has_valid_flavor_combinations( - # stack=valid_stack, - # components=[invalid_component], - # ) - - valid = True - try: - Component( - name="blupus-component", - component_type="artifact_store", - component_flavor="s3", - provider="k3d", - ) - except pydantic.error_wrappers.ValidationError: - valid = False - - assert not valid + valid_stack = Stack( + name="aria-stack", + provider="k3d", + components=[], + ) + invalid_component = Component( + name="blupus-component", + component_type="artifact_store", + component_flavor="s3", + provider="aws", + ) + assert not has_valid_flavor_combinations( + stack=valid_stack, + components=[invalid_component], + )