Skip to content

Commit

Permalink
RHOAIENG-8098 - ClusterConfiguration should support tolerations
Browse files Browse the repository at this point in the history
  • Loading branch information
jiripetrlik committed Jan 7, 2025
1 parent 6b0a3cc commit 2f2cdca
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
27 changes: 19 additions & 8 deletions src/codeflare_sdk/ray/cluster/build_ray_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
This sub-module exists primarily to be used internally by the Cluster object
(in the cluster sub-module) for RayCluster/AppWrapper generation.
"""
from typing import Union, Tuple, Dict
from typing import List, Union, Tuple, Dict
from ...common import _kube_api_error_handling
from ...common.kubernetes_cluster import get_api_client, config_check
from kubernetes.client.exceptions import ApiException
Expand All @@ -40,6 +40,7 @@
V1PodTemplateSpec,
V1PodSpec,
V1LocalObjectReference,
V1Toleration
)

import yaml
Expand Down Expand Up @@ -139,7 +140,8 @@ def build_ray_cluster(cluster: "codeflare_sdk.ray.cluster.Cluster"):
"resources": head_resources,
},
"template": {
"spec": get_pod_spec(cluster, [get_head_container_spec(cluster)])
"spec": get_pod_spec(cluster, [get_head_container_spec(cluster)],
cluster.config.head_tolerations)
},
},
"workerGroupSpecs": [
Expand All @@ -154,7 +156,8 @@ def build_ray_cluster(cluster: "codeflare_sdk.ray.cluster.Cluster"):
"resources": worker_resources,
},
"template": V1PodTemplateSpec(
spec=get_pod_spec(cluster, [get_worker_container_spec(cluster)])
spec=get_pod_spec(cluster, [get_worker_container_spec(cluster)],
cluster.config.tolerations)
),
}
],
Expand Down Expand Up @@ -243,14 +246,22 @@ def update_image(image) -> str:
return image


def get_pod_spec(cluster: "codeflare_sdk.ray.cluster.Cluster", containers):
def get_pod_spec(cluster: "codeflare_sdk.ray.cluster.Cluster", containers, tolerations):
"""
The get_pod_spec() function generates a V1PodSpec for the head/worker containers
"""
pod_spec = V1PodSpec(
containers=containers,
volumes=VOLUMES,
)
if tolerations is None:
pod_spec = V1PodSpec(
containers=containers,
volumes=VOLUMES
)
else:
pod_spec = V1PodSpec(
containers=containers,
volumes=VOLUMES,
tolerations=tolerations
)

if cluster.config.image_pull_secrets != []:
pod_spec.image_pull_secrets = generate_image_pull_secrets(cluster)

Expand Down
12 changes: 11 additions & 1 deletion src/codeflare_sdk/ray/cluster/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import warnings
from dataclasses import dataclass, field, fields
from typing import Dict, List, Optional, Union, get_args, get_origin
from kubernetes.client import V1Toleration

dir = pathlib.Path(__file__).parent.parent.resolve()

Expand Down Expand Up @@ -57,6 +58,8 @@ class ClusterConfiguration:
The number of GPUs to allocate to the head node. (Deprecated, use head_extended_resource_requests)
head_extended_resource_requests:
A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1}
head_tolerations:
List of tolerations for the head node.
min_cpus:
The minimum number of CPUs to allocate to each worker.
max_cpus:
Expand All @@ -69,6 +72,8 @@ class ClusterConfiguration:
The maximum amount of memory to allocate to each worker.
num_gpus:
The number of GPUs to allocate to each worker. (Deprecated, use worker_extended_resource_requests)
tolerations:
List of tolerations for worker nodes.
appwrapper:
A boolean indicating whether to use an AppWrapper.
envs:
Expand Down Expand Up @@ -105,6 +110,7 @@ class ClusterConfiguration:
head_extended_resource_requests: Dict[str, Union[str, int]] = field(
default_factory=dict
)
head_tolerations: Optional[List[V1Toleration]] = None
worker_cpu_requests: Union[int, str] = 1
worker_cpu_limits: Union[int, str] = 1
min_cpus: Optional[Union[int, str]] = None # Deprecating
Expand All @@ -115,6 +121,7 @@ class ClusterConfiguration:
min_memory: Optional[Union[int, str]] = None # Deprecating
max_memory: Optional[Union[int, str]] = None # Deprecating
num_gpus: Optional[int] = None # Deprecating
tolerations: Optional[List[V1Toleration]] = None
appwrapper: bool = False
envs: Dict[str, str] = field(default_factory=dict)
image: str = ""
Expand Down Expand Up @@ -265,7 +272,10 @@ def check_type(value, expected_type):
if origin_type is Union:
return any(check_type(value, union_type) for union_type in args)
if origin_type is list:
return all(check_type(elem, args[0]) for elem in value)
if value is not None:
return all(check_type(elem, args[0]) for elem in value)
else:
return True
if origin_type is dict:
return all(
check_type(k, args[0]) and check_type(v, args[1])
Expand Down

0 comments on commit 2f2cdca

Please sign in to comment.