Skip to content

Commit

Permalink
add option for programmatically defining ray job client (#762)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin <[email protected]>
  • Loading branch information
KPostOffice authored Aug 24, 2023
1 parent dd2db49 commit b9a1d0d
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 21 deletions.
56 changes: 39 additions & 17 deletions torchx/schedulers/ray_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from dataclasses import dataclass, field
from datetime import datetime
from shutil import copy2, rmtree
from typing import Any, cast, Dict, Iterable, List, Optional, Tuple # noqa
from typing import Any, cast, Dict, Final, Iterable, List, Optional, Tuple # noqa

import urllib3

from torchx.schedulers.api import (
AppDryRunInfo,
Expand Down Expand Up @@ -148,9 +150,35 @@ class RayScheduler(TmpDirWorkspaceMixin, Scheduler[RayOpts]):
"""

def __init__(self, session_name: str) -> None:
def __init__(
self, session_name: str, ray_client: Optional[JobSubmissionClient] = None
) -> None:
super().__init__("ray", session_name)

# Without Final, the None check in _get_ray_client does not work, because pyre assumes the attribute is mutable
self._ray_client: Final[Optional[JobSubmissionClient]] = ray_client

def _get_ray_client(
    self, job_submission_netloc: Optional[str] = None
) -> JobSubmissionClient:
    """Return the Ray ``JobSubmissionClient`` to use for a job.

    Resolution order:

    1. The client attached to this scheduler (``self._ray_client``), if any.
    2. A new client built from the ``RAY_ADDRESS`` environment variable.
    3. A new client built from ``job_submission_netloc``.

    Args:
        job_submission_netloc: Optional ``host[:port]`` of the job's
            submission endpoint. When a client is attached, it is only used
            to validate that the client points at the same netloc.

    Returns:
        The attached client, or a newly constructed one.

    Raises:
        ValueError: If an attached client's netloc differs from a non-empty
            ``job_submission_netloc``.
        Exception: If no client is attached, ``RAY_ADDRESS`` is unset or
            empty, and no ``job_submission_netloc`` was given.
    """
    attached = self._ray_client

    if attached is None:
        # No pre-built client: fall back to the environment first, then to
        # the netloc supplied by the caller.
        env_address = os.getenv("RAY_ADDRESS")
        if env_address:
            return JobSubmissionClient(env_address)
        if job_submission_netloc:
            return JobSubmissionClient(f"http://{job_submission_netloc}")
        raise Exception(
            "RAY_ADDRESS env variable or a scheduler with an attached Ray JobSubmissionClient is expected."
            " See https://docs.ray.io/en/latest/cluster/jobs-package-ref.html#job-submission-sdk for more info"
        )

    # An attached client always wins, but refuse to use it for a job that
    # names a different endpoint.
    client_netloc = urllib3.util.parse_url(attached.get_address()).netloc
    if job_submission_netloc and job_submission_netloc != client_netloc:
        raise ValueError(
            f"client netloc ({client_netloc}) does not match job netloc ({job_submission_netloc})"
        )
    return attached

# TODO: Add address as a potential CLI argument after writing ray.status() or passing in config file
def _run_opts(self) -> runopts:
opts = runopts()
Expand Down Expand Up @@ -196,9 +224,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[RayJob]) -> str:
)

# 0. Create Job Client
client: JobSubmissionClient = JobSubmissionClient(
f"http://{job_submission_addr}"
)
client = self._get_ray_client(job_submission_netloc=job_submission_addr)

# 1. Copy Ray driver utilities
current_directory = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -341,12 +367,12 @@ def _parse_app_id(self, app_id: str) -> Tuple[str, str]:

def _cancel_existing(self, app_id: str) -> None: # pragma: no cover
addr, app_id = self._parse_app_id(app_id)
client = JobSubmissionClient(f"http://{addr}")
client = self._get_ray_client(job_submission_netloc=addr)
client.stop_job(app_id)

def _get_job_status(self, app_id: str) -> JobStatus:
addr, app_id = self._parse_app_id(app_id)
client = JobSubmissionClient(f"http://{addr}")
client = self._get_ray_client(job_submission_netloc=addr)
status = client.get_job_status(app_id)
if isinstance(status, str):
return cast(JobStatus, status)
Expand Down Expand Up @@ -393,26 +419,22 @@ def log_iter(
) -> Iterable[str]:
# TODO: support tailing, streams etc..
addr, app_id = self._parse_app_id(app_id)
client: JobSubmissionClient = JobSubmissionClient(f"http://{addr}")
client: JobSubmissionClient = self._get_ray_client(
job_submission_netloc=addr
)
logs: str = client.get_job_logs(app_id)
iterator = split_lines(logs)
if regex:
return filter_regex(regex, iterator)
return iterator

def list(self) -> List[ListAppResponse]:
address = os.getenv("RAY_ADDRESS")
if not address:
raise Exception(
"RAY_ADDRESS env variable is expected to be set to list jobs on ray scheduler."
" See https://docs.ray.io/en/latest/cluster/jobs-package-ref.html#job-submission-sdk for more info"
)
client = JobSubmissionClient(address)
client = self._get_ray_client()
jobs = client.list_jobs()
ip = address.split("http://", 1)[-1]
netloc = urllib3.util.parse_url(client.get_address()).netloc
return [
ListAppResponse(
app_id=f"{ip}-{details.submission_id}",
app_id=f"{netloc}-{details.submission_id}",
state=_ray_status_to_torchx_appstate[details.status],
)
for details in jobs
Expand Down
40 changes: 36 additions & 4 deletions torchx/schedulers/test/ray_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from shutil import copy2
from typing import Any, cast, Iterable, Iterator, List, Optional, Type
from unittest import TestCase
from unittest.mock import patch
from unittest.mock import MagicMock, patch

from torchx.schedulers import get_scheduler_factories
from torchx.schedulers.api import AppDryRunInfo, DescribeAppResponse, ListAppResponse
Expand All @@ -22,6 +22,7 @@
if has_ray():
import ray
from ray.cluster_utils import Cluster
from ray.dashboard.modules.job.sdk import JobSubmissionClient
from ray.util.placement_group import remove_placement_group
from torchx.schedulers.ray import ray_driver
from torchx.schedulers.ray_scheduler import (
Expand Down Expand Up @@ -83,6 +84,9 @@ def setUp(self) -> None:
}
)

# mock validation step so that instantiation doesn't fail due to inability to reach dashboard
JobSubmissionClient._check_connection_and_version = MagicMock()

self._scheduler = RayScheduler("test_session")

self._isfile_patch = patch("torchx.schedulers.ray_scheduler.os.path.isfile")
Expand Down Expand Up @@ -320,11 +324,15 @@ def test_parse_app_id(self) -> None:
def test_list_throws_without_address(self) -> None:
if "RAY_ADDRESS" in os.environ:
del os.environ["RAY_ADDRESS"]
with self.assertRaisesRegex(
Exception, "RAY_ADDRESS env variable is expected"
):
with self.assertRaisesRegex(Exception, "RAY_ADDRESS env variable"):
self._scheduler.list()

def test_list_doesnt_throw_with_client(self) -> None:
    """Check that ``list()`` succeeds when the scheduler has an attached client."""
    attached_client = JobSubmissionClient(address="https://test.com")
    # Stub out the job listing so no request is made to the fake address.
    attached_client.list_jobs = MagicMock(return_value=[])
    scheduler = RayScheduler("client_session", attached_client)
    # testing for success (should not throw exception)
    scheduler.list()

def test_min_replicas(self) -> None:
app = AppDef(
name="app",
Expand Down Expand Up @@ -358,6 +366,30 @@ def test_min_replicas(self) -> None:
):
self._scheduler._submit_dryrun(app, cfg={})

def test_nonmatching_address(self) -> None:
    """Check that ``submit`` raises ``ValueError`` on a client/job netloc mismatch.

    The scheduler is given a client for ``https://test.address.com`` and the
    app is submitted with an empty cfg; the job netloc used in that case is
    presumably the scheduler's default address (not visible in this test).
    """
    mismatched_client = JobSubmissionClient(address="https://test.address.com")
    scheduler = RayScheduler("client_session", mismatched_client)
    single_role_app = AppDef(name="app", roles=[Role(name="role", image=".")])
    expected_message = "client netloc .* does not match job netloc .*"
    with self.assertRaisesRegex(ValueError, expected_message):
        scheduler.submit(app=single_role_app, cfg={})

def test_client_with_headers(self) -> None:
    """Check that a client built with custom headers is handed back with them.

    The scheduler is constructed with a pre-configured ``JobSubmissionClient``
    and ``_get_ray_client()`` is called without a netloc; the returned
    client's ``_headers`` must contain every header that was passed in.
    """
    # This tests only one option for the client. Different versions may have more options available.
    headers = {"Authorization": "Bearer: token"}
    ray_client = JobSubmissionClient(
        address="https://test.com", headers=headers, verify=False
    )
    _scheduler_with_client = RayScheduler("client_session", ray_client)
    scheduler_client = _scheduler_with_client._get_ray_client()
    # assertDictContainsSubset is deprecated since Python 3.2 and removed in
    # 3.12. Merging the expected headers into the actual ones must be a
    # no-op if, and only if, every expected header is already present with
    # the same value. The expected headers are the subset here; the original
    # call passed the arguments the other way around.
    actual_headers = scheduler_client._headers
    self.assertEqual(actual_headers, {**actual_headers, **headers})

class RayClusterSetup:
_instance = None # pyre-ignore
_cluster = None # pyre-ignore
Expand Down

0 comments on commit b9a1d0d

Please sign in to comment.