diff --git a/scripts/component_integration_tests.py b/scripts/component_integration_tests.py
index 3ff266acb..5ca4e85e8 100755
--- a/scripts/component_integration_tests.py
+++ b/scripts/component_integration_tests.py
@@ -62,7 +62,7 @@ def main() -> None:
     if scheduler in (
         "kubernetes",
         "kubernetes_mcad",
-        "kueue_job",
+        "kueue",
         "local_docker",
         "aws_batch",
         "lsf",
@@ -97,13 +97,13 @@ def main() -> None:
                 "namespace": "torchx-dev",
             },
         },
-        "kueue_job": {
+        "kueue": {
            "providers": [
                component_provider,
                examples_app_defs_providers,
            ],
            "image": torchx_image,
-           "cfg": {"namespace": "torchx-dev", "local_queue": "default-kueue"},
+           "cfg": {"namespace": "torchx-dev", "local_queue": "torchx-local-queue"},
        },
        "local_cwd": {
            "providers": [
diff --git a/scripts/kueue_test.py b/scripts/kueue_test.py
new file mode 100644
index 000000000..4c1913e7a
--- /dev/null
+++ b/scripts/kueue_test.py
@@ -0,0 +1,74 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchx.components.dist import ddp
+from torchx.runner import get_runner
+from integ_test_utils import (
+    build_images,
+    BuildInfo,
+    push_images,
+    MissingEnvError,
+)
+import argparse
+from torchx.specs import AppState
+from torchx.util.types import none_throws
+
+def argparser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description="Kueue dist trainer integration test runner.")
+    parser.add_argument("--container_repo", type=str)
+    parser.add_argument("--dryrun", action="store_true",
+        help="Does not actually submit the app, just prints the scheduler request",)
+    return parser
+
+def build_and_push_image(container_repo: str) -> BuildInfo:
+    build = build_images()
+    push_images(build, container_repo=container_repo)
+    return build
+
+def run_kueue_test(dryrun: bool = False) -> None:
+    # Gather args & build image
+    print("Building image")
+    args = argparser().parse_args()
+    build = build_and_push_image(args.container_repo)
+    image = build.torchx_image
+    # Create the app definition
+    runner = get_runner("kueue")
+    app = ddp(
+        name="kueue-test",
+        image=image,
+        m="torchx.examples.apps.lightning.train",
+        cpu=1,
+        memMB=4000,
+        j="1x1",
+    )
+    # Pass config variables
+    cfg = {"namespace": "torchx-dev", "local_queue": "torchx-local-queue"}
+    print("Submitting job")
+    if dryrun:
+        dryrun_info = runner.dryrun(app, "kueue", cfg)
+        print(f"Dryrun info: {dryrun_info}")
+    else:
+        app_handle = runner.run(app, "kueue", cfg)
+        print(app_handle)
+        runner.wait(app_handle)
+        final_status = runner.status(app_handle)
+        print(f"Final status: {final_status}")
+        if none_throws(final_status).state != AppState.SUCCEEDED:
+            raise Exception(f"Dist app failed with status: {final_status}")
+
+def main() -> None:
+    args = argparser().parse_args()
+
+    try:
+        run_kueue_test(args.dryrun)
+    except MissingEnvError:
+        print("Skipping tests, only the docker build step was executed")
+
+if __name__ == "__main__":
+    main()
+
+
diff --git a/scripts/setup_minikube_kueue.sh b/scripts/setup_minikube_kueue.sh
new file mode 100644
index 000000000..c6e8b42f5
--- /dev/null
+++ b/scripts/setup_minikube_kueue.sh
@@ -0,0 +1,90 @@
+#!/bin/bash
+
+set -eux
+minikube delete
+minikube start --driver=docker --cpus=max --memory=max --nodes=2
+minikube addons enable registry
+
+# setup multi node volumes
+# https://github.com/kubernetes/minikube/issues/12360#issuecomment-1430243861
+minikube addons disable
storage-provisioner +minikube addons disable default-storageclass +minikube addons enable volumesnapshots +minikube addons enable csi-hostpath-driver +kubectl patch storageclass csi-hostpath-sc -p '{"metadata": {"annotations":{"storageclass.kubernetes.io/is-default-class":"true"}}}' + +# create namespace +kubectl create namespace torchx-dev + +# install Kueue and Kueue related resources +VERSION=v0.6.0 +kubectl apply --server-side -f https://github.com/kubernetes-sigs/kueue/releases/download/$VERSION/manifests.yaml + +# Function to check if the kueue manager pod is running +check_pod_status() { + local status=$(kubectl get pods -n kueue-system | grep "kueue-controller-manager" | awk '{print $3}') + echo "$status" +} + +# Wait until the pod is in the 'Running' state +echo "Waiting for kueue-controller-manager pod to be running in the kueue-system namespace..." +while [[ $(check_pod_status) != "Running" ]]; do + sleep 5 +done +# Function to check if the service exists +check_service_existence() { + kubectl get svc kueue-webhook-service -n kueue-system --no-headers 2>/dev/null +} + +# Wait until the service exists +echo "Waiting for kueue-webhook-service to exist in the kueue-system namespace..." +while [[ $(check_service_existence) == "" ]]; do + sleep 5 +done +echo "kueue-webhook-service exists in the kueue-system namespace." +sleep 20 +# Create Cluster Queue - UPDATE MAX VALUES +cat < object: from kubernetes import client @@ -177,169 +173,6 @@ def sanitize_for_serialization(obj: object) -> object: return api.sanitize_for_serialization(obj) -def role_to_pod(name: str, role: Role, service_account: Optional[str]) -> "V1Pod": - from kubernetes.client.models import ( # noqa: F811 redefinition of unused - V1Container, - V1ContainerPort, - V1EmptyDirVolumeSource, - V1EnvVar, - V1HostPathVolumeSource, - V1ObjectMeta, - V1PersistentVolumeClaimVolumeSource, - V1Pod, - V1PodSpec, - V1ResourceRequirements, - V1SecurityContext, - V1Volume, - V1VolumeMount, - ) - - # limits puts an upper cap on the resources a pod may consume. - # requests is how much the scheduler allocates. We assume that the jobs will - # be allocation the whole machine so requests is slightly lower than the - # requested resources to account for the Kubernetes node reserved resources. - limits = {} - requests = {} - - resource = role.resource - if resource.cpu > 0: - mcpu = int(resource.cpu * 1000) - limits["cpu"] = f"{mcpu}m" - request_mcpu = max(mcpu - RESERVED_MILLICPU, 0) - requests["cpu"] = f"{request_mcpu}m" - if resource.memMB > 0: - limits["memory"] = f"{int(resource.memMB)}M" - request_memMB = max(int(resource.memMB) - RESERVED_MEMMB, 0) - requests["memory"] = f"{request_memMB}M" - if resource.gpu > 0: - requests["nvidia.com/gpu"] = limits["nvidia.com/gpu"] = str(resource.gpu) - - for device_name, device_limit in resource.devices.items(): - limits[device_name] = str(device_limit) - - resources = V1ResourceRequirements( - limits=limits, - requests=requests, - ) - - node_selector: Dict[str, str] = {} - if LABEL_INSTANCE_TYPE in resource.capabilities: - node_selector[LABEL_INSTANCE_TYPE] = resource.capabilities[LABEL_INSTANCE_TYPE] - - # To support PyTorch dataloaders we need to set /dev/shm to larger than the - # 64M default so we mount an unlimited sized tmpfs directory on it. 
- SHM_VOL = "dshm" - volumes = [ - V1Volume( - name=SHM_VOL, - empty_dir=V1EmptyDirVolumeSource( - medium="Memory", - ), - ), - ] - volume_mounts = [ - V1VolumeMount(name=SHM_VOL, mount_path="/dev/shm"), - ] - security_context = V1SecurityContext() - - for i, mount in enumerate(role.mounts): - mount_name = f"mount-{i}" - if isinstance(mount, BindMount): - volumes.append( - V1Volume( - name=mount_name, - host_path=V1HostPathVolumeSource( - path=mount.src_path, - ), - ) - ) - volume_mounts.append( - V1VolumeMount( - name=mount_name, - mount_path=mount.dst_path, - read_only=mount.read_only, - ) - ) - elif isinstance(mount, VolumeMount): - volumes.append( - V1Volume( - name=mount_name, - persistent_volume_claim=V1PersistentVolumeClaimVolumeSource( - claim_name=mount.src, - ), - ) - ) - volume_mounts.append( - V1VolumeMount( - name=mount_name, - mount_path=mount.dst_path, - read_only=mount.read_only, - ) - ) - elif isinstance(mount, DeviceMount): - volumes.append( - V1Volume( - name=mount_name, - host_path=V1HostPathVolumeSource( - path=mount.src_path, - ), - ) - ) - volume_mounts.append( - V1VolumeMount( - name=mount_name, - mount_path=mount.dst_path, - read_only=( - "w" not in mount.permissions and "m" not in mount.permissions - ), - ) - ) - security_context.privileged = True - else: - raise TypeError(f"unknown mount type {mount}") - - container = V1Container( - command=[role.entrypoint] + role.args, - image=role.image, - name=name, - env=[ - V1EnvVar( - name=name, - value=value, - ) - for name, value in role.env.items() - ], - resources=resources, - ports=[ - V1ContainerPort( - name=name, - container_port=port, - ) - for name, port in role.port_map.items() - ], - volume_mounts=volume_mounts, - security_context=security_context, - ) - - return V1Pod( - spec=V1PodSpec( - containers=[container], - restart_policy="Never", - service_account_name=service_account, - volumes=volumes, - node_selector=node_selector, - ), - metadata=V1ObjectMeta( - annotations={ - # Disable the istio sidecar as it prevents the containers from - # exiting once finished. - ANNOTATION_ISTIO_SIDECAR: "false", - }, - labels={}, - ), - ) - - def app_to_resource( app: AppDef, queue: str, diff --git a/torchx/schedulers/kueue_scheduler.py b/torchx/schedulers/kueue_scheduler.py new file mode 100644 index 000000000..457054bea --- /dev/null +++ b/torchx/schedulers/kueue_scheduler.py @@ -0,0 +1,646 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" + +This contains the TorchX Kubernetes Kueue Job scheduler which can be used to run TorchX +components on a Kubernetes cluster via Kueue. + +Prerequisites +============== + +The TorchX Kubernetes scheduler depends on Kueue. 
+ +""" +import json +import logging +import warnings +from dataclasses import dataclass +from datetime import datetime +from typing import ( + Any, + cast, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + TYPE_CHECKING, +) + +import torchx +import yaml +from torchx.schedulers.api import ( + AppDryRunInfo, + DescribeAppResponse, + filter_regex, + ListAppResponse, + Scheduler, + split_lines, + Stream, +) +from torchx.schedulers.ids import make_unique +from torchx.specs.api import ( + AppDef, + AppState, + BindMount, + CfgVal, + DeviceMount, + macros, + ReplicaState, + ReplicaStatus, + Role, + RoleStatus, + runopts, + VolumeMount, +) +from torchx.util.strings import normalize_str +from torchx.workspace.docker_workspace import DockerWorkspaceMixin +from typing_extensions import TypedDict +from torchx.util.role_to_pod import role_to_pod +if TYPE_CHECKING: + from docker import DockerClient + from kubernetes.client import ApiClient, BatchV1Api, CoreV1Api, CustomObjectsApi + from kubernetes.client.rest import ApiException + +logger: logging.Logger = logging.getLogger(__name__) + +# Kubernetes reserves a small amount of resources per host for the system. For +# TorchX we always assume the entire host is being requested so we adjust the +# requested numbers account for the node reserved resources. +# +# https://kubernetes.io/docs/tasks/administer-cluster/reserve-compute-resources/ +RESERVED_MILLICPU = 100 +RESERVED_MEMMB = 1024 + +JOB_STATE: Dict[str, AppState] = { + # Pending is the phase that job is pending in the queue, waiting for + # scheduling decision + "Pending": AppState.PENDING, + # Aborting is the phase that job is aborted, waiting for releasing pods + "Aborting": AppState.RUNNING, + # Aborted is the phase that job is aborted by user or error handling + "Aborted": AppState.CANCELLED, + # Running is the phase that minimal available tasks of Job are running + "Running": AppState.RUNNING, + # Restarting is the phase that the Job is restarted, waiting for pod + # releasing and recreating + "Restarting": AppState.RUNNING, + # Completed is the phase that all tasks of Job are completed successfully + "Completed": AppState.SUCCEEDED, + # Terminating is the phase that the Job is terminated, waiting for releasing + # pods + "Terminating": AppState.RUNNING, + # Teriminated is the phase that the job is finished unexpected, e.g. 
events + "Terminated": AppState.FAILED, + # Failed is the phase that the job has failed + "Failed": AppState.FAILED, + # Suspended is the phase that the job has been suspended by Kueue + "Suspended": AppState.SUSPENDED, +} + +KUEUE_STATE: Dict[str, ReplicaState] = { + # Kueue related States + # JobSuspended is the state where Kueue has suspended the job + "JobSuspended": ReplicaState.SUSPENDED, + # JobResumed is the state where Kueue releases the Job + "JobResumed": ReplicaState.RESUMED, + # Unknown is the state where the Job state is unknown + "Unknown": ReplicaState.UNKNOWN, +} + +LABEL_VERSION = "torchx.pytorch.org/version" +LABEL_APP_NAME = "torchx.pytorch.org/app-name" +LABEL_ROLE_INDEX = "torchx.pytorch.org/role-index" +LABEL_ROLE_NAME = "torchx.pytorch.org/role-name" +LABEL_REPLICA_ID = "torchx.pytorch.org/replica-id" +LABEL_KUBE_APP_NAME = "app.kubernetes.io/name" +LABEL_ORGANIZATION = "app.kubernetes.io/managed-by" +LABEL_UNIQUE_NAME = "app.kubernetes.io/instance" + +# Local Kueue and Priority class labels +LOCAL_QUEUE_LABEL = "kueue.x-k8s.io/queue-name" +PRIORITY_CLASS_LABEL = "kueue.x-k8s.io/priority-class" + + +def sanitize_for_serialization(obj: object) -> object: + from kubernetes import client + + api = client.ApiClient() + return api.sanitize_for_serialization(obj) + +def app_to_resource( + app: AppDef, + service_account: Optional[str], + local_queue: Optional[str] = None, + priority_class: Optional[str] = None, + annotations: Optional[dict] = None, +) -> Dict[str, object]: + """ + app_to_resource creates a kubernetes batch job resource definition from + the provided AppDef. The resource definition can be used to launch the + app on Kubernetes. + + The local queue is a required variable to add to the job labels. + priority_class is used to provide the workload priority class name see: https://kueue.sigs.k8s.io/docs/concepts/workload_priority_class/#how-to-use-workloadpriorityclass-on-jobs + """ + unique_app_id = normalize_str(make_unique(app.name)) + for role_idx, role in enumerate(app.roles): + for replica_id in range(role.num_replicas): + values = macros.Values( + img_root="", + app_id=unique_app_id, + replica_id=str(replica_id), + rank0_env=f"KUEUE_{normalize_str(app.roles[0].name)}_0_HOSTS".upper(), + ) + if role_idx == 0 and replica_id == 0: + values.rank0_env = "TORCHX_RANK0_HOST" + name = normalize_str(f"{role.name}-{replica_id}") + replica_role = values.apply(role) + if role_idx == 0 and replica_id == 0: + replica_role.env["TORCHX_RANK0_HOST"] = "localhost" + + pod = role_to_pod(name, replica_role, service_account) + pod.metadata.labels.update( + pod_labels( + app=app, + role_idx=role_idx, + role=role, + replica_id=replica_id, + app_id=unique_app_id, + local_queue=local_queue, + priority_class=priority_class, + ) + ) + task: Dict[str, Any] = { + "replicas": 1, + "name": name, + "template": pod, + } + if role.max_retries > 0: + task["backoffLimit"] = role.max_retries + + if role.min_replicas is not None: + # first min_replicas tasks are required, afterward optional + task["minAvailable"] = 1 if replica_id < role.min_replicas else 0 + + resource: Dict[str, object] = { + "apiVersion": "batch/v1", + "kind": "Job", + "metadata": {"name": f"{unique_app_id}"}, + "spec": task, + } + if annotations is not None: + resource["metadata"]["annotations"] = annotations + return resource + + +@dataclass +class Kueue: + images_to_push: Dict[str, Tuple[str, str]] + resource: Dict[str, object] + + def __str__(self) -> str: + return yaml.dump(sanitize_for_serialization(self.resource)) + 
+    def __repr__(self) -> str:
+        return str(self)
+
+
+class KueueOpts(TypedDict, total=False):
+    namespace: Optional[str]
+    image_repo: Optional[str]
+    service_account: Optional[str]
+    local_queue: Optional[str]
+    priority_class: Optional[str]
+    annotations: Optional[dict]
+
+
+class KueueScheduler(DockerWorkspaceMixin, Scheduler[KueueOpts]):
+    """
+    KueueScheduler is a TorchX scheduling interface to Kubernetes that relies on Kueue.
+
+    You can install Kueue by following
+    https://kueue.sigs.k8s.io/docs/installation/#install-a-released-version
+
+    .. code-block:: bash
+
+        $ pip install torchx[kueue]
+        $ torchx run --scheduler kueue --scheduler_args namespace=default,local_queue="default-kueue",image_repo="user/alpine" utils.echo --image alpine:latest --msg hello
+        kueue://torchx_user/1234
+        $ torchx status kueue://torchx_user/1234
+        ...
+
+    **Config Options**
+
+    .. runopts::
+        class: torchx.schedulers.kueue_scheduler.create_scheduler
+
+    **Mounts**
+
+    Mounting external filesystems/volumes is via the HostPath and
+    PersistentVolumeClaim support.
+
+    * hostPath volumes: ``type=bind,src=<host path>,dst=<container path>[,readonly]``
+    * PersistentVolumeClaim: ``type=volume,src=<claim name>,dst=<container path>[,readonly]``
+    * host devices: ``type=device,src=/dev/foo[,dst=<container path>][,perm=rwm]``
+      If you specify a host device the job will run in privileged mode since
+      Kubernetes doesn't expose a way to pass `--device` to the underlying
+      container runtime. Users should prefer to use device plugins.
+
+    See :py:func:`torchx.specs.parse_mounts` for more info.
+
+    External docs: https://kubernetes.io/docs/concepts/storage/persistent-volumes/
+
+    **Resources / Allocation**
+
+    To select a specific machine type you can add a capability to your resources
+    with ``node.kubernetes.io/instance-type`` which will constrain the launched
+    jobs to nodes of that instance type.
+
+    >>> from torchx import specs
+    >>> specs.Resource(
+    ...     cpu=4,
+    ...     memMB=16000,
+    ...     gpu=2,
+    ...     capabilities={
+    ...         "node.kubernetes.io/instance-type": "<cloud instance type>",
+    ...     },
+    ... )
+    Resource(...)
+
+    Kubernetes may reserve some memory for the host. TorchX assumes you're
+    scheduling on whole hosts and thus will automatically reduce the resource
+    request by a small amount to account for the node reserved CPU and memory.
+    If you run into scheduling issues you may need to reduce the requested CPU
+    and memory from the host values.
+
+    **Compatibility**
+
+    .. compatibility::
+        type: scheduler
+        features:
+            cancel: true
+            logs: true
+            distributed: true
+            describe: |
+                Partial support. KueueScheduler will return job and job suspension
+                status but does not provide the complete original AppSpec.
+            workspaces: true
+            mounts: true
+            elasticity: Requires Kueue >= v0.5.0
+    """
+
+    def __init__(
+        self,
+        session_name: str,
+        client: Optional["ApiClient"] = None,
+        docker_client: Optional["DockerClient"] = None,
+    ) -> None:
+        # NOTE: make sure any new init options are supported in create_scheduler(...)
+ super().__init__("kueue", session_name, docker_client=docker_client) + + self._client = client + + def _api_client(self) -> "ApiClient": + from kubernetes import client, config + + c = self._client + if c is None: + configuration = client.Configuration() + try: + config.load_kube_config(client_configuration=configuration) + except config.ConfigException as e: + warnings.warn(f"failed to load kube config: {e}") + + c = self._client = client.ApiClient(configuration) + + return c + + def _custom_objects_api(self) -> "CustomObjectsApi": + from kubernetes import client + + return client.CustomObjectsApi(self._api_client()) + + def _batchv1_api(self) -> "BatchV1Api": + from kubernetes import client + + return client.BatchV1Api(self._api_client()) + + def _corev1_api(self) -> "CoreV1Api": + from kubernetes import client + + return client.CoreV1Api(self._api_client()) + + def _get_job_name_from_exception(self, e: "ApiException") -> Optional[str]: + try: + return json.loads(e.body)["details"]["name"] + except Exception as e: + logger.exception("Unable to retrieve job name, got exception", e) + return None + + def _get_active_context(self) -> Dict[str, Any]: + from kubernetes import config + + contexts, active_context = config.list_kube_config_contexts() + return active_context + + def schedule(self, dryrun_info: AppDryRunInfo[Kueue]) -> str: + from kubernetes.client.rest import ApiException + + cfg = dryrun_info._cfg + assert cfg is not None, f"{dryrun_info} missing cfg" + namespace = cfg.get("namespace") or "default" + + images_to_push = dryrun_info.request.images_to_push + self.push_images(images_to_push) + + resource = dryrun_info.request.resource + try: + resp = self._batchv1_api().create_namespaced_job( + namespace=namespace, + body=resource, + ) + + except ApiException as e: + if e.status == 409 and e.reason == "Conflict": + job_name = self._get_job_name_from_exception(e) + raise ValueError( + f"Job `{job_name}` already exists. 
This seems like a transient exception, try resubmitting job" + ) from e + else: + raise + + return f"{namespace}:{resp.metadata.name}" + + def _submit_dryrun(self, app: AppDef, cfg: KueueOpts) -> AppDryRunInfo[Kueue]: + # map any local images to the remote image + images_to_push = self.dryrun_push_images(app, cast(Mapping[str, CfgVal], cfg)) + + service_account = cfg.get("service_account") + assert service_account is None or isinstance( + service_account, str + ), "service_account must be a str" + + local_queue = cfg.get("local_queue") + assert isinstance( + local_queue, str + ), "local_queue is a required string please specify local_queue in scheduler_args" + + priority_class = cfg.get("priority_class") + assert priority_class is None or isinstance( + priority_class, str + ), "priority_class must be a str" + + annotations = cfg.get("annotations") + assert annotations is None or isinstance( + annotations, dict + ), "annotations must be a dict" + + resource = app_to_resource( + app, service_account, local_queue, priority_class, annotations + ) + req = Kueue( + resource=resource, + images_to_push=images_to_push, + ) + return AppDryRunInfo(req, repr) + + def _validate(self, app: AppDef, scheduler: str) -> None: + # Skip validation step + pass + + def _cancel_existing(self, app_id: str) -> None: + from kubernetes import client + + namespace, name = app_id.split(":") + + self._batchv1_api().delete_namespaced_job( + namespace=namespace, + name=name, + body=client.V1DeleteOptions(propagation_policy="Foreground"), + ) + + def _run_opts(self) -> runopts: + opts = runopts() + opts.add( + "namespace", + type_=str, + help="Kubernetes namespace to schedule Job in", + default="default", + ) + opts.add( + "service_account", + type_=str, + help="The service account name to set on the pod specs", + ) + opts.add( + "local_queue", + type_=str, + help="The Local Kueue name to set on the local Kueue label", + ) + opts.add( + "priority_class", + type_=str, + help="The kueue priority class name to use for the priority class label", + ) + opts.add( + "annotations", + type_=dict, + help="The annotations to add to the job", + ) + return opts + + def describe(self, app_id: str) -> Optional[DescribeAppResponse]: + from kubernetes import client + + namespace, name = app_id.split(":") + roles = {} + roles_statuses = {} + + try: + api_instance = self._batchv1_api() + job = api_instance.read_namespaced_job_status(name, namespace) + except client.ApiException as e: + return f"Exception: {e}" + try: + status = job.status + except Exception as e: + print(f"Cannot gather job status: {e}") + status = None + app_state = None + if status: + for condition in status.conditions or []: + role, _, idx = job.metadata.name.rpartition("-") + condition_reason = condition.reason + if condition.type == "Suspended": + app_state = JOB_STATE["Suspended"] + state = KUEUE_STATE[condition_reason] + if condition.reason == "JobResumed": + state = KUEUE_STATE[condition_reason] + if status.active is not None: + app_state = JOB_STATE["Running"] + + elif status.active is not None: + state = app_state = JOB_STATE["Running"] + elif condition.type == "Complete": + state = app_state = JOB_STATE["Completed"] + elif condition.type == "Failed": + state = app_state = JOB_STATE["Failed"] + else: + state = app_state = JOB_STATE["Pending"] + + if role not in roles: + roles[role] = Role(name=role, num_replicas=0, image="") + roles_statuses[role] = RoleStatus(role, []) + + roles[role].num_replicas += 1 + roles_statuses[role].replicas.append( + ReplicaStatus(id=0, 
role=role, state=state, hostname="") + ) + else: + app_state = AppState.UNKNOWN + + return DescribeAppResponse( + app_id=app_id, + roles=list(roles.values()), + roles_statuses=list(roles_statuses.values()), + state=app_state, + ) + + def log_iter( + self, + app_id: str, + role_name: str, + k: int = 0, + regex: Optional[str] = None, + since: Optional[datetime] = None, + until: Optional[datetime] = None, + should_tail: bool = False, + streams: Optional[Stream] = None, + ) -> Iterable[str]: + assert until is None, "kubernetes API doesn't support until" + + if streams not in (None, Stream.COMBINED): + raise ValueError("KueueScheduler only supports COMBINED log stream") + + from kubernetes import client, watch + + namespace, name = app_id.split(":") + + pod_name = get_pod_name_from_job(self, job_name=name, namespace=namespace) + if pod_name is None: + raise ValueError("Pods not found. Is the Job Suspended?") + + args: Dict[str, object] = { + "name": pod_name, + "namespace": namespace, + "timestamps": True, + } + if since is not None: + args["since_seconds"] = (datetime.now() - since).total_seconds() + + core_api = client.CoreV1Api(self._api_client()) + if should_tail: + w = watch.Watch() + iterator = w.stream(core_api.read_namespaced_pod_log, **args) + else: + resp = core_api.read_namespaced_pod_log(**args) + iterator = split_lines(resp) + + if regex: + return filter_regex(regex, iterator) + else: + return iterator + + def list(self) -> List[ListAppResponse]: + active_context = self._get_active_context() + namespace = active_context["context"]["namespace"] + resp = self._custom_objects_api().list_namespaced_custom_object( + group="batch", + version="v1", + namespace=namespace, + plural="job", + timeout_seconds=30, + ) + return [ + ListAppResponse( + app_id=f"{namespace}:{app['metadata']['name']}", + state=JOB_STATE[app["status"]["state"]["phase"]], + ) + for app in resp["items"] + ] + + +def get_pod_name_from_job(self, job_name, namespace): + from kubernetes import client + + api_instance = self._batchv1_api() + + try: + job = api_instance.read_namespaced_job(job_name, namespace) + except client.ApiException as e: + return f"Api Exception: {e}" + + selector = job.spec.selector.match_labels + label = ",".join([f"{k}={v}" for k, v in selector.items()]) + + api_instance = self._corev1_api() + try: + pods = api_instance.list_namespaced_pod(namespace, label_selector=label) + except client.ApiException as e: + return f"Api Exception {e}" + + if not pods.items: + return None + else: + # Sort the list of pods by creation timestamp and get most recent one + sorted_pods = sorted( + pods.items, key=lambda p: str(p.metadata.creation_timestamp), reverse=True + ) + most_recent_pod = sorted_pods[0].metadata.name + + return most_recent_pod + + +def create_scheduler( + session_name: str, + client: Optional["ApiClient"] = None, + docker_client: Optional["DockerClient"] = None, + **kwargs: Any, +) -> KueueScheduler: + return KueueScheduler( + session_name=session_name, + client=client, + docker_client=docker_client, + ) + + +def pod_labels( + app: AppDef, + role_idx: int, + role: Role, + replica_id: int, + app_id: str, + local_queue: str, + priority_class: str, +) -> Dict[str, str]: + + labels = { + LABEL_VERSION: torchx.__version__, + LABEL_APP_NAME: app.name, + LABEL_ROLE_INDEX: str(role_idx), + LABEL_ROLE_NAME: role.name, + LABEL_REPLICA_ID: str(replica_id), + LABEL_KUBE_APP_NAME: app.name, + LABEL_ORGANIZATION: "torchx.pytorch.org", + LABEL_UNIQUE_NAME: app_id, + LOCAL_QUEUE_LABEL: local_queue, + } + + if 
priority_class is not None: + labels[PRIORITY_CLASS_LABEL] = priority_class + return labels diff --git a/torchx/schedulers/test/kubernetes_scheduler_test.py b/torchx/schedulers/test/kubernetes_scheduler_test.py index 8e6b356a9..ec9b92ae6 100644 --- a/torchx/schedulers/test/kubernetes_scheduler_test.py +++ b/torchx/schedulers/test/kubernetes_scheduler_test.py @@ -27,10 +27,9 @@ KubernetesJob, KubernetesOpts, KubernetesScheduler, - LABEL_INSTANCE_TYPE, - role_to_pod, ) from torchx.specs import AppState +from torchx.util.role_to_pod import LABEL_INSTANCE_TYPE, role_to_pod SKIP_DOCKER: bool = not has_docker() diff --git a/torchx/schedulers/test/kueue_job_scheduler_test.py b/torchx/schedulers/test/kueue_scheduler_test.py similarity index 93% rename from torchx/schedulers/test/kueue_job_scheduler_test.py rename to torchx/schedulers/test/kueue_scheduler_test.py index fe0abf4db..cb9b2a2b4 100644 --- a/torchx/schedulers/test/kueue_job_scheduler_test.py +++ b/torchx/schedulers/test/kueue_scheduler_test.py @@ -15,20 +15,19 @@ import torchx from torchx import schedulers, specs -# @manual=//torchx/schedulers:kueue_job_scheduler -from torchx.schedulers import kueue_job_scheduler +# @manual=//torchx/schedulers:kueue_scheduler +from torchx.schedulers import kueue_scheduler from torchx.schedulers.api import AppDryRunInfo, DescribeAppResponse, ListAppResponse from torchx.schedulers.docker_scheduler import has_docker -from torchx.schedulers.kueue_job_scheduler import ( +from torchx.schedulers.kueue_scheduler import ( app_to_resource, create_scheduler, - KueueJob, - KueueJobOpts, - KueueJobScheduler, - LABEL_INSTANCE_TYPE, - role_to_pod, + Kueue, + KueueOpts, + KueueScheduler, ) from torchx.specs import AppState +from torchx.util.role_to_pod import role_to_pod, LABEL_INSTANCE_TYPE SKIP_DOCKER: bool = not has_docker() @@ -89,12 +88,12 @@ def _test_app(num_replicas: int = 1) -> specs.AppDef: return specs.AppDef("test", roles=[trainer_role]) -class KueueJobSchedulerTest(unittest.TestCase): +class KueueSchedulerTest(unittest.TestCase): def test_create_scheduler(self) -> None: client = MagicMock() docker_client = MagicMock scheduler = create_scheduler("foo", client=client, docker_client=docker_client) - self.assertIsInstance(scheduler, kueue_job_scheduler.KueueJobScheduler) + self.assertIsInstance(scheduler, kueue_scheduler.KueueScheduler) self.assertEqual(scheduler._docker_client, docker_client) self.assertEqual(scheduler._client, client) @@ -102,7 +101,7 @@ def test_app_to_resource_resolved_macros(self) -> None: app = _test_app() unique_app_name = "app-name-42" with patch( - "torchx.schedulers.kueue_job_scheduler.make_unique" + "torchx.schedulers.kueue_scheduler.make_unique" ) as make_unique_ctx: make_unique_ctx.return_value = unique_app_name resource = app_to_resource( @@ -239,13 +238,13 @@ def test_role_to_pod(self) -> None: ) def test_submit_dryrun(self) -> None: - cfg = KueueJobOpts( + cfg = KueueOpts( {"namespace": "testnamespace", "local_queue": "default-kueue"} ) scheduler = create_scheduler("test") app = _test_app() with patch( - "torchx.schedulers.kueue_job_scheduler.make_unique" + "torchx.schedulers.kueue_scheduler.make_unique" ) as make_unique_ctx: make_unique_ctx.return_value = "app-name-42" info = scheduler.submit_dryrun(app, cfg) @@ -478,11 +477,11 @@ def test_rank0_env(self) -> None: scheduler = create_scheduler("test") app = _test_app(num_replicas=2) - cfg = KueueJobOpts( + cfg = KueueOpts( {"namespace": "testnamespace", "local_queue": "default-kueue"} ) with patch( - 
"torchx.schedulers.kueue_job_scheduler.make_unique" + "torchx.schedulers.kueue_scheduler.make_unique" ) as make_unique_ctx: make_unique_ctx.return_value = "app-name-42" info = scheduler.submit_dryrun(app, cfg) @@ -490,18 +489,18 @@ def test_rank0_env(self) -> None: task = info.request.resource["spec"] container0 = task["template"].spec.containers[0] - self.assertIn("KUEUE_JOB_TRAINERFOO_0_HOSTS", container0.command) + self.assertIn("KUEUE_TRAINERFOO_0_HOSTS", container0.command) self.assertIn(V1EnvVar(name="FOO", value="bar"), container0.env) def test_submit_dryrun_patch(self) -> None: scheduler = create_scheduler("test") app = _test_app() app.roles[0].image = "sha256:testhash" - cfg = KueueJobOpts( + cfg = KueueOpts( {"image_repo": "example.com/some/repo", "local_queue": "default-kueue"} ) with patch( - "torchx.schedulers.kueue_job_scheduler.make_unique" + "torchx.schedulers.kueue_scheduler.make_unique" ) as make_unique_ctx: make_unique_ctx.return_value = "app-name-42" info = scheduler.submit_dryrun(app, cfg) @@ -521,7 +520,7 @@ def test_submit_dryrun_service_account(self) -> None: scheduler = create_scheduler("test") self.assertIn("service_account", scheduler.run_opts()._opts) app = _test_app() - cfg = KueueJobOpts( + cfg = KueueOpts( { "service_account": "srvacc", "local_queue": "default-kueue", @@ -547,7 +546,7 @@ def test_submit(self, create_namespaced_job: MagicMock) -> None: scheduler = create_scheduler("test") app = _test_app() - cfg = KueueJobOpts( + cfg = KueueOpts( { "namespace": "testnamespace", "local_queue": "default-kueue", @@ -571,7 +570,7 @@ def test_submit_no_kueue_label( } scheduler = create_scheduler("test") app = _test_app() - cfg = KueueJobOpts( + cfg = KueueOpts( { "namespace": "testnamespace", } @@ -589,7 +588,7 @@ def test_submit_job_name_conflict(self, create_namespaced_job: MagicMock) -> Non scheduler = create_scheduler("test") app = _test_app() - cfg = KueueJobOpts( + cfg = KueueOpts( { "namespace": "testnamespace", "local_queue": "default-kueue", @@ -698,7 +697,7 @@ def test_describe_unknown( ) def test_runopts(self) -> None: - scheduler = kueue_job_scheduler.create_scheduler("foo") + scheduler = kueue_scheduler.create_scheduler("foo") runopts = scheduler.run_opts() self.assertEqual( set(runopts._opts.keys()), @@ -707,7 +706,7 @@ def test_runopts(self) -> None: "local_queue", "image_repo", "service_account", - "kueue_priority_class", + "priority_class", "annotations", }, ) @@ -732,7 +731,7 @@ def test_cancel_existing(self, delete_namespaced_job: MagicMock) -> None: @patch("kubernetes.client.CustomObjectsApi.list_namespaced_custom_object") def test_list(self, list_namespaced_custom_object: MagicMock) -> None: with patch( - "torchx.schedulers.kueue_job_scheduler.KueueJobScheduler._get_active_context" + "torchx.schedulers.kueue_scheduler.KueueScheduler._get_active_context" ) as test_context: test_context.return_value = TEST_KUBE_CONFIG["contexts"][0] scheduler = create_scheduler("test") @@ -797,7 +796,7 @@ def test_list_values(self, list_namespaced_custom_object: MagicMock) -> None: ], } with patch( - "torchx.schedulers.kueue_job_scheduler.KueueJobScheduler._get_active_context" + "torchx.schedulers.kueue_scheduler.KueueScheduler._get_active_context" ) as test_context: test_context.return_value = TEST_KUBE_CONFIG["contexts"][0] @@ -839,7 +838,7 @@ def test_list_failure(self, list_namespaced_custom_object: MagicMock) -> None: ) list_namespaced_custom_object.side_effect = api_exc with patch( - 
"torchx.schedulers.kueue_job_scheduler.KueueJobScheduler._get_active_context" + "torchx.schedulers.kueue_scheduler.KueueScheduler._get_active_context" ) as test_context: test_context.return_value = TEST_KUBE_CONFIG["contexts"][0] scheduler = create_scheduler("test") @@ -847,7 +846,7 @@ def test_list_failure(self, list_namespaced_custom_object: MagicMock) -> None: scheduler.list() @patch( - "torchx.schedulers.kueue_job_scheduler.get_pod_name_from_job", + "torchx.schedulers.kueue_scheduler.get_pod_name_from_job", return_value="testjob-roleblah-1-0", ) @patch("kubernetes.client.CoreV1Api.read_namespaced_pod_log") @@ -887,13 +886,13 @@ def test_log_iter( def test_push_patches(self) -> None: client = MagicMock() - scheduler = KueueJobScheduler( + scheduler = KueueScheduler( "foo", client=MagicMock(), docker_client=client, ) - job = KueueJob( + job = Kueue( images_to_push={ "sha256:testimage": ("repo.com/img", "testimage"), }, @@ -928,13 +927,13 @@ def test_min_replicas(self) -> None: def test_submit_dryrun_priority_class(self) -> None: scheduler = create_scheduler("test") - self.assertIn("kueue_priority_class", scheduler.run_opts()._opts) + self.assertIn("priority_class", scheduler.run_opts()._opts) app = _test_app() - cfg = KueueJobOpts( + cfg = KueueOpts( { "namespace": "testnamespace", "local_queue": "default-kueue", - "kueue_priority_class": "sample-priority", + "priority_class": "sample-priority", } ) @@ -944,7 +943,7 @@ def test_submit_dryrun_priority_class(self) -> None: str(info.request.resource), ) - del cfg["kueue_priority_class"] + del cfg["priority_class"] info = scheduler.submit_dryrun(app, cfg) self.assertNotIn( "'kueue.x-k8s.io/priority-class': 'sample-priority'", @@ -955,7 +954,7 @@ def test_submit_dryrun_with_annotations(self) -> None: scheduler = create_scheduler("test") self.assertIn("annotations", scheduler.run_opts()._opts) app = _test_app() - cfg = KueueJobOpts( + cfg = KueueOpts( { "namespace": "testnamespace", "local_queue": "default-kueue", @@ -977,9 +976,9 @@ def test_submit_dryrun_with_annotations(self) -> None: ) -class KueueJobSchedulerNoImportTest(unittest.TestCase): +class KueueSchedulerNoImportTest(unittest.TestCase): """ - KueueJobSchedulerNoImportTest tests the kubernetes scheduler behavior when + KueueSchedulerNoImportTest tests the kubernetes scheduler behavior when Kubernetes is not available. 
""" @@ -989,9 +988,9 @@ def setUp(self) -> None: if mod.startswith("kubernetes"): sys.modules[mod] = None # pyre-ignore - # reload to ensure kueue_job_scheduler doesn't depend on them at import + # reload to ensure kueue_scheduler doesn't depend on them at import # time - importlib.reload(kueue_job_scheduler) + importlib.reload(kueue_scheduler) importlib.reload(schedulers) def tearDown(self) -> None: @@ -999,22 +998,22 @@ def tearDown(self) -> None: for mod in list(sys.modules.keys()): if mod.startswith("kubernetes"): del sys.modules[mod] - # reimport kueue_job_scheduler to get to a clean state - importlib.reload(kueue_job_scheduler) + # reimport kueue_scheduler to get to a clean state + importlib.reload(kueue_scheduler) def test_runopts(self) -> None: - scheduler = kueue_job_scheduler.create_scheduler("foo") + scheduler = kueue_scheduler.create_scheduler("foo") self.assertIsNotNone(scheduler.run_opts()) def test_describe(self) -> None: - scheduler = kueue_job_scheduler.create_scheduler("foo") + scheduler = kueue_scheduler.create_scheduler("foo") with self.assertRaises(ModuleNotFoundError): scheduler.describe("foo:bar") def test_dryrun(self) -> None: - scheduler = kueue_job_scheduler.create_scheduler("foo") + scheduler = kueue_scheduler.create_scheduler("foo") app = _test_app() - cfg = KueueJobOpts( + cfg = KueueOpts( {"namespace": "testnamespace", "local_queue": "default-kueue"} ) diff --git a/torchx/specs/api.py b/torchx/specs/api.py index 1f3ecd452..55483e711 100644 --- a/torchx/specs/api.py +++ b/torchx/specs/api.py @@ -407,7 +407,7 @@ class AppState(int, Enum): 7. CANCELLED - app was cancelled before completing 8. UNKNOWN - app state is unknown 9. SUSPENDED - app is suspended - 10. Resumed - app is resumed + 10. RESUMED - app is resumed """ UNSUBMITTED = 0 diff --git a/torchx/util/role_to_pod.py b/torchx/util/role_to_pod.py new file mode 100644 index 000000000..f52ad0df0 --- /dev/null +++ b/torchx/util/role_to_pod.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from torchx.specs.api import ( + BindMount, + DeviceMount, + Role, + VolumeMount, +) +from typing import Dict, Optional +from kubernetes.client.models import V1Pod + +# Constants +RESERVED_MILLICPU = 100 +RESERVED_MEMMB = 1024 +LABEL_INSTANCE_TYPE = "node.kubernetes.io/instance-type" +ANNOTATION_ISTIO_SIDECAR = "sidecar.istio.io/inject" + +def role_to_pod(name: str, role: Role, service_account: Optional[str]) -> "V1Pod": + from kubernetes.client.models import ( # noqa: F811 redefinition of unused + V1Container, + V1ContainerPort, + V1EmptyDirVolumeSource, + V1EnvVar, + V1HostPathVolumeSource, + V1ObjectMeta, + V1PersistentVolumeClaimVolumeSource, + V1Pod, + V1PodSpec, + V1ResourceRequirements, + V1SecurityContext, + V1Volume, + V1VolumeMount, + ) + + # limits puts an upper cap on the resources a pod may consume. + # requests is how much the scheduler allocates. We assume that the jobs will + # be allocation the whole machine so requests is slightly lower than the + # requested resources to account for the Kubernetes node reserved resources. 
+ limits = {} + requests = {} + + resource = role.resource + if resource.cpu > 0: + mcpu = int(resource.cpu * 1000) + limits["cpu"] = f"{mcpu}m" + request_mcpu = max(mcpu - RESERVED_MILLICPU, 0) + requests["cpu"] = f"{request_mcpu}m" + if resource.memMB > 0: + limits["memory"] = f"{int(resource.memMB)}M" + request_memMB = max(int(resource.memMB) - RESERVED_MEMMB, 0) + requests["memory"] = f"{request_memMB}M" + if resource.gpu > 0: + requests["nvidia.com/gpu"] = limits["nvidia.com/gpu"] = str(resource.gpu) + + for device_name, device_limit in resource.devices.items(): + requests[device_name] = str(device_limit) + limits[device_name] = str(device_limit) + + resources = V1ResourceRequirements( + limits=limits, + requests=requests, + ) + + node_selector: Dict[str, str] = {} + if LABEL_INSTANCE_TYPE in resource.capabilities: + node_selector[LABEL_INSTANCE_TYPE] = resource.capabilities[LABEL_INSTANCE_TYPE] + + # To support PyTorch dataloaders we need to set /dev/shm to larger than the + # 64M default so we mount an unlimited sized tmpfs directory on it. + SHM_VOL = "dshm" + volumes = [ + V1Volume( + name=SHM_VOL, + empty_dir=V1EmptyDirVolumeSource( + medium="Memory", + ), + ), + ] + volume_mounts = [ + V1VolumeMount(name=SHM_VOL, mount_path="/dev/shm"), + ] + security_context = V1SecurityContext() + + for i, mount in enumerate(role.mounts): + mount_name = f"mount-{i}" + if isinstance(mount, BindMount): + volumes.append( + V1Volume( + name=mount_name, + host_path=V1HostPathVolumeSource( + path=mount.src_path, + ), + ) + ) + volume_mounts.append( + V1VolumeMount( + name=mount_name, + mount_path=mount.dst_path, + read_only=mount.read_only, + ) + ) + elif isinstance(mount, VolumeMount): + volumes.append( + V1Volume( + name=mount_name, + persistent_volume_claim=V1PersistentVolumeClaimVolumeSource( + claim_name=mount.src, + ), + ) + ) + volume_mounts.append( + V1VolumeMount( + name=mount_name, + mount_path=mount.dst_path, + read_only=mount.read_only, + ) + ) + elif isinstance(mount, DeviceMount): + volumes.append( + V1Volume( + name=mount_name, + host_path=V1HostPathVolumeSource( + path=mount.src_path, + ), + ) + ) + volume_mounts.append( + V1VolumeMount( + name=mount_name, + mount_path=mount.dst_path, + read_only=( + "w" not in mount.permissions and "m" not in mount.permissions + ), + ) + ) + security_context.privileged = True + else: + raise TypeError(f"unknown mount type {mount}") + + container = V1Container( + command=[role.entrypoint] + role.args, + image=role.image, + name=name, + env=[ + V1EnvVar( + name=name, + value=value, + ) + for name, value in role.env.items() + ], + resources=resources, + ports=[ + V1ContainerPort( + name=name, + container_port=port, + ) + for name, port in role.port_map.items() + ], + volume_mounts=volume_mounts, + security_context=security_context, + ) + + return V1Pod( + spec=V1PodSpec( + containers=[container], + restart_policy="Never", + service_account_name=service_account, + volumes=volumes, + node_selector=node_selector, + ), + metadata=V1ObjectMeta( + annotations={ + # Disable the istio sidecar as it prevents the containers from + # exiting once finished. + ANNOTATION_ISTIO_SIDECAR: "false", + }, + labels={}, + ), + ) \ No newline at end of file
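As a usage reference, the following is a minimal sketch (not part of the diff) of how the run options defined in kueue_scheduler.py above are expected to surface through a dryrun. It assumes torchx and the kubernetes Python client are installed; the "kueue-echo", "torchx-dev", "default-kueue", and "sample-priority" names are placeholders.

```python
# Hedged sketch: exercises only the new KueueScheduler dryrun path.
from torchx import specs
from torchx.schedulers.kueue_scheduler import create_scheduler, KueueOpts

app = specs.AppDef(
    name="kueue-echo",
    roles=[
        specs.Role(
            name="echo",
            image="alpine:latest",
            entrypoint="echo",
            args=["hello kueue"],
            resource=specs.Resource(cpu=1, gpu=0, memMB=1024),
        )
    ],
)

cfg = KueueOpts(
    {
        "namespace": "torchx-dev",
        "local_queue": "default-kueue",       # required by _submit_dryrun
        "priority_class": "sample-priority",  # optional WorkloadPriorityClass name
        "annotations": {"team": "example"},   # optional Job annotations
    }
)

scheduler = create_scheduler("example")
info = scheduler.submit_dryrun(app, cfg)
# The rendered batch/v1 Job carries the Kueue labels on the pod template, e.g.
#   kueue.x-k8s.io/queue-name: default-kueue
#   kueue.x-k8s.io/priority-class: sample-priority
print(info.request)
```

This mirrors what test_submit_dryrun_priority_class asserts in the renamed test module: the priority-class label appears only when priority_class is set in the scheduler config.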