diff --git a/torchx/schedulers/ray_scheduler.py b/torchx/schedulers/ray_scheduler.py index 912b1d680..335095313 100644 --- a/torchx/schedulers/ray_scheduler.py +++ b/torchx/schedulers/ray_scheduler.py @@ -14,7 +14,9 @@ from dataclasses import dataclass, field from datetime import datetime from shutil import copy2, rmtree -from typing import Any, cast, Dict, Iterable, List, Optional, Tuple # noqa +from typing import Any, cast, Dict, Final, Iterable, List, Optional, Tuple # noqa + +import urllib3 from torchx.schedulers.api import ( AppDryRunInfo, @@ -148,9 +150,35 @@ class RayScheduler(TmpDirWorkspaceMixin, Scheduler[RayOpts]): """ - def __init__(self, session_name: str) -> None: + def __init__( + self, session_name: str, ray_client: Optional[JobSubmissionClient] = None + ) -> None: super().__init__("ray", session_name) + # w/o Final None check in _get_ray_client does not work as it pyre assumes mutability + self._ray_client: Final[Optional[JobSubmissionClient]] = ray_client + + def _get_ray_client( + self, job_submission_netloc: Optional[str] = None + ) -> JobSubmissionClient: + if self._ray_client is not None: + client_netloc = urllib3.util.parse_url( + self._ray_client.get_address() + ).netloc + if job_submission_netloc and job_submission_netloc != client_netloc: + raise ValueError( + f"client netloc ({client_netloc}) does not match job netloc ({job_submission_netloc})" + ) + return self._ray_client + elif os.getenv("RAY_ADDRESS"): + return JobSubmissionClient(os.getenv("RAY_ADDRESS")) + elif not job_submission_netloc: + raise Exception( + "RAY_ADDRESS env variable or a scheduler with an attached Ray JobSubmissionClient is expected." + " See https://docs.ray.io/en/latest/cluster/jobs-package-ref.html#job-submission-sdk for more info" + ) + return JobSubmissionClient(f"http://{job_submission_netloc}") + # TODO: Add address as a potential CLI argument after writing ray.status() or passing in config file def _run_opts(self) -> runopts: opts = runopts() @@ -196,9 +224,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[RayJob]) -> str: ) # 0. Create Job Client - client: JobSubmissionClient = JobSubmissionClient( - f"http://{job_submission_addr}" - ) + client = self._get_ray_client(job_submission_netloc=job_submission_addr) # 1. Copy Ray driver utilities current_directory = os.path.dirname(os.path.abspath(__file__)) @@ -341,12 +367,12 @@ def _parse_app_id(self, app_id: str) -> Tuple[str, str]: def _cancel_existing(self, app_id: str) -> None: # pragma: no cover addr, app_id = self._parse_app_id(app_id) - client = JobSubmissionClient(f"http://{addr}") + client = self._get_ray_client(job_submission_netloc=addr) client.stop_job(app_id) def _get_job_status(self, app_id: str) -> JobStatus: addr, app_id = self._parse_app_id(app_id) - client = JobSubmissionClient(f"http://{addr}") + client = self._get_ray_client(job_submission_netloc=addr) status = client.get_job_status(app_id) if isinstance(status, str): return cast(JobStatus, status) @@ -393,7 +419,9 @@ def log_iter( ) -> Iterable[str]: # TODO: support tailing, streams etc.. addr, app_id = self._parse_app_id(app_id) - client: JobSubmissionClient = JobSubmissionClient(f"http://{addr}") + client: JobSubmissionClient = self._get_ray_client( + job_submission_netloc=addr + ) logs: str = client.get_job_logs(app_id) iterator = split_lines(logs) if regex: @@ -401,18 +429,12 @@ def log_iter( return iterator def list(self) -> List[ListAppResponse]: - address = os.getenv("RAY_ADDRESS") - if not address: - raise Exception( - "RAY_ADDRESS env variable is expected to be set to list jobs on ray scheduler." - " See https://docs.ray.io/en/latest/cluster/jobs-package-ref.html#job-submission-sdk for more info" - ) - client = JobSubmissionClient(address) + client = self._get_ray_client() jobs = client.list_jobs() - ip = address.split("http://", 1)[-1] + netloc = urllib3.util.parse_url(client.get_address()).netloc return [ ListAppResponse( - app_id=f"{ip}-{details.submission_id}", + app_id=f"{netloc}-{details.submission_id}", state=_ray_status_to_torchx_appstate[details.status], ) for details in jobs diff --git a/torchx/schedulers/test/ray_scheduler_test.py b/torchx/schedulers/test/ray_scheduler_test.py index 68d321cd9..1e7d181b5 100644 --- a/torchx/schedulers/test/ray_scheduler_test.py +++ b/torchx/schedulers/test/ray_scheduler_test.py @@ -11,7 +11,7 @@ from shutil import copy2 from typing import Any, cast, Iterable, Iterator, List, Optional, Type from unittest import TestCase -from unittest.mock import patch +from unittest.mock import MagicMock, patch from torchx.schedulers import get_scheduler_factories from torchx.schedulers.api import AppDryRunInfo, DescribeAppResponse, ListAppResponse @@ -22,6 +22,7 @@ if has_ray(): import ray from ray.cluster_utils import Cluster + from ray.dashboard.modules.job.sdk import JobSubmissionClient from ray.util.placement_group import remove_placement_group from torchx.schedulers.ray import ray_driver from torchx.schedulers.ray_scheduler import ( @@ -83,6 +84,9 @@ def setUp(self) -> None: } ) + # mock validation step so that instantiation doesn't fail due to inability to reach dashboard + JobSubmissionClient._check_connection_and_version = MagicMock() + self._scheduler = RayScheduler("test_session") self._isfile_patch = patch("torchx.schedulers.ray_scheduler.os.path.isfile") @@ -320,11 +324,15 @@ def test_parse_app_id(self) -> None: def test_list_throws_without_address(self) -> None: if "RAY_ADDRESS" in os.environ: del os.environ["RAY_ADDRESS"] - with self.assertRaisesRegex( - Exception, "RAY_ADDRESS env variable is expected" - ): + with self.assertRaisesRegex(Exception, "RAY_ADDRESS env variable"): self._scheduler.list() + def test_list_doesnt_throw_with_client(self) -> None: + ray_client = JobSubmissionClient(address="https://test.com") + ray_client.list_jobs = MagicMock(return_value=[]) + _scheduler_with_client = RayScheduler("client_session", ray_client) + _scheduler_with_client.list() # testing for success (should not throw exception) + def test_min_replicas(self) -> None: app = AppDef( name="app", @@ -358,6 +366,30 @@ def test_min_replicas(self) -> None: ): self._scheduler._submit_dryrun(app, cfg={}) + def test_nonmatching_address(self) -> None: + ray_client = JobSubmissionClient(address="https://test.address.com") + _scheduler_with_client = RayScheduler("client_session", ray_client) + app = AppDef( + name="app", + roles=[ + Role(name="role", image="."), + ], + ) + with self.assertRaisesRegex( + ValueError, "client netloc .* does not match job netloc .*" + ): + _scheduler_with_client.submit(app=app, cfg={}) + + def test_client_with_headers(self) -> None: + # This tests only one option for the client. Different versions may have more options available. + headers = {"Authorization": "Bearer: token"} + ray_client = JobSubmissionClient( + address="https://test.com", headers=headers, verify=False + ) + _scheduler_with_client = RayScheduler("client_session", ray_client) + scheduler_client = _scheduler_with_client._get_ray_client() + self.assertDictContainsSubset(scheduler_client._headers, headers) + class RayClusterSetup: _instance = None # pyre-ignore _cluster = None # pyre-ignore