diff --git a/docs/installation.md b/docs/installation.md index b1a6ee1421..9c086a977d 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -43,8 +43,8 @@ pip install "jax[cpu]" -f https://whls.blob.core.windows.net/unstable/index.html All scvi-tools models will be faster when accelerated with a GPU. Before installing scvi-tools, you can install GPU versions of PyTorch and jax using conda as follows: ``` -conda install pytorch torchvision torchaudio cudatoolkit=11.6 -c pytorch -c conda-forge -conda install jax jaxlib cuda-nvcc -c conda-forge -c nvidia +conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia +conda install jax jaxlib -c conda-forge ``` Please go to the respective package website for more information on how to install with pip. diff --git a/pyproject.toml b/pyproject.toml index 79ac750f4a..0516f386e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = ["hatchling"] [project] name = "scvi-tools" -version = "0.20.0b0" +version = "0.20.0b1" description = "Deep probabilistic analysis of single-cell omics data." readme = "README.md" requires-python = ">=3.8" @@ -44,11 +44,11 @@ dependencies = [ "scipy", "scikit-learn>=0.21.2", "openpyxl>=3.0", - "rich>=9.1.0", + "rich>=12.0.0", "h5py>=2.9.0", "torch>=1.8.0", - "pytorch-lightning>=1.8.0,<1.9.0", - "torchmetrics>=0.6.0", + "pytorch-lightning>=1.9.0,<1.10.0", + "torchmetrics>=0.11.0", "pyro-ppl>=1.6.0", "tqdm>=4.56.0", "scikit-learn>=0.21.2", diff --git a/scvi/__init__.py b/scvi/__init__.py index 7c516d6941..216af7a3e4 100644 --- a/scvi/__init__.py +++ b/scvi/__init__.py @@ -3,12 +3,6 @@ # Set default logging handler to avoid logging with logging.lastResort logger. import logging -try: - # necessary as importing scvi after ray causes kernel crash - from ray import tune # noqa -except ImportError: - pass - from ._constants import REGISTRY_KEYS from ._settings import settings diff --git a/scvi/_settings.py b/scvi/_settings.py index 17c6988a60..54ecb1beb2 100644 --- a/scvi/_settings.py +++ b/scvi/_settings.py @@ -4,7 +4,7 @@ from typing import Literal, Union import torch -from lightning_lite import seed_everything +from pytorch_lightning import seed_everything from rich.console import Console from rich.logging import RichHandler diff --git a/scvi/train/_trainer.py b/scvi/train/_trainer.py index f1f05ecd12..2c00a478b5 100644 --- a/scvi/train/_trainer.py +++ b/scvi/train/_trainer.py @@ -4,7 +4,7 @@ import numpy as np import pytorch_lightning as pl -from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.loggers import Logger from scvi import settings from scvi.autotune._types import Tunable, TunableMixin @@ -103,7 +103,7 @@ def __init__( enable_progress_bar: bool = True, progress_bar_refresh_rate: int = 1, simple_progress_bar: bool = True, - logger: Union[Optional[LightningLoggerBase], bool] = None, + logger: Union[Optional[Logger], bool] = None, log_every_n_steps: int = 10, replace_sampler_ddp: bool = False, **kwargs,