From ffbdfecfed79a890bf250fbd840d3aee1c140fe7 Mon Sep 17 00:00:00 2001 From: Ignas Baranauskas Date: Tue, 1 Oct 2024 17:41:59 +0100 Subject: [PATCH] Refactor: kueue module --- src/codeflare_sdk/common/kueue/__init__.py | 5 ++ src/codeflare_sdk/common/kueue/kueue.py | 78 +++++++++++++++++++ .../ray/cluster/generate_yaml.py | 62 +-------------- tests/unit_test.py | 18 ++--- 4 files changed, 93 insertions(+), 70 deletions(-) create mode 100644 src/codeflare_sdk/common/kueue/__init__.py create mode 100644 src/codeflare_sdk/common/kueue/kueue.py diff --git a/src/codeflare_sdk/common/kueue/__init__.py b/src/codeflare_sdk/common/kueue/__init__.py new file mode 100644 index 000000000..b02e3c240 --- /dev/null +++ b/src/codeflare_sdk/common/kueue/__init__.py @@ -0,0 +1,5 @@ +from .kueue import ( + get_default_kueue_name, + local_queue_exists, + add_queue_label, +) diff --git a/src/codeflare_sdk/common/kueue/kueue.py b/src/codeflare_sdk/common/kueue/kueue.py new file mode 100644 index 000000000..0c207548d --- /dev/null +++ b/src/codeflare_sdk/common/kueue/kueue.py @@ -0,0 +1,78 @@ +# Copyright 2024 IBM, Red Hat +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +from codeflare_sdk.common import _kube_api_error_handling +from codeflare_sdk.common.kubernetes_cluster.auth import config_check, get_api_client +from kubernetes import client +from kubernetes.client.exceptions import ApiException + + +def get_default_kueue_name(namespace: str): + # If the local queue is set, use it. Otherwise, try to use the default queue. + try: + config_check() + api_instance = client.CustomObjectsApi(get_api_client()) + local_queues = api_instance.list_namespaced_custom_object( + group="kueue.x-k8s.io", + version="v1beta1", + namespace=namespace, + plural="localqueues", + ) + except ApiException as e: # pragma: no cover + if e.status == 404 or e.status == 403: + return + else: + return _kube_api_error_handling(e) + for lq in local_queues["items"]: + if ( + "annotations" in lq["metadata"] + and "kueue.x-k8s.io/default-queue" in lq["metadata"]["annotations"] + and lq["metadata"]["annotations"]["kueue.x-k8s.io/default-queue"].lower() + == "true" + ): + return lq["metadata"]["name"] + + +def local_queue_exists(namespace: str, local_queue_name: str): + # get all local queues in the namespace + try: + config_check() + api_instance = client.CustomObjectsApi(get_api_client()) + local_queues = api_instance.list_namespaced_custom_object( + group="kueue.x-k8s.io", + version="v1beta1", + namespace=namespace, + plural="localqueues", + ) + except Exception as e: # pragma: no cover + return _kube_api_error_handling(e) + # check if local queue with the name provided in cluster config exists + for lq in local_queues["items"]: + if lq["metadata"]["name"] == local_queue_name: + return True + return False + + +def add_queue_label(item: dict, namespace: str, local_queue: Optional[str]): + lq_name = local_queue or get_default_kueue_name(namespace) + if lq_name == None: + return + elif not local_queue_exists(namespace, lq_name): + raise ValueError( + "local_queue provided does not exist or is not in this namespace. Please provide the correct local_queue name in Cluster Configuration" + ) + if not "labels" in item["metadata"]: + item["metadata"]["labels"] = {} + item["metadata"]["labels"].update({"kueue.x-k8s.io/queue-name": lq_name}) diff --git a/src/codeflare_sdk/ray/cluster/generate_yaml.py b/src/codeflare_sdk/ray/cluster/generate_yaml.py index f0d70cf52..0b174650a 100755 --- a/src/codeflare_sdk/ray/cluster/generate_yaml.py +++ b/src/codeflare_sdk/ray/cluster/generate_yaml.py @@ -18,18 +18,17 @@ """ import json -from typing import Optional import typing import yaml import os import uuid from kubernetes import client from ...common import _kube_api_error_handling +from ...common.kueue import add_queue_label from ...common.kubernetes_cluster.auth import ( get_api_client, config_check, ) -from kubernetes.client.exceptions import ApiException import codeflare_sdk @@ -229,65 +228,6 @@ def del_from_list_by_name(l: list, target: typing.List[str]) -> list: return [x for x in l if x["name"] not in target] -def get_default_kueue_name(namespace: str): - # If the local queue is set, use it. Otherwise, try to use the default queue. - try: - config_check() - api_instance = client.CustomObjectsApi(get_api_client()) - local_queues = api_instance.list_namespaced_custom_object( - group="kueue.x-k8s.io", - version="v1beta1", - namespace=namespace, - plural="localqueues", - ) - except ApiException as e: # pragma: no cover - if e.status == 404 or e.status == 403: - return - else: - return _kube_api_error_handling(e) - for lq in local_queues["items"]: - if ( - "annotations" in lq["metadata"] - and "kueue.x-k8s.io/default-queue" in lq["metadata"]["annotations"] - and lq["metadata"]["annotations"]["kueue.x-k8s.io/default-queue"].lower() - == "true" - ): - return lq["metadata"]["name"] - - -def local_queue_exists(namespace: str, local_queue_name: str): - # get all local queues in the namespace - try: - config_check() - api_instance = client.CustomObjectsApi(get_api_client()) - local_queues = api_instance.list_namespaced_custom_object( - group="kueue.x-k8s.io", - version="v1beta1", - namespace=namespace, - plural="localqueues", - ) - except Exception as e: # pragma: no cover - return _kube_api_error_handling(e) - # check if local queue with the name provided in cluster config exists - for lq in local_queues["items"]: - if lq["metadata"]["name"] == local_queue_name: - return True - return False - - -def add_queue_label(item: dict, namespace: str, local_queue: Optional[str]): - lq_name = local_queue or get_default_kueue_name(namespace) - if lq_name == None: - return - elif not local_queue_exists(namespace, lq_name): - raise ValueError( - "local_queue provided does not exist or is not in this namespace. Please provide the correct local_queue name in Cluster Configuration" - ) - if not "labels" in item["metadata"]: - item["metadata"]["labels"] = {} - item["metadata"]["labels"].update({"kueue.x-k8s.io/queue-name": lq_name}) - - def augment_labels(item: dict, labels: dict): if not "labels" in item["metadata"]: item["metadata"]["labels"] = {} diff --git a/tests/unit_test.py b/tests/unit_test.py index 235eed0e7..74da56b77 100644 --- a/tests/unit_test.py +++ b/tests/unit_test.py @@ -968,7 +968,7 @@ def test_ray_details(mocker, capsys): return_value="", ) mocker.patch( - "codeflare_sdk.ray.cluster.generate_yaml.local_queue_exists", + "codeflare_sdk.common.kueue.kueue.local_queue_exists", return_value="true", ) cf = Cluster( @@ -2007,7 +2007,7 @@ def test_get_cluster_openshift(mocker): ] mocker.patch("kubernetes.client.ApisApi", return_value=mock_api) mocker.patch( - "codeflare_sdk.ray.cluster.generate_yaml.local_queue_exists", + "codeflare_sdk.common.kueue.kueue.local_queue_exists", return_value="true", ) @@ -2042,7 +2042,7 @@ def custom_side_effect(group, version, namespace, plural, **kwargs): ], ) mocker.patch( - "codeflare_sdk.ray.cluster.generate_yaml.local_queue_exists", + "codeflare_sdk.common.kueue.kueue.local_queue_exists", return_value="true", ) @@ -2085,7 +2085,7 @@ def test_get_cluster(mocker): return_value=ingress_retrieval(cluster_name="quicktest", client_ing=True), ) mocker.patch( - "codeflare_sdk.ray.cluster.generate_yaml.local_queue_exists", + "codeflare_sdk.common.kueue.kueue.local_queue_exists", return_value="true", ) cluster = get_cluster("quicktest") @@ -2123,7 +2123,7 @@ def test_get_cluster_no_mcad(mocker): return_value=ingress_retrieval(cluster_name="quicktest", client_ing=True), ) mocker.patch( - "codeflare_sdk.ray.cluster.generate_yaml.local_queue_exists", + "codeflare_sdk.common.kueue.kueue.local_queue_exists", return_value="true", ) cluster = get_cluster("quicktest") @@ -2359,7 +2359,7 @@ def test_cluster_status(mocker): mocker.patch("kubernetes.client.ApisApi.get_api_versions") mocker.patch("kubernetes.config.load_kube_config", return_value="ignore") mocker.patch( - "codeflare_sdk.ray.cluster.generate_yaml.local_queue_exists", + "codeflare_sdk.common.kueue.kueue.local_queue_exists", return_value="true", ) fake_aw = AppWrapper("test", AppWrapperStatus.FAILED) @@ -2462,7 +2462,7 @@ def test_wait_ready(mocker, capsys): "codeflare_sdk.ray.cluster.cluster._ray_cluster_status", return_value=None ) mocker.patch( - "codeflare_sdk.ray.cluster.generate_yaml.local_queue_exists", + "codeflare_sdk.common.kueue.kueue.local_queue_exists", return_value="true", ) mocker.patch.object( @@ -2694,11 +2694,11 @@ def test_cluster_throw_for_no_raycluster(mocker: MockerFixture): return_value="opendatahub", ) mocker.patch( - "codeflare_sdk.ray.cluster.generate_yaml.get_default_kueue_name", + "codeflare_sdk.common.kueue.kueue.get_default_kueue_name", return_value="default", ) mocker.patch( - "codeflare_sdk.ray.cluster.generate_yaml.local_queue_exists", + "codeflare_sdk.common.kueue.kueue.local_queue_exists", return_value="true", )