Skip to content

Commit

Permalink
simplify function calls and add option for custom resources
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin <[email protected]>
  • Loading branch information
KPostOffice committed May 9, 2024
1 parent 6a9b185 commit 7b5bacc
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 190 deletions.
48 changes: 1 addition & 47 deletions src/codeflare_sdk/cluster/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,53 +165,7 @@ def create_app_wrapper(self):
else:
priority_val = None

name = self.config.name
namespace = self.config.namespace
head_cpus = self.config.head_cpus
head_memory = self.config.head_memory
head_gpus = self.config.head_gpus
min_cpu = self.config.min_cpus
max_cpu = self.config.max_cpus
min_memory = self.config.min_memory
max_memory = self.config.max_memory
gpu = self.config.num_gpus
workers = self.config.num_workers
template = self.config.template
image = self.config.image
instascale = self.config.instascale
mcad = self.config.mcad
instance_types = self.config.machine_types
env = self.config.envs
image_pull_secrets = self.config.image_pull_secrets
dispatch_priority = self.config.dispatch_priority
write_to_file = self.config.write_to_file
verify_tls = self.config.verify_tls
local_queue = self.config.local_queue
return generate_appwrapper(
name=name,
namespace=namespace,
head_cpus=head_cpus,
head_memory=head_memory,
head_gpus=head_gpus,
min_cpu=min_cpu,
max_cpu=max_cpu,
min_memory=min_memory,
max_memory=max_memory,
gpu=gpu,
workers=workers,
template=template,
image=image,
instascale=instascale,
mcad=mcad,
instance_types=instance_types,
env=env,
image_pull_secrets=image_pull_secrets,
dispatch_priority=dispatch_priority,
priority_val=priority_val,
write_to_file=write_to_file,
verify_tls=verify_tls,
local_queue=local_queue,
)
return generate_appwrapper(self)

# creates a new cluster with the provided or default spec
def up(self):
Expand Down
47 changes: 47 additions & 0 deletions src/codeflare_sdk/cluster/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,22 @@
from dataclasses import dataclass, field
import pathlib
import typing
import warnings

dir = pathlib.Path(__file__).parent.parent.resolve()

# https://docs.ray.io/en/latest/ray-core/scheduling/accelerators.html
DEFAULT_RESOURCE_MAPPING = {
"nvidia.com/gpu": "GPU",
"intel.com/gpu": "GPU",
"amd.com/gpu": "GPU",
"aws.amazon.com/neuroncore": "neuron_cores",
"google.com/tpu": "TPU",
"habana.ai/gaudi": "HPU",
"huawei.com/Ascend910": "NPU",
"huawei.com/Ascend310": "NPU",
}


@dataclass
class ClusterConfiguration:
Expand All @@ -38,6 +51,7 @@ class ClusterConfiguration:
head_cpus: typing.Union[int, str] = 2
head_memory: typing.Union[int, str] = 8
head_gpus: int = 0
head_custom_resource_requests: typing.Dict[str, int] = field(default_factory=dict)
machine_types: list = field(default_factory=list) # ["m4.xlarge", "g4dn.xlarge"]
min_cpus: typing.Union[int, str] = 1
max_cpus: typing.Union[int, str] = 1
Expand All @@ -54,6 +68,9 @@ class ClusterConfiguration:
dispatch_priority: str = None
write_to_file: bool = False
verify_tls: bool = True
worker_custom_resource_requests: typing.Dict[str, int] = field(default_factory=dict)
custom_resource_mapping: typing.Dict[str, str] = field(default_factory=dict)
overwrite_default_resource_mapping: bool = False

def __post_init__(self):
if not self.verify_tls:
Expand All @@ -63,6 +80,36 @@ def __post_init__(self):
self._memory_to_string()
self._str_mem_no_unit_add_GB()

def _combine_custom_resource_mapping(self):
if self.overwrite_default_resource_mapping:
self.custom_resource_mapping = self.worker_custom_resource_requests
else:
if overwritten := self.worker_custom_resource_requests.keys().intersection(
DEFAULT_RESOURCE_MAPPING.keys()
):
warnings.warn(
f"Overwriting default resource mapping for {overwritten}",
UserWarning,
)
self.custom_resource_mapping = {
**DEFAULT_RESOURCE_MAPPING,
**self.worker_custom_resource_requests,
}

def _gpu_to_resource(self):
if self.head_gpus:
if "nvidia.com/gpu" in self.head_custom_resource_requests:
raise ValueError(
"nvidia.com/gpu already exists in head_custom_resource_requests"
)
self.head_custom_resource_requests["nvidia.com/gpu"] = self.head_gpus
if self.num_gpus:
if "nvidia.com/gpu" in self.worker_custom_resource_requests:
raise ValueError(
"nvidia.com/gpu already exists in worker_custom_resource_requests"
)
self.worker_custom_resource_requests["nvidia.com/gpu"] = self.num_gpus

def _str_mem_no_unit_add_GB(self):
if isinstance(self.head_memory, str) and self.head_memory.isdecimal():
self.head_memory = f"{self.head_memory}G"
Expand Down
Loading

0 comments on commit 7b5bacc

Please sign in to comment.