-
Notifications
You must be signed in to change notification settings - Fork 366
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement a basic working
ModelTuner
API with Ray Tune (#1785)
* Add hyperopt and ray as optional dependencies * Add basic module structure * Add classproperty and dependencies decorators * Add file with default search spaces and helpers * Add skeleton ModelTuner API * Add Tunable type and notebook util function * Add TunerManager API, update __init__s * Add skeleton functions for TunerManager * Implement TunerManager validation functions * Address scvi, ray import kernel crashes * Update CI workflow with autotune deps * Implement dummy SCVI autotune interface * Add sanity autotune test * Retrigger checks * Update .github/workflows/test.yml * Potential fix for CUDA forked subprocess error * Potential fix for CUDA forked process error * Force spawn subprocesses * Add optional pytest mark, change to forkserver default * Try import ray on init * Update docs with autotune * Update docs and docstrings for autotune * Update docs, include more basic tests for autotune * Faster tests, validate anndata earlier * Fix missing kwargs in test
- Loading branch information
1 parent
b2784f7
commit 91193e4
Showing
18 changed files
with
909 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from functools import wraps | ||
from typing import Callable, List, Union | ||
|
||
|
||
class classproperty: | ||
""" | ||
Read-only class property decorator. | ||
Source: https://stackoverflow.com/questions/5189699/how-to-make-a-class-property | ||
""" | ||
|
||
def __init__(self, f): | ||
self.f = f | ||
|
||
def __get__(self, obj, owner): | ||
return self.f(owner) | ||
|
||
|
||
def dependencies(packages: Union[str, List[str]]) -> Callable: | ||
""" | ||
Decorator to check for dependencies. | ||
Parameters | ||
---------- | ||
packages | ||
A string or list of strings of packages to check for. | ||
""" | ||
if isinstance(packages, str): | ||
packages = [packages] | ||
|
||
def decorator(fn: Callable) -> Callable: | ||
@wraps(fn) | ||
def wrapper(*args, **kwargs): | ||
try: | ||
import importlib | ||
|
||
for package in packages: | ||
importlib.import_module(package) | ||
except ImportError: | ||
raise ImportError( | ||
f"Please install {packages} to use this functionality." | ||
) | ||
return fn(*args, **kwargs) | ||
|
||
return wrapper | ||
|
||
return decorator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from ._manager import TunerManager | ||
from ._tuner import ModelTuner | ||
from ._types import Tunable | ||
|
||
__all__ = ["ModelTuner", "Tunable", "TunerManager"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from pytorch_lightning import LightningDataModule, LightningModule, Trainer | ||
|
||
from scvi import model | ||
from scvi.module.base import BaseModuleClass, JaxBaseModuleClass, PyroBaseModuleClass | ||
from scvi.train import TrainRunner | ||
|
||
# colors for rich table columns | ||
COLORS = [ | ||
"dodger_blue1", | ||
"dark_violet", | ||
"green", | ||
"dark_orange", | ||
] | ||
|
||
# default rich table column kwargs | ||
COLUMN_KWARGS = { | ||
"justify": "center", | ||
"no_wrap": True, | ||
"overflow": "fold", | ||
} | ||
|
||
# maps classes to the type of hyperparameters they use | ||
TUNABLE_TYPES = { | ||
"model": [ | ||
BaseModuleClass, | ||
JaxBaseModuleClass, | ||
PyroBaseModuleClass, | ||
], | ||
"train": [ | ||
LightningDataModule, | ||
Trainer, | ||
TrainRunner, | ||
], | ||
"train_plan": [ | ||
LightningModule, | ||
], | ||
} | ||
|
||
# supported model classes | ||
SUPPORTED = [model.SCVI] | ||
|
||
# default hyperparameter search spaces for each model class | ||
DEFAULTS = { | ||
model.SCVI: { | ||
"n_hidden": {"fn": "choice", "args": [[64, 128]]}, | ||
} | ||
} |
Oops, something went wrong.