From 9b9367b3254f7b14d5238d1e444fe4f84ea2869d Mon Sep 17 00:00:00 2001 From: Jiri Petrlik Date: Wed, 19 Jun 2024 16:34:33 +0200 Subject: [PATCH] RHOAIENG-8098 - ClusterConfiguration can be patched --- src/codeflare_sdk/cluster/cluster.py | 2 ++ src/codeflare_sdk/cluster/config.py | 1 + src/codeflare_sdk/utils/generate_yaml.py | 2 ++ 3 files changed, 5 insertions(+) diff --git a/src/codeflare_sdk/cluster/cluster.py b/src/codeflare_sdk/cluster/cluster.py index 015f15eda..371fd0949 100644 --- a/src/codeflare_sdk/cluster/cluster.py +++ b/src/codeflare_sdk/cluster/cluster.py @@ -145,6 +145,7 @@ def create_app_wrapper(self): gpu = self.config.num_gpus workers = self.config.num_workers template = self.config.template + template_update_dict = self.config.template_update_dict image = self.config.image appwrapper = self.config.appwrapper env = self.config.envs @@ -176,6 +177,7 @@ def create_app_wrapper(self): labels=labels, volumes=volumes, volume_mounts=volume_mounts, + template_update_dict=template_update_dict, ) # creates a new cluster with the provided or default spec diff --git a/src/codeflare_sdk/cluster/config.py b/src/codeflare_sdk/cluster/config.py index 970673652..c1a5c383b 100644 --- a/src/codeflare_sdk/cluster/config.py +++ b/src/codeflare_sdk/cluster/config.py @@ -46,6 +46,7 @@ class ClusterConfiguration: max_memory: typing.Union[int, str] = 2 num_gpus: int = 0 template: str = f"{dir}/templates/base-template.yaml" + template_update_dict = {} appwrapper: bool = False envs: dict = field(default_factory=dict) image: str = "" diff --git a/src/codeflare_sdk/utils/generate_yaml.py b/src/codeflare_sdk/utils/generate_yaml.py index 3192ae1bc..3146aa3d8 100755 --- a/src/codeflare_sdk/utils/generate_yaml.py +++ b/src/codeflare_sdk/utils/generate_yaml.py @@ -300,8 +300,10 @@ def generate_appwrapper( labels, volumes: list[client.V1Volume], volume_mounts: list[client.V1VolumeMount], + template_update_dict={}, ): cluster_yaml = read_template(template) + cluster_yaml.update(template_update_dict) appwrapper_name, cluster_name = gen_names(name) update_names(cluster_yaml, cluster_name, namespace) update_nodes(