Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Type validation for Cluster Configuration Parameters #593

Merged
merged 2 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 59 additions & 29 deletions src/codeflare_sdk/cluster/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
Cluster object.
"""

from dataclasses import dataclass, field
import pathlib
import typing
import warnings
from dataclasses import dataclass, field, fields
from typing import Dict, List, Optional, Union, get_args, get_origin

dir = pathlib.Path(__file__).parent.parent.resolve()

Expand Down Expand Up @@ -73,43 +73,45 @@ class ClusterConfiguration:
"""

name: str
namespace: str = None
head_info: list = field(default_factory=list)
head_cpus: typing.Union[int, str] = 2
head_memory: typing.Union[int, str] = 8
head_gpus: int = None # Deprecating
head_extended_resource_requests: typing.Dict[str, int] = field(default_factory=dict)
machine_types: list = field(default_factory=list) # ["m4.xlarge", "g4dn.xlarge"]
worker_cpu_requests: typing.Union[int, str] = 1
worker_cpu_limits: typing.Union[int, str] = 1
min_cpus: typing.Union[int, str] = None # Deprecating
max_cpus: typing.Union[int, str] = None # Deprecating
namespace: Optional[str] = None
head_info: List[str] = field(default_factory=list)
head_cpus: Union[int, str] = 2
head_memory: Union[int, str] = 8
head_gpus: Optional[int] = None # Deprecating
head_extended_resource_requests: Dict[str, int] = field(default_factory=dict)
machine_types: List[str] = field(
default_factory=list
) # ["m4.xlarge", "g4dn.xlarge"]
worker_cpu_requests: Union[int, str] = 1
worker_cpu_limits: Union[int, str] = 1
min_cpus: Optional[Union[int, str]] = None # Deprecating
max_cpus: Optional[Union[int, str]] = None # Deprecating
num_workers: int = 1
worker_memory_requests: typing.Union[int, str] = 2
worker_memory_limits: typing.Union[int, str] = 2
min_memory: typing.Union[int, str] = None # Deprecating
max_memory: typing.Union[int, str] = None # Deprecating
num_gpus: int = None # Deprecating
worker_memory_requests: Union[int, str] = 2
worker_memory_limits: Union[int, str] = 2
min_memory: Optional[Union[int, str]] = None # Deprecating
max_memory: Optional[Union[int, str]] = None # Deprecating
num_gpus: Optional[int] = None # Deprecating
template: str = f"{dir}/templates/base-template.yaml"
appwrapper: bool = False
envs: dict = field(default_factory=dict)
envs: Dict[str, str] = field(default_factory=dict)
image: str = ""
image_pull_secrets: list = field(default_factory=list)
image_pull_secrets: List[str] = field(default_factory=list)
write_to_file: bool = False
verify_tls: bool = True
labels: dict = field(default_factory=dict)
worker_extended_resource_requests: typing.Dict[str, int] = field(
default_factory=dict
)
extended_resource_mapping: typing.Dict[str, str] = field(default_factory=dict)
labels: Dict[str, str] = field(default_factory=dict)
worker_extended_resource_requests: Dict[str, int] = field(default_factory=dict)
extended_resource_mapping: Dict[str, str] = field(default_factory=dict)
overwrite_default_resource_mapping: bool = False
local_queue: Optional[str] = None

def __post_init__(self):
if not self.verify_tls:
print(
"Warning: TLS verification has been disabled - Endpoint checks will be bypassed"
)

self._validate_types()
self._memory_to_string()
self._str_mem_no_unit_add_GB()
self._memory_to_resource()
Expand Down Expand Up @@ -139,9 +141,7 @@ def _combine_extended_resource_mapping(self):
**self.extended_resource_mapping,
}

def _validate_extended_resource_requests(
self, extended_resources: typing.Dict[str, int]
):
def _validate_extended_resource_requests(self, extended_resources: Dict[str, int]):
for k in extended_resources.keys():
if k not in self.extended_resource_mapping.keys():
raise ValueError(
Expand Down Expand Up @@ -206,4 +206,34 @@ def _memory_to_resource(self):
warnings.warn("max_memory is being deprecated, use worker_memory_limits")
self.worker_memory_limits = f"{self.max_memory}G"

local_queue: str = None
def _validate_types(self):
"""Validate the types of all fields in the ClusterConfiguration dataclass."""
for field_info in fields(self):
value = getattr(self, field_info.name)
expected_type = field_info.type
if not self._is_type(value, expected_type):
raise TypeError(
f"'{field_info.name}' should be of type {expected_type}"
)

@staticmethod
def _is_type(value, expected_type):
"""Check if the value matches the expected type."""

def check_type(value, expected_type):
origin_type = get_origin(expected_type)
args = get_args(expected_type)
if origin_type is Union:
return any(check_type(value, union_type) for union_type in args)
if origin_type is list:
return all(check_type(elem, args[0]) for elem in value)
if origin_type is dict:
return all(
check_type(k, args[0]) and check_type(v, args[1])
for k, v in value.items()
)
if origin_type is tuple:
return all(check_type(elem, etype) for elem, etype in zip(value, args))
return isinstance(value, expected_type)

return check_type(value, expected_type)
6 changes: 6 additions & 0 deletions tests/unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from tests.unit_test_support import (
createClusterWithConfig,
createClusterConfig,
createClusterWrongType,
)

import codeflare_sdk.utils.kube_api_helpers
Expand Down Expand Up @@ -268,6 +269,11 @@ def test_config_creation():
assert config.appwrapper == True


def test_config_creation_wrong_type():
    """Building a configuration from mistyped arguments must raise TypeError."""
    with pytest.raises(TypeError):
        createClusterWrongType()


def test_cluster_creation(mocker):
# Create AppWrapper containing a Ray Cluster with no local queue specified
mocker.patch("kubernetes.client.ApisApi.get_api_versions")
Expand Down
20 changes: 20 additions & 0 deletions tests/unit_test_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,23 @@ def createClusterWithConfig(mocker):
)
cluster = Cluster(createClusterConfig())
return cluster


def createClusterWrongType():
    """Construct a ClusterConfiguration from deliberately mistyped arguments.

    ``worker_cpu_requests`` is given a list, ``machine_types`` a list of
    booleans and ``labels`` a mapping with integer keys and values. The
    accompanying unit test expects the constructor to raise ``TypeError``
    for this input.
    """
    mistyped_kwargs = dict(
        name="unit-test-cluster",
        namespace="ns",
        num_workers=2,
        worker_cpu_requests=[],
        worker_cpu_limits=4,
        worker_memory_requests=5,
        worker_memory_limits=6,
        worker_extended_resource_requests={"nvidia.com/gpu": 7},
        appwrapper=True,
        machine_types=[True, False],
        image_pull_secrets=["unit-test-pull-secret"],
        image="quay.io/rhoai/ray:2.23.0-py39-cu121",
        write_to_file=True,
        labels={1: 1},
    )
    return ClusterConfiguration(**mistyped_kwargs)