diff --git a/src/codeflare_sdk/cluster/config.py b/src/codeflare_sdk/cluster/config.py index 6a522fbc4..610d53c44 100644 --- a/src/codeflare_sdk/cluster/config.py +++ b/src/codeflare_sdk/cluster/config.py @@ -18,10 +18,10 @@ Cluster object. """ -from dataclasses import dataclass, field import pathlib -import typing import warnings +from dataclasses import dataclass, field, fields +from typing import Dict, List, Optional, Union, get_args, get_origin dir = pathlib.Path(__file__).parent.parent.resolve() @@ -73,36 +73,37 @@ class ClusterConfiguration: """ name: str - namespace: str = None - head_info: list = field(default_factory=list) - head_cpus: typing.Union[int, str] = 2 - head_memory: typing.Union[int, str] = 8 - head_gpus: int = None # Deprecating - head_extended_resource_requests: typing.Dict[str, int] = field(default_factory=dict) - machine_types: list = field(default_factory=list) # ["m4.xlarge", "g4dn.xlarge"] - worker_cpu_requests: typing.Union[int, str] = 1 - worker_cpu_limits: typing.Union[int, str] = 1 - min_cpus: typing.Union[int, str] = None # Deprecating - max_cpus: typing.Union[int, str] = None # Deprecating + namespace: Optional[str] = None + head_info: List[str] = field(default_factory=list) + head_cpus: Union[int, str] = 2 + head_memory: Union[int, str] = 8 + head_gpus: Optional[int] = None # Deprecating + head_extended_resource_requests: Dict[str, int] = field(default_factory=dict) + machine_types: List[str] = field( + default_factory=list + ) # ["m4.xlarge", "g4dn.xlarge"] + worker_cpu_requests: Union[int, str] = 1 + worker_cpu_limits: Union[int, str] = 1 + min_cpus: Optional[Union[int, str]] = None # Deprecating + max_cpus: Optional[Union[int, str]] = None # Deprecating num_workers: int = 1 - worker_memory_requests: typing.Union[int, str] = 2 - worker_memory_limits: typing.Union[int, str] = 2 - min_memory: typing.Union[int, str] = None # Deprecating - max_memory: typing.Union[int, str] = None # Deprecating - num_gpus: int = None # Deprecating + 
worker_memory_requests: Union[int, str] = 2 + worker_memory_limits: Union[int, str] = 2 + min_memory: Optional[Union[int, str]] = None # Deprecating + max_memory: Optional[Union[int, str]] = None # Deprecating + num_gpus: Optional[int] = None # Deprecating template: str = f"{dir}/templates/base-template.yaml" appwrapper: bool = False - envs: dict = field(default_factory=dict) + envs: Dict[str, str] = field(default_factory=dict) image: str = "" - image_pull_secrets: list = field(default_factory=list) + image_pull_secrets: List[str] = field(default_factory=list) write_to_file: bool = False verify_tls: bool = True - labels: dict = field(default_factory=dict) - worker_extended_resource_requests: typing.Dict[str, int] = field( - default_factory=dict - ) - extended_resource_mapping: typing.Dict[str, str] = field(default_factory=dict) + labels: Dict[str, str] = field(default_factory=dict) + worker_extended_resource_requests: Dict[str, int] = field(default_factory=dict) + extended_resource_mapping: Dict[str, str] = field(default_factory=dict) overwrite_default_resource_mapping: bool = False + local_queue: Optional[str] = None def __post_init__(self): if not self.verify_tls: @@ -110,6 +111,7 @@ def __post_init__(self): "Warning: TLS verification has been disabled - Endpoint checks will be bypassed" ) + self._validate_types() self._memory_to_string() self._str_mem_no_unit_add_GB() self._memory_to_resource() @@ -139,9 +141,7 @@ def _combine_extended_resource_mapping(self): **self.extended_resource_mapping, } - def _validate_extended_resource_requests( - self, extended_resources: typing.Dict[str, int] - ): + def _validate_extended_resource_requests(self, extended_resources: Dict[str, int]): for k in extended_resources.keys(): if k not in self.extended_resource_mapping.keys(): raise ValueError( @@ -206,4 +206,34 @@ def _memory_to_resource(self): warnings.warn("max_memory is being deprecated, use worker_memory_limits") self.worker_memory_limits = f"{self.max_memory}G" - 
def _validate_types(self):
    """Validate the types of all fields in the ClusterConfiguration dataclass.

    Walks every dataclass field, reads its current value from the instance
    and compares it against the field's declared annotation via
    ``_is_type``.

    Raises:
        TypeError: For the first field whose value does not conform to its
            annotation.
    """
    for field_info in fields(self):
        value = getattr(self, field_info.name)
        expected_type = field_info.type
        if not self._is_type(value, expected_type):
            raise TypeError(
                f"'{field_info.name}' should be of type {expected_type}"
            )


@staticmethod
def _is_type(value, expected_type):
    """Check if the value matches the expected type.

    Handles plain classes and the ``typing`` generics used by the
    configuration fields: ``Union``/``Optional``, ``List[X]``,
    ``Dict[K, V]`` and ``Tuple[...]`` (fixed-length or ``Tuple[X, ...]``).
    Container contents are checked recursively.

    Args:
        value: The runtime value to inspect.
        expected_type: A class or ``typing`` annotation to check against.

    Returns:
        bool: True if ``value`` conforms to ``expected_type``, else False.
    """

    def check_type(candidate, annotation):
        origin_type = get_origin(annotation)
        args = get_args(annotation)

        if origin_type is Union:
            return any(check_type(candidate, member) for member in args)

        if origin_type is list:
            # Check the container itself first: iterating a str would
            # otherwise let "abc" pass as List[str], and a non-iterable
            # would raise from inside the generator instead of failing
            # validation cleanly.
            if not isinstance(candidate, list):
                return False
            # A bare ``List`` carries no element type to enforce.
            return not args or all(
                check_type(elem, args[0]) for elem in candidate
            )

        if origin_type is dict:
            # Without this guard a non-dict hits ``.items()`` and raises
            # AttributeError rather than being reported as a type mismatch.
            if not isinstance(candidate, dict):
                return False
            return len(args) != 2 or all(
                check_type(key, args[0]) and check_type(val, args[1])
                for key, val in candidate.items()
            )

        if origin_type is tuple:
            if not isinstance(candidate, tuple):
                return False
            if not args:
                return True
            # Tuple[X, ...] is homogeneous and of arbitrary length; the
            # Ellipsis marker must not be handed to isinstance().
            if len(args) == 2 and args[1] is Ellipsis:
                return all(check_type(elem, args[0]) for elem in candidate)
            # zip() alone would silently ignore a length mismatch.
            return len(candidate) == len(args) and all(
                check_type(elem, etype) for elem, etype in zip(candidate, args)
            )

        if annotation is int:
            # bool is a subclass of int, but True/False are not meaningful
            # values for integer settings.
            return isinstance(candidate, int) and not isinstance(candidate, bool)

        return isinstance(candidate, annotation)

    return check_type(value, expected_type)
def createClusterWrongType():
    """Build a ClusterConfiguration from deliberately mistyped arguments.

    Three arguments contradict the field annotations declared on
    ClusterConfiguration: ``worker_cpu_requests`` is an empty list rather
    than an int or str, ``machine_types`` holds booleans rather than
    strings, and ``labels`` maps an int to an int rather than str to str.
    The companion test ``test_config_creation_wrong_type`` expects the
    constructor call to raise ``TypeError``.
    """
    mistyped_kwargs = {
        "name": "unit-test-cluster",
        "namespace": "ns",
        "num_workers": 2,
        "worker_cpu_requests": [],
        "worker_cpu_limits": 4,
        "worker_memory_requests": 5,
        "worker_memory_limits": 6,
        "worker_extended_resource_requests": {"nvidia.com/gpu": 7},
        "appwrapper": True,
        "machine_types": [True, False],
        "image_pull_secrets": ["unit-test-pull-secret"],
        "image": "quay.io/rhoai/ray:2.23.0-py39-cu121",
        "write_to_file": True,
        "labels": {1: 1},
    }
    return ClusterConfiguration(**mistyped_kwargs)