Skip to content

Commit

Permalink
Made final changes for pull request. Got rid of comments and print ca…
Browse files Browse the repository at this point in the history
…lls.
  • Loading branch information
MASisserson committed Mar 27, 2024
1 parent 6a2353c commit 8f937d9
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 70 deletions.
2 changes: 1 addition & 1 deletion src/mlstacks/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def is_valid_component_flavor(
specs["component_type"]
]
)
except ValueError:
except KeyError:
return False

return is_valid
6 changes: 6 additions & 0 deletions src/mlstacks/utils/yaml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def load_component_yaml(path: str) -> Component:
Returns:
The component model.
Raises:
FileNotFoundError: If the file is not found.
"""
try:
with open(path) as file:
Expand Down Expand Up @@ -97,6 +100,9 @@ def load_stack_yaml(path: str) -> Stack:
Returns:
The stack model.
Raises:
ValueError: If the stack and component have different providers
"""
with open(path) as yaml_file:
stack_data = yaml.safe_load(yaml_file)
Expand Down
2 changes: 1 addition & 1 deletion src/mlstacks/utils/test_utils.py → tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) ZenML GmbH 2023. All Rights Reserved.
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
Expand Down
5 changes: 0 additions & 5 deletions tests/unit/models/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def test_component_metadata(instance):
)


# @given(st.builds(Component, name=st.from_regex(PERMITTED_NAME_REGEX), provider=st.sampled_from(["aws", "gcp", "k3d"])))
@given(valid_components())
def test_component(instance):
print(f"instance: {instance}")
Expand All @@ -73,11 +72,7 @@ def test_component(instance):
assert instance.name is not None
assert instance.spec_version is not None
assert instance.spec_type is not None
print("!!!!!!!!!!!!!!!!!!!!!!!!")
print("just prior to component type test")
assert isinstance(instance.component_type, ComponentTypeEnum)
print("just after component type test")
print("!!!!!!!!!!!!!!!!!!!!!!!!")
assert isinstance(instance.component_flavor, ComponentFlavorEnum)
assert isinstance(instance.provider, str)
assert instance.provider is not None
Expand Down
7 changes: 1 addition & 6 deletions tests/unit/utils/test_terraform_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
remote_state_bucket_exists,
tf_definitions_present,
)
from mlstacks.utils.test_utils import get_allowed_providers
from tests.test_utils import get_allowed_providers

EXISTING_S3_BUCKET_URL = "s3://public-flavor-logos"
EXISTING_S3_BUCKET_REGION = "eu-central-1"
Expand Down Expand Up @@ -117,8 +117,6 @@ def test_enable_key_function_handles_components_without_flavors(
name=dummy_name,
component_flavor=comp_flavor,
component_type=comp_type,
# provider=random.choice(list(ProviderEnum)).value,
# Not sure why the above line was used when only "aws" is valid here
provider=comp_provider,
)
key = _compose_enable_key(c)
Expand All @@ -130,7 +128,6 @@ def test_component_variable_parsing_works():
metadata = ComponentMetadata()
component_flavor = "zenml"

# random_test = random.choice(list(ProviderEnum)).value
allowed_providers = get_allowed_providers()
random_test = random.choice(allowed_providers)

Expand Down Expand Up @@ -159,7 +156,6 @@ def test_component_var_parsing_works_for_env_vars():
# EXCLUDE AZURE
allowed_providers = get_allowed_providers()
random_test = random.choice(allowed_providers)
# random_test = random.choice(list(ProviderEnum)).value

components = [
Component(
Expand All @@ -180,7 +176,6 @@ def test_component_var_parsing_works_for_env_vars():

def test_tf_variable_parsing_from_stack_works():
"""Tests that the Terraform variables extraction (from a stack) works."""
# provider = random.choice(list(ProviderEnum)).value
allowed_providers = get_allowed_providers()
provider = random.choice(allowed_providers)

Expand Down
87 changes: 30 additions & 57 deletions tests/unit/utils/test_zenml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# permissions and limitations under the License.
"""Tests for utilities for mlstacks-ZenML interaction."""

import pydantic

from mlstacks.models.component import Component
from mlstacks.models.stack import Stack
Expand Down Expand Up @@ -46,66 +45,40 @@ def test_flavor_combination_validator_fails_aws_gcp():
Tests a known failure case. (AWS Stack with a GCP artifact store.)
"""
# valid_stack = Stack(
# name="aria-stack",
# provider="aws",
# components=[],
# )
# invalid_component = Component(
# name="blupus-component",
# component_type="artifact_store",
# component_flavor="gcp",
# provider=valid_stack.provider,
# )
# assert not has_valid_flavor_combinations(
# stack=valid_stack,
# components=[invalid_component],
# )

valid = True
try:
Component(
name="blupus-component",
component_type="artifact_store",
component_flavor="gcp",
provider="aws",
)
except pydantic.error_wrappers.ValidationError:
valid = False

assert not valid
valid_stack = Stack(
name="aria-stack",
provider="aws",
components=[],
)
invalid_component = Component(
name="blupus-component",
component_type="artifact_store",
component_flavor="gcp",
provider="gcp",
)
assert not has_valid_flavor_combinations(
stack=valid_stack,
components=[invalid_component],
)


def test_flavor_combination_validator_fails_k3d_s3():
"""Checks that the flavor combination validator fails.
Tests a known failure case. (K3D Stack with a S3 artifact store.)
"""
# valid_stack = Stack(
# name="aria-stack",
# provider="k3d",
# components=[],
# )
# invalid_component = Component(
# name="blupus-component",
# component_type="artifact_store",
# component_flavor="s3",
# provider="k3d",
# )
# assert not has_valid_flavor_combinations(
# stack=valid_stack,
# components=[invalid_component],
# )

valid = True
try:
Component(
name="blupus-component",
component_type="artifact_store",
component_flavor="s3",
provider="k3d",
)
except pydantic.error_wrappers.ValidationError:
valid = False

assert not valid
valid_stack = Stack(
name="aria-stack",
provider="k3d",
components=[],
)
invalid_component = Component(
name="blupus-component",
component_type="artifact_store",
component_flavor="s3",
provider="aws",
)
assert not has_valid_flavor_combinations(
stack=valid_stack,
components=[invalid_component],
)

0 comments on commit 8f937d9

Please sign in to comment.