diff --git a/src/codeflare_sdk/__init__.py b/src/codeflare_sdk/__init__.py index a1b5535c4..9ab5c7450 100644 --- a/src/codeflare_sdk/__init__.py +++ b/src/codeflare_sdk/__init__.py @@ -21,6 +21,10 @@ KubeConfigFileAuthentication, ) +from .common.kueue import ( + list_local_queues, +) + from .common.utils import generate_cert from .common.utils.demos import copy_demo_nbs diff --git a/src/codeflare_sdk/common/kubernetes_cluster/kube_api_helpers.py b/src/codeflare_sdk/common/kubernetes_cluster/kube_api_helpers.py index efa1d2b6c..8974a3f30 100644 --- a/src/codeflare_sdk/common/kubernetes_cluster/kube_api_helpers.py +++ b/src/codeflare_sdk/common/kubernetes_cluster/kube_api_helpers.py @@ -20,6 +20,7 @@ import executing from kubernetes import client, config from urllib3.util import parse_url +import os # private methods diff --git a/src/codeflare_sdk/common/kueue/__init__.py b/src/codeflare_sdk/common/kueue/__init__.py index b02e3c240..c9c641c1e 100644 --- a/src/codeflare_sdk/common/kueue/__init__.py +++ b/src/codeflare_sdk/common/kueue/__init__.py @@ -2,4 +2,5 @@ get_default_kueue_name, local_queue_exists, add_queue_label, + list_local_queues, ) diff --git a/src/codeflare_sdk/common/kueue/kueue.py b/src/codeflare_sdk/common/kueue/kueue.py index 0c207548d..c063c6fe0 100644 --- a/src/codeflare_sdk/common/kueue/kueue.py +++ b/src/codeflare_sdk/common/kueue/kueue.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
def list_local_queues(
    namespace: Optional[str] = None, flavors: Optional[List[str]] = None
) -> List[dict]:
    """
    List the local queues in a namespace, optionally filtered by flavor.

    If no namespace is provided, the current namespace is used. If flavors is
    provided, only local queues that support all of the given flavors are
    returned.

    Note:
        Depending on the version of the local queue API, the available flavors
        may not be present in the response. Queues without flavor information
        are returned without a "flavors" key when no filter is given, and are
        skipped when a filter is given.

    Args:
        namespace (str, optional): The namespace to list local queues from. Defaults to None.
        flavors (List[str], optional): The flavors to filter local queues by. Defaults to None.

    Returns:
        List[dict]: One dictionary per local queue holding its "name" and, when
        the queue status reports them, its "flavors" (a list of flavor names).
        If the API call fails, the result of `_kube_api_error_handling` is
        returned instead.
    """

    # Function-scope import kept from the original implementation
    # (presumably to avoid a circular import at module load -- confirm).
    from ...ray.cluster.cluster import get_current_namespace

    if namespace is None:  # pragma: no cover
        namespace = get_current_namespace()
    try:
        config_check()
        api_instance = client.CustomObjectsApi(get_api_client())
        local_queues = api_instance.list_namespaced_custom_object(
            group="kueue.x-k8s.io",
            version="v1beta1",
            namespace=namespace,
            plural="localqueues",
        )
    except ApiException as e:  # pragma: no cover
        return _kube_api_error_handling(e)

    # Build the filter set once instead of on every loop iteration.
    required = None if flavors is None else set(flavors)

    to_return = []
    for lq in local_queues["items"]:
        item = {"name": lq["metadata"]["name"]}
        # A queue object can come back without a populated "status"; treat that
        # the same as "no flavor information" rather than raising KeyError.
        status = lq.get("status") or {}
        if "flavors" in status:
            item["flavors"] = [f["name"] for f in status["flavors"]]
            if required is not None and not required.issubset(item["flavors"]):
                continue
        elif required is not None:
            # NOTE: a missing "flavors" entry may indicate an old local queue
            # API; it might be worthwhile raising or warning here.
            continue
        to_return.append(item)
    return to_return