diff --git a/src/codeflare_sdk/cluster/cluster.py b/src/codeflare_sdk/cluster/cluster.py index 015f15eda..b27413d91 100644 --- a/src/codeflare_sdk/cluster/cluster.py +++ b/src/codeflare_sdk/cluster/cluster.py @@ -145,6 +145,7 @@ def create_app_wrapper(self): gpu = self.config.num_gpus workers = self.config.num_workers template = self.config.template + template_update_dict = self.config.template_update_dict image = self.config.image appwrapper = self.config.appwrapper env = self.config.envs @@ -176,6 +177,7 @@ def create_app_wrapper(self): labels=labels, volumes=volumes, volume_mounts=volume_mounts, + template_update_dict = template_update_dict ) # creates a new cluster with the provided or default spec diff --git a/src/codeflare_sdk/cluster/config.py b/src/codeflare_sdk/cluster/config.py index 970673652..c1a5c383b 100644 --- a/src/codeflare_sdk/cluster/config.py +++ b/src/codeflare_sdk/cluster/config.py @@ -46,6 +46,7 @@ class ClusterConfiguration: max_memory: typing.Union[int, str] = 2 num_gpus: int = 0 template: str = f"{dir}/templates/base-template.yaml" + template_update_dict = {} appwrapper: bool = False envs: dict = field(default_factory=dict) image: str = "" diff --git a/src/codeflare_sdk/utils/generate_yaml.py b/src/codeflare_sdk/utils/generate_yaml.py index 3192ae1bc..b163a38a5 100755 --- a/src/codeflare_sdk/utils/generate_yaml.py +++ b/src/codeflare_sdk/utils/generate_yaml.py @@ -300,8 +300,10 @@ def generate_appwrapper( labels, volumes: list[client.V1Volume], volume_mounts: list[client.V1VolumeMount], + template_update_dict = {} ): cluster_yaml = read_template(template) + cluster_yaml.update(template_update_dict) appwrapper_name, cluster_name = gen_names(name) update_names(cluster_yaml, cluster_name, namespace) update_nodes(