diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 078a514e9f..afbdc01297 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -37,7 +37,7 @@ jobs: id: dependencies run: | pip install pytest-cov - pip install .[dev,pymde] + pip install .[dev,pymde,autotune] # Following checks are independent and are run even if one fails - name: Lint with flake8 diff --git a/docs/api/developer.md b/docs/api/developer.md index db03fa7ee0..7aa468635f 100644 --- a/docs/api/developer.md +++ b/docs/api/developer.md @@ -237,6 +237,24 @@ TrainingPlans define train/test/val optimization steps for modules. ``` +## Model hyperparameter autotuning + +`scvi-tools` supports automatic model hyperparameter tuning using [Ray Tune]. These +classes allow new model classes to be easily integrated with the autotuning module. + +```{eval-rst} +.. currentmodule:: scvi +``` + +```{eval-rst} +.. autosummary:: + :toctree: reference/ + :nosignatures: + + autotune.TunerManager + autotune.Tunable +``` + ## Utilities ```{eval-rst} .. currentmodule:: scvi ``` @@ -254,3 +272,5 @@ Utility functions used by scvi-tools. utils.setup_anndata_dsp utils.attrdict ``` + +[ray tune]: https://docs.ray.io/en/latest/tune/index.html diff --git a/docs/api/user.md b/docs/api/user.md index b8ab54fda3..b5c3126708 100644 --- a/docs/api/user.md +++ b/docs/api/user.md @@ -99,6 +99,22 @@ Here we maintain a few package specific utilities for feature selection, etc. data.organize_multiome_anndatas ``` +```{eval-rst} +.. currentmodule:: scvi +``` + +## Model hyperparameter autotuning + +`scvi-tools` supports automatic model hyperparameter tuning using [Ray Tune]. + +```{eval-rst} +.. autosummary:: + :toctree: reference/ + :nosignatures: + + autotune.ModelTuner +``` + ## Utilities Here we maintain miscellaneous general methods.
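To make the new documentation section concrete, here is a minimal usage sketch assembled from the `ModelTuner` docstring elsewhere in this diff (`path_to_h5ad` is a placeholder):

```python
import anndata
import scvi

adata = anndata.read_h5ad(path_to_h5ad)  # placeholder path to a prepared dataset
model_cls = scvi.model.SCVI
model_cls.setup_anndata(adata)

# ModelTuner wraps a ray.tune.Tuner around the model class
tuner = scvi.autotune.ModelTuner(model_cls)
results = tuner.fit(adata, metric="validation_loss")
```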
@@ -126,3 +142,4 @@ An instance of the {class}`~scvi._settings.ScviConfig` is available as `scvi.set [anndata]: https://anndata.readthedocs.io/en/stable/ [scanpy]: https://scanpy.readthedocs.io/en/stable/index.html [utilities]: https://scanpy.readthedocs.io/en/stable/api/index.html#reading +[ray tune]: https://docs.ray.io/en/latest/tune/index.html diff --git a/docs/conf.py b/docs/conf.py index 12ccac036e..ded6b165b5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -92,6 +92,7 @@ "jax": ("https://jax.readthedocs.io/en/latest/", None), "ml_collections": ("https://ml-collections.readthedocs.io/en/latest/", None), "mudata": ("https://mudata.readthedocs.io/en/latest/", None), + "ray": ("https://docs.ray.io/en/latest/", None), } diff --git a/pyproject.toml b/pyproject.toml index 45c309fd4e..92c6973e4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ flake8 = {version = ">=3.7.7", optional = true} flax = "*" furo = {version = ">=2022.2.14.1", optional = true} h5py = ">=2.9.0" +hyperopt = {version = ">=0.2", optional = true} importlib-metadata = {version = ">1.0", python = "<3.8"} ipython = {version = ">=7.20", optional = true, python = ">=3.7"} ipywidgets = "*" @@ -68,6 +69,7 @@ pytest = {version = ">=4.4", optional = true} python = ">=3.7,<4.0" python-igraph = {version = "*", optional = true} pytorch-lightning = ">=1.8.0,<1.9" +ray = {extras = ["tune"], version = ">=2.1.0", optional = true} rich = ">=9.1.0" scanpy = {version = ">=1.6", optional = true} scikit-learn = ">=0.21.2" @@ -100,6 +102,7 @@ docs = [ "sphinxcontrib-bibtex", "myst-parser", ] +autotune = ["hyperopt", "ray", "ipython"] pymde = ["pymde"] tutorials = ["scanpy", "leidenalg", "python-igraph", "loompy", "scikit-misc", "pynndescent", "pymde"] diff --git a/scvi/__init__.py b/scvi/__init__.py index a9fbb7ca10..238b201d8e 100644 --- a/scvi/__init__.py +++ b/scvi/__init__.py @@ -3,11 +3,17 @@ # Set default logging handler to avoid logging with logging.lastResort logger. import logging +try: + # necessary as importing scvi after ray causes kernel crash + from ray import tune # noqa +except ImportError: + pass + from ._constants import REGISTRY_KEYS from ._settings import settings # this import needs to come after prior imports to prevent circular import -from . import data, model, external, utils +from . import autotune, data, model, external, utils # https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094 # https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302 @@ -25,4 +31,12 @@ scvi_logger = logging.getLogger("scvi") scvi_logger.propagate = False -__all__ = ["settings", "REGISTRY_KEYS", "data", "model", "external", "utils"] +__all__ = [ + "settings", + "REGISTRY_KEYS", + "autotune", + "data", + "model", + "external", + "utils", +] diff --git a/scvi/_decorators.py b/scvi/_decorators.py new file mode 100644 index 0000000000..2d11947c7a --- /dev/null +++ b/scvi/_decorators.py @@ -0,0 +1,47 @@ +from functools import wraps +from typing import Callable, List, Union + + +class classproperty: + """ + Read-only class property decorator. + + Source: https://stackoverflow.com/questions/5189699/how-to-make-a-class-property + """ + + def __init__(self, f): + self.f = f + + def __get__(self, obj, owner): + return self.f(owner) + + +def dependencies(packages: Union[str, List[str]]) -> Callable: + """ + Decorator to check for dependencies. + + Parameters + ---------- + packages + A string or list of strings of packages to check for. 
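+
+    Examples
+    --------
+    A sketch of typical usage (hypothetical decorated function):
+
+    >>> @dependencies(["ray.tune", "hyperopt"])
+    ... def fit_with_ray(*args, **kwargs):
+    ...     ...  # body only runs when ray.tune and hyperopt are importable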
+ """ + if isinstance(packages, str): + packages = [packages] + + def decorator(fn: Callable) -> Callable: + @wraps(fn) + def wrapper(*args, **kwargs): + try: + import importlib + + for package in packages: + importlib.import_module(package) + except ImportError: + raise ImportError( + f"Please install {packages} to use this functionality." + ) + return fn(*args, **kwargs) + + return wrapper + + return decorator diff --git a/scvi/autotune/__init__.py b/scvi/autotune/__init__.py new file mode 100644 index 0000000000..61e1517330 --- /dev/null +++ b/scvi/autotune/__init__.py @@ -0,0 +1,5 @@ +from ._manager import TunerManager +from ._tuner import ModelTuner +from ._types import Tunable + +__all__ = ["ModelTuner", "Tunable", "TunerManager"] diff --git a/scvi/autotune/_defaults.py b/scvi/autotune/_defaults.py new file mode 100644 index 0000000000..bde19a92fa --- /dev/null +++ b/scvi/autotune/_defaults.py @@ -0,0 +1,47 @@ +from pytorch_lightning import LightningDataModule, LightningModule, Trainer + +from scvi import model +from scvi.module.base import BaseModuleClass, JaxBaseModuleClass, PyroBaseModuleClass +from scvi.train import TrainRunner + +# colors for rich table columns +COLORS = [ + "dodger_blue1", + "dark_violet", + "green", + "dark_orange", +] + +# default rich table column kwargs +COLUMN_KWARGS = { + "justify": "center", + "no_wrap": True, + "overflow": "fold", +} + +# maps classes to the type of hyperparameters they use +TUNABLE_TYPES = { + "model": [ + BaseModuleClass, + JaxBaseModuleClass, + PyroBaseModuleClass, + ], + "train": [ + LightningDataModule, + Trainer, + TrainRunner, + ], + "train_plan": [ + LightningModule, + ], +} + +# supported model classes +SUPPORTED = [model.SCVI] + +# default hyperparameter search spaces for each model class +DEFAULTS = { + model.SCVI: { + "n_hidden": {"fn": "choice", "args": [[64, 128]]}, + } +} diff --git a/scvi/autotune/_manager.py b/scvi/autotune/_manager.py new file mode 100644 index 0000000000..5a4cac5022 --- /dev/null +++ b/scvi/autotune/_manager.py @@ -0,0 +1,504 @@ +import inspect +import logging +import warnings +from collections import OrderedDict +from typing import Any, Callable, List, Optional, Tuple + +import rich + +try: + from ray import air, tune + from ray.tune.integration.pytorch_lightning import TuneReportCallback +except ImportError: + pass + +from scvi._decorators import dependencies +from scvi._types import AnnOrMuData +from scvi.model.base import BaseModelClass + +from ._defaults import COLORS, COLUMN_KWARGS, DEFAULTS, SUPPORTED, TUNABLE_TYPES +from ._types import TunableMeta +from ._utils import in_notebook + +logger = logging.getLogger(__name__) + + +class TunerManager: + """ + Internal manager for validation of inputs from :class:`~scvi.autotune.ModelTuner`. + + Parameters + ---------- + model_cls + :class:`~scvi.model.base.BaseModelClass` on which to tune hyperparameters. See + :class:`~scvi.autotune.ModelTuner` forsupported model classes. + """ + + def __init__(self, model_cls: BaseModelClass): + self._model_cls: BaseModelClass = self._validate_model_cls(model_cls) + self._defaults: dict = self._get_defaults(self._model_cls) + self._registry: dict = self._get_registry(self._model_cls) + + def _validate_model_cls(self, model_cls: BaseModelClass) -> BaseModelClass: + """Checks if the model class is suppo rted.""" + if model_cls not in SUPPORTED: + raise NotImplementedError( + f"{model_cls} is currently unsupported. Please see ModelTuner for a " + "list of supported model classes." 
+ ) + return model_cls + + def _get_defaults(self, model_cls: BaseModelClass) -> dict: + """Returns the model class's default search space if available.""" + if model_cls not in DEFAULTS: + warnings.warn( + f"No default search space available for {model_cls}.", + UserWarning, + ) + return DEFAULTS.get(model_cls, {}) + + def _get_registry(self, model_cls: BaseModelClass) -> dict: + """ + Returns the model class's registry of tunable hyperparameters and metrics. + + For a given model class, checks whether a `_tunables` class property has been + defined. If so, iterates through the attribute to recursively discover tunable + hyperparameters. + + Parameters + ---------- + model_cls + A validated :class:`~scvi.model.base.BaseModelClass`. + + Returns + ------- + registry: dict + A dictionary with the following keys: + + * ``"tunables"``: a dictionary of tunable hyperparameters and metadata + * ``"metrics"``: a dictionary of available metrics and metadata + """ + + def _cls_to_tunable_type(cls: Any) -> str: + for tunable_type, cls_list in TUNABLE_TYPES.items(): + if any(issubclass(cls, c) for c in cls_list): + return tunable_type + return "unknown" + + def _get_tunables( + attr: Any, parent: Any = None, tunable_type: Optional[str] = None + ) -> dict: + tunables = {} + if inspect.isfunction(attr): + # check if function kwargs are tunable + for kwarg, metadata in inspect.signature(attr).parameters.items(): + if not isinstance(metadata.annotation, TunableMeta): + continue + default_val = metadata.default + if default_val is inspect.Parameter.empty: + default_val = None + tunables[kwarg] = { + "parent_class": parent, + "default_value": default_val, + "function": attr, + "tunable_type": tunable_type, + } + elif inspect.isclass(attr) and hasattr(attr, "_tunables"): + # recursively check if `_tunables` is implemented + tunable_type = _cls_to_tunable_type(attr) + for child in attr._tunables: + tunables.update( + _get_tunables(child, parent=attr, tunable_type=tunable_type) + ) + return tunables + + def _get_metrics(model_cls: BaseModelClass) -> OrderedDict: + # TODO: discover more metrics + return OrderedDict({"validation_loss": "min"}) + + registry = { + "tunables": _get_tunables(model_cls), + "metrics": _get_metrics(model_cls), + } + return registry + + def _get_search_space(self, search_space: dict) -> Tuple[dict, dict]: + """Parses a compact search space into separate kwargs dictionaries.""" + model_kwargs = {} + train_kwargs = {} + plan_kwargs = {} + tunables = self._registry["tunables"] + + for param, value in search_space.items(): + _type = tunables[param]["tunable_type"] + if _type == "model": + model_kwargs[param] = value + elif _type == "train": + train_kwargs[param] = value + elif _type == "train_plan": + plan_kwargs[param] = value + + train_kwargs["plan_kwargs"] = plan_kwargs + return model_kwargs, train_kwargs + + @dependencies("ray.tune") + def _validate_search_space( + self, search_space: dict, use_defaults: bool, exclude: List[str] + ) -> dict: + """Validates a search space against the hyperparameter registry.""" + # validate user-provided search space, iterating over a copy since + # invalid keys are popped below + for param in list(search_space): + if param in self._registry["tunables"]: + continue + warnings.warn( + f"Provided parameter {param} is invalid for {self._model_cls.__name__}." + " Please see available parameters with `ModelTuner.info()`. 
" + "Ignoring parameter.", + UserWarning, + ) + search_space.pop(param) + + # add defaults if requested + _search_space = {} + if use_defaults: + + # parse defaults into tune sample functions + for param, metadata in self._defaults.items(): + sample_fn = getattr(tune, metadata["fn"]) + fn_args = metadata.get("args", []) + fn_kwargs = metadata.get("kwargs", {}) + _search_space[param] = sample_fn(*fn_args, **fn_kwargs) + + # exclude defaults if requested + logger.info( + f"Merging search space with defaults for {self._model_cls.__name__}." + ) + for param in exclude: + if param not in _search_space: + warnings.warn( + f"Excluded parameter {param} not in defaults search space. " + "Ignoring parameter.", + UserWarning, + ) + _search_space.pop(param, None) + + # priority given to user-provided search space + _search_space.update(search_space) + return _search_space + + def _validate_metrics( + self, metric: str, additional_metrics: List[str] + ) -> OrderedDict: + """Validates a set of metrics against the metric registry.""" + registry_metrics = self._registry["metrics"] + _metrics = OrderedDict() + + # validate primary metric + if metric not in registry_metrics: + raise ValueError( + f"Provided metric {metric} is invalid for {self._model_cls.__name__}. " + "Please see available metrics with `ModelTuner.info()`. ", + ) + _metrics[metric] = registry_metrics[metric] + + # validate additional metrics + for m in additional_metrics: + if m not in registry_metrics: + warnings.warn( + f"Provided metric {m} is invalid for {self._model_cls.__name__}. " + "Please see available metrics with `ModelTuner.info()`. " + "Ignoring metric.", + UserWarning, + ) + continue + _metrics[m] = registry_metrics[m] + + return _metrics + + @dependencies("ray.tune") + def _validate_scheduler( + self, scheduler: str, metrics: OrderedDict, scheduler_kwargs: dict + ) -> Any: + """Validates a trial scheduler.""" + metric = list(metrics.keys())[0] + mode = metrics[metric] + _kwargs = {"metric": metric, "mode": mode} + + if scheduler == "asha": + _default_kwargs = { + "max_t": 100, + "grace_period": 1, + "reduction_factor": 2, + } + _scheduler = tune.schedulers.ASHAScheduler + elif scheduler == "hyperband": + _default_kwargs = {} + _scheduler = tune.schedulers.HyperBandScheduler + elif scheduler == "median": + _default_kwargs = {} + _scheduler = tune.schedulers.MedianStoppingRule + elif scheduler == "pbt": + _default_kwargs = {} + _scheduler = tune.schedulers.PopulationBasedTraining + elif scheduler == "fifo": + _default_kwargs = {} + _scheduler = tune.schedulers.FIFOScheduler + + # prority given to user-provided scheduler kwargs + _default_kwargs.update(scheduler_kwargs) + _kwargs.update(_default_kwargs) + return _scheduler(**_kwargs) + + @dependencies(["ray.tune", "hyperopt"]) + def _validate_search_algorithm( + self, searcher: str, metrics: OrderedDict, searcher_kwargs: dict + ) -> Any: + """Validates a hyperparameter search algorithm.""" + metric = list(metrics.keys())[0] + mode = metrics[metric] + + if searcher == "random": + _default_kwargs = {} + _searcher = tune.search.basic_variant.BasicVariantGenerator + elif searcher == "grid": + _default_kwargs = {} + _searcher = tune.search.basic_variant.BasicVariantGenerator + elif searcher == "hyperopt": + _default_kwargs = { + "metric": metric, + "mode": mode, + } + tune.search.SEARCH_ALG_IMPORT["hyperopt"]() # tune not importing hyperopt + _searcher = tune.search.hyperopt.HyperOptSearch + + # prority given to user-provided searcher kwargs + _default_kwargs.update(searcher_kwargs) 
+ return _searcher(**_default_kwargs) + + def _validate_scheduler_and_search_algorithm( + self, + scheduler: str, + searcher: str, + metrics: OrderedDict, + scheduler_kwargs: dict, + searcher_kwargs: dict, + ) -> Tuple[Any, Any]: + """Validates a scheduler and search algorithm pair for compatibility.""" + if scheduler not in ["asha", "hyperband", "median", "pbt", "fifo"]: + raise ValueError( + f"Provided scheduler {scheduler} is unsupported. Must be one of " + "['asha', 'hyperband', 'median', 'pbt', 'fifo']. ", + ) + if searcher not in ["random", "grid", "hyperopt"]: + raise ValueError( + f"Provided searcher {searcher} is unsupported. Must be one of " + "['random', 'grid', 'hyperopt']. ", + ) + if scheduler not in ["asha", "median", "hyperband"] and searcher not in [ + "random", + "grid", + ]: + raise ValueError( + f"Provided scheduler {scheduler} is incompatible with the provided " + f"searcher {searcher}. Please see " + "https://docs.ray.io/en/latest/tune/key-concepts.html for more info." + ) + + _scheduler = self._validate_scheduler(scheduler, metrics, scheduler_kwargs) + _searcher = self._validate_search_algorithm(searcher, metrics, searcher_kwargs) + return _scheduler, _searcher + + @dependencies("ray.tune") + def _validate_reporter( + self, reporter: bool, search_space: dict, metrics: OrderedDict + ) -> Any: + """Validates a reporter depending on the execution environment.""" + _metric_keys = list(metrics.keys()) + _param_keys = list(search_space.keys()) + _kwargs = { + "metric_columns": _metric_keys, + "parameter_columns": _param_keys, + "metric": _metric_keys[0], + "mode": metrics[_metric_keys[0]], + } + + if not reporter: + _reporter = None + elif in_notebook(): + _reporter = tune.JupyterNotebookReporter(**_kwargs) + else: + _reporter = tune.CLIReporter(**_kwargs) + + return _reporter + + def _validate_resources(self, resources: dict) -> dict: + """Validates a resource-usage specification.""" + # TODO: perform resource checking + return resources + + def _get_setup_kwargs(self, adata: AnnOrMuData) -> dict: + """Retrieves the kwargs used for setting up `adata` with the model class.""" + manager = self._model_cls._get_most_recent_anndata_manager(adata) + return manager._get_setup_method_args().get("setup_args", {}) + + @dependencies("ray.tune") + def _get_trainable( + self, + adata: AnnOrMuData, + metrics: OrderedDict, + resources: dict, + setup_kwargs: dict, + max_epochs: int, + ) -> Callable: + """Returns a trainable function consumable by :class:`~ray.tune.Tuner`.""" + + def _trainable( + search_space: dict, + *, + model_cls: BaseModelClass = None, + adata: AnnOrMuData = None, + metric: str = None, + setup_kwargs: dict = None, + max_epochs: int = None, + ) -> None: + model_kwargs, train_kwargs = self._get_search_space(search_space) + # TODO: generalize to models with mudata + model_cls.setup_anndata(adata, **setup_kwargs) + model = model_cls(adata, **model_kwargs) + monitor = TuneReportCallback(metric, on="validation_end") + # TODO: adaptive max_epochs + model.train( + max_epochs=max_epochs, + check_val_every_n_epoch=1, + callbacks=[monitor], + **train_kwargs, + ) + + _wrap_params = tune.with_parameters( + _trainable, + model_cls=self._model_cls, + adata=adata, + metric=list(metrics.keys())[0], + setup_kwargs=setup_kwargs, + max_epochs=max_epochs, + ) + return tune.with_resources(_wrap_params, resources=resources) + + @dependencies(["ray.tune", "ray.air"]) + def _get_tuner( + self, + adata: AnnOrMuData, + *, + metric: Optional[str] = None, + additional_metrics: Optional[List[str]]
= None, + search_space: Optional[dict] = None, + use_defaults: bool = True, + exclude: Optional[List[str]] = None, + num_samples: Optional[int] = None, + max_epochs: Optional[int] = None, + scheduler: Optional[str] = None, + scheduler_kwargs: Optional[dict] = None, + searcher: Optional[str] = None, + searcher_kwargs: Optional[dict] = None, + reporter: bool = True, + resources: Optional[dict] = None, + ) -> Any: + """Configures a :class:`~ray.tune.Tuner` instance after validation.""" + metric = metric or list(self._registry["metrics"].keys())[0] + additional_metrics = additional_metrics or [] + search_space = search_space or {} + exclude = exclude or [] + num_samples = num_samples or 10 + max_epochs = max_epochs or 10 + scheduler = scheduler or "asha" + scheduler_kwargs = scheduler_kwargs or {} + searcher = searcher or "hyperopt" + searcher_kwargs = searcher_kwargs or {} + resources = resources or {} + + _ = self._model_cls(adata) + _metrics = self._validate_metrics(metric, additional_metrics) + _search_space = self._validate_search_space(search_space, use_defaults, exclude) + _scheduler, _searcher = self._validate_scheduler_and_search_algorithm( + scheduler, searcher, _metrics, scheduler_kwargs, searcher_kwargs + ) + _reporter = self._validate_reporter(reporter, _search_space, _metrics) + _resources = self._validate_resources(resources) + _setup_kwargs = self._get_setup_kwargs(adata) + _trainable = self._get_trainable( + adata, + _metrics, + _resources, + _setup_kwargs, + max_epochs, + ) + + tune_config = tune.tune_config.TuneConfig( + scheduler=_scheduler, + search_alg=_searcher, + num_samples=num_samples, + ) + # TODO: add kwarg for name or auto-generate name? + run_config = air.config.RunConfig( + name="scvi-tune", + progress_reporter=_reporter, + ) + tuner = tune.Tuner( + trainable=_trainable, + param_space=_search_space, + tune_config=tune_config, + run_config=run_config, + ) + return tuner + + def _add_columns( + self, table: rich.table.Table, columns: List[str] + ) -> rich.table.Table: + """Adds columns to a :class:`~rich.table.Table` with default formatting.""" + for i, column in enumerate(columns): + table.add_column(column, style=COLORS[i], **COLUMN_KWARGS) + return table + + def _view_registry(self, show_resources: bool) -> None: + """Displays a summary of the model class's registry and available resources.""" + console = rich.console.Console(force_jupyter=in_notebook()) + + tunables_table = self._add_columns( + rich.table.Table(title="Tunable hyperparameters"), + ["Hyperparameter", "Tunable type", "Default value", "Source"], + ) + for param, metadata in self._registry["tunables"].items(): + tunables_table.add_row( + str(param), + str(metadata["tunable_type"]), + str(metadata["default_value"]), + str(metadata["parent_class"]), + ) + + metrics_table = self._add_columns( + rich.table.Table(title="Available metrics"), + ["Metric", "Mode"], + ) + for metric, mode in self._registry["metrics"].items(): + metrics_table.add_row(str(metric), str(mode)) + + defaults_table = self._add_columns( + rich.table.Table(title="Default search space"), + ["Hyperparameter", "Sample function", "Arguments", "Keyword arguments"], + ) + for param, metadata in self._defaults.items(): + defaults_table.add_row( + str(param), + str(metadata["fn"]), + str(metadata.get("args", [])), + str(metadata.get("kwargs", {})), + ) + + console.print(f"Registry for {self._model_cls.__name__}") + console.print(tunables_table) + console.print(metrics_table) + console.print(defaults_table) + + if show_resources: + # TODO: 
retrieve available resources + pass diff --git a/scvi/autotune/_tuner.py b/scvi/autotune/_tuner.py new file mode 100644 index 0000000000..e5ee3fc90b --- /dev/null +++ b/scvi/autotune/_tuner.py @@ -0,0 +1,111 @@ +from typing import Any + +from scvi._types import AnnOrMuData +from scvi.model.base import BaseModelClass + +from ._manager import TunerManager + + +class ModelTuner: + """ + Automated and parallel hyperparameter tuning with :mod:`ray.tune`. + + Wraps a :class:`~ray.tune.Tuner` instance attached to a scvi-tools model class. + Note: this API is in beta and is subject to change in future releases. + + Parameters + ---------- + model_cls + :class:`~scvi.model.base.BaseModelClass` on which to tune hyperparameters. + Currently supported model classes are: + + * :class:`~scvi.model.SCVI` + + Examples + -------- + >>> import anndata + >>> import scvi + >>> adata = anndata.read_h5ad(path_to_h5ad) + >>> model_cls = scvi.model.SCVI + >>> model_cls.setup_anndata(adata) + >>> tuner = scvi.autotune.ModelTuner(model_cls) + >>> results = tuner.fit(adata, metric="validation_loss") + """ + + def __init__(self, model_cls: BaseModelClass): + self._manager = TunerManager(model_cls) + + def fit( + self, + adata: AnnOrMuData, + **kwargs, + ) -> Any: + """ + Run a specified hyperparameter sweep for the associated model class. + + Parameters + ---------- + adata + :class:`~anndata.AnnData` or :class:`~mudata.MuData` that has been set up + with the associated model class. + metric + The primary metric to optimize. If not provided, defaults to the model + class's validation loss. + additional_metrics + Additional metrics to track during the experiment. If not provided, no + additional metrics are tracked. + search_space + Dictionary of hyperparameter names and their respective search spaces + provided as instantiated Ray Tune sample functions. Available + hyperparameters can be viewed with :meth:`~scvi.autotune.ModelTuner.info`. + Must be provided if `use_defaults` is `False`. + use_defaults + Whether to use the model class's default search space, which can be viewed + with :meth:`~scvi.autotune.ModelTuner.info`. If `True` and `search_space` is + provided, the two will be merged, giving priority to user-provided values. + Defaults to `True`. + exclude + List of hyperparameters to exclude from the default search space. If + `use_defaults` is `False`, this argument is ignored. + num_samples + Number of hyperparameter configurations to sample. + max_epochs + Maximum number of epochs to train each model. + scheduler + Ray Tune scheduler to use. Supported options are: + + * ``"asha"``: :class:`~ray.tune.schedulers.ASHAScheduler` + * ``"hyperband"``: :class:`~ray.tune.schedulers.HyperBandScheduler` + * ``"median"``: :class:`~ray.tune.schedulers.MedianStoppingRule` + * ``"pbt"``: :class:`~ray.tune.schedulers.PopulationBasedTraining` + * ``"fifo"``: :class:`~ray.tune.schedulers.FIFOScheduler` + scheduler_kwargs + Keyword arguments to pass to the scheduler. + searcher + Ray Tune search algorithm to use. Supported options are: + + * ``"random"``: :class:`~ray.tune.search.basic_variant.BasicVariantGenerator` + * ``"grid"``: :class:`~ray.tune.search.basic_variant.BasicVariantGenerator` + * ``"hyperopt"``: :class:`~ray.tune.search.hyperopt.HyperOptSearch` + searcher_kwargs + Keyword arguments to pass to the search algorithm. + reporter + Whether to display progress with a Ray Tune reporter.
Depending on the + execution environment, one of the following reporters will be used: + + * :class:`~ray.tune.CLIReporter` if running in a script + * :class:`~ray.tune.JupyterNotebookReporter` if running in a notebook + resources + Dictionary of maximum resources to allocate for the experiment. Available + keys include: + + * ``"cpu"``: maximum number of CPU threads to use + * ``"gpu"``: maximum number of GPUs to use + + If not provided, defaults to using one CPU thread and one GPU if available. + """ + tuner = self._manager._get_tuner(adata, **kwargs) + results = tuner.fit() + return results + + def info(self, show_resources: bool = False) -> None: + """Display information about the associated model class.""" + self._manager._view_registry(show_resources=show_resources) diff --git a/scvi/autotune/_types.py b/scvi/autotune/_types.py new file mode 100644 index 0000000000..036ad44288 --- /dev/null +++ b/scvi/autotune/_types.py @@ -0,0 +1,11 @@ +class TunableMeta(type): + """Metaclass for Tunable class.""" + + def __getitem__(cls, values): + if not isinstance(values, tuple): + values = (values,) + return type("Tunable_", (Tunable,), dict(__args__=values)) + + +class Tunable(metaclass=TunableMeta): + """Typing class for tagging keyword arguments as tunable.""" diff --git a/scvi/autotune/_utils.py b/scvi/autotune/_utils.py new file mode 100644 index 0000000000..e7768f43e8 --- /dev/null +++ b/scvi/autotune/_utils.py @@ -0,0 +1,22 @@ +import sys + + +def in_notebook() -> bool: + """ + Check if running in a Jupyter notebook or Colab session. + + Based on: https://stackoverflow.com/questions/15411967/how-can-i-check-if-code-is-executed-in-the-ipython-notebook + """ + try: + from IPython import get_ipython + + shell = get_ipython().__class__.__name__ + # ZMQInteractiveShell indicates a Jupyter notebook or qtconsole; + # terminal IPython and other shells are not notebooks + return shell == "ZMQInteractiveShell" + except ImportError: + in_colab = "google.colab" in sys.modules + return in_colab diff --git a/scvi/model/_scvi.py b/scvi/model/_scvi.py index 15ccff46c9..4402cf1517 100644 --- a/scvi/model/_scvi.py +++ b/scvi/model/_scvi.py @@ -1,11 +1,12 @@ import logging -from typing import List, Optional +from typing import Any, List, Optional, Tuple from anndata import AnnData from scipy.sparse import csr_matrix from scvi import REGISTRY_KEYS from scvi._compat import Literal +from scvi._decorators import classproperty from scvi._types import LatentDataType from scvi.data import AnnDataManager from scvi.data._constants import _ADATA_LATENT_UNS_KEY @@ -22,6 +23,7 @@ from scvi.model._utils import _init_library_size from scvi.model.base import UnsupervisedTrainingMixin from scvi.module import VAE +from scvi.module.base import BaseModuleClass from scvi.utils import setup_anndata_dsp from .base import ArchesMixin, BaseLatentModeModelClass, RNASeqMixin, VAEMixin @@ -133,7 +135,7 @@ def __init__( self.adata_manager, n_batch ) - self.module = VAE( + self.module = self._module_cls( n_input=self.summary_stats.n_vars, n_batch=n_batch, n_labels=self.summary_stats.n_labels, @@ -166,6 +168,14 @@ def __init__( ) self.init_params_ = self._get_init_params(locals()) + @classproperty + def _module_cls(cls) -> BaseModuleClass: + return VAE + + @classproperty + def _tunables(cls) -> Tuple[Any, ...]: + return (cls._module_cls,) + @classmethod @setup_anndata_dsp.dedent def setup_anndata( diff --git a/scvi/module/_vae.py b/scvi/module/_vae.py index 51c7d6446e..ed7c7f91b3 100644 --- a/scvi/module/_vae.py +++ b/scvi/module/_vae.py @@ -1,5 +1,5 @@ """Main module."""
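The `Tunable` annotation applied to `n_hidden` in the hunk below, together with the `_tunables` class properties, is what `TunerManager._get_registry` discovers via `inspect`. A minimal sketch of that mechanism (hypothetical standalone function; assumes this branch of scvi-tools is installed):

```python
import inspect

from scvi.autotune import Tunable
from scvi.autotune._types import TunableMeta


def example(n_hidden: Tunable[int] = 128, n_latent: int = 10):
    """Hypothetical function with one tunable kwarg and one plain kwarg."""


# Only kwargs annotated with Tunable are picked up, mirroring the
# isinstance(metadata.annotation, TunableMeta) check in TunerManager.
for name, param in inspect.signature(example).parameters.items():
    if isinstance(param.annotation, TunableMeta):
        print(name, param.default)  # prints: n_hidden 128
```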
-from typing import Callable, Iterable, Optional +from typing import Any, Callable, Iterable, Optional, Tuple import numpy as np import torch @@ -10,7 +10,9 @@ from scvi import REGISTRY_KEYS from scvi._compat import Literal +from scvi._decorators import classproperty from scvi._types import LatentDataType +from scvi.autotune._types import Tunable from scvi.distributions import NegativeBinomial, Poisson, ZeroInflatedNegativeBinomial from scvi.module.base import BaseLatentModeModuleClass, LossOutput, auto_move_data from scvi.nn import DecoderSCVI, Encoder, LinearDecoderSCVI, one_hot @@ -95,7 +97,7 @@ def __init__( n_input: int, n_batch: int = 0, n_labels: int = 0, - n_hidden: int = 128, + n_hidden: Tunable[int] = 128, n_latent: int = 10, n_layers: int = 1, n_continuous_cov: int = 0, @@ -211,6 +213,10 @@ def __init__( scale_activation="softplus" if use_size_factor_key else "softmax", ) + @classproperty + def _tunables(cls) -> Tuple[Any, ...]: + return (cls.__init__,) + def _get_inference_input( self, tensors, diff --git a/tests/autotune/test_manager.py b/tests/autotune/test_manager.py new file mode 100644 index 0000000000..1c09b74d9e --- /dev/null +++ b/tests/autotune/test_manager.py @@ -0,0 +1,36 @@ +import pytest + +import scvi + + +def test_tuner_manager_init(): + model_cls = scvi.model.SCVI + manager = scvi.autotune.TunerManager(model_cls) + assert hasattr(manager, "_model_cls") + assert hasattr(manager, "_defaults") + assert hasattr(manager, "_registry") + + registry = manager._registry + assert "tunables" in registry + assert "metrics" in registry + + +def test_tuner_manager_basic_validation(): + model_cls = scvi.model.SCVI + manager = scvi.autotune.TunerManager(model_cls) + + # invalid params should raise an exception + with pytest.raises(Exception): + manager._validate_search_space({"not_a_param": None}, False, []) + + # search space does not change when `use_defaults=False` + search_space = manager._validate_search_space({"n_hidden": None}, False, []) + assert search_space == {"n_hidden": None} + + # search space does not include "n_hidden" if excluded + search_space = manager._validate_search_space({}, True, ["n_hidden"]) + assert "n_hidden" not in search_space + + # invalid metrics should raise an exception + with pytest.raises(Exception): + manager._validate_metrics("not_a_metric", []) diff --git a/tests/autotune/test_tuner.py b/tests/autotune/test_tuner.py new file mode 100644 index 0000000000..c235b5e501 --- /dev/null +++ b/tests/autotune/test_tuner.py @@ -0,0 +1,29 @@ +import pytest + +import scvi + + +def test_model_tuner_init(): + model_cls = scvi.model.SCVI + scvi.autotune.ModelTuner(model_cls) + + +def test_model_tuner_fit(): + model_cls = scvi.model.SCVI + tuner = scvi.autotune.ModelTuner(model_cls) + + # adata should be set up before passing to `fit` + adata = scvi.data.synthetic_iid() + with pytest.raises(Exception): + tuner.fit(adata, num_samples=1, max_epochs=1) + + model_cls.setup_anndata(adata) + results = tuner.fit(adata, num_samples=1, max_epochs=1) + assert results is not None + + +def test_model_tuner_info(): + model_cls = scvi.model.SCVI + tuner = scvi.autotune.ModelTuner(model_cls) + + tuner.info() diff --git a/tests/conftest.py b/tests/conftest.py index c08016c3de..237372f423 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,6 +21,17 @@ def pytest_addoption(parser): default=False, help="Run tests that retrieve stuff from the internet. 
This increases test time.", ) + parser.addoption( + "--optional", + action="store_true", + default=False, + help="Run tests that are optional.", + ) + + +def pytest_configure(config): + """Register the custom `optional` marker.""" + config.addinivalue_line("markers", "optional: mark test as optional.") def pytest_collection_modifyitems(config, items): @@ -33,6 +44,14 @@ if not run_internet and ("internet" in item.keywords): item.add_marker(skip_internet) + run_optional = config.getoption("--optional") + skip_optional = pytest.mark.skip(reason="need --optional option to run") + for item in items: + # All tests marked with `pytest.mark.optional` get skipped unless + # `--optional` is passed + if not run_optional and ("optional" in item.keywords): + item.add_marker(skip_optional) + @pytest.fixture(scope="session") def save_path(tmpdir_factory):
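As a usage note for the new `--optional` machinery: a hypothetical test marked optional would look like the following and is skipped unless pytest is invoked with `--optional`:

```python
import pytest

import scvi


@pytest.mark.optional
def test_autotune_end_to_end():
    # Skipped by default; included when running `pytest --optional`.
    adata = scvi.data.synthetic_iid()
    scvi.model.SCVI.setup_anndata(adata)
    tuner = scvi.autotune.ModelTuner(scvi.model.SCVI)
    results = tuner.fit(adata, num_samples=1, max_epochs=1)
    assert results is not None
```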