From 32cb75151dc549cd014a46732555cff97f1495d4 Mon Sep 17 00:00:00 2001 From: Fiona Waters Date: Fri, 10 May 2024 10:28:03 +0100 Subject: [PATCH] adding validation for local_queue provided in cluster config --- src/codeflare_sdk/utils/generate_yaml.py | 28 ++++++++++++++++++++++++ tests/unit_test.py | 8 +++++++ 2 files changed, 36 insertions(+) diff --git a/src/codeflare_sdk/utils/generate_yaml.py b/src/codeflare_sdk/utils/generate_yaml.py index f5de1fbae..2ea6dd78d 100755 --- a/src/codeflare_sdk/utils/generate_yaml.py +++ b/src/codeflare_sdk/utils/generate_yaml.py @@ -308,6 +308,26 @@ def get_default_kueue_name(namespace: str): ) +def local_queue_exists(namespace: str, local_queue_name: str): + # get all local queues in the namespace + try: + config_check() + api_instance = client.CustomObjectsApi(api_config_handler()) + local_queues = api_instance.list_namespaced_custom_object( + group="kueue.x-k8s.io", + version="v1beta1", + namespace=namespace, + plural="localqueues", + ) + except Exception as e: # pragma: no cover + return _kube_api_error_handling(e) + # check if local queue with the name provided in cluster config exists + for lq in local_queues["items"]: + if lq["metadata"]["name"] == local_queue_name: + return True + return False + + def write_components( user_yaml: dict, output_file_name: str, @@ -324,6 +344,10 @@ def write_components( open(output_file_name, "w").close() lq_name = local_queue or get_default_kueue_name(namespace) cluster_labels = labels + if not local_queue_exists(namespace, lq_name): + raise ValueError( + "local_queue provided does not exist or is not in this namespace. Please provide the correct local_queue name in Cluster Configuration" + ) with open(output_file_name, "a") as outfile: for component in components: if "generictemplate" in component: @@ -355,6 +379,10 @@ def load_components( components = user_yaml.get("spec", "resources")["resources"].get("GenericItems") lq_name = local_queue or get_default_kueue_name(namespace) cluster_labels = labels + if not local_queue_exists(namespace, lq_name): + raise ValueError( + "local_queue provided does not exist or is not in this namespace. Please provide the correct local_queue name in Cluster Configuration" + ) for component in components: if "generictemplate" in component: if ( diff --git a/tests/unit_test.py b/tests/unit_test.py index e8837a139..1d4ca3616 100644 --- a/tests/unit_test.py +++ b/tests/unit_test.py @@ -344,6 +344,10 @@ def test_cluster_creation_no_mcad_local_queue(mocker): "kubernetes.client.CustomObjectsApi.get_cluster_custom_object", return_value={"spec": {"domain": "apps.cluster.awsroute.org"}}, ) + mocker.patch( + "kubernetes.client.CustomObjectsApi.list_namespaced_custom_object", + return_value=get_local_queue("kueue.x-k8s.io", "v1beta1", "ns", "localqueues"), + ) config = createClusterConfig() config.name = "unit-test-cluster-ray" config.mcad = False @@ -3046,6 +3050,10 @@ def test_cluster_throw_for_no_raycluster(mocker: MockerFixture): "codeflare_sdk.utils.generate_yaml.get_default_kueue_name", return_value="default", ) + mocker.patch( + "codeflare_sdk.utils.generate_yaml.local_queue_exists", + return_value="true", + ) def throw_if_getting_raycluster(group, version, namespace, plural): if plural == "rayclusters":