Skip to content

Commit

Permalink
add options to create_scheduler so that the get_runner method is full…
Browse files Browse the repository at this point in the history
…y configurable

I also added a note to each schedulers __init__ method to help with maintainablility

Signed-off-by: Kevin <[email protected]>
  • Loading branch information
KPostOffice committed Sep 15, 2023
1 parent 1f5eac8 commit ac577c0
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 8 deletions.
12 changes: 11 additions & 1 deletion torchx/schedulers/aws_batch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ def __init__(
log_client: Optional[Any] = None,
docker_client: Optional["DockerClient"] = None,
) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("aws_batch", session_name, docker_client=docker_client)

# pyre-fixme[4]: Attribute annotation cannot be `Any`.
Expand Down Expand Up @@ -796,7 +797,16 @@ def _stream_events(
yield event["message"] + "\n"


def create_scheduler(session_name: str, **kwargs: object) -> AWSBatchScheduler:
def create_scheduler(
session_name: str,
client: Optional[Any] = None,
log_client: Optional[Any] = None,
docker_client: Optional["DockerClient"] = None,
**kwargs: object,
) -> AWSBatchScheduler:
return AWSBatchScheduler(
session_name=session_name,
client=client,
log_client=log_client,
docker_client=docker_client,
)
1 change: 1 addition & 0 deletions torchx/schedulers/docker_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]):
"""

def __init__(self, session_name: str) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("docker", session_name)

def schedule(self, dryrun_info: AppDryRunInfo[DockerJob]) -> str:
Expand Down
6 changes: 5 additions & 1 deletion torchx/schedulers/gcp_batch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def __init__(
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
client: Optional[Any] = None,
) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
Scheduler.__init__(self, "gcp_batch", session_name)
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
self.__client = client
Expand Down Expand Up @@ -474,7 +475,10 @@ def _cancel_existing(self, app_id: str) -> None:
self._client.delete_job(request=request)


def create_scheduler(session_name: str, **kwargs: object) -> GCPBatchScheduler:
def create_scheduler(
session_name: str, client: Optional[Any], **kwargs: object
) -> GCPBatchScheduler:
return GCPBatchScheduler(
session_name=session_name,
client=client,
)
10 changes: 9 additions & 1 deletion torchx/schedulers/kubernetes_mcad_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,7 @@ def __init__(
client: Optional["ApiClient"] = None,
docker_client: Optional["DockerClient"] = None,
) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("kubernetes_mcad", session_name, docker_client=docker_client)

self._client = client
Expand Down Expand Up @@ -1230,9 +1231,16 @@ def list(self) -> List[ListAppResponse]:
]


def create_scheduler(session_name: str, **kwargs: Any) -> KubernetesMCADScheduler:
def create_scheduler(
session_name: str,
client: Optional["ApiClient"] = None,
docker_client: Optional["DockerClient"] = None,
**kwargs: Any,
) -> KubernetesMCADScheduler:
return KubernetesMCADScheduler(
session_name=session_name,
client=client,
docker_client=docker_client,
)


Expand Down
10 changes: 9 additions & 1 deletion torchx/schedulers/kubernetes_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ def __init__(
client: Optional["ApiClient"] = None,
docker_client: Optional["DockerClient"] = None,
) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("kubernetes", session_name, docker_client=docker_client)

self._client = client
Expand Down Expand Up @@ -777,9 +778,16 @@ def list(self) -> List[ListAppResponse]:
]


def create_scheduler(session_name: str, **kwargs: Any) -> KubernetesScheduler:
def create_scheduler(
session_name: str,
client: Optional["ApiClient"] = None,
docker_client: Optional["DockerClient"] = None,
**kwargs: Any,
) -> KubernetesScheduler:
return KubernetesScheduler(
session_name=session_name,
client=client,
docker_client=docker_client,
)


Expand Down
11 changes: 9 additions & 2 deletions torchx/schedulers/local_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,7 @@ def __init__(
cache_size: int = 100,
extra_paths: Optional[List[str]] = None,
) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("local", session_name)

# TODO T72035686 replace dict with a proper LRUCache data structure
Expand Down Expand Up @@ -1124,9 +1125,15 @@ def __next__(self) -> str:
return line


def create_scheduler(session_name: str, **kwargs: Any) -> LocalScheduler:
def create_scheduler(
session_name: str,
cache_size: int = 100,
extra_paths: Optional[List[str]] = None,
**kwargs: Any,
) -> LocalScheduler:
return LocalScheduler(
session_name=session_name,
cache_size=kwargs.get("cache_size", 100),
image_provider_class=CWDImageProvider,
cache_size=cache_size,
extra_paths=extra_paths,
)
1 change: 1 addition & 0 deletions torchx/schedulers/lsf_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ class LsfScheduler(Scheduler[LsfOpts]):
"""

def __init__(self, session_name: str) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("lsf", session_name)

def _run_opts(self) -> runopts:
Expand Down
7 changes: 5 additions & 2 deletions torchx/schedulers/ray_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ class RayScheduler(TmpDirWorkspaceMixin, Scheduler[RayOpts]):
def __init__(
self, session_name: str, ray_client: Optional[JobSubmissionClient] = None
) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("ray", session_name)

# w/o Final None check in _get_ray_client does not work as it pyre assumes mutability
Expand Down Expand Up @@ -441,10 +442,12 @@ def list(self) -> List[ListAppResponse]:
]


def create_scheduler(session_name: str, **kwargs: Any) -> "RayScheduler":
def create_scheduler(
session_name: str, ray_client: Optional[JobSubmissionClient] = None, **kwargs: Any
) -> "RayScheduler":
if not has_ray(): # pragma: no cover
raise ModuleNotFoundError(
"Ray is not installed in the current Python environment."
)

return RayScheduler(session_name=session_name)
return RayScheduler(session_name=session_name, ray_client=ray_client)
1 change: 1 addition & 0 deletions torchx/schedulers/slurm_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ class SlurmScheduler(DirWorkspaceMixin, Scheduler[SlurmOpts]):
"""

def __init__(self, session_name: str) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("slurm", session_name)

def _run_opts(self) -> runopts:
Expand Down

0 comments on commit ac577c0

Please sign in to comment.