From e44dd8e37dbb79c7fd62fb68bf8593f49ff20069 Mon Sep 17 00:00:00 2001 From: Kevin Postlethwait Date: Fri, 20 Oct 2023 09:41:10 -0400 Subject: [PATCH] add functions for creating ray with oauth proxy in front of the dashboard (#298) * add functions for creating ray with oauth proxy in front of the dashboard Signed-off-by: Kevin * add unit test for OAuth create Signed-off-by: Kevin * add tests for replace and generate sidecar Signed-off-by: Kevin --------- Signed-off-by: Kevin --- src/codeflare_sdk/cluster/auth.py | 4 +- src/codeflare_sdk/cluster/cluster.py | 96 +++++++-- src/codeflare_sdk/cluster/config.py | 1 + src/codeflare_sdk/job/jobs.py | 151 +++++++------ src/codeflare_sdk/utils/generate_yaml.py | 95 ++++++++- src/codeflare_sdk/utils/kube_api_helpers.py | 5 + src/codeflare_sdk/utils/openshift_oauth.py | 217 +++++++++++++++++++ tests/unit_test.py | 225 ++++++++++++++++---- 8 files changed, 667 insertions(+), 127 deletions(-) create mode 100644 src/codeflare_sdk/utils/openshift_oauth.py diff --git a/src/codeflare_sdk/cluster/auth.py b/src/codeflare_sdk/cluster/auth.py index eb739136b..1015a8018 100644 --- a/src/codeflare_sdk/cluster/auth.py +++ b/src/codeflare_sdk/cluster/auth.py @@ -25,6 +25,8 @@ import urllib3 from ..utils.kube_api_helpers import _kube_api_error_handling +from typing import Optional + global api_client api_client = None global config_path @@ -188,7 +190,7 @@ def config_check() -> str: return config_path -def api_config_handler() -> str: +def api_config_handler() -> Optional[client.ApiClient]: """ This function is used to load the api client if the user has logged in """ diff --git a/src/codeflare_sdk/cluster/cluster.py b/src/codeflare_sdk/cluster/cluster.py index 5d00cdae8..29c026bdc 100644 --- a/src/codeflare_sdk/cluster/cluster.py +++ b/src/codeflare_sdk/cluster/cluster.py @@ -21,12 +21,19 @@ from time import sleep from typing import List, Optional, Tuple, Dict +import openshift as oc +from kubernetes import config from ray.job_submission import JobSubmissionClient +import urllib3 from .auth import config_check, api_config_handler from ..utils import pretty_print from ..utils.generate_yaml import generate_appwrapper from ..utils.kube_api_helpers import _kube_api_error_handling +from ..utils.openshift_oauth import ( + create_openshift_oauth_objects, + delete_openshift_oauth_objects, +) from .config import ClusterConfiguration from .model import ( AppWrapper, @@ -40,6 +47,8 @@ import os import requests +from kubernetes import config + class Cluster: """ @@ -61,6 +70,39 @@ def __init__(self, config: ClusterConfiguration): self.config = config self.app_wrapper_yaml = self.create_app_wrapper() self.app_wrapper_name = self.app_wrapper_yaml.split(".")[0] + self._client = None + + @property + def _client_headers(self): + k8_client = api_config_handler() or client.ApiClient() + return { + "Authorization": k8_client.configuration.get_api_key_with_prefix( + "authorization" + ) + } + + @property + def _client_verify_tls(self): + return not self.config.openshift_oauth + + @property + def client(self): + if self._client: + return self._client + if self.config.openshift_oauth: + print( + api_config_handler().configuration.get_api_key_with_prefix( + "authorization" + ) + ) + self._client = JobSubmissionClient( + self.cluster_dashboard_uri(), + headers=self._client_headers, + verify=self._client_verify_tls, + ) + else: + self._client = JobSubmissionClient(self.cluster_dashboard_uri()) + return self._client def evaluate_dispatch_priority(self): priority_class = self.config.dispatch_priority @@ -147,6 +189,7 @@ def create_app_wrapper(self): image_pull_secrets=image_pull_secrets, dispatch_priority=dispatch_priority, priority_val=priority_val, + openshift_oauth=self.config.openshift_oauth, ) # creates a new cluster with the provided or default spec @@ -156,6 +199,11 @@ def up(self): the MCAD queue. """ namespace = self.config.namespace + if self.config.openshift_oauth: + create_openshift_oauth_objects( + cluster_name=self.config.name, namespace=namespace + ) + try: config_check() api_instance = client.CustomObjectsApi(api_config_handler()) @@ -190,6 +238,11 @@ def down(self): except Exception as e: # pragma: no cover return _kube_api_error_handling(e) + if self.config.openshift_oauth: + delete_openshift_oauth_objects( + cluster_name=self.config.name, namespace=namespace + ) + def status( self, print_to_console: bool = True ) -> Tuple[CodeFlareClusterStatus, bool]: @@ -258,7 +311,16 @@ def status( return status, ready def is_dashboard_ready(self) -> bool: - response = requests.get(self.cluster_dashboard_uri(), timeout=5) + try: + response = requests.get( + self.cluster_dashboard_uri(), + headers=self._client_headers, + timeout=5, + verify=self._client_verify_tls, + ) + except requests.exceptions.SSLError: + # SSL exception occurs when oauth ingress has been created but cluster is not up + return False if response.status_code == 200: return True else: @@ -330,7 +392,13 @@ def cluster_dashboard_uri(self) -> str: return _kube_api_error_handling(e) for route in routes["items"]: - if route["metadata"]["name"] == f"ray-dashboard-{self.config.name}": + if route["metadata"][ + "name" + ] == f"ray-dashboard-{self.config.name}" or route["metadata"][ + "name" + ].startswith( + f"{self.config.name}-ingress" + ): protocol = "https" if route["spec"].get("tls") else "http" return f"{protocol}://{route['spec']['host']}" return "Dashboard route not available yet, have you run cluster.up()?" @@ -339,30 +407,24 @@ def list_jobs(self) -> List: """ This method accesses the head ray node in your cluster and lists the running jobs. """ - dashboard_route = self.cluster_dashboard_uri() - client = JobSubmissionClient(dashboard_route) - return client.list_jobs() + return self.client.list_jobs() def job_status(self, job_id: str) -> str: """ This method accesses the head ray node in your cluster and returns the job status for the provided job id. """ - dashboard_route = self.cluster_dashboard_uri() - client = JobSubmissionClient(dashboard_route) - return client.get_job_status(job_id) + return self.client.get_job_status(job_id) def job_logs(self, job_id: str) -> str: """ This method accesses the head ray node in your cluster and returns the logs for the provided job id. """ - dashboard_route = self.cluster_dashboard_uri() - client = JobSubmissionClient(dashboard_route) - return client.get_job_logs(job_id) + return self.client.get_job_logs(job_id) def torchx_config( self, working_dir: str = None, requirements: str = None ) -> Dict[str, str]: - dashboard_address = f"{self.cluster_dashboard_uri().lstrip('http://')}" + dashboard_address = urllib3.util.parse_url(self.cluster_dashboard_uri()).host to_return = { "cluster_name": self.config.name, "dashboard_address": dashboard_address, @@ -591,7 +653,7 @@ def _get_app_wrappers( def _map_to_ray_cluster(rc) -> Optional[RayCluster]: - if "status" in rc and "state" in rc["status"]: + if "state" in rc["status"]: status = RayClusterStatus(rc["status"]["state"].lower()) else: status = RayClusterStatus.UNKNOWN @@ -606,7 +668,13 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]: ) ray_route = None for route in routes["items"]: - if route["metadata"]["name"] == f"ray-dashboard-{rc['metadata']['name']}": + if route["metadata"][ + "name" + ] == f"ray-dashboard-{rc['metadata']['name']}" or route["metadata"][ + "name" + ].startswith( + f"{rc['metadata']['name']}-ingress" + ): protocol = "https" if route["spec"].get("tls") else "http" ray_route = f"{protocol}://{route['spec']['host']}" diff --git a/src/codeflare_sdk/cluster/config.py b/src/codeflare_sdk/cluster/config.py index 30875a98f..fe83e9e55 100644 --- a/src/codeflare_sdk/cluster/config.py +++ b/src/codeflare_sdk/cluster/config.py @@ -51,3 +51,4 @@ class ClusterConfiguration: local_interactive: bool = False image_pull_secrets: list = field(default_factory=list) dispatch_priority: str = None + openshift_oauth: bool = False # NOTE: to use the user must have permission to create a RoleBinding for system:auth-delegator diff --git a/src/codeflare_sdk/job/jobs.py b/src/codeflare_sdk/job/jobs.py index b9bb9cdc1..27f15283d 100644 --- a/src/codeflare_sdk/job/jobs.py +++ b/src/codeflare_sdk/job/jobs.py @@ -18,15 +18,19 @@ from pathlib import Path from torchx.components.dist import ddp -from torchx.runner import get_runner +from torchx.runner import get_runner, Runner +from torchx.schedulers.ray_scheduler import RayScheduler from torchx.specs import AppHandle, parse_app_handle, AppDryRunInfo +from ray.job_submission import JobSubmissionClient + +import openshift as oc + if TYPE_CHECKING: from ..cluster.cluster import Cluster from ..cluster.cluster import get_current_namespace all_jobs: List["Job"] = [] -torchx_runner = get_runner() class JobDefinition(metaclass=abc.ABCMeta): @@ -92,30 +96,37 @@ def __init__( def _dry_run(self, cluster: "Cluster"): j = f"{cluster.config.num_workers}x{max(cluster.config.num_gpus, 1)}" # # of proc. = # of gpus - return torchx_runner.dryrun( - app=ddp( - *self.script_args, - script=self.script, - m=self.m, - name=self.name, - h=self.h, - cpu=self.cpu if self.cpu is not None else cluster.config.max_cpus, - gpu=self.gpu if self.gpu is not None else cluster.config.num_gpus, - memMB=self.memMB - if self.memMB is not None - else cluster.config.max_memory * 1024, - j=self.j if self.j is not None else j, - env=self.env, - max_retries=self.max_retries, - rdzv_port=self.rdzv_port, - rdzv_backend=self.rdzv_backend - if self.rdzv_backend is not None - else "static", - mounts=self.mounts, + runner = get_runner(ray_client=cluster.client) + runner._scheduler_instances["ray"] = RayScheduler( + session_name=runner._name, ray_client=cluster.client + ) + return ( + runner.dryrun( + app=ddp( + *self.script_args, + script=self.script, + m=self.m, + name=self.name, + h=self.h, + cpu=self.cpu if self.cpu is not None else cluster.config.max_cpus, + gpu=self.gpu if self.gpu is not None else cluster.config.num_gpus, + memMB=self.memMB + if self.memMB is not None + else cluster.config.max_memory * 1024, + j=self.j if self.j is not None else j, + env=self.env, + max_retries=self.max_retries, + rdzv_port=self.rdzv_port, + rdzv_backend=self.rdzv_backend + if self.rdzv_backend is not None + else "static", + mounts=self.mounts, + ), + scheduler=cluster.torchx_scheduler, + cfg=cluster.torchx_config(**self.scheduler_args), + workspace=self.workspace, ), - scheduler=cluster.torchx_scheduler, - cfg=cluster.torchx_config(**self.scheduler_args), - workspace=self.workspace, + runner, ) def _missing_spec(self, spec: str): @@ -125,41 +136,47 @@ def _dry_run_no_cluster(self): if self.scheduler_args is not None: if self.scheduler_args.get("namespace") is None: self.scheduler_args["namespace"] = get_current_namespace() - return torchx_runner.dryrun( - app=ddp( - *self.script_args, - script=self.script, - m=self.m, - name=self.name if self.name is not None else self._missing_spec("name"), - h=self.h, - cpu=self.cpu - if self.cpu is not None - else self._missing_spec("cpu (# cpus per worker)"), - gpu=self.gpu - if self.gpu is not None - else self._missing_spec("gpu (# gpus per worker)"), - memMB=self.memMB - if self.memMB is not None - else self._missing_spec("memMB (memory in MB)"), - j=self.j - if self.j is not None - else self._missing_spec( - "j (`workers`x`procs`)" - ), # # of proc. = # of gpus, - env=self.env, # should this still exist? - max_retries=self.max_retries, - rdzv_port=self.rdzv_port, # should this still exist? - rdzv_backend=self.rdzv_backend - if self.rdzv_backend is not None - else "c10d", - mounts=self.mounts, - image=self.image - if self.image is not None - else self._missing_spec("image"), + runner = get_runner() + return ( + runner.dryrun( + app=ddp( + *self.script_args, + script=self.script, + m=self.m, + name=self.name + if self.name is not None + else self._missing_spec("name"), + h=self.h, + cpu=self.cpu + if self.cpu is not None + else self._missing_spec("cpu (# cpus per worker)"), + gpu=self.gpu + if self.gpu is not None + else self._missing_spec("gpu (# gpus per worker)"), + memMB=self.memMB + if self.memMB is not None + else self._missing_spec("memMB (memory in MB)"), + j=self.j + if self.j is not None + else self._missing_spec( + "j (`workers`x`procs`)" + ), # # of proc. = # of gpus, + env=self.env, # should this still exist? + max_retries=self.max_retries, + rdzv_port=self.rdzv_port, # should this still exist? + rdzv_backend=self.rdzv_backend + if self.rdzv_backend is not None + else "c10d", + mounts=self.mounts, + image=self.image + if self.image is not None + else self._missing_spec("image"), + ), + scheduler="kubernetes_mcad", + cfg=self.scheduler_args, + workspace="", ), - scheduler="kubernetes_mcad", - cfg=self.scheduler_args, - workspace="", + runner, ) def submit(self, cluster: "Cluster" = None) -> "Job": @@ -171,18 +188,20 @@ def __init__(self, job_definition: "DDPJobDefinition", cluster: "Cluster" = None self.job_definition = job_definition self.cluster = cluster if self.cluster: - self._app_handle = torchx_runner.schedule(job_definition._dry_run(cluster)) + definition, runner = job_definition._dry_run(cluster) + self._app_handle = runner.schedule(definition) + self._runner = runner else: - self._app_handle = torchx_runner.schedule( - job_definition._dry_run_no_cluster() - ) + definition, runner = job_definition._dry_run_no_cluster() + self._app_handle = runner.schedule(definition) + self._runner = runner all_jobs.append(self) def status(self) -> str: - return torchx_runner.status(self._app_handle) + return self._runner.status(self._app_handle) def logs(self) -> str: - return "".join(torchx_runner.log_lines(self._app_handle, None)) + return "".join(self._runner.log_lines(self._app_handle, None)) def cancel(self): - torchx_runner.cancel(self._app_handle) + self._runner.cancel(self._app_handle) diff --git a/src/codeflare_sdk/utils/generate_yaml.py b/src/codeflare_sdk/utils/generate_yaml.py index 95e1c5ecb..4757f5370 100755 --- a/src/codeflare_sdk/utils/generate_yaml.py +++ b/src/codeflare_sdk/utils/generate_yaml.py @@ -24,6 +24,13 @@ from kubernetes import client, config from .kube_api_helpers import _kube_api_error_handling from ..cluster.auth import api_config_handler, config_check +from os import urandom +from base64 import b64encode +from urllib3.util import parse_url + +from kubernetes import client, config + +from .kube_api_helpers import _get_api_host def read_template(template): @@ -46,13 +53,17 @@ def gen_names(name): def update_dashboard_route(route_item, cluster_name, namespace): metadata = route_item.get("generictemplate", {}).get("metadata") - metadata["name"] = f"ray-dashboard-{cluster_name}" + metadata["name"] = gen_dashboard_route_name(cluster_name) metadata["namespace"] = namespace metadata["labels"]["odh-ray-cluster-service"] = f"{cluster_name}-head-svc" spec = route_item.get("generictemplate", {}).get("spec") spec["to"]["name"] = f"{cluster_name}-head-svc" +def gen_dashboard_route_name(cluster_name): + return f"ray-dashboard-{cluster_name}" + + # ToDo: refactor the update_x_route() functions def update_rayclient_route(route_item, cluster_name, namespace): metadata = route_item.get("generictemplate", {}).get("metadata") @@ -369,6 +380,83 @@ def write_user_appwrapper(user_yaml, output_file_name): print(f"Written to: {output_file_name}") +def enable_openshift_oauth(user_yaml, cluster_name, namespace): + config_check() + k8_client = api_config_handler() or client.ApiClient() + tls_mount_location = "/etc/tls/private" + oauth_port = 8443 + oauth_sa_name = f"{cluster_name}-oauth-proxy" + tls_secret_name = f"{cluster_name}-proxy-tls-secret" + tls_volume_name = "proxy-tls-secret" + port_name = "oauth-proxy" + host = _get_api_host(k8_client) + host = host.replace( + "api.", f"{gen_dashboard_route_name(cluster_name)}-{namespace}.apps." + ) + oauth_sidecar = _create_oauth_sidecar_object( + namespace, + tls_mount_location, + oauth_port, + oauth_sa_name, + tls_volume_name, + port_name, + ) + tls_secret_volume = client.V1Volume( + name=tls_volume_name, + secret=client.V1SecretVolumeSource(secret_name=tls_secret_name), + ) + # allows for setting value of Cluster object when initializing object from an existing AppWrapper on cluster + user_yaml["metadata"]["annotations"] = user_yaml["metadata"].get("annotations", {}) + user_yaml["metadata"]["annotations"][ + "codeflare-sdk-use-oauth" + ] = "true" # if the user gets an + ray_headgroup_pod = user_yaml["spec"]["resources"]["GenericItems"][0][ + "generictemplate" + ]["spec"]["headGroupSpec"]["template"]["spec"] + user_yaml["spec"]["resources"]["GenericItems"].pop(1) + ray_headgroup_pod["serviceAccount"] = oauth_sa_name + ray_headgroup_pod["volumes"] = ray_headgroup_pod.get("volumes", []) + + # we use a generic api client here so that the serialization function doesn't need to be mocked for unit tests + ray_headgroup_pod["volumes"].append( + client.ApiClient().sanitize_for_serialization(tls_secret_volume) + ) + ray_headgroup_pod["containers"].append( + client.ApiClient().sanitize_for_serialization(oauth_sidecar) + ) + + +def _create_oauth_sidecar_object( + namespace: str, + tls_mount_location: str, + oauth_port: int, + oauth_sa_name: str, + tls_volume_name: str, + port_name: str, +) -> client.V1Container: + return client.V1Container( + args=[ + f"--https-address=:{oauth_port}", + "--provider=openshift", + f"--openshift-service-account={oauth_sa_name}", + "--upstream=http://localhost:8265", + f"--tls-cert={tls_mount_location}/tls.crt", + f"--tls-key={tls_mount_location}/tls.key", + f"--cookie-secret={b64encode(urandom(64)).decode('utf-8')}", # create random string for encrypting cookie + f'--openshift-delegate-urls={{"/":{{"resource":"pods","namespace":"{namespace}","verb":"get"}}}}', + ], + image="registry.redhat.io/openshift4/ose-oauth-proxy@sha256:1ea6a01bf3e63cdcf125c6064cbd4a4a270deaf0f157b3eabb78f60556840366", + name="oauth-proxy", + ports=[client.V1ContainerPort(container_port=oauth_port, name=port_name)], + resources=client.V1ResourceRequirements(limits=None, requests=None), + volume_mounts=[ + client.V1VolumeMount( + mount_path=tls_mount_location, name=tls_volume_name, read_only=True + ) + ], + ) + + def generate_appwrapper( name: str, namespace: str, @@ -390,6 +478,7 @@ def generate_appwrapper( image_pull_secrets: list, dispatch_priority: str, priority_val: int, + openshift_oauth: bool, ): user_yaml = read_template(template) appwrapper_name, cluster_name = gen_names(name) @@ -433,6 +522,10 @@ def generate_appwrapper( enable_local_interactive(resources, cluster_name, namespace) else: disable_raycluster_tls(resources["resources"]) + + if openshift_oauth: + enable_openshift_oauth(user_yaml, cluster_name, namespace) + outfile = appwrapper_name + ".yaml" write_user_appwrapper(user_yaml, outfile) return outfile diff --git a/src/codeflare_sdk/utils/kube_api_helpers.py b/src/codeflare_sdk/utils/kube_api_helpers.py index 58358a053..8f8180b97 100644 --- a/src/codeflare_sdk/utils/kube_api_helpers.py +++ b/src/codeflare_sdk/utils/kube_api_helpers.py @@ -19,6 +19,7 @@ import executing from kubernetes import client, config +from urllib3.util import parse_url # private methods @@ -42,3 +43,7 @@ def _kube_api_error_handling(e: Exception): # pragma: no cover elif e.reason == "Conflict": raise FileExistsError(exists_msg) raise e + + +def _get_api_host(api_client: client.ApiClient): # pragma: no cover + return parse_url(api_client.configuration.host).host diff --git a/src/codeflare_sdk/utils/openshift_oauth.py b/src/codeflare_sdk/utils/openshift_oauth.py new file mode 100644 index 000000000..5c3fc55aa --- /dev/null +++ b/src/codeflare_sdk/utils/openshift_oauth.py @@ -0,0 +1,217 @@ +from urllib3.util import parse_url +from .generate_yaml import gen_dashboard_route_name +from .kube_api_helpers import _get_api_host +from base64 import b64decode + +from ..cluster.auth import config_check, api_config_handler + +from kubernetes import client + + +def create_openshift_oauth_objects(cluster_name, namespace): + config_check() + api_client = api_config_handler() or client.ApiClient() + oauth_port = 8443 + oauth_sa_name = f"{cluster_name}-oauth-proxy" + tls_secret_name = _gen_tls_secret_name(cluster_name) + service_name = f"{cluster_name}-oauth" + port_name = "oauth-proxy" + host = _get_api_host(api_client) + + # replace "^api" with the expected host + host = f"{gen_dashboard_route_name(cluster_name)}-{namespace}.apps" + host.lstrip( + "api" + ) + + _create_or_replace_oauth_sa(namespace, oauth_sa_name, host) + _create_or_replace_oauth_service_obj( + cluster_name, namespace, oauth_port, tls_secret_name, service_name, port_name + ) + _create_or_replace_oauth_ingress_object( + cluster_name, namespace, service_name, port_name, host + ) + _create_or_replace_oauth_rb(cluster_name, namespace, oauth_sa_name) + + +def _create_or_replace_oauth_sa(namespace, oauth_sa_name, host): + api_client = api_config_handler() + oauth_sa = client.V1ServiceAccount( + api_version="v1", + kind="ServiceAccount", + metadata=client.V1ObjectMeta( + name=oauth_sa_name, + namespace=namespace, + annotations={ + "serviceaccounts.openshift.io/oauth-redirecturi.first": f"https://{host}" + }, + ), + ) + try: + client.CoreV1Api(api_client).create_namespaced_service_account( + namespace=namespace, body=oauth_sa + ) + except client.ApiException as e: + if e.reason == "Conflict": + client.CoreV1Api(api_client).replace_namespaced_service_account( + namespace=namespace, + body=oauth_sa, + name=oauth_sa_name, + ) + else: + raise e + + +def _create_or_replace_oauth_rb(cluster_name, namespace, oauth_sa_name): + api_client = api_config_handler() + oauth_crb = client.V1ClusterRoleBinding( + api_version="rbac.authorization.k8s.io/v1", + kind="ClusterRoleBinding", + metadata=client.V1ObjectMeta(name=f"{cluster_name}-rb"), + role_ref=client.V1RoleRef( + api_group="rbac.authorization.k8s.io", + kind="ClusterRole", + name="system:auth-delegator", + ), + subjects=[ + client.V1Subject( + kind="ServiceAccount", name=oauth_sa_name, namespace=namespace + ) + ], + ) + try: + client.RbacAuthorizationV1Api(api_client).create_cluster_role_binding( + body=oauth_crb + ) + except client.ApiException as e: + if e.reason == "Conflict": + client.RbacAuthorizationV1Api(api_client).replace_cluster_role_binding( + body=oauth_crb, name=f"{cluster_name}-rb" + ) + else: + raise e + + +def _gen_tls_secret_name(cluster_name): + return f"{cluster_name}-proxy-tls-secret" + + +def delete_openshift_oauth_objects(cluster_name, namespace): + # NOTE: it might be worth adding error handling here, but shouldn't be necessary because cluster.down(...) checks + # for an existing cluster before calling this => the objects should never be deleted twice + api_client = api_config_handler() + oauth_sa_name = f"{cluster_name}-oauth-proxy" + service_name = f"{cluster_name}-oauth" + client.CoreV1Api(api_client).delete_namespaced_service_account( + name=oauth_sa_name, namespace=namespace + ) + client.CoreV1Api(api_client).delete_namespaced_service( + name=service_name, namespace=namespace + ) + client.NetworkingV1Api(api_client).delete_namespaced_ingress( + name=f"{cluster_name}-ingress", namespace=namespace + ) + client.RbacAuthorizationV1Api(api_client).delete_cluster_role_binding( + name=f"{cluster_name}-rb" + ) + + +def _create_or_replace_oauth_service_obj( + cluster_name: str, + namespace: str, + oauth_port: int, + tls_secret_name: str, + service_name: str, + port_name: str, +) -> client.V1Service: + api_client = api_config_handler() + oauth_service = client.V1Service( + api_version="v1", + kind="Service", + metadata=client.V1ObjectMeta( + annotations={ + "service.beta.openshift.io/serving-cert-secret-name": tls_secret_name + }, + name=service_name, + namespace=namespace, + ), + spec=client.V1ServiceSpec( + ports=[ + client.V1ServicePort( + name=port_name, + protocol="TCP", + port=443, + target_port=oauth_port, + ) + ], + selector={ + "app.kubernetes.io/created-by": "kuberay-operator", + "app.kubernetes.io/name": "kuberay", + "ray.io/cluster": cluster_name, + "ray.io/identifier": f"{cluster_name}-head", + "ray.io/node-type": "head", + }, + ), + ) + try: + client.CoreV1Api(api_client).create_namespaced_service( + namespace=namespace, body=oauth_service + ) + except client.ApiException as e: + if e.reason == "Conflict": + client.CoreV1Api(api_client).replace_namespaced_service( + namespace=namespace, body=oauth_service, name=service_name + ) + else: + raise e + + +def _create_or_replace_oauth_ingress_object( + cluster_name: str, + namespace: str, + service_name: str, + port_name: str, + host: str, +) -> client.V1Ingress: + api_client = api_config_handler() + ingress = client.V1Ingress( + api_version="networking.k8s.io/v1", + kind="Ingress", + metadata=client.V1ObjectMeta( + annotations={"route.openshift.io/termination": "passthrough"}, + name=f"{cluster_name}-ingress", + namespace=namespace, + ), + spec=client.V1IngressSpec( + rules=[ + client.V1IngressRule( + host=host, + http=client.V1HTTPIngressRuleValue( + paths=[ + client.V1HTTPIngressPath( + backend=client.V1IngressBackend( + service=client.V1IngressServiceBackend( + name=service_name, + port=client.V1ServiceBackendPort( + name=port_name + ), + ) + ), + path_type="ImplementationSpecific", + ) + ] + ), + ) + ] + ), + ) + try: + client.NetworkingV1Api(api_client).create_namespaced_ingress( + namespace=namespace, body=ingress + ) + except client.ApiException as e: + if e.reason == "Conflict": + client.NetworkingV1Api(api_client).replace_namespaced_ingress( + namespace=namespace, body=ingress, name=f"{cluster_name}-ingress" + ) + else: + raise e diff --git a/tests/unit_test.py b/tests/unit_test.py index 9b78e9e2f..f2c86f1f9 100644 --- a/tests/unit_test.py +++ b/tests/unit_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# TODO: replace all instances of torchx_runner + from pathlib import Path import sys import filecmp @@ -38,8 +40,8 @@ Authentication, KubeConfigFileAuthentication, config_check, - api_config_handler, ) +from codeflare_sdk.utils.openshift_oauth import create_openshift_oauth_objects from codeflare_sdk.utils.pretty_print import ( print_no_resources_found, print_app_wrappers_status, @@ -58,7 +60,6 @@ Job, DDPJobDefinition, DDPJob, - torchx_runner, ) from codeflare_sdk.utils.generate_cert import ( generate_ca_cert, @@ -74,6 +75,8 @@ createDDPJob_with_cluster, ) +import codeflare_sdk.utils.kube_api_helpers + import openshift from openshift.selector import Selector import ray @@ -83,7 +86,9 @@ from torchx.schedulers.kubernetes_mcad_scheduler import KubernetesMCADJob import pytest import yaml - +from unittest.mock import MagicMock +from pytest_mock import MockerFixture +from ray.job_submission import JobSubmissionClient # For mocking openshift client results fake_res = openshift.Result("fake") @@ -1835,7 +1840,7 @@ def test_DDPJobDefinition_creation(): assert ddp.scheduler_args == {"requirements": "test"} -def test_DDPJobDefinition_dry_run(mocker): +def test_DDPJobDefinition_dry_run(mocker: MockerFixture): """ Test that the dry run method returns the correct type: AppDryRunInfo, that the attributes of the returned object are of the correct type, @@ -1846,9 +1851,10 @@ def test_DDPJobDefinition_dry_run(mocker): "codeflare_sdk.cluster.cluster.Cluster.cluster_dashboard_uri", return_value="", ) + mocker.patch.object(Cluster, "client") ddp = createTestDDP() cluster = createClusterWithConfig() - ddp_job = ddp._dry_run(cluster) + ddp_job, _ = ddp._dry_run(cluster) assert type(ddp_job) == AppDryRunInfo assert ddp_job._fmt is not None assert type(ddp_job.request) == RayJob @@ -1884,7 +1890,7 @@ def test_DDPJobDefinition_dry_run_no_cluster(mocker): ddp = createTestDDP() ddp.image = "fake-image" - ddp_job = ddp._dry_run_no_cluster() + ddp_job, _ = ddp._dry_run_no_cluster() assert type(ddp_job) == AppDryRunInfo assert ddp_job._fmt is not None assert type(ddp_job.request) == KubernetesMCADJob @@ -1915,6 +1921,7 @@ def test_DDPJobDefinition_dry_run_no_resource_args(mocker): Test that the dry run correctly gets resources from the cluster object when the job definition does not specify resources. """ + mocker.patch.object(Cluster, "client") mocker.patch( "codeflare_sdk.cluster.cluster.Cluster.cluster_dashboard_uri", return_value="", @@ -1932,7 +1939,7 @@ def test_DDPJobDefinition_dry_run_no_resource_args(mocker): rdzv_port=29500, scheduler_args={"requirements": "test"}, ) - ddp_job = ddp._dry_run(cluster) + ddp_job, _ = ddp._dry_run(cluster) assert ddp_job._app.roles[0].resource.cpu == cluster.config.max_cpus assert ddp_job._app.roles[0].resource.gpu == cluster.config.num_gpus @@ -1998,25 +2005,24 @@ def test_DDPJobDefinition_dry_run_no_cluster_no_resource_args(mocker): assert str(e) == "Job definition missing arg: j (`workers`x`procs`)" -def test_DDPJobDefinition_submit(mocker): +def test_DDPJobDefinition_submit(mocker: MockerFixture): """ Tests that the submit method returns the correct type: DDPJob And that the attributes of the returned object are of the correct type """ - mocker.patch( - "codeflare_sdk.cluster.cluster.Cluster.cluster_dashboard_uri", - return_value="fake-dashboard-uri", - ) + mock_schedule = MagicMock() + mocker.patch.object(Runner, "schedule", mock_schedule) + mock_schedule.return_value = "fake-dashboard-url" + mocker.patch.object(Cluster, "client") ddp_def = createTestDDP() cluster = createClusterWithConfig() mocker.patch( "codeflare_sdk.job.jobs.get_current_namespace", side_effect="opendatahub", ) - mocker.patch( - "codeflare_sdk.job.jobs.torchx_runner.schedule", - return_value="fake-dashboard-url", - ) # a fake app_handle + mocker.patch.object( + Cluster, "cluster_dashboard_uri", return_value="fake-dashboard-url" + ) ddp_job = ddp_def.submit(cluster) assert type(ddp_job) == DDPJob assert type(ddp_job.job_definition) == DDPJobDefinition @@ -2033,24 +2039,23 @@ def test_DDPJobDefinition_submit(mocker): assert ddp_job._app_handle == "fake-dashboard-url" -def test_DDPJob_creation(mocker): - mocker.patch( - "codeflare_sdk.cluster.cluster.Cluster.cluster_dashboard_uri", - return_value="fake-dashboard-uri", +def test_DDPJob_creation(mocker: MockerFixture): + mocker.patch.object(Cluster, "client") + mock_schedule = MagicMock() + mocker.patch.object(Runner, "schedule", mock_schedule) + mocker.patch.object( + Cluster, "cluster_dashboard_uri", return_value="fake-dashboard-url" ) ddp_def = createTestDDP() cluster = createClusterWithConfig() - mocker.patch( - "codeflare_sdk.job.jobs.torchx_runner.schedule", - return_value="fake-dashboard-url", - ) # a fake app_handle + mock_schedule.return_value = "fake-dashboard-url" ddp_job = createDDPJob_with_cluster(ddp_def, cluster) assert type(ddp_job) == DDPJob assert type(ddp_job.job_definition) == DDPJobDefinition assert type(ddp_job.cluster) == Cluster assert type(ddp_job._app_handle) == str assert ddp_job._app_handle == "fake-dashboard-url" - _, args, kwargs = torchx_runner.schedule.mock_calls[0] + _, args, kwargs = mock_schedule.mock_calls[0] assert type(args[0]) == AppDryRunInfo job_info = args[0] assert type(job_info.request) == RayJob @@ -2059,24 +2064,23 @@ def test_DDPJob_creation(mocker): assert type(job_info._scheduler) == type(str()) -def test_DDPJob_creation_no_cluster(mocker): +def test_DDPJob_creation_no_cluster(mocker: MockerFixture): ddp_def = createTestDDP() ddp_def.image = "fake-image" mocker.patch( "codeflare_sdk.job.jobs.get_current_namespace", side_effect="opendatahub", ) - mocker.patch( - "codeflare_sdk.job.jobs.torchx_runner.schedule", - return_value="fake-app-handle", - ) # a fake app_handle + mock_schedule = MagicMock() + mocker.patch.object(Runner, "schedule", mock_schedule) + mock_schedule.return_value = "fake-app-handle" ddp_job = createDDPJob_no_cluster(ddp_def, None) assert type(ddp_job) == DDPJob assert type(ddp_job.job_definition) == DDPJobDefinition assert ddp_job.cluster == None assert type(ddp_job._app_handle) == str assert ddp_job._app_handle == "fake-app-handle" - _, args, kwargs = torchx_runner.schedule.mock_calls[0] + _, args, kwargs = mock_schedule.mock_calls[0] assert type(args[0]) == AppDryRunInfo job_info = args[0] assert type(job_info.request) == KubernetesMCADJob @@ -2085,31 +2089,31 @@ def test_DDPJob_creation_no_cluster(mocker): assert type(job_info._scheduler) == type(str()) -def test_DDPJob_status(mocker): +def test_DDPJob_status(mocker: MockerFixture): # Setup the neccesary mock patches + mock_status = MagicMock() + mocker.patch.object(Runner, "status", mock_status) test_DDPJob_creation(mocker) ddp_def = createTestDDP() cluster = createClusterWithConfig() ddp_job = createDDPJob_with_cluster(ddp_def, cluster) - mocker.patch( - "codeflare_sdk.job.jobs.torchx_runner.status", return_value="fake-status" - ) + mock_status.return_value = "fake-status" assert ddp_job.status() == "fake-status" - _, args, kwargs = torchx_runner.status.mock_calls[0] + _, args, kwargs = mock_status.mock_calls[0] assert args[0] == "fake-dashboard-url" -def test_DDPJob_logs(mocker): +def test_DDPJob_logs(mocker: MockerFixture): + mock_log = MagicMock() + mocker.patch.object(Runner, "log_lines", mock_log) # Setup the neccesary mock patches test_DDPJob_creation(mocker) ddp_def = createTestDDP() cluster = createClusterWithConfig() ddp_job = createDDPJob_with_cluster(ddp_def, cluster) - mocker.patch( - "codeflare_sdk.job.jobs.torchx_runner.log_lines", return_value="fake-logs" - ) + mock_log.return_value = "fake-logs" assert ddp_job.logs() == "fake-logs" - _, args, kwargs = torchx_runner.log_lines.mock_calls[0] + _, args, kwargs = mock_log.mock_calls[0] assert args[0] == "fake-dashboard-url" @@ -2117,7 +2121,9 @@ def arg_check_side_effect(*args): assert args[0] == "fake-app-handle" -def test_DDPJob_cancel(mocker): +def test_DDPJob_cancel(mocker: MockerFixture): + mock_cancel = MagicMock() + mocker.patch.object(Runner, "cancel", mock_cancel) # Setup the neccesary mock patches test_DDPJob_creation_no_cluster(mocker) ddp_def = createTestDDP() @@ -2127,9 +2133,7 @@ def test_DDPJob_cancel(mocker): "openshift.get_project_name", return_value="opendatahub", ) - mocker.patch( - "codeflare_sdk.job.jobs.torchx_runner.cancel", side_effect=arg_check_side_effect - ) + mock_cancel.side_effect = arg_check_side_effect ddp_job.cancel() @@ -2292,6 +2296,137 @@ def test_export_env(): ) +def test_create_openshift_oauth(mocker: MockerFixture): + create_namespaced_service_account = MagicMock() + create_cluster_role_binding = MagicMock() + create_namespaced_service = MagicMock() + create_namespaced_ingress = MagicMock() + mocker.patch.object( + client.CoreV1Api, + "create_namespaced_service_account", + create_namespaced_service_account, + ) + mocker.patch.object( + client.RbacAuthorizationV1Api, + "create_cluster_role_binding", + create_cluster_role_binding, + ) + mocker.patch.object( + client.CoreV1Api, "create_namespaced_service", create_namespaced_service + ) + mocker.patch.object( + client.NetworkingV1Api, "create_namespaced_ingress", create_namespaced_ingress + ) + mocker.patch( + "codeflare_sdk.utils.openshift_oauth._get_api_host", return_value="foo.com" + ) + create_openshift_oauth_objects("foo", "bar") + create_ns_sa_args = create_namespaced_service_account.call_args + create_crb_args = create_cluster_role_binding.call_args + create_ns_serv_args = create_namespaced_service.call_args + create_ns_ingress_args = create_namespaced_ingress.call_args + assert ( + create_ns_sa_args.kwargs["namespace"] == create_ns_serv_args.kwargs["namespace"] + ) + assert ( + create_ns_serv_args.kwargs["namespace"] + == create_ns_ingress_args.kwargs["namespace"] + ) + assert isinstance(create_ns_sa_args.kwargs["body"], client.V1ServiceAccount) + assert isinstance(create_crb_args.kwargs["body"], client.V1ClusterRoleBinding) + assert isinstance(create_ns_serv_args.kwargs["body"], client.V1Service) + assert isinstance(create_ns_ingress_args.kwargs["body"], client.V1Ingress) + assert ( + create_ns_serv_args.kwargs["body"].spec.ports[0].name + == create_ns_ingress_args.kwargs["body"] + .spec.rules[0] + .http.paths[0] + .backend.service.port.name + ) + + +def test_replace_openshift_oauth(mocker: MockerFixture): + # not_found_exception = client.ApiException(reason="Conflict") + create_namespaced_service_account = MagicMock( + side_effect=client.ApiException(reason="Conflict") + ) + create_cluster_role_binding = MagicMock( + side_effect=client.ApiException(reason="Conflict") + ) + create_namespaced_service = MagicMock( + side_effect=client.ApiException(reason="Conflict") + ) + create_namespaced_ingress = MagicMock( + side_effect=client.ApiException(reason="Conflict") + ) + mocker.patch.object( + client.CoreV1Api, + "create_namespaced_service_account", + create_namespaced_service_account, + ) + mocker.patch.object( + client.RbacAuthorizationV1Api, + "create_cluster_role_binding", + create_cluster_role_binding, + ) + mocker.patch.object( + client.CoreV1Api, "create_namespaced_service", create_namespaced_service + ) + mocker.patch.object( + client.NetworkingV1Api, "create_namespaced_ingress", create_namespaced_ingress + ) + mocker.patch( + "codeflare_sdk.utils.openshift_oauth._get_api_host", return_value="foo.com" + ) + replace_namespaced_service_account = MagicMock() + replace_cluster_role_binding = MagicMock() + replace_namespaced_service = MagicMock() + replace_namespaced_ingress = MagicMock() + mocker.patch.object( + client.CoreV1Api, + "replace_namespaced_service_account", + replace_namespaced_service_account, + ) + mocker.patch.object( + client.RbacAuthorizationV1Api, + "replace_cluster_role_binding", + replace_cluster_role_binding, + ) + mocker.patch.object( + client.CoreV1Api, "replace_namespaced_service", replace_namespaced_service + ) + mocker.patch.object( + client.NetworkingV1Api, "replace_namespaced_ingress", replace_namespaced_ingress + ) + create_openshift_oauth_objects("foo", "bar") + replace_namespaced_service_account.assert_called_once() + replace_cluster_role_binding.assert_called_once() + replace_namespaced_service.assert_called_once() + replace_namespaced_ingress.assert_called_once() + + +def test_gen_app_wrapper_with_oauth(mocker: MockerFixture): + mocker.patch( + "codeflare_sdk.utils.generate_yaml._get_api_host", return_value="foo.com" + ) + mocker.patch( + "codeflare_sdk.cluster.cluster.get_current_namespace", + return_value="opendatahub", + ) + write_user_appwrapper = MagicMock() + mocker.patch( + "codeflare_sdk.utils.generate_yaml.write_user_appwrapper", write_user_appwrapper + ) + Cluster(ClusterConfiguration("test_cluster", openshift_oauth=True)) + user_yaml = write_user_appwrapper.call_args.args[0] + assert any( + container["name"] == "oauth-proxy" + for container in user_yaml["spec"]["resources"]["GenericItems"][0][ + "generictemplate" + ]["spec"]["headGroupSpec"]["template"]["spec"]["containers"] + ) + + # Make sure to always keep this function last def test_cleanup(): os.remove("unit-test-cluster.yaml")