diff --git a/pyproject.toml b/pyproject.toml index a295e383b..1979a378c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ cryptography = "40.0.2" executing = "1.2.0" pydantic = "< 2" ipywidgets = "8.1.2" +mergedeep = "1.3.4" [tool.poetry.group.docs] optional = true diff --git a/src/codeflare_sdk/cluster/cluster.py b/src/codeflare_sdk/cluster/cluster.py index 015f15eda..371fd0949 100644 --- a/src/codeflare_sdk/cluster/cluster.py +++ b/src/codeflare_sdk/cluster/cluster.py @@ -145,6 +145,7 @@ def create_app_wrapper(self): gpu = self.config.num_gpus workers = self.config.num_workers template = self.config.template + template_update_dict = self.config.template_update_dict image = self.config.image appwrapper = self.config.appwrapper env = self.config.envs @@ -176,6 +177,7 @@ def create_app_wrapper(self): labels=labels, volumes=volumes, volume_mounts=volume_mounts, + template_update_dict=template_update_dict, ) # creates a new cluster with the provided or default spec diff --git a/src/codeflare_sdk/cluster/config.py b/src/codeflare_sdk/cluster/config.py index 970673652..c1a5c383b 100644 --- a/src/codeflare_sdk/cluster/config.py +++ b/src/codeflare_sdk/cluster/config.py @@ -46,6 +46,7 @@ class ClusterConfiguration: max_memory: typing.Union[int, str] = 2 num_gpus: int = 0 template: str = f"{dir}/templates/base-template.yaml" + template_update_dict = {} appwrapper: bool = False envs: dict = field(default_factory=dict) image: str = "" diff --git a/src/codeflare_sdk/utils/generate_yaml.py b/src/codeflare_sdk/utils/generate_yaml.py index 3192ae1bc..9a513ae4a 100755 --- a/src/codeflare_sdk/utils/generate_yaml.py +++ b/src/codeflare_sdk/utils/generate_yaml.py @@ -30,6 +30,7 @@ from os import urandom from base64 import b64encode from urllib3.util import parse_url +from mergedeep import merge def read_template(template): @@ -300,8 +301,10 @@ def generate_appwrapper( labels, volumes: list[client.V1Volume], volume_mounts: list[client.V1VolumeMount], + template_update_dict={}, ): cluster_yaml = read_template(template) + cluster_yaml = merge(cluster_yaml, template_update_dict) appwrapper_name, cluster_name = gen_names(name) update_names(cluster_yaml, cluster_name, namespace) update_nodes(