Skip to content

Commit

Permalink
Add validation for Cluster configuration parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
Ygnas committed Jul 18, 2024
1 parent e7a45ba commit ec09866
Showing 1 changed file with 59 additions and 29 deletions.
88 changes: 59 additions & 29 deletions src/codeflare_sdk/cluster/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
Cluster object.
"""

from dataclasses import dataclass, field
import pathlib
import typing
import warnings
from dataclasses import dataclass, field, fields
from typing import Dict, List, Optional, Union, get_args, get_origin

dir = pathlib.Path(__file__).parent.parent.resolve()

Expand Down Expand Up @@ -73,36 +73,37 @@ class ClusterConfiguration:
"""

name: str
namespace: str = None
head_info: list = field(default_factory=list)
head_cpus: typing.Union[int, str] = 2
head_memory: typing.Union[int, str] = 8
head_gpus: int = None # Deprecating
head_extended_resource_requests: typing.Dict[str, int] = field(default_factory=dict)
machine_types: list = field(default_factory=list) # ["m4.xlarge", "g4dn.xlarge"]
worker_cpu_requests: typing.Union[int, str] = 1
worker_cpu_limits: typing.Union[int, str] = 1
min_cpus: typing.Union[int, str] = None # Deprecating
max_cpus: typing.Union[int, str] = None # Deprecating
namespace: Optional[str] = None
head_info: List[str] = field(default_factory=list)
head_cpus: Union[int, str] = 2
head_memory: Union[int, str] = 8
head_gpus: Optional[int] = None # Deprecating
head_extended_resource_requests: Dict[str, int] = field(default_factory=dict)
machine_types: List[str] = field(
default_factory=list
) # ["m4.xlarge", "g4dn.xlarge"]
worker_cpu_requests: Union[int, str] = 1
worker_cpu_limits: Union[int, str] = 1
min_cpus: Optional[Union[int, str]] = None # Deprecating
max_cpus: Optional[Union[int, str]] = None # Deprecating
num_workers: int = 1
worker_memory_requests: typing.Union[int, str] = 2
worker_memory_limits: typing.Union[int, str] = 2
min_memory: typing.Union[int, str] = None # Deprecating
max_memory: typing.Union[int, str] = None # Deprecating
num_gpus: int = None # Deprecating
worker_memory_requests: Union[int, str] = 2
worker_memory_limits: Union[int, str] = 2
min_memory: Optional[Union[int, str]] = None # Deprecating
max_memory: Optional[Union[int, str]] = None # Deprecating
num_gpus: Optional[int] = None # Deprecating
template: str = f"{dir}/templates/base-template.yaml"
appwrapper: bool = False
envs: dict = field(default_factory=dict)
envs: Dict[str, str] = field(default_factory=dict)
image: str = ""
image_pull_secrets: list = field(default_factory=list)
image_pull_secrets: List[str] = field(default_factory=list)
write_to_file: bool = False
verify_tls: bool = True
labels: dict = field(default_factory=dict)
worker_extended_resource_requests: typing.Dict[str, int] = field(
default_factory=dict
)
extended_resource_mapping: typing.Dict[str, str] = field(default_factory=dict)
labels: Dict[str, str] = field(default_factory=dict)
worker_extended_resource_requests: Dict[str, int] = field(default_factory=dict)
extended_resource_mapping: Dict[str, str] = field(default_factory=dict)
overwrite_default_resource_mapping: bool = False
local_queue: Optional[str] = None

def __post_init__(self):
if not self.verify_tls:
Expand All @@ -120,6 +121,7 @@ def __post_init__(self):
self._validate_extended_resource_requests(
self.worker_extended_resource_requests
)
self._validate_types()

def _combine_extended_resource_mapping(self):
if overwritten := set(self.extended_resource_mapping.keys()).intersection(
Expand All @@ -139,9 +141,7 @@ def _combine_extended_resource_mapping(self):
**self.extended_resource_mapping,
}

def _validate_extended_resource_requests(
self, extended_resources: typing.Dict[str, int]
):
def _validate_extended_resource_requests(self, extended_resources: Dict[str, int]):
for k in extended_resources.keys():
if k not in self.extended_resource_mapping.keys():
raise ValueError(
Expand Down Expand Up @@ -206,4 +206,34 @@ def _memory_to_resource(self):
warnings.warn("max_memory is being deprecated, use worker_memory_limits")
self.worker_memory_limits = f"{self.max_memory}G"

local_queue: str = None
def _validate_types(self):
    """Check every dataclass field of this instance against its annotation.

    Fields are visited in the order returned by ``dataclasses.fields``.
    Each current attribute value is handed to ``self._is_type`` together
    with the field's declared type.

    Raises:
        TypeError: on the first field whose value is rejected by
            ``_is_type``; the message names the field and its annotation.
    """
    for spec in fields(self):
        if self._is_type(getattr(self, spec.name), spec.type):
            continue
        # Fail fast: only the first offending field is reported.
        raise TypeError(f"'{spec.name}' should be of type {spec.type}")

@staticmethod
def _is_type(value, expected_type):
    """Check whether ``value`` conforms to the annotation ``expected_type``.

    Supports plain classes plus ``Union``/``Optional``, ``List[...]``,
    ``Dict[..., ...]`` and ``Tuple[...]`` annotations, recursing into the
    element types of containers.

    Args:
        value: The runtime value to check.
        expected_type: A class or a ``typing`` annotation.

    Returns:
        bool: ``True`` if ``value`` matches ``expected_type``, else ``False``.
    """

    def check_type(value, expected_type):
        origin_type = get_origin(expected_type)
        args = get_args(expected_type)
        if origin_type is Union:
            return any(check_type(value, union_type) for union_type in args)
        if origin_type is list:
            # The container itself must be a list; otherwise a str would
            # pass as List[str] because it iterates as 1-char strings.
            if not isinstance(value, list):
                return False
            # A bare ``List`` carries no element type to check.
            return not args or all(check_type(elem, args[0]) for elem in value)
        if origin_type is dict:
            # Reject non-dicts up front instead of failing on ``.items()``.
            if not isinstance(value, dict):
                return False
            return not args or all(
                check_type(k, args[0]) and check_type(v, args[1])
                for k, v in value.items()
            )
        if origin_type is tuple:
            if not isinstance(value, tuple):
                return False
            if not args:
                return True
            # ``Tuple[T, ...]``: homogeneous tuple of arbitrary length.
            if len(args) == 2 and args[1] is Ellipsis:
                return all(check_type(elem, args[0]) for elem in value)
            # Fixed-size tuple: lengths must agree, because ``zip`` alone
            # would silently ignore missing or extra elements.
            return len(value) == len(args) and all(
                check_type(elem, etype) for elem, etype in zip(value, args)
            )
        if expected_type is int:
            # ``bool`` subclasses ``int``; do not accept True/False where
            # a number is expected.
            return isinstance(value, int) and not isinstance(value, bool)
        return isinstance(value, expected_type)

    return check_type(value, expected_type)

0 comments on commit ec09866

Please sign in to comment.