Skip to content

Commit

Permalink
Implement a basic working ModelTuner API with Ray Tune (#1785)
Browse files Browse the repository at this point in the history
* Add hyperopt and ray as optional dependencies

* Add basic module structure

* Add classproperty and dependencies decorators

* Add file with default search spaces and helpers

* Add skeleton ModelTuner API

* Add Tunable type and notebook util function

* Add TunerManager API, update __init__s

* Add skeleton functions for TunerManager

* Implement TunerManager validation functions

* Address scvi, ray import kernel crashes

* Update CI workflow with autotune deps

* Implement dummy SCVI autotune interface

* Add sanity autotune test

* Retrigger checks

* Update .github/workflows/test.yml

* Potential fix for CUDA forked subprocess error

* Potential fix for CUDA forked process error

* Force spawn subprocesses

* Add optional pytest mark, change to forkserver default

* Try import ray on init

* Update docs with autotune

* Update docs and docstrings for autotune

* Update docs, include more basic tests for autotune

* Faster tests, validate anndata earlier

* Fix missing kwargs in test
  • Loading branch information
martinkim0 authored Nov 18, 2022
1 parent b2784f7 commit 91193e4
Show file tree
Hide file tree
Showing 18 changed files with 909 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
id: dependencies
run: |
pip install pytest-cov
pip install .[dev,pymde]
pip install .[dev,pymde,autotune]
# Following checks are independent and are run even if one fails
- name: Lint with flake8
Expand Down
20 changes: 20 additions & 0 deletions docs/api/developer.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,24 @@ TrainingPlans define train/test/val optimization steps for modules.
```

## Model hyperparameter autotuning

`scvi-tools` supports automatic model hyperparameter tuning using [Ray Tune]. These
classes allow for new model classes to be easily integrated with the module.

```{eval-rst}
.. currentmodule:: scvi
```

```{eval-rst}
.. autosummary::
:toctree: reference/
:nosignatures:
autotune.TunerManager
autotune.Tunable
```

## Utilities

```{eval-rst}
Expand All @@ -254,3 +272,5 @@ Utility functions used by scvi-tools.
utils.setup_anndata_dsp
utils.attrdict
```

[ray tune]: https://docs.ray.io/en/latest/tune/index.html
17 changes: 17 additions & 0 deletions docs/api/user.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,22 @@ Here we maintain a few package specific utilities for feature selection, etc.
data.organize_multiome_anndatas
```

```{eval-rst}
.. currentmodule:: scvi
```

## Model hyperparameter autotuning

`scvi-tools` supports automatic model hyperparameter tuning using [Ray Tune].

```{eval-rst}
.. autosummary::
:toctree: reference/
:nosignatures:
autotune.ModelTuner
```

## Utilities

Here we maintain miscellaneous general methods.
Expand Down Expand Up @@ -126,3 +142,4 @@ An instance of the {class}`~scvi._settings.ScviConfig` is available as `scvi.set
[anndata]: https://anndata.readthedocs.io/en/stable/
[scanpy]: https://scanpy.readthedocs.io/en/stable/index.html
[utilities]: https://scanpy.readthedocs.io/en/stable/api/index.html#reading
[ray tune]: https://docs.ray.io/en/latest/tune/index.html
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
"jax": ("https://jax.readthedocs.io/en/latest/", None),
"ml_collections": ("https://ml-collections.readthedocs.io/en/latest/", None),
"mudata": ("https://mudata.readthedocs.io/en/latest/", None),
"ray": ("https://docs.ray.io/en/latest/", None),
}


Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ flake8 = {version = ">=3.7.7", optional = true}
flax = "*"
furo = {version = ">=2022.2.14.1", optional = true}
h5py = ">=2.9.0"
hyperopt = {version = ">=0.2", optional = true}
importlib-metadata = {version = ">1.0", python = "<3.8"}
ipython = {version = ">=7.20", optional = true, python = ">=3.7"}
ipywidgets = "*"
Expand Down Expand Up @@ -68,6 +69,7 @@ pytest = {version = ">=4.4", optional = true}
python = ">=3.7,<4.0"
python-igraph = {version = "*", optional = true}
pytorch-lightning = ">=1.8.0,<1.9"
ray = {extras = ["tune"], version = ">=2.1.0", optional = true}
rich = ">=9.1.0"
scanpy = {version = ">=1.6", optional = true}
scikit-learn = ">=0.21.2"
Expand Down Expand Up @@ -100,6 +102,7 @@ docs = [
"sphinxcontrib-bibtex",
"myst-parser",
]
autotune = ["hyperopt", "ray", "ipython"]
pymde = ["pymde"]
tutorials = ["scanpy", "leidenalg", "python-igraph", "loompy", "scikit-misc", "pynndescent", "pymde"]

Expand Down
18 changes: 16 additions & 2 deletions scvi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@
# Set default logging handler to avoid logging with logging.lastResort logger.
import logging

try:
# necessary as importing scvi after ray causes kernel crash
from ray import tune # noqa
except ImportError:
pass

from ._constants import REGISTRY_KEYS
from ._settings import settings

# this import needs to come after prior imports to prevent circular import
from . import data, model, external, utils
from . import autotune, data, model, external, utils

# https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094
# https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302
Expand All @@ -25,4 +31,12 @@
scvi_logger = logging.getLogger("scvi")
scvi_logger.propagate = False

__all__ = ["settings", "REGISTRY_KEYS", "data", "model", "external", "utils"]
__all__ = [
"settings",
"REGISTRY_KEYS",
"autotune",
"data",
"model",
"external",
"utils",
]
47 changes: 47 additions & 0 deletions scvi/_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from functools import wraps
from typing import Callable, List, Union


class classproperty:
"""
Read-only class property decorator.
Source: https://stackoverflow.com/questions/5189699/how-to-make-a-class-property
"""

def __init__(self, f):
self.f = f

def __get__(self, obj, owner):
return self.f(owner)


def dependencies(packages: Union[str, List[str]]) -> Callable:
"""
Decorator to check for dependencies.
Parameters
----------
packages
A string or list of strings of packages to check for.
"""
if isinstance(packages, str):
packages = [packages]

def decorator(fn: Callable) -> Callable:
@wraps(fn)
def wrapper(*args, **kwargs):
try:
import importlib

for package in packages:
importlib.import_module(package)
except ImportError:
raise ImportError(
f"Please install {packages} to use this functionality."
)
return fn(*args, **kwargs)

return wrapper

return decorator
5 changes: 5 additions & 0 deletions scvi/autotune/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ._manager import TunerManager
from ._tuner import ModelTuner
from ._types import Tunable

__all__ = ["ModelTuner", "Tunable", "TunerManager"]
47 changes: 47 additions & 0 deletions scvi/autotune/_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from pytorch_lightning import LightningDataModule, LightningModule, Trainer

from scvi import model
from scvi.module.base import BaseModuleClass, JaxBaseModuleClass, PyroBaseModuleClass
from scvi.train import TrainRunner

# colors for rich table columns
COLORS = [
"dodger_blue1",
"dark_violet",
"green",
"dark_orange",
]

# default rich table column kwargs
COLUMN_KWARGS = {
"justify": "center",
"no_wrap": True,
"overflow": "fold",
}

# maps classes to the type of hyperparameters they use
TUNABLE_TYPES = {
"model": [
BaseModuleClass,
JaxBaseModuleClass,
PyroBaseModuleClass,
],
"train": [
LightningDataModule,
Trainer,
TrainRunner,
],
"train_plan": [
LightningModule,
],
}

# supported model classes
SUPPORTED = [model.SCVI]

# default hyperparameter search spaces for each model class
DEFAULTS = {
model.SCVI: {
"n_hidden": {"fn": "choice", "args": [[64, 128]]},
}
}
Loading

0 comments on commit 91193e4

Please sign in to comment.