From 6cee07c21f0f3e773e77e6fe854391a4c6c0c67a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 Sep 2024 16:10:05 +0300 Subject: [PATCH 01/22] [pre-commit.ci] pre-commit autoupdate (#2957) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.6.2 → v0.6.4](https://github.com/astral-sh/ruff-pre-commit/compare/v0.6.2...v0.6.4) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 75cd134781..bd473b5d4f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,7 +41,7 @@ repos: )$ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.2 + rev: v0.6.4 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] From 0028cd3d3c92811adcfcbd54d3478ed143a81228 Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Mon, 16 Sep 2024 12:30:35 +0300 Subject: [PATCH 02/22] Drop Python 3.9 (#2970) close https://github.com/scverse/scvi-tools/issues/2966 --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/build.yml | 2 +- .github/workflows/build_image_base.yaml | 2 +- .github/workflows/build_image_latest.yaml | 2 +- .github/workflows/release.yml | 4 +- .github/workflows/test_linux.yml | 2 +- .github/workflows/test_linux_cuda.yml | 2 +- .github/workflows/test_linux_private.yml | 2 +- .github/workflows/test_linux_resolution.yml | 2 +- .github/workflows/test_macos.yml | 2 +- .github/workflows/test_windows.yml | 2 +- CHANGELOG.md | 1 + docs/installation.md | 4 +- pyproject.toml | 8 +-- src/scvi/_settings.py | 12 ++-- src/scvi/_types.py | 12 ++-- src/scvi/autotune/_experiment.py | 4 +- src/scvi/data/_anntorchdataset.py | 4 +- 
src/scvi/data/_built_in_data/_cellxgene.py | 5 +- src/scvi/data/_built_in_data/_pbmc.py | 2 +- src/scvi/data/_built_in_data/_synthetic.py | 3 +- src/scvi/data/_compat.py | 5 +- src/scvi/data/_download.py | 3 +- src/scvi/data/_preprocessing.py | 21 ++++--- src/scvi/data/_read.py | 5 +- src/scvi/data/fields/_arraylike_field.py | 20 +++---- src/scvi/data/fields/_base_field.py | 9 ++- src/scvi/data/fields/_dataframe_field.py | 10 ++-- src/scvi/data/fields/_layer_field.py | 7 +-- src/scvi/data/fields/_mudata.py | 14 ++--- src/scvi/data/fields/_protein.py | 5 +- src/scvi/data/fields/_scanvi.py | 5 +- src/scvi/data/fields/_uns_field.py | 5 +- src/scvi/dataloaders/_ann_dataloader.py | 7 +-- src/scvi/dataloaders/_concat_dataloader.py | 7 +-- src/scvi/dataloaders/_data_splitting.py | 29 +++++----- src/scvi/dataloaders/_semi_dataloader.py | 10 ++-- src/scvi/distributions/_beta_binomial.py | 10 ++-- src/scvi/external/cellassign/_module.py | 7 +-- .../_contrastive_data_splitting.py | 8 +-- .../contrastivevi/_contrastive_dataloader.py | 5 +- src/scvi/external/contrastivevi/_model.py | 3 +- src/scvi/external/contrastivevi/_module.py | 6 +- src/scvi/external/gimvi/_model.py | 4 +- src/scvi/external/gimvi/_module.py | 26 ++++----- src/scvi/external/gimvi/_utils.py | 12 ++-- src/scvi/external/mrvi/_components.py | 3 +- src/scvi/external/mrvi/_module.py | 3 +- src/scvi/external/mrvi/_types.py | 4 +- src/scvi/external/scar/_model.py | 14 +++-- src/scvi/external/scbasset/_module.py | 9 +-- src/scvi/external/tangram/_model.py | 20 +++---- src/scvi/external/tangram/_module.py | 4 +- src/scvi/model/_utils.py | 28 +++++----- src/scvi/model/base/_archesmixin.py | 13 ++--- src/scvi/model/base/_de_core.py | 5 +- src/scvi/model/base/_differential.py | 56 +++++++++---------- src/scvi/model/base/_log_likelihood.py | 4 +- src/scvi/model/base/_pyromixin.py | 2 +- src/scvi/module/_amortizedlda.py | 11 ++-- src/scvi/module/_autozivae.py | 18 +++--- src/scvi/module/_jaxvae.py | 10 ++-- 
src/scvi/module/_mrdeconv.py | 6 +- src/scvi/module/_multivae.py | 8 +-- src/scvi/module/_peakvae.py | 8 +-- src/scvi/module/_totalvae.py | 34 +++++------ src/scvi/module/_vae.py | 3 +- src/scvi/module/base/_base_module.py | 4 +- src/scvi/module/base/_decorators.py | 6 +- src/scvi/nn/_base_components.py | 12 ++-- src/scvi/nn/_embedding.py | 2 +- src/scvi/train/_callbacks.py | 2 +- src/scvi/train/_logger.py | 6 +- src/scvi/train/_trainer.py | 12 ++-- src/scvi/train/_trainingplans.py | 52 ++++++++--------- src/scvi/train/_trainrunner.py | 5 +- src/scvi/utils/_decorators.py | 2 +- src/scvi/utils/_dependencies.py | 2 +- src/scvi/utils/_jax.py | 2 +- tests/data/test_synthetic_iid.py | 4 +- tests/data/utils.py | 42 +++++++------- tests/dataloaders/sparse_utils.py | 8 +-- tests/external/scbasset/test_scbasset.py | 3 +- tests/external/tangram/test_tangram.py | 4 +- tests/model/base/test_base_model.py | 14 ++--- 84 files changed, 364 insertions(+), 401 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 508e6e8127..143ff4eacc 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,7 +19,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: "3.12" cache: "pip" cache-dependency-path: "**/pyproject.toml" diff --git a/.github/workflows/build_image_base.yaml b/.github/workflows/build_image_base.yaml index 728030d9db..c3453da680 100644 --- a/.github/workflows/build_image_base.yaml +++ b/.github/workflows/build_image_base.yaml @@ -38,4 +38,4 @@ jobs: cache-from: type=registry,ref=ghcr.io/scverse/scvi-tools:buildcache cache-to: type=inline,ref=ghcr.io/scverse/scvi-tools:buildcache target: base - tags: ghcr.io/scverse/scvi-tools:py3.11-cu12-base + tags: ghcr.io/scverse/scvi-tools:py3.12-cu12-base diff --git a/.github/workflows/build_image_latest.yaml b/.github/workflows/build_image_latest.yaml index 7998be6edf..1b3760bc5f 100644 --- a/.github/workflows/build_image_latest.yaml +++ 
b/.github/workflows/build_image_latest.yaml @@ -60,6 +60,6 @@ jobs: cache-from: type=registry,ref=ghcr.io/scverse/scvi-tools:buildcache cache-to: type=inline,ref=ghcr.io/scverse/scvi-tools:buildcache target: build - tags: ghcr.io/scverse/scvi-tools:py3.11-cu12-${{ steps.build.outputs.version }}-${{ steps.build.outputs.dependencies }} + tags: ghcr.io/scverse/scvi-tools:py3.12-cu12-${{ steps.build.outputs.version }}-${{ steps.build.outputs.dependencies }} build-args: | DEPENDENCIES=${{ matrix.dependencies }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c4b0826b38..49bbf95230 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -112,7 +112,7 @@ jobs: - uses: actions/setup-python@v4 with: - python-version: "3.11" + python-version: "3.12" - run: pip install build @@ -158,6 +158,6 @@ jobs: cache-from: type=registry,ref=ghcr.io/scverse/scvi-tools:buildcache cache-to: type=inline,ref=ghcr.io/scverse/scvi-tools:buildcache target: build - tags: ghcr.io/scverse/scvi-tools:py3.11-cu12-${{ inputs.tag }}-${{ matrix.dependencies }},ghcr.io/scverse/scvi-tools:py3.11-cu12-stable-${{ matrix.dependencies }} + tags: ghcr.io/scverse/scvi-tools:py3.12-cu12-${{ inputs.tag }}-${{ matrix.dependencies }},ghcr.io/scverse/scvi-tools:py3.12-cu12-stable-${{ matrix.dependencies }} build-args: | DEPENDENCIES=${{ matrix.dependencies }} diff --git a/.github/workflows/test_linux.yml b/.github/workflows/test_linux.yml index 74765a22b8..bd48c9dd75 100644 --- a/.github/workflows/test_linux.yml +++ b/.github/workflows/test_linux.yml @@ -24,7 +24,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python: ["3.9", "3.10", "3.11"] + python: ["3.10", "3.11", "3.12"] name: integration diff --git a/.github/workflows/test_linux_cuda.yml b/.github/workflows/test_linux_cuda.yml index e6d979e38e..492fbffeba 100644 --- a/.github/workflows/test_linux_cuda.yml +++ b/.github/workflows/test_linux_cuda.yml @@ -30,7 +30,7 @@ jobs: shell: bash -e {0} 
# -e to fail on error container: - image: ghcr.io/scverse/scvi-tools:py3.11-cu12-base + image: ghcr.io/scverse/scvi-tools:py3.12-cu12-base options: --user root --gpus all name: integration diff --git a/.github/workflows/test_linux_private.yml b/.github/workflows/test_linux_private.yml index 55a58b3923..a3fe2e9fcb 100644 --- a/.github/workflows/test_linux_private.yml +++ b/.github/workflows/test_linux_private.yml @@ -34,7 +34,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python: ["3.11"] + python: ["3.12"] permissions: id-token: write diff --git a/.github/workflows/test_linux_resolution.yml b/.github/workflows/test_linux_resolution.yml index ff7c18b045..d41447365c 100644 --- a/.github/workflows/test_linux_resolution.yml +++ b/.github/workflows/test_linux_resolution.yml @@ -33,7 +33,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python: ["3.9", "3.10", "3.11", "3.12"] + python: ["3.10", "3.11", "3.12"] install-flags: [ "--prerelease if-necessary-or-explicit", diff --git a/.github/workflows/test_macos.yml b/.github/workflows/test_macos.yml index d2da65edf2..ce798f6144 100644 --- a/.github/workflows/test_macos.yml +++ b/.github/workflows/test_macos.yml @@ -34,7 +34,7 @@ jobs: fail-fast: false matrix: os: [macos-14] - python: ["3.9", "3.10", "3.11"] + python: ["3.10", "3.11", "3.12"] name: integration diff --git a/.github/workflows/test_windows.yml b/.github/workflows/test_windows.yml index 6c6cda6a54..c7aa7b5ac5 100644 --- a/.github/workflows/test_windows.yml +++ b/.github/workflows/test_windows.yml @@ -34,7 +34,7 @@ jobs: fail-fast: false matrix: os: [windows-latest] - python: ["3.9", "3.10", "3.11"] + python: ["3.10", "3.11", "3.12"] name: integration diff --git a/CHANGELOG.md b/CHANGELOG.md index 277817beca..4a4c22dc63 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ to [Semantic Versioning]. Full commit history is available in the #### Added +- Add support for Python 3.12 {pr}`2966`. 
- Add support for categorial covariates in scArches in `scvi.model.archesmixin` {pr}`2936`. - Add assertion error in cellAssign for checking duplicates in celltype markers {pr}`2951`. - Add `scvi.external.poissonvi.get_region_factors` {pr}`2940`. diff --git a/docs/installation.md b/docs/installation.md index 05cb8a43a0..b5dfb0450a 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -23,14 +23,14 @@ Don't know how to get started with virtual environments or `conda`/`pip`? Check ### Virtual environment A virtual environment can be created with either `conda` or `venv`. We recommend using `conda`. We -currently support Python 3.9 - 3.11. +currently support Python 3.10 - 3.12. For `conda`, we recommend using the [Miniforge](https://github.com/conda-forge/miniforge) distribution, which is generally faster than the official distribution and comes with conda-forge as the default channel (where scvi-tools is hosted). ```bash -conda create -n scvi-env python=3.11 # any python 3.9 to 3.11 +conda create -n scvi-env python=3.12 # any python 3.10 to 3.12 conda activate scvi-env ``` diff --git a/pyproject.toml b/pyproject.toml index 33617ccde3..8f50db143d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,13 +7,13 @@ name = "scvi-tools" version = "1.1.6" description = "Deep probabilistic analysis of single-cell omics data." 
readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.10" license = {file = "LICENSE"} authors = [ {name = "The scvi-tools development team"}, ] maintainers = [ - {name = "The scvi-tools development team", email = "martinkim@berkeley.edu"}, + {name = "The scvi-tools development team", email = "ori.kronfeld@weizmann.ac.il"}, ] urls.Documentation = "https://scvi-tools.org" urls.Source = "https://github.com/scverse/scvi-tools" @@ -22,9 +22,9 @@ classifiers = [ "Development Status :: 4 - Beta", "Intended Audience :: Science/Research", "Natural Language :: English", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Operating System :: MacOS :: MacOS X", "Operating System :: Microsoft :: Windows", "Operating System :: POSIX :: Linux", @@ -137,7 +137,7 @@ markers = [ src = ["src"] line-length = 99 indent-width = 4 -target-version = "py39" +target-version = "py310" # Exclude a variety of commonly ignored directories. 
exclude = [ diff --git a/src/scvi/_settings.py b/src/scvi/_settings.py index bfdc9cbc2d..25979de587 100644 --- a/src/scvi/_settings.py +++ b/src/scvi/_settings.py @@ -1,7 +1,7 @@ import logging import os from pathlib import Path -from typing import Literal, Optional, Union +from typing import Literal import torch from lightning.pytorch import seed_everything @@ -47,7 +47,7 @@ def __init__( verbosity: int = logging.INFO, progress_bar_style: Literal["rich", "tqdm"] = "tqdm", batch_size: int = 128, - seed: Optional[int] = None, + seed: int | None = None, logging_dir: str = "./scvi_log/", dl_num_workers: int = 0, dl_persistent_workers: bool = False, @@ -111,7 +111,7 @@ def logging_dir(self) -> Path: return self._logging_dir @logging_dir.setter - def logging_dir(self, logging_dir: Union[str, Path]): + def logging_dir(self, logging_dir: str | Path): self._logging_dir = Path(logging_dir).resolve() @property @@ -141,7 +141,7 @@ def seed(self) -> int: return self._seed @seed.setter - def seed(self, seed: Union[int, None] = None): + def seed(self, seed: int | None = None): """Random seed for torch and numpy.""" if seed is None: self._seed = None @@ -162,7 +162,7 @@ def verbosity(self) -> int: return self._verbosity @verbosity.setter - def verbosity(self, level: Union[str, int]): + def verbosity(self, level: str | int): """Sets logging configuration for scvi based on chosen level of verbosity. If "scvi" logger has no StreamHandler, add one. 
@@ -220,7 +220,7 @@ def jax_preallocate_gpu_memory(self): return self._jax_gpu @jax_preallocate_gpu_memory.setter - def jax_preallocate_gpu_memory(self, value: Union[float, bool]): + def jax_preallocate_gpu_memory(self, value: float | bool): # see https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html#gpu-memory-allocation if value is False: os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" diff --git a/src/scvi/_types.py b/src/scvi/_types.py index 0233283e36..5335df8d40 100644 --- a/src/scvi/_types.py +++ b/src/scvi/_types.py @@ -1,13 +1,15 @@ -from typing import Literal, Union +from __future__ import annotations + +from typing import Literal import anndata import jax.numpy as jnp import mudata import torch -Number = Union[int, float] -AnnOrMuData = Union[anndata.AnnData, mudata.MuData] -Tensor = Union[torch.Tensor, jnp.ndarray] -LossRecord = Union[dict[str, Tensor], Tensor] +Number = int | float +AnnOrMuData = anndata.AnnData | mudata.MuData +Tensor = torch.Tensor | jnp.ndarray +LossRecord = dict[str, Tensor] | Tensor # TODO(adamgayoso): Add constants for minified data types. 
MinifiedDataType = Literal["latent_posterior_parameters"] diff --git a/src/scvi/autotune/_experiment.py b/src/scvi/autotune/_experiment.py index ed5572572e..4b59536b14 100644 --- a/src/scvi/autotune/_experiment.py +++ b/src/scvi/autotune/_experiment.py @@ -160,7 +160,7 @@ def data(self, value: AnnOrMuData | LightningDataModule) -> None: raise AttributeError("Cannot reassign `data`") self._data = value - if isinstance(value, (AnnData, MuData)): + if isinstance(value, AnnData | MuData): data_manager = self.model_cls._get_most_recent_anndata_manager(value, required=True) self._setup_method_name = data_manager._registry.get( _SETUP_METHOD_NAME, "setup_anndata" @@ -537,7 +537,7 @@ def _trainable( } settings.seed = experiment.seed - if isinstance(experiment.data, (AnnData, MuData)): + if isinstance(experiment.data, AnnData | MuData): getattr(experiment.model_cls, experiment.setup_method_name)( experiment.data, **experiment.setup_method_args, diff --git a/src/scvi/data/_anntorchdataset.py b/src/scvi/data/_anntorchdataset.py index 3374c5016e..f0725efb13 100644 --- a/src/scvi/data/_anntorchdataset.py +++ b/src/scvi/data/_anntorchdataset.py @@ -133,7 +133,7 @@ def __getitem__( if isinstance(indexes, int): indexes = [indexes] # force batched single observations - if self.adata_manager.adata.isbacked and isinstance(indexes, (list, np.ndarray)): + if self.adata_manager.adata.isbacked and isinstance(indexes, list | np.ndarray): # need to sort indexes for h5py datasets indexes = np.sort(indexes) @@ -142,7 +142,7 @@ def __getitem__( for key, dtype in self.keys_and_dtypes.items(): data = self.data[key] - if isinstance(data, (np.ndarray, h5py.Dataset)): + if isinstance(data, np.ndarray | h5py.Dataset): sliced_data = data[indexes].astype(dtype, copy=False) elif isinstance(data, pd.DataFrame): sliced_data = data.iloc[indexes, :].to_numpy().astype(dtype, copy=False) diff --git a/src/scvi/data/_built_in_data/_cellxgene.py b/src/scvi/data/_built_in_data/_cellxgene.py index 
ec4c8fc0eb..7027d8820d 100644 --- a/src/scvi/data/_built_in_data/_cellxgene.py +++ b/src/scvi/data/_built_in_data/_cellxgene.py @@ -1,6 +1,5 @@ import os import re -from typing import Optional, Union from anndata import AnnData, read_h5ad @@ -18,10 +17,10 @@ def _parse_dataset_id(url: str): @dependencies("cellxgene_census") def _load_cellxgene_dataset( url: str, - filename: Optional[str] = None, + filename: str | None = None, save_path: str = "data/", return_path: bool = False, -) -> Union[AnnData, str]: +) -> AnnData | str: """Loads a file from `cellxgene `_ portal. Parameters diff --git a/src/scvi/data/_built_in_data/_pbmc.py b/src/scvi/data/_built_in_data/_pbmc.py index 69f0f02208..b67a16b250 100644 --- a/src/scvi/data/_built_in_data/_pbmc.py +++ b/src/scvi/data/_built_in_data/_pbmc.py @@ -75,7 +75,7 @@ def _load_pbmc_dataset( adata = pbmc8k.concatenate(pbmc4k) adata.obs_names = barcodes - dict_barcodes = dict(zip(barcodes, np.arange(len(barcodes)))) + dict_barcodes = dict(zip(barcodes, np.arange(len(barcodes)), strict=True)) subset_cells = [] barcodes_metadata = pbmc_metadata["barcodes"].index.values.ravel().astype(str) for barcode in barcodes_metadata: diff --git a/src/scvi/data/_built_in_data/_synthetic.py b/src/scvi/data/_built_in_data/_synthetic.py index b2a4bede0c..714e8111f5 100644 --- a/src/scvi/data/_built_in_data/_synthetic.py +++ b/src/scvi/data/_built_in_data/_synthetic.py @@ -1,5 +1,4 @@ import logging -from typing import Optional import numpy as np import pandas as pd @@ -21,7 +20,7 @@ def _generate_synthetic( n_batches: int, n_labels: int, dropout_ratio: float, - sparse_format: Optional[str], + sparse_format: str | None, generate_coordinates: bool, return_mudata: bool, batch_key: str = "batch", diff --git a/src/scvi/data/_compat.py b/src/scvi/data/_compat.py index 1fa4d55b37..7c89609fb9 100644 --- a/src/scvi/data/_compat.py +++ b/src/scvi/data/_compat.py @@ -1,5 +1,4 @@ from copy import deepcopy -from typing import Optional import numpy as np @@ 
-26,7 +25,7 @@ } -def _infer_setup_args(model_cls, setup_dict: dict, unlabeled_category: Optional[str]) -> dict: +def _infer_setup_args(model_cls, setup_dict: dict, unlabeled_category: str | None) -> dict: setup_args = {} data_registry = setup_dict[_constants._DATA_REGISTRY_KEY] categorical_mappings = setup_dict["categorical_mappings"] @@ -73,7 +72,7 @@ def _infer_setup_args(model_cls, setup_dict: dict, unlabeled_category: Optional[ def registry_from_setup_dict( - model_cls, setup_dict: dict, unlabeled_category: Optional[str] = None + model_cls, setup_dict: dict, unlabeled_category: str | None = None ) -> dict: """Converts old setup dict format to new registry dict format. diff --git a/src/scvi/data/_download.py b/src/scvi/data/_download.py index b4e4f5dee7..8bc7cc657d 100644 --- a/src/scvi/data/_download.py +++ b/src/scvi/data/_download.py @@ -1,7 +1,6 @@ import logging import os import urllib -from typing import Optional import numpy as np @@ -10,7 +9,7 @@ logger = logging.getLogger(__name__) -def _download(url: Optional[str], save_path: str, filename: str): +def _download(url: str | None, save_path: str, filename: str): """Writes data from url to file.""" if os.path.exists(os.path.join(save_path, filename)): logger.info(f"File {os.path.join(save_path, filename)} already downloaded") diff --git a/src/scvi/data/_preprocessing.py b/src/scvi/data/_preprocessing.py index 172fef4309..2cf251a3d9 100644 --- a/src/scvi/data/_preprocessing.py +++ b/src/scvi/data/_preprocessing.py @@ -1,7 +1,6 @@ import logging import tempfile from pathlib import Path -from typing import Optional, Union import anndata import numpy as np @@ -22,17 +21,17 @@ @devices_dsp.dedent def poisson_gene_selection( adata, - layer: Optional[str] = None, + layer: str | None = None, n_top_genes: int = 4000, accelerator: str = "auto", - device: Union[int, str] = "auto", + device: int | str = "auto", subset: bool = False, inplace: bool = True, n_samples: int = 10000, batch_key: str = None, silent: bool = 
False, minibatch_size: int = 5000, -) -> Optional[pd.DataFrame]: +) -> pd.DataFrame | None: """Rank and select genes based on the enrichment of zero counts. Enrichment is considered by comparing data to a Poisson count model. @@ -231,7 +230,7 @@ def poisson_gene_selection( return df -def organize_cite_seq_10x(adata: anndata.AnnData, copy: bool = False) -> Optional[anndata.AnnData]: +def organize_cite_seq_10x(adata: anndata.AnnData, copy: bool = False) -> anndata.AnnData | None: """Organize anndata object loaded from 10x for scvi models. Parameters @@ -277,8 +276,8 @@ def organize_cite_seq_10x(adata: anndata.AnnData, copy: bool = False) -> Optiona def organize_multiome_anndatas( multi_anndata: anndata.AnnData, - rna_anndata: Optional[anndata.AnnData] = None, - atac_anndata: Optional[anndata.AnnData] = None, + rna_anndata: anndata.AnnData | None = None, + atac_anndata: anndata.AnnData | None = None, modality_key: str = "modality", ) -> anndata.AnnData: """Concatenate multiome and single-modality input anndata objects. @@ -362,8 +361,8 @@ def add_dna_sequence( adata: anndata.AnnData, seq_len: int = 1344, genome_name: str = "hg38", - genome_dir: Optional[Path] = None, - genome_provider: Optional[str] = None, + genome_dir: Path | None = None, + genome_provider: str | None = None, install_genome: bool = True, chr_var_key: str = "chr", start_var_key: str = "start", @@ -432,7 +431,7 @@ def add_dna_sequence( block_ends = block_starts + seq_len seqs = [] - for start, end in zip(block_starts, block_ends - 1): + for start, end in zip(block_starts, block_ends - 1, strict=True): seq = str(g.get_seq(chrom, start, end)).upper() seqs.append(list(seq)) @@ -446,7 +445,7 @@ def add_dna_sequence( def reads_to_fragments( adata: anndata.AnnData, - read_layer: Optional[str] = None, + read_layer: str | None = None, fragment_layer: str = "fragments", ) -> None: """Convert scATAC-seq read counts to appoximate fragment counts. 
diff --git a/src/scvi/data/_read.py b/src/scvi/data/_read.py index 4947fd975d..c8b4df6f9c 100644 --- a/src/scvi/data/_read.py +++ b/src/scvi/data/_read.py @@ -1,13 +1,12 @@ import os from pathlib import Path -from typing import Union import pandas as pd from anndata import AnnData from scipy.io import mmread -def read_10x_atac(base_path: Union[str, Path]) -> AnnData: +def read_10x_atac(base_path: str | Path) -> AnnData: """Read scATAC-seq data outputted by 10x Genomics software. Parameters @@ -39,7 +38,7 @@ def read_10x_atac(base_path: Union[str, Path]) -> AnnData: return AnnData(data.tocsr(), var=coords, obs=cell_annot) -def read_10x_multiome(base_path: Union[str, Path]) -> AnnData: +def read_10x_multiome(base_path: str | Path) -> AnnData: """Read Multiome (scRNA + scATAC) data outputted by 10x Genomics software. Parameters diff --git a/src/scvi/data/fields/_arraylike_field.py b/src/scvi/data/fields/_arraylike_field.py index 4326c6e7f8..4e02751d8c 100644 --- a/src/scvi/data/fields/_arraylike_field.py +++ b/src/scvi/data/fields/_arraylike_field.py @@ -1,6 +1,6 @@ import logging import warnings -from typing import Literal, Optional, Union +from typing import Literal import numpy as np import pandas as pd @@ -75,7 +75,7 @@ def __init__( registry_key: str, attr_key: str, field_type: Literal["obsm", "varm"] = None, - colnames_uns_key: Optional[str] = None, + colnames_uns_key: str | None = None, is_count_data: bool = False, correct_data_format: bool = True, ) -> None: @@ -117,7 +117,7 @@ def validate_field(self, adata: AnnData) -> None: stacklevel=settings.warnings_stacklevel, ) - def _setup_column_names(self, adata: AnnData) -> Union[list, np.ndarray]: + def _setup_column_names(self, adata: AnnData) -> list | np.ndarray: """Returns a list or ndarray of column names that will be used for the relevant .obsm data. 
If the ``colnames_uns_key`` was specified, then the columns stored in that @@ -175,7 +175,7 @@ def get_summary_stats(self, state_registry: dict) -> dict: n_array_cols = len(state_registry[self.COLUMN_NAMES_KEY]) return {self.count_stat_key: n_array_cols} - def view_state_registry(self, state_registry: dict) -> Optional[rich.table.Table]: + def view_state_registry(self, state_registry: dict) -> rich.table.Table | None: """View the state registry.""" return None @@ -217,7 +217,7 @@ class BaseJointField(BaseArrayLikeField): def __init__( self, registry_key: str, - attr_keys: Optional[list[str]], + attr_keys: list[str] | None, field_type: Literal["obsm", "varm"] = None, ) -> None: super().__init__(registry_key) @@ -289,7 +289,7 @@ class NumericalJointField(BaseJointField): def __init__( self, registry_key: str, - attr_keys: Optional[list[str]], + attr_keys: list[str] | None, field_type: Literal["obsm", "varm"] = None, ) -> None: super().__init__(registry_key, attr_keys, field_type=field_type) @@ -317,7 +317,7 @@ def get_summary_stats(self, _state_registry: dict) -> dict: n_keys = len(self.attr_keys) return {self.count_stat_key: n_keys} - def view_state_registry(self, state_registry: dict) -> Optional[rich.table.Table]: + def view_state_registry(self, state_registry: dict) -> rich.table.Table | None: """View the state registry.""" if self.is_empty: return None @@ -376,7 +376,7 @@ class CategoricalJointField(BaseJointField): def __init__( self, registry_key: str, - attr_keys: Optional[list[str]], + attr_keys: list[str] | None, field_type: Literal["obsm", "varm"] = None, ) -> None: super().__init__(registry_key, attr_keys, field_type=field_type) @@ -390,7 +390,7 @@ def _default_mappings_dict(self) -> dict: } def _make_array_categorical( - self, adata: AnnData, category_dict: Optional[dict[str, list[str]]] = None + self, adata: AnnData, category_dict: dict[str, list[str]] | None = None ) -> dict: """Make the .obsm categorical.""" if self.attr_keys != getattr(adata, 
self.attr_name)[self.attr_key].columns.tolist(): @@ -458,7 +458,7 @@ def get_summary_stats(self, _state_registry: dict) -> dict: self.count_stat_key: n_keys, } - def view_state_registry(self, state_registry: dict) -> Optional[rich.table.Table]: + def view_state_registry(self, state_registry: dict) -> rich.table.Table | None: """View the state registry.""" if self.is_empty: return None diff --git a/src/scvi/data/fields/_base_field.py b/src/scvi/data/fields/_base_field.py index 920c4baa73..5f69d8ac31 100644 --- a/src/scvi/data/fields/_base_field.py +++ b/src/scvi/data/fields/_base_field.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Optional, Union import numpy as np import pandas as pd @@ -29,11 +28,11 @@ def attr_name(self) -> str: @property @abstractmethod - def attr_key(self) -> Optional[str]: + def attr_key(self) -> str | None: """The key of the data field within the relevant AnnData attribute.""" @property - def mod_key(self) -> Optional[str]: + def mod_key(self) -> str | None: """The modality key of the data field within the MuData (if applicable).""" return None @@ -108,7 +107,7 @@ def get_summary_stats(self, state_registry: dict) -> dict: """ @abstractmethod - def view_state_registry(self, state_registry: dict) -> Optional[rich.table.Table]: + def view_state_registry(self, state_registry: dict) -> rich.table.Table | None: """Returns a :class:`rich.table.Table` summarizing a state registry produced by this field. Parameters @@ -123,7 +122,7 @@ def view_state_registry(self, state_registry: dict) -> Optional[rich.table.Table Optional :class:`rich.table.Table` summarizing the ``state_registry``. 
""" - def get_field_data(self, adata: AnnOrMuData) -> Union[np.ndarray, pd.DataFrame]: + def get_field_data(self, adata: AnnOrMuData) -> np.ndarray | pd.DataFrame: """Returns the requested data as determined by the field.""" if self.is_empty: raise AssertionError(f"The {self.registry_key} field is empty.") diff --git a/src/scvi/data/fields/_dataframe_field.py b/src/scvi/data/fields/_dataframe_field.py index bd5c691e0c..633f3e8673 100644 --- a/src/scvi/data/fields/_dataframe_field.py +++ b/src/scvi/data/fields/_dataframe_field.py @@ -1,5 +1,5 @@ import logging -from typing import Literal, Optional +from typing import Literal import numpy as np import rich @@ -34,7 +34,7 @@ class BaseDataFrameField(BaseAnnDataField): def __init__( self, registry_key: str, - attr_key: Optional[str], + attr_key: str | None, field_type: Literal["obs", "var"] = None, required: bool = True, ) -> None: @@ -102,7 +102,7 @@ def get_summary_stats(self, _state_registry: dict) -> dict: """Get summary stats.""" return {} - def view_state_registry(self, _state_registry: dict) -> Optional[rich.table.Table]: + def view_state_registry(self, _state_registry: dict) -> rich.table.Table | None: """View state registry.""" return None @@ -145,7 +145,7 @@ class CategoricalDataFrameField(BaseDataFrameField): def __init__( self, registry_key: str, - attr_key: Optional[str], + attr_key: str | None, field_type: Literal["obs", "var"] = None, ) -> None: self.is_default = attr_key is None @@ -238,7 +238,7 @@ def get_summary_stats(self, state_registry: dict) -> dict: n_categories = len(np.unique(categorical_mapping)) return {self.count_stat_key: n_categories} - def view_state_registry(self, state_registry: dict) -> Optional[rich.table.Table]: + def view_state_registry(self, state_registry: dict) -> rich.table.Table | None: """View state registry.""" source_key = state_registry[self.ORIGINAL_ATTR_KEY] mapping = state_registry[self.CATEGORICAL_MAPPING_KEY] diff --git a/src/scvi/data/fields/_layer_field.py 
b/src/scvi/data/fields/_layer_field.py index 7ff36bc24f..41daa00848 100644 --- a/src/scvi/data/fields/_layer_field.py +++ b/src/scvi/data/fields/_layer_field.py @@ -1,5 +1,4 @@ import warnings -from typing import Optional import numpy as np import rich @@ -43,7 +42,7 @@ class LayerField(BaseAnnDataField): def __init__( self, registry_key: str, - layer: Optional[str], + layer: str | None, is_count_data: bool = True, correct_data_format: bool = True, check_fragment_counts: bool = False, @@ -72,7 +71,7 @@ def attr_name(self) -> str: return self._attr_name @property - def attr_key(self) -> Optional[str]: + def attr_key(self) -> str | None: return self._attr_key @property @@ -140,7 +139,7 @@ def get_summary_stats(self, state_registry: dict) -> dict: summary_stats[self.N_CELLS_KEY] = state_registry[self.N_OBS_KEY] return summary_stats - def view_state_registry(self, _state_registry: dict) -> Optional[rich.table.Table]: + def view_state_registry(self, _state_registry: dict) -> rich.table.Table | None: """View the state registry.""" return None diff --git a/src/scvi/data/fields/_mudata.py b/src/scvi/data/fields/_mudata.py index c6c3b66e86..c16e4f4c0e 100644 --- a/src/scvi/data/fields/_mudata.py +++ b/src/scvi/data/fields/_mudata.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional +from collections.abc import Callable import rich from anndata import AnnData @@ -24,7 +24,7 @@ class BaseMuDataWrapperClass(BaseAnnDataField): If ``True``, raises ``ValueError`` when ``mod_key`` is ``None``. 
""" - def __init__(self, mod_key: Optional[str] = None, mod_required: bool = False) -> None: + def __init__(self, mod_key: str | None = None, mod_required: bool = False) -> None: super().__init__() if mod_required and mod_key is None: raise ValueError(f"Modality required for {self.__class__.__name__} but not provided.") @@ -42,7 +42,7 @@ def registry_key(self) -> str: return self.adata_field.registry_key @property - def mod_key(self) -> Optional[str]: + def mod_key(self) -> str | None: """The modality key of the data field within the MuData (if applicable).""" return self._mod_key @@ -52,7 +52,7 @@ def attr_name(self) -> str: return self.adata_field.attr_name @property - def attr_key(self) -> Optional[str]: + def attr_key(self) -> str | None: """The key of the data field within the relevant AnnData/MuData attribute.""" return self.adata_field.attr_key @@ -104,13 +104,13 @@ def get_summary_stats(self, state_registry: dict) -> dict: """Get summary stats.""" return self.adata_field.get_summary_stats(state_registry) - def view_state_registry(self, state_registry: dict) -> Optional[rich.table.Table]: + def view_state_registry(self, state_registry: dict) -> rich.table.Table | None: """View the state registry.""" return self.adata_field.view_state_registry(state_registry) def MuDataWrapper( - adata_field_cls: AnnDataField, preregister_fn: Optional[Callable] = None + adata_field_cls: AnnDataField, preregister_fn: Callable | None = None ) -> AnnDataField: """Wraps an AnnDataField with :class:`~scvi.data.fields.BaseMuDataWrapperClass`. 
@@ -127,7 +127,7 @@ def MuDataWrapper( raise ValueError("`adata_field_cls` must be a class, not an instance.") def mudata_field_init( - self, *args, mod_key: Optional[str] = None, mod_required: bool = False, **kwargs + self, *args, mod_key: str | None = None, mod_required: bool = False, **kwargs ): BaseMuDataWrapperClass.__init__(self, mod_key=mod_key, mod_required=mod_required) self._adata_field = adata_field_cls(*args, **kwargs) diff --git a/src/scvi/data/fields/_protein.py b/src/scvi/data/fields/_protein.py index 5000f93494..f39984d408 100644 --- a/src/scvi/data/fields/_protein.py +++ b/src/scvi/data/fields/_protein.py @@ -1,5 +1,4 @@ import logging -from typing import Optional import numpy as np import pandas as pd @@ -35,7 +34,7 @@ def __init__( self, *base_field_args, use_batch_mask: bool = True, - batch_field: Optional[str] = None, + batch_field: str | None = None, **base_field_kwargs, ) -> None: if use_batch_mask and batch_field is None: @@ -50,7 +49,7 @@ def __init__( **base_field_kwargs, ) - def _get_batch_mask_protein_data(self, adata: AnnData) -> Optional[dict]: + def _get_batch_mask_protein_data(self, adata: AnnData) -> dict | None: """Returns a dict with length number of batches where each entry is a mask. 
The mask is over cell measurement columns that are present (observed) diff --git a/src/scvi/data/fields/_scanvi.py b/src/scvi/data/fields/_scanvi.py index 65ef6357fd..de264179fe 100644 --- a/src/scvi/data/fields/_scanvi.py +++ b/src/scvi/data/fields/_scanvi.py @@ -1,5 +1,4 @@ import warnings -from typing import Optional, Union import numpy as np from anndata import AnnData @@ -32,8 +31,8 @@ class LabelsWithUnlabeledObsField(CategoricalObsField): def __init__( self, registry_key: str, - obs_key: Optional[str], - unlabeled_category: Union[str, int, float], + obs_key: str | None, + unlabeled_category: str | int | float, ) -> None: super().__init__(registry_key, obs_key) self._unlabeled_category = unlabeled_category diff --git a/src/scvi/data/fields/_uns_field.py b/src/scvi/data/fields/_uns_field.py index 799657caa9..65f4bd601b 100644 --- a/src/scvi/data/fields/_uns_field.py +++ b/src/scvi/data/fields/_uns_field.py @@ -1,5 +1,4 @@ import logging -from typing import Optional import rich from anndata import AnnData @@ -26,7 +25,7 @@ class BaseUnsField(BaseAnnDataField): _attr_name = _constants._ADATA_ATTRS.UNS - def __init__(self, registry_key: str, uns_key: Optional[str], required: bool = True) -> None: + def __init__(self, registry_key: str, uns_key: str | None, required: bool = True) -> None: super().__init__() if required and uns_key is None: raise ValueError( @@ -80,6 +79,6 @@ def get_summary_stats(self, _state_registry: dict) -> dict: """Get summary stats.""" return {} - def view_state_registry(self, _state_registry: dict) -> Optional[rich.table.Table]: + def view_state_registry(self, _state_registry: dict) -> rich.table.Table | None: """View the state registry.""" return None diff --git a/src/scvi/dataloaders/_ann_dataloader.py b/src/scvi/dataloaders/_ann_dataloader.py index d10f8156de..27e17302d5 100644 --- a/src/scvi/dataloaders/_ann_dataloader.py +++ b/src/scvi/dataloaders/_ann_dataloader.py @@ -1,6 +1,5 @@ import copy import logging -from typing import 
Optional, Union import numpy as np from torch.utils.data import ( @@ -75,13 +74,13 @@ class AnnDataLoader(DataLoader): def __init__( self, adata_manager: AnnDataManager, - indices: Optional[Union[list[int], list[bool]]] = None, + indices: list[int] | list[bool] | None = None, batch_size: int = 128, shuffle: bool = False, - sampler: Optional[Sampler] = None, + sampler: Sampler | None = None, drop_last: bool = False, drop_dataset_tail: bool = False, - data_and_attributes: Optional[Union[list[str], dict[str, np.dtype]]] = None, + data_and_attributes: list[str] | dict[str, np.dtype] | None = None, iter_ndarray: bool = False, distributed_sampler: bool = False, load_sparse_tensor: bool = False, diff --git a/src/scvi/dataloaders/_concat_dataloader.py b/src/scvi/dataloaders/_concat_dataloader.py index dcb41b0a80..9aa9071a85 100644 --- a/src/scvi/dataloaders/_concat_dataloader.py +++ b/src/scvi/dataloaders/_concat_dataloader.py @@ -1,5 +1,4 @@ from itertools import cycle -from typing import Optional, Union import numpy as np from torch.utils.data import DataLoader @@ -36,8 +35,8 @@ def __init__( indices_list: list[list[int]], shuffle: bool = False, batch_size: int = 128, - data_and_attributes: Optional[dict] = None, - drop_last: Union[bool, int] = False, + data_and_attributes: dict | None = None, + drop_last: bool | int = False, **data_loader_kwargs, ): self.adata_manager = adata_manager @@ -75,4 +74,4 @@ def __iter__(self): is the same as indices_list. 
""" iter_list = [cycle(dl) if dl != self.largest_dl else dl for dl in self.dataloaders] - return zip(*iter_list) + return zip(*iter_list, strict=True) diff --git a/src/scvi/dataloaders/_data_splitting.py b/src/scvi/dataloaders/_data_splitting.py index 7bc746b703..9530eb3975 100644 --- a/src/scvi/dataloaders/_data_splitting.py +++ b/src/scvi/dataloaders/_data_splitting.py @@ -1,6 +1,5 @@ import warnings from math import ceil, floor -from typing import Optional, Union import lightning.pytorch as pl import numpy as np @@ -22,9 +21,7 @@ from scvi.utils._docstrings import devices_dsp -def validate_data_split( - n_samples: int, train_size: float, validation_size: Optional[float] = None -): +def validate_data_split(n_samples: int, train_size: float, validation_size: float | None = None): """Check data splitting parameters and return n_train and n_val. Parameters @@ -71,7 +68,7 @@ def validate_data_split( def validate_data_split_with_external_indexing( n_samples: int, - external_indexing: Optional[list[np.array, np.array, np.array]] = None, + external_indexing: list[np.array, np.array, np.array] | None = None, ): """Check data splitting parameters and return n_train and n_val. 
@@ -186,11 +183,11 @@ def __init__( self, adata_manager: AnnDataManager, train_size: float = 0.9, - validation_size: Optional[float] = None, + validation_size: float | None = None, shuffle_set_split: bool = True, load_sparse_tensor: bool = False, pin_memory: bool = False, - external_indexing: Optional[list[np.array, np.array, np.array]] = None, + external_indexing: list[np.array, np.array, np.array] | None = None, **kwargs, ): super().__init__() @@ -214,7 +211,7 @@ def __init__( self.adata_manager.adata.n_obs, self.train_size, self.validation_size ) - def setup(self, stage: Optional[str] = None): + def setup(self, stage: str | None = None): """Split indices in train/test/val sets.""" if self.external_indexing is not None: # The structure and its order are guaranteed at this stage @@ -337,11 +334,11 @@ def __init__( self, adata_manager: AnnDataManager, train_size: float = 0.9, - validation_size: Optional[float] = None, + validation_size: float | None = None, shuffle_set_split: bool = True, - n_samples_per_label: Optional[int] = None, + n_samples_per_label: int | None = None, pin_memory: bool = False, - external_indexing: Optional[list[np.array, np.array, np.array]] = None, + external_indexing: list[np.array, np.array, np.array] | None = None, **kwargs, ): super().__init__() @@ -366,7 +363,7 @@ def __init__( self.pin_memory = pin_memory self.external_indexing = external_indexing - def setup(self, stage: Optional[str] = None): + def setup(self, stage: str | None = None): """Split indices in train/test/val sets.""" n_labeled_idx = len(self._labeled_indices) n_unlabeled_idx = len(self._unlabeled_indices) @@ -540,13 +537,13 @@ def __init__( self, adata_manager: AnnDataManager, train_size: float = 1.0, - validation_size: Optional[float] = None, + validation_size: float | None = None, accelerator: str = "auto", - device: Union[int, str] = "auto", + device: int | str = "auto", pin_memory: bool = False, shuffle: bool = False, shuffle_test_val: bool = False, - batch_size: 
Optional[int] = None, + batch_size: int | None = None, **kwargs, ): super().__init__( @@ -563,7 +560,7 @@ def __init__( accelerator=accelerator, devices=device, return_device="torch" ) - def setup(self, stage: Optional[str] = None): + def setup(self, stage: str | None = None): """Create the train, validation, and test indices.""" super().setup() diff --git a/src/scvi/dataloaders/_semi_dataloader.py b/src/scvi/dataloaders/_semi_dataloader.py index 5a7c79c3c2..545f3c6d9b 100644 --- a/src/scvi/dataloaders/_semi_dataloader.py +++ b/src/scvi/dataloaders/_semi_dataloader.py @@ -1,5 +1,3 @@ -from typing import Optional, Union - import numpy as np from scvi import REGISTRY_KEYS @@ -37,12 +35,12 @@ class SemiSupervisedDataLoader(ConcatDataLoader): def __init__( self, adata_manager: AnnDataManager, - n_samples_per_label: Optional[int] = None, - indices: Optional[list[int]] = None, + n_samples_per_label: int | None = None, + indices: list[int] | None = None, shuffle: bool = False, batch_size: int = 128, - data_and_attributes: Optional[dict] = None, - drop_last: Union[bool, int] = False, + data_and_attributes: dict | None = None, + drop_last: bool | int = False, **data_loader_kwargs, ): adata = adata_manager.adata diff --git a/src/scvi/distributions/_beta_binomial.py b/src/scvi/distributions/_beta_binomial.py index 2cce66b4a8..c388523d71 100644 --- a/src/scvi/distributions/_beta_binomial.py +++ b/src/scvi/distributions/_beta_binomial.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from pyro.distributions import BetaBinomial as BetaBinomialDistribution from torch.distributions import constraints @@ -71,10 +69,10 @@ class BetaBinomial(BetaBinomialDistribution): def __init__( self, total_count: torch.Tensor, - alpha: Optional[torch.Tensor] = None, - beta: Optional[torch.Tensor] = None, - mu: Optional[torch.Tensor] = None, - gamma: Optional[torch.Tensor] = None, + alpha: torch.Tensor | None = None, + beta: torch.Tensor | None = None, + mu: torch.Tensor | None = 
None, + gamma: torch.Tensor | None = None, validate_args: bool = False, eps: float = 1e-8, ): diff --git a/src/scvi/external/cellassign/_module.py b/src/scvi/external/cellassign/_module.py index 8b91e19d4c..1f4b50ca8f 100644 --- a/src/scvi/external/cellassign/_module.py +++ b/src/scvi/external/cellassign/_module.py @@ -1,5 +1,4 @@ from collections.abc import Iterable -from typing import Optional import torch import torch.nn.functional as F @@ -46,10 +45,10 @@ def __init__( n_genes: int, rho: torch.Tensor, basis_means: torch.Tensor, - b_g_0: Optional[torch.Tensor] = None, + b_g_0: torch.Tensor | None = None, random_b_g_0: bool = True, n_batch: int = 0, - n_cats_per_cov: Optional[Iterable[int]] = None, + n_cats_per_cov: Iterable[int] | None = None, n_continuous_cov: int = 0, ): super().__init__() @@ -122,7 +121,7 @@ def _get_generative_input(self, tensors, inference_outputs): cat_key = REGISTRY_KEYS.CAT_COVS_KEY if cat_key in tensors.keys(): for cat_input, n_cat in zip( - torch.split(tensors[cat_key], 1, dim=1), self.n_cats_per_cov + torch.split(tensors[cat_key], 1, dim=1), self.n_cats_per_cov, strict=True ): to_cat.append(F.one_hot(cat_input.squeeze(-1), n_cat)) diff --git a/src/scvi/external/contrastivevi/_contrastive_data_splitting.py b/src/scvi/external/contrastivevi/_contrastive_data_splitting.py index dc2827f5e0..659a91f406 100644 --- a/src/scvi/external/contrastivevi/_contrastive_data_splitting.py +++ b/src/scvi/external/contrastivevi/_contrastive_data_splitting.py @@ -1,5 +1,3 @@ -from typing import Optional - import numpy as np from scvi import settings @@ -56,11 +54,11 @@ def __init__( background_indices: list[int], target_indices: list[int], train_size: float = 0.9, - validation_size: Optional[float] = None, + validation_size: float | None = None, shuffle_set_split: bool = True, load_sparse_tensor: bool = False, pin_memory: bool = False, - external_indexing: Optional[list[np.array, np.array, np.array]] = None, + external_indexing: list[np.array, np.array, 
np.array] | None = None, **kwargs, ) -> None: super().__init__( @@ -119,7 +117,7 @@ def __init__( self.n_train = self.n_background_train + self.n_target_train self.n_val = self.n_background_val + self.n_target_val - def setup(self, stage: Optional[str] = None): + def setup(self, stage: str | None = None): """Split background and target indices into train/val/test sets.""" background_indices = self.background_indices n_background_train = self.n_background_train diff --git a/src/scvi/external/contrastivevi/_contrastive_dataloader.py b/src/scvi/external/contrastivevi/_contrastive_dataloader.py index 0c88ee10a1..5fd972d324 100644 --- a/src/scvi/external/contrastivevi/_contrastive_dataloader.py +++ b/src/scvi/external/contrastivevi/_contrastive_dataloader.py @@ -1,6 +1,5 @@ import warnings from itertools import cycle -from typing import Optional, Union from scvi import settings from scvi.data import AnnDataManager @@ -74,8 +73,8 @@ def __init__( target_indices: list[int], shuffle: bool = False, batch_size: int = 128, - data_and_attributes: Optional[dict] = None, - drop_last: Union[bool, int] = False, + data_and_attributes: dict | None = None, + drop_last: bool | int = False, distributed_sampler: bool = False, load_sparse_tensor: bool = False, **data_loader_kwargs, diff --git a/src/scvi/external/contrastivevi/_model.py b/src/scvi/external/contrastivevi/_model.py index 35d3e01e74..1eba0c40cc 100644 --- a/src/scvi/external/contrastivevi/_model.py +++ b/src/scvi/external/contrastivevi/_model.py @@ -6,7 +6,6 @@ import warnings from collections.abc import Iterable, Sequence from functools import partial -from typing import Union import numpy as np import pandas as pd @@ -40,7 +39,7 @@ from ._module import ContrastiveVAE logger = logging.getLogger(__name__) -Number = Union[int, float] +Number = int | float class ContrastiveVI(BaseModelClass): diff --git a/src/scvi/external/contrastivevi/_module.py b/src/scvi/external/contrastivevi/_module.py index f8596dc969..3b61378b12 100644 
--- a/src/scvi/external/contrastivevi/_module.py +++ b/src/scvi/external/contrastivevi/_module.py @@ -1,7 +1,5 @@ """PyTorch module for Contrastive VI for single cell expression data.""" -from typing import Optional - import numpy as np import torch import torch.nn.functional as F @@ -61,8 +59,8 @@ def __init__( n_layers: int = 1, dropout_rate: float = 0.1, use_observed_lib_size: bool = True, - library_log_means: Optional[np.ndarray] = None, - library_log_vars: Optional[np.ndarray] = None, + library_log_means: np.ndarray | None = None, + library_log_vars: np.ndarray | None = None, wasserstein_penalty: float = 0, ) -> None: super().__init__() diff --git a/src/scvi/external/gimvi/_model.py b/src/scvi/external/gimvi/_model.py index 8c7f3dd21b..7789970555 100644 --- a/src/scvi/external/gimvi/_model.py +++ b/src/scvi/external/gimvi/_model.py @@ -540,7 +540,7 @@ def load( ) registries = attr_dict.pop("registries_") - for adata, registry in zip(adatas, registries): + for adata, registry in zip(adatas, registries, strict=True): if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__: raise ValueError("It appears you are loading a model from a different class.") @@ -689,4 +689,4 @@ def __iter__(self): dl if i == self.largest_train_dl_idx else cycle(dl) for i, dl in enumerate(self.data_loader_list) ] - return zip(*train_dls) + return zip(*train_dls, strict=True) diff --git a/src/scvi/external/gimvi/_module.py b/src/scvi/external/gimvi/_module.py index 54e4de4acf..a2db2637c3 100644 --- a/src/scvi/external/gimvi/_module.py +++ b/src/scvi/external/gimvi/_module.py @@ -1,7 +1,5 @@ """Main module.""" -from typing import Optional, Union - import numpy as np import torch import torch.nn.functional as F @@ -80,11 +78,11 @@ def __init__( self, dim_input_list: list[int], total_genes: int, - indices_mappings: list[Union[np.ndarray, slice]], + indices_mappings: list[np.ndarray | slice], gene_likelihoods: list[str], model_library_bools: list[bool], - 
library_log_means: list[Optional[np.ndarray]], - library_log_vars: list[Optional[np.ndarray]], + library_log_means: list[np.ndarray | None], + library_log_vars: list[np.ndarray | None], n_latent: int = 10, n_layers_encoder_individual: int = 1, n_layers_encoder_shared: int = 1, @@ -235,9 +233,9 @@ def sample_scale( x: torch.Tensor, mode: int, batch_index: torch.Tensor, - y: Optional[torch.Tensor] = None, + y: torch.Tensor | None = None, deterministic: bool = False, - decode_mode: Optional[int] = None, + decode_mode: int | None = None, ) -> torch.Tensor: """Return the tensor of predicted frequencies of expression. @@ -282,7 +280,7 @@ def _run_forward( x: torch.Tensor, mode: int, batch_index: torch.Tensor, - y: Optional[torch.Tensor] = None, + y: torch.Tensor | None = None, deterministic: bool = False, decode_mode: int = None, ) -> dict: @@ -307,7 +305,7 @@ def sample_rate( x: torch.Tensor, mode: int, batch_index: torch.Tensor, - y: Optional[torch.Tensor] = None, + y: torch.Tensor | None = None, deterministic: bool = False, decode_mode: int = None, ) -> torch.Tensor: @@ -379,7 +377,7 @@ def _get_generative_input(self, tensors, inference_outputs): return {"z": z, "library": library, "batch_index": batch_index, "y": y} @auto_move_data - def inference(self, x: torch.Tensor, mode: Optional[int] = None) -> dict: + def inference(self, x: torch.Tensor, mode: int | None = None) -> dict: """Run the inference model.""" x_ = x if self.log_variational: @@ -399,9 +397,9 @@ def generative( self, z: torch.Tensor, library: torch.Tensor, - batch_index: Optional[torch.Tensor] = None, - y: Optional[torch.Tensor] = None, - mode: Optional[int] = None, + batch_index: torch.Tensor | None = None, + y: torch.Tensor | None = None, + mode: int | None = None, ) -> dict: """Run the generative model.""" px_scale, px_r, px_rate, px_dropout = self.decoder( @@ -432,7 +430,7 @@ def loss( tensors, inference_outputs, generative_outputs, - mode: Optional[int] = None, + mode: int | None = None, 
kl_weight=1.0, ) -> tuple[torch.Tensor, torch.Tensor]: """Return the reconstruction loss and the Kullback divergences. diff --git a/src/scvi/external/gimvi/_utils.py b/src/scvi/external/gimvi/_utils.py index 1596a09ff7..7d4d32235b 100644 --- a/src/scvi/external/gimvi/_utils.py +++ b/src/scvi/external/gimvi/_utils.py @@ -1,6 +1,6 @@ import os import pickle -from typing import Literal, Optional +from typing import Literal import numpy as np import torch @@ -14,7 +14,7 @@ def _load_legacy_saved_gimvi_files( file_name_prefix: str, load_seq_adata: bool, load_spatial_adata: bool, -) -> tuple[dict, np.ndarray, np.ndarray, dict, Optional[AnnData], Optional[AnnData]]: +) -> tuple[dict, np.ndarray, np.ndarray, dict, AnnData | None, AnnData | None]: model_path = os.path.join(dir_path, f"{file_name_prefix}model_params.pt") setup_dict_path = os.path.join(dir_path, f"{file_name_prefix}attr.pkl") seq_var_names_path = os.path.join(dir_path, f"{file_name_prefix}var_names_seq.csv") @@ -56,10 +56,10 @@ def _load_saved_gimvi_files( dir_path: str, load_seq_adata: bool, load_spatial_adata: bool, - prefix: Optional[str] = None, - map_location: Optional[Literal["cpu", "cuda"]] = None, - backup_url: Optional[str] = None, -) -> tuple[dict, dict, np.ndarray, np.ndarray, dict, Optional[AnnData], Optional[AnnData]]: + prefix: str | None = None, + map_location: Literal["cpu", "cuda"] | None = None, + backup_url: str | None = None, +) -> tuple[dict, dict, np.ndarray, np.ndarray, dict, AnnData | None, AnnData | None]: file_name_prefix = prefix or "" model_file_name = f"{file_name_prefix}model.pt" diff --git a/src/scvi/external/mrvi/_components.py b/src/scvi/external/mrvi/_components.py index a00303a9b0..3ead06092b 100644 --- a/src/scvi/external/mrvi/_components.py +++ b/src/scvi/external/mrvi/_components.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, Callable, Literal +from collections.abc import Callable +from typing import Any, Literal import flax.linen as nn 
import jax diff --git a/src/scvi/external/mrvi/_module.py b/src/scvi/external/mrvi/_module.py index 88e7d1ae77..03a1a5c37d 100644 --- a/src/scvi/external/mrvi/_module.py +++ b/src/scvi/external/mrvi/_module.py @@ -1,7 +1,8 @@ from __future__ import annotations import warnings -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import flax.linen as nn import jax diff --git a/src/scvi/external/mrvi/_types.py b/src/scvi/external/mrvi/_types.py index 5a40e23157..7b7e33e4c9 100644 --- a/src/scvi/external/mrvi/_types.py +++ b/src/scvi/external/mrvi/_types.py @@ -1,8 +1,8 @@ from __future__ import annotations -from collections.abc import Iterable +from collections.abc import Callable, Iterable from dataclasses import dataclass -from typing import Callable, Literal +from typing import Literal from xarray import DataArray diff --git a/src/scvi/external/scar/_model.py b/src/scvi/external/scar/_model.py index 6c3cc7c2ad..b7affed0bc 100644 --- a/src/scvi/external/scar/_model.py +++ b/src/scvi/external/scar/_model.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import logging -from typing import Literal, Optional, Union +from typing import Literal import numpy as np import pandas as pd @@ -81,7 +83,7 @@ class SCAR(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass): def __init__( self, adata: AnnData, - ambient_profile: Union[str, np.ndarray, pd.DataFrame, torch.tensor] = None, + ambient_profile: str | np.ndarray | pd.DataFrame | torch.tensor | None = None, n_hidden: int = 150, n_latent: int = 15, n_layers: int = 2, @@ -146,8 +148,8 @@ def __init__( def setup_anndata( cls, adata: AnnData, - layer: Optional[str] = None, - size_factor_key: Optional[str] = None, + layer: str | None = None, + size_factor_key: str | None = None, **kwargs, ): """%(summary)s. 
@@ -261,9 +263,9 @@ def get_ambient_profile( @torch.no_grad() def get_denoised_counts( self, - adata: Optional[AnnData] = None, + adata: AnnData | None = None, n_samples: int = 1, - batch_size: Optional[int] = None, + batch_size: int | None = None, ) -> np.ndarray: r"""Generate observation samples from the posterior predictive distribution. diff --git a/src/scvi/external/scbasset/_module.py b/src/scvi/external/scbasset/_module.py index 4923be501a..53c5b6fb7c 100644 --- a/src/scvi/external/scbasset/_module.py +++ b/src/scvi/external/scbasset/_module.py @@ -1,5 +1,6 @@ import logging -from typing import Callable, NamedTuple, Optional +from collections.abc import Callable +from typing import NamedTuple import numpy as np import torch @@ -64,7 +65,7 @@ def __init__( pool_size: int = None, batch_norm: bool = True, dropout: float = 0.0, - activation_fn: Optional[Callable] = None, + activation_fn: Callable | None = None, ceil_mode: bool = False, ): super().__init__() @@ -100,7 +101,7 @@ def __init__( out_features: int, batch_norm: bool = True, dropout: float = 0.2, - activation_fn: Optional[Callable] = None, + activation_fn: Callable | None = None, ): super().__init__() self.dense = _Linear(in_features, out_features, bias=not batch_norm) @@ -223,7 +224,7 @@ class ScBassetModule(BaseModuleClass): def __init__( self, n_cells: int, - batch_ids: Optional[np.ndarray] = None, + batch_ids: np.ndarray | None = None, n_filters_init: int = 288, n_repeat_blocks_tower: int = 6, filters_mult: float = 1.122, diff --git a/src/scvi/external/tangram/_model.py b/src/scvi/external/tangram/_model.py index 49dda4a0d0..88f658e53b 100644 --- a/src/scvi/external/tangram/_model.py +++ b/src/scvi/external/tangram/_model.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import logging -from typing import Literal, Optional, Union +from typing import Literal import flax import jax @@ -82,7 +84,7 @@ def __init__( self, sc_adata: AnnData, constrained: bool = False, - target_count: Optional[int] 
= None, + target_count: int | None = None, **model_kwargs, ): super().__init__(sc_adata) @@ -125,9 +127,9 @@ def train( self, max_epochs: int = 1000, accelerator: str = "auto", - devices: Union[int, list[int], str] = "auto", + devices: int | list[int] | str = "auto", lr: float = 0.1, - plan_kwargs: Optional[dict] = None, + plan_kwargs: dict | None = None, ): """Train the model. @@ -195,12 +197,10 @@ def train( def setup_mudata( cls, mdata: MuData, - density_prior_key: Union[ - str, Literal["rna_count_based", "uniform"], None - ] = "rna_count_based", - sc_layer: Optional[str] = None, - sp_layer: Optional[str] = None, - modalities: Optional[dict[str, str]] = None, + density_prior_key: str | Literal["rna_count_based", "uniform"] | None = "rna_count_based", + sc_layer: str | None = None, + sp_layer: str | None = None, + modalities: dict[str, str] | None = None, **kwargs, ): """%(summary)s. diff --git a/src/scvi/external/tangram/_module.py b/src/scvi/external/tangram/_module.py index bfce474061..97c22034f7 100644 --- a/src/scvi/external/tangram/_module.py +++ b/src/scvi/external/tangram/_module.py @@ -1,4 +1,4 @@ -from typing import NamedTuple, Optional +from typing import NamedTuple import jax import jax.numpy as jnp @@ -40,7 +40,7 @@ class TangramMapper(JaxBaseModuleClass): lambda_count: float = 1.0 lambda_f_reg: float = 1.0 constrained: bool = False - target_count: Optional[int] = None + target_count: int | None = None training: bool = True def setup(self): diff --git a/src/scvi/model/_utils.py b/src/scvi/model/_utils.py index 6cc3ef0cde..0759253006 100644 --- a/src/scvi/model/_utils.py +++ b/src/scvi/model/_utils.py @@ -2,7 +2,7 @@ import warnings from collections.abc import Iterable as IterableClass from collections.abc import Sequence -from typing import Literal, Optional, Union +from typing import Literal import jax import numpy as np @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) -def use_distributed_sampler(strategy: Union[str, Strategy]) -> bool: +def 
use_distributed_sampler(strategy: str | Strategy) -> bool: """``EXPERIMENTAL`` Return whether to use a distributed sampler. Currently only supports DDP. @@ -73,8 +73,8 @@ def get_max_epochs_heuristic( @devices_dsp.dedent def parse_device_args( accelerator: str = "auto", - devices: Union[int, list[int], str] = "auto", - return_device: Optional[Literal["torch", "jax"]] = None, + devices: int | list[int] | str = "auto", + return_device: Literal["torch", "jax"] | None = None, validate_single_device: bool = False, ): """Parses device-related arguments. @@ -150,9 +150,9 @@ def parse_device_args( def scrna_raw_counts_properties( adata_manager: AnnDataManager, - idx1: Union[list[int], np.ndarray], - idx2: Union[list[int], np.ndarray], - var_idx: Optional[Union[list[int], np.ndarray]] = None, + idx1: list[int] | np.ndarray, + idx2: list[int] | np.ndarray, + var_idx: list[int] | np.ndarray | None = None, ) -> dict[str, np.ndarray]: """Computes and returns some statistics on the raw counts of two sub-populations. @@ -218,8 +218,8 @@ def scrna_raw_counts_properties( def cite_seq_raw_counts_properties( adata_manager: AnnDataManager, - idx1: Union[list[int], np.ndarray], - idx2: Union[list[int], np.ndarray], + idx1: list[int] | np.ndarray, + idx2: list[int] | np.ndarray, ) -> dict[str, np.ndarray]: """Computes and returns some statistics on the raw counts of two sub-populations. @@ -262,9 +262,9 @@ def cite_seq_raw_counts_properties( def scatac_raw_counts_properties( adata_manager: AnnDataManager, - idx1: Union[list[int], np.ndarray], - idx2: Union[list[int], np.ndarray], - var_idx: Optional[Union[list[int], np.ndarray]] = None, + idx1: list[int] | np.ndarray, + idx2: list[int] | np.ndarray, + var_idx: list[int] | np.ndarray | None = None, ) -> dict[str, np.ndarray]: """Computes and returns some statistics on the raw counts of two sub-populations. 
@@ -296,9 +296,7 @@ def scatac_raw_counts_properties( return properties -def _get_batch_code_from_category( - adata_manager: AnnDataManager, category: Sequence[Union[Number, str]] -): +def _get_batch_code_from_category(adata_manager: AnnDataManager, category: Sequence[Number | str]): if not isinstance(category, IterableClass) or isinstance(category, str): category = [category] diff --git a/src/scvi/model/base/_archesmixin.py b/src/scvi/model/base/_archesmixin.py index 363be6c0ed..c490360c2e 100644 --- a/src/scvi/model/base/_archesmixin.py +++ b/src/scvi/model/base/_archesmixin.py @@ -1,7 +1,6 @@ import logging import warnings from copy import deepcopy -from typing import Optional, Union import anndata import numpy as np @@ -40,10 +39,10 @@ class ArchesMixin: def load_query_data( cls, adata: AnnOrMuData, - reference_model: Union[str, BaseModelClass], + reference_model: str | BaseModelClass, inplace_subset_query_vars: bool = False, accelerator: str = "auto", - device: Union[int, str] = "auto", + device: int | str = "auto", unfrozen: bool = False, freeze_dropout: bool = False, freeze_expression: bool = True, @@ -175,10 +174,10 @@ def load_query_data( @staticmethod def prepare_query_anndata( adata: AnnData, - reference_model: Union[str, BaseModelClass], + reference_model: str | BaseModelClass, return_reference_var_names: bool = False, inplace: bool = True, - ) -> Optional[Union[AnnData, pd.Index]]: + ) -> AnnData | pd.Index | None: """Prepare data for query integration. This function will return a new AnnData object with padded zeros @@ -214,10 +213,10 @@ def prepare_query_anndata( @staticmethod def prepare_query_mudata( mdata: MuData, - reference_model: Union[str, BaseModelClass], + reference_model: str | BaseModelClass, return_reference_var_names: bool = False, inplace: bool = True, - ) -> Optional[Union[MuData, dict[str, pd.Index]]]: + ) -> MuData | dict[str, pd.Index] | None: """Prepare multimodal dataset for query integration. 
This function will return a new MuData object such that the diff --git a/src/scvi/model/base/_de_core.py b/src/scvi/model/base/_de_core.py index dbc05febf3..bef79ad900 100644 --- a/src/scvi/model/base/_de_core.py +++ b/src/scvi/model/base/_de_core.py @@ -1,6 +1,5 @@ import logging from collections.abc import Iterable as IterableClass -from typing import Union import anndata import numpy as np @@ -14,8 +13,8 @@ def _prepare_obs( - idx1: Union[list[bool], np.ndarray, str], - idx2: Union[list[bool], np.ndarray, str], + idx1: list[bool] | np.ndarray | str, + idx2: list[bool] | np.ndarray | str, adata: anndata.AnnData, ): """Construct an array used for masking. diff --git a/src/scvi/model/base/_differential.py b/src/scvi/model/base/_differential.py index 4796e8cde8..60817788c4 100644 --- a/src/scvi/model/base/_differential.py +++ b/src/scvi/model/base/_differential.py @@ -1,8 +1,8 @@ import inspect import logging import warnings -from collections.abc import Sequence -from typing import Callable, Literal, Optional, Union +from collections.abc import Callable, Sequence +from typing import Literal import numpy as np import pandas as pd @@ -47,7 +47,7 @@ def __init__( self.model_fn = model_fn self.representation_fn = representation_fn - def filter_outlier_cells(self, selection: Union[list[bool], np.ndarray]): + def filter_outlier_cells(self, selection: list[bool] | np.ndarray): """Filters out cells that are outliers in the representation space.""" selection = self.process_selection(selection) reps = self.representation_fn( @@ -68,20 +68,20 @@ def filter_outlier_cells(self, selection: Union[list[bool], np.ndarray]): def get_bayes_factors( self, - idx1: Union[list[bool], np.ndarray], - idx2: Union[list[bool], np.ndarray], + idx1: list[bool] | np.ndarray, + idx2: list[bool] | np.ndarray, mode: Literal["vanilla", "change"] = "vanilla", - batchid1: Optional[Sequence[Union[Number, str]]] = None, - batchid2: Optional[Sequence[Union[Number, str]]] = None, - use_observed_batches: 
Optional[bool] = False, + batchid1: Sequence[Number | str] | None = None, + batchid2: Sequence[Number | str] | None = None, + use_observed_batches: bool | None = False, n_samples: int = 5000, use_permutation: bool = False, m_permutation: int = 10000, - change_fn: Optional[Union[str, Callable]] = None, - m1_domain_fn: Optional[Callable] = None, - delta: Optional[float] = 0.5, - pseudocounts: Union[float, None] = 0.0, - cred_interval_lvls: Optional[Union[list[float], np.ndarray]] = None, + change_fn: str | Callable | None = None, + m1_domain_fn: Callable | None = None, + delta: float | None = 0.5, + pseudocounts: float | None = 0.0, + cred_interval_lvls: list[float] | np.ndarray | None = None, ) -> dict[str, np.ndarray]: r"""A unified method for differential expression inference. @@ -375,12 +375,12 @@ def m1_domain_fn(samples): @torch.inference_mode() def scale_sampler( self, - selection: Union[list[bool], np.ndarray], - n_samples: Optional[int] = 5000, - n_samples_per_cell: Optional[int] = None, - batchid: Optional[Sequence[Union[Number, str]]] = None, - use_observed_batches: Optional[bool] = False, - give_mean: Optional[bool] = False, + selection: list[bool] | np.ndarray, + n_samples: int | None = 5000, + n_samples_per_cell: int | None = None, + batchid: Sequence[Number | str] | None = None, + use_observed_batches: bool | None = False, + give_mean: bool | None = False, ) -> dict: """Samples the posterior scale using the variational posterior distribution. 
@@ -471,7 +471,7 @@ def scale_sampler( px_scales = px_scales.mean(0) return {"scale": px_scales, "batch": batch_ids} - def process_selection(self, selection: Union[list[bool], np.ndarray]) -> np.ndarray: + def process_selection(self, selection: list[bool] | np.ndarray) -> np.ndarray: """If selection is a mask, convert it to indices.""" selection = np.asarray(selection) if selection.dtype is np.dtype("bool"): @@ -509,7 +509,7 @@ def estimate_pseudocounts_offset( scales_b: list[np.ndarray], where_zero_a: list[np.ndarray], where_zero_b: list[np.ndarray], - percentile: Optional[float] = 0.9, + percentile: float | None = 0.9, ): """Determines pseudocount offset. @@ -551,13 +551,13 @@ def estimate_pseudocounts_offset( def pairs_sampler( - arr1: Union[list[float], np.ndarray, torch.Tensor], - arr2: Union[list[float], np.ndarray, torch.Tensor], + arr1: list[float] | np.ndarray | torch.Tensor, + arr2: list[float] | np.ndarray | torch.Tensor, use_permutation: bool = True, m_permutation: int = None, sanity_check_perm: bool = False, - weights1: Union[list[float], np.ndarray, torch.Tensor] = None, - weights2: Union[list[float], np.ndarray, torch.Tensor] = None, + weights1: list[float] | np.ndarray | torch.Tensor = None, + weights2: list[float] | np.ndarray | torch.Tensor = None, ) -> tuple: """Creates more pairs. @@ -615,7 +615,7 @@ def pairs_sampler( def credible_intervals( - ary: np.ndarray, confidence_level: Union[float, list[float], np.ndarray] = 0.94 + ary: np.ndarray, confidence_level: float | list[float] | np.ndarray = 0.94 ) -> np.ndarray: """Calculate highest posterior density (HPD) of array for given credible_interval. 
@@ -660,8 +660,8 @@ def credible_intervals( def describe_continuous_distrib( - samples: Union[np.ndarray, torch.Tensor], - credible_intervals_levels: Optional[Union[list[float], np.ndarray]] = None, + samples: np.ndarray | torch.Tensor, + credible_intervals_levels: list[float] | np.ndarray | None = None, ) -> dict: """Computes properties of distribution based on its samples. diff --git a/src/scvi/model/base/_log_likelihood.py b/src/scvi/model/base/_log_likelihood.py index 6275dcb5fd..0c07e38a02 100644 --- a/src/scvi/model/base/_log_likelihood.py +++ b/src/scvi/model/base/_log_likelihood.py @@ -1,7 +1,7 @@ from __future__ import annotations -from collections.abc import Iterator -from typing import Any, Callable +from collections.abc import Callable, Iterator +from typing import Any import torch from torch import Tensor diff --git a/src/scvi/model/base/_pyromixin.py b/src/scvi/model/base/_pyromixin.py index 2e066f827e..58992bb085 100755 --- a/src/scvi/model/base/_pyromixin.py +++ b/src/scvi/model/base/_pyromixin.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import Callable +from collections.abc import Callable import numpy as np import torch diff --git a/src/scvi/module/_amortizedlda.py b/src/scvi/module/_amortizedlda.py index 516b6356da..3179d5538d 100644 --- a/src/scvi/module/_amortizedlda.py +++ b/src/scvi/module/_amortizedlda.py @@ -1,6 +1,5 @@ import math from collections.abc import Iterable, Sequence -from typing import Optional, Union import pyro import pyro.distributions as dist @@ -98,7 +97,7 @@ def __init__( @staticmethod def _get_fn_args_from_batch( tensor_dict: dict[str, torch.Tensor], - ) -> Union[Iterable, dict]: + ) -> Iterable | dict: x = tensor_dict[REGISTRY_KEYS.X_KEY] library = torch.sum(x, dim=1) return (x, library), {} @@ -108,7 +107,7 @@ def forward( self, x: torch.Tensor, library: torch.Tensor, - n_obs: Optional[int] = None, + n_obs: int | None = None, kl_weight: float = 1.0, ): """Forward pass.""" @@ -184,7 
+183,7 @@ def forward( self, x: torch.Tensor, _library: torch.Tensor, - n_obs: Optional[int] = None, + n_obs: int | None = None, kl_weight: float = 1.0, ): """Forward pass.""" @@ -242,8 +241,8 @@ def __init__( n_input: int, n_topics: int, n_hidden: int, - cell_topic_prior: Optional[Union[float, Sequence[float]]] = None, - topic_feature_prior: Optional[Union[float, Sequence[float]]] = None, + cell_topic_prior: float | Sequence[float] | None = None, + topic_feature_prior: float | Sequence[float] | None = None, ): super().__init__() diff --git a/src/scvi/module/_autozivae.py b/src/scvi/module/_autozivae.py index 804aa5d612..73b8624ca2 100644 --- a/src/scvi/module/_autozivae.py +++ b/src/scvi/module/_autozivae.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional, Union +from typing import Literal import numpy as np import torch @@ -129,9 +129,7 @@ def __init__( else: # gene-cell raise Exception("Gene-cell not implemented yet for AutoZI") - def get_alphas_betas( - self, as_numpy: bool = True - ) -> dict[str, Union[torch.Tensor, np.ndarray]]: + def get_alphas_betas(self, as_numpy: bool = True) -> dict[str, torch.Tensor | np.ndarray]: """Get the parameters of the Bernoulli beta prior and posterior distributions.""" # Return parameters of Bernoulli Beta distributions in a dictionary outputs = {} @@ -181,8 +179,8 @@ def sample_from_beta_distribution( def reshape_bernoulli( self, bernoulli_params: torch.Tensor, - batch_index: Optional[torch.Tensor] = None, - y: Optional[torch.Tensor] = None, + batch_index: torch.Tensor | None = None, + y: torch.Tensor | None = None, ) -> torch.Tensor: """Reshape Bernoulli parameters to match the input tensor.""" if self.zero_inflation == "gene-label": @@ -214,8 +212,8 @@ def reshape_bernoulli( def sample_bernoulli_params( self, - batch_index: Optional[torch.Tensor] = None, - y: Optional[torch.Tensor] = None, + batch_index: torch.Tensor | None = None, + y: torch.Tensor | None = None, n_samples: int = 1, ) -> torch.Tensor: """Sample 
Bernoulli parameters from the posterior distribution.""" @@ -261,8 +259,8 @@ def generative( self, z, library, - batch_index: Optional[torch.Tensor] = None, - y: Optional[torch.Tensor] = None, + batch_index: torch.Tensor | None = None, + y: torch.Tensor | None = None, size_factor=None, cont_covs=None, cat_covs=None, diff --git a/src/scvi/module/_jaxvae.py b/src/scvi/module/_jaxvae.py index ea03b8335b..d2811bf3cd 100644 --- a/src/scvi/module/_jaxvae.py +++ b/src/scvi/module/_jaxvae.py @@ -1,5 +1,3 @@ -from typing import Optional - import jax import jax.numpy as jnp import numpyro.distributions as dist @@ -30,7 +28,7 @@ class FlaxEncoder(nn.Module): n_latent: int n_hidden: int dropout_rate: int - training: Optional[bool] = None + training: bool | None = None def setup(self): """Setup encoder.""" @@ -44,7 +42,7 @@ def setup(self): self.dropout1 = nn.Dropout(self.dropout_rate) self.dropout2 = nn.Dropout(self.dropout_rate) - def __call__(self, x: jnp.ndarray, training: Optional[bool] = None): + def __call__(self, x: jnp.ndarray, training: bool | None = None): """Forward pass.""" training = nn.merge_param("training", self.training, training) is_eval = not training @@ -72,7 +70,7 @@ class FlaxDecoder(nn.Module): n_input: int dropout_rate: float n_hidden: int - training: Optional[bool] = None + training: bool | None = None def setup(self): """Setup decoder.""" @@ -91,7 +89,7 @@ def setup(self): "disp", lambda rng, shape: jax.random.normal(rng, shape), (self.n_input, 1) ) - def __call__(self, z: jnp.ndarray, batch: jnp.ndarray, training: Optional[bool] = None): + def __call__(self, z: jnp.ndarray, batch: jnp.ndarray, training: bool | None = None): """Forward pass.""" # TODO(adamgayoso): Test this training = nn.merge_param("training", self.training, training) diff --git a/src/scvi/module/_mrdeconv.py b/src/scvi/module/_mrdeconv.py index 9f06fcfb2d..1e00d5e4fe 100644 --- a/src/scvi/module/_mrdeconv.py +++ b/src/scvi/module/_mrdeconv.py @@ -1,5 +1,5 @@ from collections import 
OrderedDict -from typing import Literal, Optional +from typing import Literal import numpy as np import torch @@ -89,8 +89,8 @@ def __init__( l1_reg: float = 0.0, beta_reg: float = 5.0, eta_reg: float = 1e-4, - extra_encoder_kwargs: Optional[dict] = None, - extra_decoder_kwargs: Optional[dict] = None, + extra_encoder_kwargs: dict | None = None, + extra_decoder_kwargs: dict | None = None, ): super().__init__() self.n_spots = n_spots diff --git a/src/scvi/module/_multivae.py b/src/scvi/module/_multivae.py index 7a32707876..6ad24b65d2 100644 --- a/src/scvi/module/_multivae.py +++ b/src/scvi/module/_multivae.py @@ -1,5 +1,5 @@ from collections.abc import Iterable -from typing import Literal, Optional +from typing import Literal import numpy as np import torch @@ -278,7 +278,7 @@ def __init__( n_layers_encoder: int = 2, n_layers_decoder: int = 2, n_continuous_cov: int = 0, - n_cats_per_cov: Optional[Iterable[int]] = None, + n_cats_per_cov: Iterable[int] | None = None, dropout_rate: float = 0.1, region_factors: bool = True, use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "none", @@ -287,8 +287,8 @@ def __init__( deeply_inject_covariates: bool = False, encode_covariates: bool = False, use_size_factor_key: bool = False, - protein_background_prior_mean: Optional[np.ndarray] = None, - protein_background_prior_scale: Optional[np.ndarray] = None, + protein_background_prior_mean: np.ndarray | None = None, + protein_background_prior_scale: np.ndarray | None = None, protein_dispersion: str = "protein", ): super().__init__() diff --git a/src/scvi/module/_peakvae.py b/src/scvi/module/_peakvae.py index 9912fe12de..10f4898354 100644 --- a/src/scvi/module/_peakvae.py +++ b/src/scvi/module/_peakvae.py @@ -1,5 +1,5 @@ from collections.abc import Iterable -from typing import Literal, Optional +from typing import Literal import numpy as np import torch @@ -143,7 +143,7 @@ def __init__( n_layers_encoder: int = 2, n_layers_decoder: int = 2, n_continuous_cov: int = 0, - 
n_cats_per_cov: Optional[Iterable[int]] = None, + n_cats_per_cov: Iterable[int] | None = None, dropout_rate: float = 0.1, model_depth: bool = True, region_factors: bool = True, @@ -152,8 +152,8 @@ def __init__( latent_distribution: Literal["normal", "ln"] = "normal", deeply_inject_covariates: bool = False, encode_covariates: bool = False, - extra_encoder_kwargs: Optional[dict] = None, - extra_decoder_kwargs: Optional[dict] = None, + extra_encoder_kwargs: dict | None = None, + extra_decoder_kwargs: dict | None = None, ): super().__init__() diff --git a/src/scvi/module/_totalvae.py b/src/scvi/module/_totalvae.py index 6ad7e716df..bdb80638b6 100644 --- a/src/scvi/module/_totalvae.py +++ b/src/scvi/module/_totalvae.py @@ -1,7 +1,7 @@ """Main module.""" from collections.abc import Iterable -from typing import Literal, Optional, Union +from typing import Literal import numpy as np import torch @@ -120,7 +120,7 @@ def __init__( n_layers_encoder: int = 2, n_layers_decoder: int = 1, n_continuous_cov: int = 0, - n_cats_per_cov: Optional[Iterable[int]] = None, + n_cats_per_cov: Iterable[int] | None = None, dropout_rate_decoder: float = 0.2, dropout_rate_encoder: float = 0.2, gene_dispersion: Literal["gene", "gene-batch", "gene-label"] = "gene", @@ -128,18 +128,18 @@ def __init__( log_variational: bool = True, gene_likelihood: Literal["zinb", "nb"] = "nb", latent_distribution: Literal["normal", "ln"] = "normal", - protein_batch_mask: dict[Union[str, int], np.ndarray] = None, + protein_batch_mask: dict[str | int, np.ndarray] = None, encode_covariates: bool = True, - protein_background_prior_mean: Optional[np.ndarray] = None, - protein_background_prior_scale: Optional[np.ndarray] = None, + protein_background_prior_mean: np.ndarray | None = None, + protein_background_prior_scale: np.ndarray | None = None, use_size_factor_key: bool = False, use_observed_lib_size: bool = True, - library_log_means: Optional[np.ndarray] = None, - library_log_vars: Optional[np.ndarray] = None, + 
library_log_means: np.ndarray | None = None, + library_log_vars: np.ndarray | None = None, use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both", use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none", - extra_encoder_kwargs: Optional[dict] = None, - extra_decoder_kwargs: Optional[dict] = None, + extra_encoder_kwargs: dict | None = None, + extra_decoder_kwargs: dict | None = None, ): super().__init__() self.gene_dispersion = gene_dispersion @@ -255,8 +255,8 @@ def get_sample_dispersion( self, x: torch.Tensor, y: torch.Tensor, - batch_index: Optional[torch.Tensor] = None, - label: Optional[torch.Tensor] = None, + batch_index: torch.Tensor | None = None, + label: torch.Tensor | None = None, n_samples: int = 1, ) -> tuple[torch.Tensor, torch.Tensor]: """Returns the tensors of dispersions for genes and proteins. @@ -290,7 +290,7 @@ def get_reconstruction_loss( y: torch.Tensor, px_dict: dict[str, torch.Tensor], py_dict: dict[str, torch.Tensor], - pro_batch_mask_minibatch: Optional[torch.Tensor] = None, + pro_batch_mask_minibatch: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Compute reconstruction loss.""" px_ = px_dict @@ -379,8 +379,8 @@ def generative( cont_covs=None, cat_covs=None, size_factor=None, - transform_batch: Optional[int] = None, - ) -> dict[str, Union[torch.Tensor, dict[str, torch.Tensor]]]: + transform_batch: int | None = None, + ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: """Run the generative step.""" if cont_covs is None: decoder_input = z @@ -437,12 +437,12 @@ def inference( self, x: torch.Tensor, y: torch.Tensor, - batch_index: Optional[torch.Tensor] = None, - label: Optional[torch.Tensor] = None, + batch_index: torch.Tensor | None = None, + label: torch.Tensor | None = None, n_samples=1, cont_covs=None, cat_covs=None, - ) -> dict[str, Union[torch.Tensor, dict[str, torch.Tensor]]]: + ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: """Internal helper function to compute 
necessary inference quantities. We use the dictionary ``px_`` to contain the parameters of the ZINB/NB for genes. diff --git a/src/scvi/module/_vae.py b/src/scvi/module/_vae.py index c48ba89693..c40faf60f2 100644 --- a/src/scvi/module/_vae.py +++ b/src/scvi/module/_vae.py @@ -2,7 +2,8 @@ import logging import warnings -from typing import Callable, Literal +from collections.abc import Callable +from typing import Literal import numpy as np import torch diff --git a/src/scvi/module/base/_base_module.py b/src/scvi/module/base/_base_module.py index f198fb2f66..c61b758b12 100644 --- a/src/scvi/module/base/_base_module.py +++ b/src/scvi/module/base/_base_module.py @@ -1,9 +1,9 @@ from __future__ import annotations from abc import abstractmethod -from collections.abc import Iterable +from collections.abc import Callable, Iterable from dataclasses import field -from typing import Any, Callable +from typing import Any import flax import jax diff --git a/src/scvi/module/base/_decorators.py b/src/scvi/module/base/_decorators.py index 5e6a51daeb..b0052a95b1 100644 --- a/src/scvi/module/base/_decorators.py +++ b/src/scvi/module/base/_decorators.py @@ -1,6 +1,6 @@ -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from functools import wraps -from typing import Any, Callable, Union +from typing import Any import flax.linen as nn import torch @@ -71,7 +71,7 @@ def batch_to(data): def _apply_to_collection( data: Any, - dtype: Union[type, tuple], + dtype: type | tuple, function: Callable, *args, **kwargs, diff --git a/src/scvi/nn/_base_components.py b/src/scvi/nn/_base_components.py index 909fef573b..988f288bbe 100644 --- a/src/scvi/nn/_base_components.py +++ b/src/scvi/nn/_base_components.py @@ -1,6 +1,6 @@ import collections -from collections.abc import Iterable -from typing import Callable, Literal, Optional +from collections.abc import Callable, Iterable +from typing import Literal import torch from torch import nn @@ -96,7 
+96,9 @@ def __init__( nn.Dropout(p=dropout_rate) if dropout_rate > 0 else None, ), ) - for i, (n_in, n_out) in enumerate(zip(layers_dim[:-1], layers_dim[1:])) + for i, (n_in, n_out) in enumerate( + zip(layers_dim[:-1], layers_dim[1:], strict=True) + ) ] ) ) @@ -152,7 +154,7 @@ def forward(self, x: torch.Tensor, *cat_list: int): if len(self.n_cat_list) > len(cat_list): raise ValueError("nb. categorical args provided doesn't match init. params.") - for n_cat, cat in zip(self.n_cat_list, cat_list): + for n_cat, cat in zip(self.n_cat_list, cat_list, strict=False): if n_cat and cat is None: raise ValueError("cat not provided while n_cat != 0 in init. params.") if n_cat > 1: # n_cat = 1 will be ignored - no additional information @@ -229,7 +231,7 @@ def __init__( dropout_rate: float = 0.1, distribution: str = "normal", var_eps: float = 1e-4, - var_activation: Optional[Callable] = None, + var_activation: Callable | None = None, return_dist: bool = False, **kwargs, ): diff --git a/src/scvi/nn/_embedding.py b/src/scvi/nn/_embedding.py index 9231b61c00..396e6a47a8 100644 --- a/src/scvi/nn/_embedding.py +++ b/src/scvi/nn/_embedding.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable +from collections.abc import Callable import torch from torch import nn diff --git a/src/scvi/train/_callbacks.py b/src/scvi/train/_callbacks.py index cc2dfd173c..b097f03440 100644 --- a/src/scvi/train/_callbacks.py +++ b/src/scvi/train/_callbacks.py @@ -2,10 +2,10 @@ import os import warnings +from collections.abc import Callable from copy import deepcopy from datetime import datetime from shutil import rmtree -from typing import Callable import flax import lightning.pytorch as pl diff --git a/src/scvi/train/_logger.py b/src/scvi/train/_logger.py index 357078cfdc..0c973e1b55 100644 --- a/src/scvi/train/_logger.py +++ b/src/scvi/train/_logger.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any import pandas as pd import torch @@ -15,7 
+15,7 @@ def __init__(self): def log_hparams(self, params: dict[str, Any]) -> None: """Record hparams.""" - def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: dict[str, float], step: int | None = None) -> None: """Record metrics.""" def _handle_value(value): @@ -45,7 +45,7 @@ def save(self) -> None: class SimpleLogger(Logger): """Simple logger class.""" - def __init__(self, name: str = "lightning_logs", version: Optional[Union[int, str]] = None): + def __init__(self, name: str = "lightning_logs", version: int | str | None = None): super().__init__() self._name = name self._experiment = None diff --git a/src/scvi/train/_trainer.py b/src/scvi/train/_trainer.py index 3a7051dbde..d091383f8e 100644 --- a/src/scvi/train/_trainer.py +++ b/src/scvi/train/_trainer.py @@ -1,6 +1,6 @@ import sys import warnings -from typing import Literal, Optional, Union +from typing import Literal import lightning.pytorch as pl from lightning.pytorch.accelerators import Accelerator @@ -89,12 +89,12 @@ class Trainer(pl.Trainer): def __init__( self, - accelerator: Optional[Union[str, Accelerator]] = None, - devices: Optional[Union[list[int], str, int]] = None, + accelerator: str | Accelerator | None = None, + devices: list[int] | str | int | None = None, benchmark: bool = True, - check_val_every_n_epoch: Optional[int] = None, + check_val_every_n_epoch: int | None = None, max_epochs: int = 400, - default_root_dir: Optional[str] = None, + default_root_dir: str | None = None, enable_checkpointing: bool = False, checkpointing_monitor: str = "validation_loss", num_sanity_val_steps: int = 0, @@ -109,7 +109,7 @@ def __init__( enable_progress_bar: bool = True, progress_bar_refresh_rate: int = 1, simple_progress_bar: bool = True, - logger: Union[Optional[Logger], bool] = None, + logger: Logger | None | bool = None, log_every_n_steps: int = 10, learning_rate_monitor: bool = False, **kwargs, diff --git a/src/scvi/train/_trainingplans.py 
b/src/scvi/train/_trainingplans.py index b6eea8ab7e..79aa4bf0e3 100644 --- a/src/scvi/train/_trainingplans.py +++ b/src/scvi/train/_trainingplans.py @@ -1,9 +1,9 @@ import warnings from collections import OrderedDict -from collections.abc import Iterable +from collections.abc import Callable, Iterable from functools import partial from inspect import signature -from typing import Any, Callable, Literal, Optional, Union +from typing import Any, Literal import jax import jax.numpy as jnp @@ -37,8 +37,8 @@ def _compute_kl_weight( epoch: int, step: int, - n_epochs_kl_warmup: Optional[int], - n_steps_kl_warmup: Optional[int], + n_epochs_kl_warmup: int | None, + n_steps_kl_warmup: int | None, max_kl_weight: float = 1.0, min_kl_weight: float = 0.0, ) -> float: @@ -144,7 +144,7 @@ def __init__( module: BaseModuleClass, *, optimizer: Literal["Adam", "AdamW", "Custom"] = "Adam", - optimizer_creator: Optional[TorchOptimizerCreator] = None, + optimizer_creator: TorchOptimizerCreator | None = None, lr: float = 1e-3, weight_decay: float = 1e-6, eps: float = 0.01, @@ -196,7 +196,7 @@ def __init__( self.initialize_val_metrics() @staticmethod - def _create_elbo_metric_components(mode: str, n_total: Optional[int] = None): + def _create_elbo_metric_components(mode: str, n_total: int | None = None): """Initialize ELBO metric and the metric collection.""" rec_loss = ElboMetric("reconstruction_loss", mode, "obs") kl_local = ElboMetric("kl_local", mode, "obs") @@ -366,7 +366,7 @@ def validation_step(self, batch, batch_idx): ) self.compute_and_log_metrics(scvi_loss, self.val_metrics, "validation") - def _optimizer_creator_fn(self, optimizer_cls: Union[torch.optim.Adam, torch.optim.AdamW]): + def _optimizer_creator_fn(self, optimizer_cls: torch.optim.Adam | torch.optim.AdamW): """Create optimizer for the model. 
This type of function can be passed as the `optimizer_creator` @@ -481,7 +481,7 @@ def __init__( module: BaseModuleClass, *, optimizer: Literal["Adam", "AdamW", "Custom"] = "Adam", - optimizer_creator: Optional[TorchOptimizerCreator] = None, + optimizer_creator: TorchOptimizerCreator | None = None, lr: float = 1e-3, weight_decay: float = 1e-6, n_steps_kl_warmup: int = None, @@ -494,8 +494,8 @@ def __init__( "elbo_validation", "reconstruction_loss_validation", "kl_local_validation" ] = "elbo_validation", lr_min: float = 0, - adversarial_classifier: Union[bool, Classifier] = False, - scale_adversarial_loss: Union[float, Literal["auto"]] = "auto", + adversarial_classifier: bool | Classifier = False, + scale_adversarial_loss: float | Literal["auto"] = "auto", **loss_kwargs, ): super().__init__( @@ -698,8 +698,8 @@ def __init__( classification_ratio: int = 50, lr: float = 1e-3, weight_decay: float = 1e-6, - n_steps_kl_warmup: Union[int, None] = None, - n_epochs_kl_warmup: Union[int, None] = 400, + n_steps_kl_warmup: int | None = None, + n_epochs_kl_warmup: int | None = 400, reduce_lr_on_plateau: bool = False, lr_factor: float = 0.6, lr_patience: int = 30, @@ -879,11 +879,11 @@ class LowLevelPyroTrainingPlan(pl.LightningModule): def __init__( self, pyro_module: PyroBaseModuleClass, - loss_fn: Optional[pyro.infer.ELBO] = None, - optim: Optional[torch.optim.Adam] = None, - optim_kwargs: Optional[dict] = None, - n_steps_kl_warmup: Union[int, None] = None, - n_epochs_kl_warmup: Union[int, None] = 400, + loss_fn: pyro.infer.ELBO | None = None, + optim: torch.optim.Adam | None = None, + optim_kwargs: dict | None = None, + n_steps_kl_warmup: int | None = None, + n_epochs_kl_warmup: int | None = 400, scale_elbo: float = 1.0, ): super().__init__() @@ -1012,11 +1012,11 @@ class PyroTrainingPlan(LowLevelPyroTrainingPlan): def __init__( self, pyro_module: PyroBaseModuleClass, - loss_fn: Optional[pyro.infer.ELBO] = None, - optim: Optional[pyro.optim.PyroOptim] = None, - optim_kwargs: 
Optional[dict] = None, - n_steps_kl_warmup: Union[int, None] = None, - n_epochs_kl_warmup: Union[int, None] = 400, + loss_fn: pyro.infer.ELBO | None = None, + optim: pyro.optim.PyroOptim | None = None, + optim_kwargs: dict | None = None, + n_steps_kl_warmup: int | None = None, + n_epochs_kl_warmup: int | None = 400, scale_elbo: float = 1.0, ): super().__init__( @@ -1195,13 +1195,13 @@ def __init__( module: JaxBaseModuleClass, *, optimizer: Literal["Adam", "AdamW", "Custom"] = "Adam", - optimizer_creator: Optional[JaxOptimizerCreator] = None, + optimizer_creator: JaxOptimizerCreator | None = None, lr: float = 1e-3, weight_decay: float = 1e-6, eps: float = 0.01, - max_norm: Optional[float] = None, - n_steps_kl_warmup: Union[int, None] = None, - n_epochs_kl_warmup: Union[int, None] = 400, + max_norm: float | None = None, + n_steps_kl_warmup: int | None = None, + n_epochs_kl_warmup: int | None = 400, **loss_kwargs, ): super().__init__( diff --git a/src/scvi/train/_trainrunner.py b/src/scvi/train/_trainrunner.py index 53e056c429..c619947741 100644 --- a/src/scvi/train/_trainrunner.py +++ b/src/scvi/train/_trainrunner.py @@ -1,6 +1,5 @@ import logging import warnings -from typing import Union import lightning.pytorch as pl import numpy as np @@ -58,10 +57,10 @@ def __init__( self, model: BaseModelClass, training_plan: pl.LightningModule, - data_splitter: Union[SemiSupervisedDataSplitter, DataSplitter], + data_splitter: SemiSupervisedDataSplitter | DataSplitter, max_epochs: int, accelerator: str = "auto", - devices: Union[int, list[int], str] = "auto", + devices: int | list[int] | str = "auto", **trainer_kwargs, ): self.training_plan = training_plan diff --git a/src/scvi/utils/_decorators.py b/src/scvi/utils/_decorators.py index aa205a2da7..156dc90e74 100644 --- a/src/scvi/utils/_decorators.py +++ b/src/scvi/utils/_decorators.py @@ -1,5 +1,5 @@ +from collections.abc import Callable from functools import wraps -from typing import Callable def 
unsupported_if_adata_minified(fn: Callable) -> Callable: diff --git a/src/scvi/utils/_dependencies.py b/src/scvi/utils/_dependencies.py index c37f827031..2b08fa256f 100644 --- a/src/scvi/utils/_dependencies.py +++ b/src/scvi/utils/_dependencies.py @@ -1,6 +1,6 @@ import importlib +from collections.abc import Callable from functools import wraps -from typing import Callable def error_on_missing_dependencies(*modules): diff --git a/src/scvi/utils/_jax.py b/src/scvi/utils/_jax.py index 48516eb4b0..35e3ed11b1 100644 --- a/src/scvi/utils/_jax.py +++ b/src/scvi/utils/_jax.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable import jax from jax import random diff --git a/tests/data/test_synthetic_iid.py b/tests/data/test_synthetic_iid.py index 53473c634c..ce2e560061 100644 --- a/tests/data/test_synthetic_iid.py +++ b/tests/data/test_synthetic_iid.py @@ -1,5 +1,3 @@ -from typing import Optional - import numpy as np import pytest @@ -7,7 +5,7 @@ @pytest.mark.parametrize("sparse_format", ["csr_matrix", "csc_matrix", None]) -def test_synthetic_iid_sparse_format(sparse_format: Optional[str]): +def test_synthetic_iid_sparse_format(sparse_format: str | None): _ = synthetic_iid(sparse_format=sparse_format) diff --git a/tests/data/utils.py b/tests/data/utils.py index 73f14b3cc2..517716ed31 100644 --- a/tests/data/utils.py +++ b/tests/data/utils.py @@ -1,5 +1,3 @@ -from typing import Optional - from anndata import AnnData from mudata import MuData @@ -25,8 +23,8 @@ def unsupervised_training_one_epoch( adata: AnnData, run_setup_anndata: bool = True, - batch_key: Optional[str] = None, - labels_key: Optional[str] = None, + batch_key: str | None = None, + labels_key: str | None = None, ): if run_setup_anndata: SCVI.setup_anndata(adata, batch_key=batch_key, labels_key=labels_key) @@ -36,13 +34,13 @@ def unsupervised_training_one_epoch( def generic_setup_adata_manager( adata: AnnData, - batch_key: Optional[str] = None, - labels_key: Optional[str] = None, 
- categorical_covariate_keys: Optional[list[str]] = None, - continuous_covariate_keys: Optional[list[str]] = None, - layer: Optional[str] = None, - protein_expression_obsm_key: Optional[str] = None, - protein_names_uns_key: Optional[str] = None, + batch_key: str | None = None, + labels_key: str | None = None, + categorical_covariate_keys: list[str] | None = None, + continuous_covariate_keys: list[str] | None = None, + layer: str | None = None, + protein_expression_obsm_key: str | None = None, + protein_names_uns_key: str | None = None, ) -> AnnDataManager: setup_args = locals() setup_args.pop("adata") @@ -75,8 +73,8 @@ def generic_setup_adata_manager( def scanvi_setup_adata_manager( adata: AnnData, unlabeled_category: str, - batch_key: Optional[str] = None, - labels_key: Optional[str] = None, + batch_key: str | None = None, + labels_key: str | None = None, ) -> AnnDataManager: setup_args = locals() setup_args.pop("adata") @@ -93,15 +91,15 @@ def scanvi_setup_adata_manager( def generic_setup_mudata_manager( mdata: MuData, layer_mod, - layer: Optional[str] = None, - batch_mod: Optional[str] = None, - batch_key: Optional[str] = None, - categorical_covariate_mod: Optional[str] = None, - categorical_covariate_keys: Optional[list[str]] = None, - continuous_covariate_mod: Optional[str] = None, - continuous_covariate_keys: Optional[list[str]] = None, - protein_expression_mod: Optional[str] = None, - protein_expression_layer: Optional[str] = None, + layer: str | None = None, + batch_mod: str | None = None, + batch_key: str | None = None, + categorical_covariate_mod: str | None = None, + categorical_covariate_keys: list[str] | None = None, + continuous_covariate_mod: str | None = None, + continuous_covariate_keys: list[str] | None = None, + protein_expression_mod: str | None = None, + protein_expression_layer: str | None = None, ) -> AnnDataManager: setup_args = locals() setup_args.pop("mdata") diff --git a/tests/dataloaders/sparse_utils.py 
b/tests/dataloaders/sparse_utils.py index 21d6b8026f..50abdb0c68 100644 --- a/tests/dataloaders/sparse_utils.py +++ b/tests/dataloaders/sparse_utils.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional, Union +from typing import Literal import torch from anndata import AnnData @@ -72,8 +72,8 @@ def __init__(self, adata: AnnData): def setup_anndata( cls, adata: AnnData, - layer: Optional[str] = None, - batch_key: Optional[str] = None, + layer: str | None = None, + batch_key: str | None = None, ): setup_method_args = cls._get_setup_method_args(**locals()) anndata_fields = [ @@ -88,7 +88,7 @@ def train( self, max_epochs: int = 1, accelerator: str = "auto", - devices: Union[int, list[int], str] = "auto", + devices: int | list[int] | str = "auto", expected_sparse_layout: Literal["csr", "csc"] = None, ): data_splitter = TestSparseDataSplitter( diff --git a/tests/external/scbasset/test_scbasset.py b/tests/external/scbasset/test_scbasset.py index 135d813446..789a36c915 100644 --- a/tests/external/scbasset/test_scbasset.py +++ b/tests/external/scbasset/test_scbasset.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import Optional import numpy as np import pytest @@ -10,7 +9,7 @@ _DNA_CODE_KEY = "code" -def _get_adata(sparse_format: Optional[str] = None): +def _get_adata(sparse_format: str | None = None): dataset1 = synthetic_iid(batch_size=100, sparse_format=sparse_format).transpose() dataset1.X = (dataset1.X > 0).astype(float) dataset1.obsm[_DNA_CODE_KEY] = np.random.randint(0, 3, size=(dataset1.n_obs, 1344)) diff --git a/tests/external/tangram/test_tangram.py b/tests/external/tangram/test_tangram.py index 075d8a0bc9..d1c9707d3e 100644 --- a/tests/external/tangram/test_tangram.py +++ b/tests/external/tangram/test_tangram.py @@ -1,5 +1,3 @@ -from typing import Optional - import mudata import numpy as np import pytest @@ -10,7 +8,7 @@ modalities = {"density_prior_key": "sp", "sc_layer": "sc", "sp_layer": "sp"} -def _get_mdata(sparse_format: Optional[str] = None): 
+def _get_mdata(sparse_format: str | None = None): dataset1 = synthetic_iid(batch_size=100, sparse_format=sparse_format) dataset2 = dataset1[-25:].copy() dataset1 = dataset1[:-25].copy() diff --git a/tests/model/base/test_base_model.py b/tests/model/base/test_base_model.py index 4f3e8e3e8f..1435829406 100644 --- a/tests/model/base/test_base_model.py +++ b/tests/model/base/test_base_model.py @@ -1,5 +1,3 @@ -from typing import Optional - import pytest from anndata import AnnData @@ -14,12 +12,12 @@ class TestModelClass(BaseModelClass): def setup_anndata( cls, adata: AnnData, - layer: Optional[str] = None, - batch_key: Optional[str] = None, - labels_key: Optional[str] = None, - size_factor_key: Optional[str] = None, - categorical_covariate_keys: Optional[list[str]] = None, - continuous_covariate_keys: Optional[list[str]] = None, + layer: str | None = None, + batch_key: str | None = None, + labels_key: str | None = None, + size_factor_key: str | None = None, + categorical_covariate_keys: list[str] | None = None, + continuous_covariate_keys: list[str] | None = None, **kwargs, ): setup_method_args = cls._get_setup_method_args(**locals()) From e579cbbc1f62dab4b80c45210f9cd66080ba59ce Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Tue, 17 Sep 2024 16:43:36 +0300 Subject: [PATCH 03/22] update docker to py3.12 (#2974) --- Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index e3fe047758..5f4d005e50 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ FROM nvidia/cuda:12.5.0-runtime-ubuntu22.04 -FROM python:3.11 AS base +FROM python:3.12 AS base RUN pip install --no-cache-dir uv @@ -9,7 +9,7 @@ CMD ["/bin/bash"] FROM base AS build -ENV SCVI_PATH="/usr/local/lib/python3.11/site-packages/scvi-tools" +ENV SCVI_PATH="/usr/local/lib/python3.12/site-packages/scvi-tools" COPY . 
${SCVI_PATH} From 9c53fd60f8e89b4743d79f310981226ea21ecb24 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Sep 2024 16:44:07 +0300 Subject: [PATCH 04/22] [pre-commit.ci] pre-commit autoupdate (#2975) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.6.4 → v0.6.5](https://github.com/astral-sh/ruff-pre-commit/compare/v0.6.4...v0.6.5) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ori Kronfeld --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bd473b5d4f..e0239d02f7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,7 +41,7 @@ repos: )$ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.4 + rev: v0.6.5 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] From 7545970737b64b5c77b728f377db371e28e96a84 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 17 Sep 2024 17:04:58 +0300 Subject: [PATCH 05/22] docs: automated update of tutorials (#2976) automated update of tutorials submodule Co-authored-by: ori-kron-wis <175299014+ori-kron-wis@users.noreply.github.com> --- docs/tutorials/notebooks | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 2f9c2ac012..0f9908fdc7 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 2f9c2ac012942f3478405c2d489c4334abc1e22f +Subproject commit 0f9908fdc7565b147493db0bca6c46d2608db5eb From 405241299a5852df66cb1a4856055cb0641c44a3 Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Tue, 17 Sep 2024 23:48:49 +0300 Subject: [PATCH 06/22] update ruff rules (#2973) Co-authored-by: pre-commit-ci[bot] 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Can Ergen --- docs/conf.py | 5 +- docs/extensions/typed_returns.py | 7 +- pyproject.toml | 18 +++-- src/scvi/_settings.py | 7 +- src/scvi/autotune/_experiment.py | 19 +++-- src/scvi/autotune/_tune.py | 14 ++-- src/scvi/criticism/_ppc.py | 6 +- src/scvi/data/_anntorchdataset.py | 3 +- src/scvi/data/_built_in_data/_cellxgene.py | 8 ++- src/scvi/data/_built_in_data/_synthetic.py | 6 +- src/scvi/data/_datasets.py | 9 ++- src/scvi/data/_manager.py | 16 +++-- src/scvi/data/_preprocessing.py | 10 ++- src/scvi/data/_utils.py | 14 ++-- src/scvi/dataloaders/_data_splitting.py | 2 +- src/scvi/distributions/_normal.py | 6 +- src/scvi/external/cellassign/_model.py | 5 +- src/scvi/external/contrastivevi/_model.py | 8 ++- src/scvi/external/gimvi/_model.py | 9 ++- src/scvi/external/mrvi/_components.py | 7 +- src/scvi/external/mrvi/_model.py | 12 ++-- src/scvi/external/mrvi/_module.py | 7 +- src/scvi/external/mrvi/_types.py | 7 +- src/scvi/external/mrvi/_utils.py | 13 +++- src/scvi/external/poissonvi/_model.py | 12 ++-- src/scvi/external/scar/_model.py | 8 ++- src/scvi/external/scbasset/_model.py | 10 ++- src/scvi/external/solo/_model.py | 8 ++- src/scvi/external/stereoscope/_model.py | 8 ++- src/scvi/external/tangram/_model.py | 14 ++-- src/scvi/external/velovi/_model.py | 10 ++- src/scvi/external/velovi/_module.py | 10 ++- src/scvi/hub/_metadata.py | 5 +- src/scvi/hub/_model.py | 8 ++- src/scvi/model/_amortizedlda.py | 8 ++- src/scvi/model/_autozi.py | 10 ++- src/scvi/model/_condscvi.py | 5 +- src/scvi/model/_destvi.py | 13 ++-- src/scvi/model/_jaxscvi.py | 12 ++-- src/scvi/model/_linear_scvi.py | 8 ++- src/scvi/model/_multivi.py | 13 ++-- src/scvi/model/_peakvi.py | 10 ++- src/scvi/model/_scanvi.py | 19 +++-- src/scvi/model/_scvi.py | 15 ++-- src/scvi/model/_totalvi.py | 15 ++-- src/scvi/model/base/_base_model.py | 14 ++-- src/scvi/model/base/_embedding_mixin.py | 8 ++- src/scvi/model/base/_log_likelihood.py 
| 12 ++-- src/scvi/model/base/_pyromixin.py | 5 +- src/scvi/model/base/_rnamixin.py | 11 ++- src/scvi/model/base/_save_load.py | 11 ++- src/scvi/model/base/_training_mixin.py | 5 +- src/scvi/model/base/_vaemixin.py | 16 +++-- src/scvi/model/utils/_mde.py | 7 +- src/scvi/module/_scanvae.py | 14 ++-- src/scvi/module/_vae.py | 10 ++- src/scvi/module/_vaec.py | 8 ++- src/scvi/module/base/_base_module.py | 21 +++--- src/scvi/nn/_embedding.py | 5 +- src/scvi/train/_callbacks.py | 16 +++-- src/scvi/utils/_track.py | 3 +- tests/autotune/test_experiment.py | 72 +++++++++---------- tests/conftest.py | 2 +- tests/criticism/test_criticism.py | 6 +- tests/data/test_preprocessing.py | 6 +- tests/external/gimvi/test_gimvi.py | 10 ++- tests/external/mrvi/test_model.py | 10 ++- tests/external/tangram/test_tangram.py | 12 ++-- tests/hub/test_hub_model.py | 6 +- tests/model/test_amortizedlda.py | 16 ++--- tests/model/test_autozi.py | 5 +- tests/model/test_differential.py | 3 +- tests/model/test_jaxscvi.py | 3 +- tests/model/test_models_with_minified_data.py | 6 +- tests/model/test_pyro.py | 5 +- tests/model/test_scanvi.py | 6 +- tests/model/test_scvi.py | 39 +++++----- tests/model/test_totalvi.py | 26 ++++--- tests/nn/test_embedding.py | 14 ++-- tests/train/test_trainingplans.py | 5 +- 80 files changed, 556 insertions(+), 301 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 44d3773acc..de40f3572b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -5,9 +5,12 @@ import subprocess import sys from pathlib import Path -from typing import Any from importlib.metadata import metadata from datetime import datetime +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any HERE = Path(__file__).parent sys.path[:0] = [str(HERE.parent), str(HERE / "extensions")] diff --git a/docs/extensions/typed_returns.py b/docs/extensions/typed_returns.py index 113520471c..47292453af 100644 --- a/docs/extensions/typed_returns.py +++ b/docs/extensions/typed_returns.py @@ -3,10 
+3,13 @@ from __future__ import annotations import re -from collections.abc import Generator, Iterable -from sphinx.application import Sphinx from sphinx.ext.napoleon import NumpyDocstring +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from sphinx.application import Sphinx + from collections.abc import Generator, Iterable def _process_return(lines: Iterable[str]) -> Generator[str, None, None]: diff --git a/pyproject.toml b/pyproject.toml index 8f50db143d..7d4c961d48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -177,6 +177,10 @@ select = [ "BLE", # flake8-blind-except "UP", # pyupgrade "RUF100", # Report unused noqa directives + "PT", # pytest style + "NPY", # numpy formatting + "TCH", # flake8-type-checking + "FA", # flake8-future-annotations ] ignore = [ # allow I, O, l as variable names -> I is the identity matrix @@ -189,17 +193,16 @@ ignore = [ "D107", # Errors from function calls in argument defaults. These are fine when the result is immutable. "B008", - # __magic__ methods are are often self-explanatory, allow missing docstrings - "D105", # first line should end with a period [Bug: doesn't work with single-line docstrings] "D400", # First line should be in imperative mood; try rephrasing "D401", - ## Disable one in each pair of mutually incompatible rules - # We don’t want a blank line before a class docstring - "D203", # We want docstrings to start immediately after the opening triple quote "D213", + # Raising ValueError is sufficient in tests. + "PT011", + # We support np.random functions. + "NPY002" ] [tool.ruff.lint.pydocstyle] @@ -212,6 +215,7 @@ convention = "numpy" "src/scvi/__init__.py" = ["I"] [tool.ruff.format] +docstring-code-format = true # Like Black, use double quotes for strings. 
quote-style = "double" @@ -226,3 +230,7 @@ line-ending = "auto" [tool.jupytext] formats = "ipynb,md" + +[tool.ruff.lint.flake8-type-checking] +exempt-modules = [] +strict = true diff --git a/src/scvi/_settings.py b/src/scvi/_settings.py index 25979de587..062ea2607b 100644 --- a/src/scvi/_settings.py +++ b/src/scvi/_settings.py @@ -1,13 +1,18 @@ +from __future__ import annotations + import logging import os from pathlib import Path -from typing import Literal +from typing import TYPE_CHECKING import torch from lightning.pytorch import seed_everything from rich.console import Console from rich.logging import RichHandler +if TYPE_CHECKING: + from typing import Literal + scvi_logger = logging.getLogger("scvi") diff --git a/src/scvi/autotune/_experiment.py b/src/scvi/autotune/_experiment.py index 4b59536b14..f9154e2217 100644 --- a/src/scvi/autotune/_experiment.py +++ b/src/scvi/autotune/_experiment.py @@ -1,19 +1,24 @@ from __future__ import annotations from os.path import join -from typing import Any, Literal +from typing import TYPE_CHECKING from anndata import AnnData -from lightning.pytorch import LightningDataModule from lightning.pytorch.callbacks import Callback from lightning.pytorch.loggers import TensorBoardLogger from mudata import MuData -from ray.tune import ResultGrid, Tuner -from ray.tune.schedulers import TrialScheduler -from ray.tune.search import SearchAlgorithm +from ray.tune import Tuner -from scvi._types import AnnOrMuData -from scvi.model.base import BaseModelClass +if TYPE_CHECKING: + from typing import Any, Literal + + from lightning.pytorch import LightningDataModule + from ray.tune import ResultGrid + from ray.tune.schedulers import TrialScheduler + from ray.tune.search import SearchAlgorithm + + from scvi._types import AnnOrMuData + from scvi.model.base import BaseModelClass _ASHA_DEFAULT_KWARGS = { "max_t": 100, diff --git a/src/scvi/autotune/_tune.py b/src/scvi/autotune/_tune.py index ed079432b8..7bd90d5f16 100644 --- 
a/src/scvi/autotune/_tune.py +++ b/src/scvi/autotune/_tune.py @@ -1,13 +1,17 @@ from __future__ import annotations import logging -from typing import Any, Literal +from typing import TYPE_CHECKING -from lightning.pytorch import LightningDataModule - -from scvi._types import AnnOrMuData from scvi.autotune._experiment import AutotuneExperiment -from scvi.model.base import BaseModelClass + +if TYPE_CHECKING: + from typing import Any, Literal + + from lightning.pytorch import LightningDataModule + + from scvi._types import AnnOrMuData + from scvi.model.base import BaseModelClass logger = logging.getLogger(__name__) diff --git a/src/scvi/criticism/_ppc.py b/src/scvi/criticism/_ppc.py index 055ffc299e..a1b66478ab 100644 --- a/src/scvi/criticism/_ppc.py +++ b/src/scvi/criticism/_ppc.py @@ -2,7 +2,7 @@ import json import warnings -from typing import Literal +from typing import TYPE_CHECKING, Literal import numpy as np import pandas as pd @@ -17,7 +17,6 @@ from sparse import GCXS, SparseArray from xarray import DataArray, Dataset -from scvi.model.base import BaseModelClass from scvi.utils import dependencies from ._constants import ( @@ -33,6 +32,9 @@ UNS_NAME_RGG_RAW, ) +if TYPE_CHECKING: + from scvi.model.base import BaseModelClass + Dims = Literal["cells", "features"] diff --git a/src/scvi/data/_anntorchdataset.py b/src/scvi/data/_anntorchdataset.py index f0725efb13..e6bf34142b 100644 --- a/src/scvi/data/_anntorchdataset.py +++ b/src/scvi/data/_anntorchdataset.py @@ -6,7 +6,6 @@ import h5py import numpy as np import pandas as pd -import torch try: # anndata >= 0.10 @@ -22,6 +21,8 @@ from scvi._constants import REGISTRY_KEYS if TYPE_CHECKING: + import torch + from ._manager import AnnDataManager from ._utils import registry_key_to_default_dtype, scipy_to_torch_sparse diff --git a/src/scvi/data/_built_in_data/_cellxgene.py b/src/scvi/data/_built_in_data/_cellxgene.py index 7027d8820d..8ef90a4fc9 100644 --- a/src/scvi/data/_built_in_data/_cellxgene.py +++ 
b/src/scvi/data/_built_in_data/_cellxgene.py @@ -1,10 +1,16 @@ +from __future__ import annotations + import os import re +from typing import TYPE_CHECKING -from anndata import AnnData, read_h5ad +from anndata import read_h5ad from scvi.utils import dependencies +if TYPE_CHECKING: + from anndata import AnnData + def _parse_dataset_id(url: str): match = re.search(r"/e/(.+)", url) diff --git a/src/scvi/data/_built_in_data/_synthetic.py b/src/scvi/data/_built_in_data/_synthetic.py index 714e8111f5..54425961ca 100644 --- a/src/scvi/data/_built_in_data/_synthetic.py +++ b/src/scvi/data/_built_in_data/_synthetic.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import logging +from typing import TYPE_CHECKING import numpy as np import pandas as pd @@ -6,7 +9,8 @@ from anndata import AnnData from mudata import MuData -from scvi._types import AnnOrMuData +if TYPE_CHECKING: + from scvi._types import AnnOrMuData logger = logging.getLogger(__name__) diff --git a/src/scvi/data/_datasets.py b/src/scvi/data/_datasets.py index 8bbaea68de..1560679b5f 100644 --- a/src/scvi/data/_datasets.py +++ b/src/scvi/data/_datasets.py @@ -1,11 +1,9 @@ from __future__ import annotations import warnings - -import anndata +from typing import TYPE_CHECKING from scvi import settings -from scvi._types import AnnOrMuData from ._built_in_data._brain_large import _load_brainlarge_dataset from ._built_in_data._cellxgene import _load_cellxgene_dataset @@ -28,6 +26,11 @@ from ._built_in_data._smfish import _load_smfish from ._built_in_data._synthetic import _generate_synthetic +if TYPE_CHECKING: + import anndata + + from scvi._types import AnnOrMuData + def pbmc_dataset( save_path: str = "data/", diff --git a/src/scvi/data/_manager.py b/src/scvi/data/_manager.py index 10d0219041..cf1c0708d1 100644 --- a/src/scvi/data/_manager.py +++ b/src/scvi/data/_manager.py @@ -2,14 +2,12 @@ import sys from collections import defaultdict -from collections.abc import Sequence from copy import deepcopy from 
dataclasses import dataclass from io import StringIO +from typing import TYPE_CHECKING from uuid import uuid4 -import numpy as np -import pandas as pd import rich from mudata import MuData from rich import box @@ -17,7 +15,6 @@ from torch.utils.data import Subset import scvi -from scvi._types import AnnOrMuData from scvi.utils import attrdict from . import _constants @@ -28,7 +25,16 @@ _check_mudata_fully_paired, get_anndata_attribute, ) -from .fields import AnnDataField + +if TYPE_CHECKING: + from collections.abc import Sequence + + import numpy as np + import pandas as pd + + from scvi._types import AnnOrMuData + + from .fields import AnnDataField @dataclass diff --git a/src/scvi/data/_preprocessing.py b/src/scvi/data/_preprocessing.py index 2cf251a3d9..fa18dcd134 100644 --- a/src/scvi/data/_preprocessing.py +++ b/src/scvi/data/_preprocessing.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import logging import tempfile -from pathlib import Path +from typing import TYPE_CHECKING -import anndata import numpy as np import pandas as pd import torch @@ -14,6 +15,11 @@ from ._utils import _check_nonnegative_integers +if TYPE_CHECKING: + from pathlib import Path + + import anndata + logger = logging.getLogger(__name__) diff --git a/src/scvi/data/_utils.py b/src/scvi/data/_utils.py index 20fbfef293..fc6228a29f 100644 --- a/src/scvi/data/_utils.py +++ b/src/scvi/data/_utils.py @@ -6,7 +6,6 @@ import h5py import numpy as np -import numpy.typing as npt import pandas as pd import scipy.sparse as sp_sparse from anndata import AnnData @@ -25,15 +24,22 @@ except ImportError: from anndata._io.specs import read_elem +from typing import TYPE_CHECKING + from mudata import MuData -from pandas.api.types import CategoricalDtype -from torch import Tensor, as_tensor, sparse_csc_tensor, sparse_csr_tensor +from torch import as_tensor, sparse_csc_tensor, sparse_csr_tensor from scvi import REGISTRY_KEYS, settings -from scvi._types import AnnOrMuData, MinifiedDataType from . 
import _constants +if TYPE_CHECKING: + import numpy.typing as npt + from pandas.api.types import CategoricalDtype + from torch import Tensor + + from scvi._types import AnnOrMuData, MinifiedDataType + logger = logging.getLogger(__name__) diff --git a/src/scvi/dataloaders/_data_splitting.py b/src/scvi/dataloaders/_data_splitting.py index 9530eb3975..a33748d40c 100644 --- a/src/scvi/dataloaders/_data_splitting.py +++ b/src/scvi/dataloaders/_data_splitting.py @@ -324,7 +324,7 @@ class is :class:`~scvi.dataloaders.SemiSupervisedDataLoader`, >>> adata = scvi.data.synthetic_iid() >>> scvi.model.SCVI.setup_anndata(adata, labels_key="labels") >>> adata_manager = scvi.model.SCVI(adata).adata_manager - >>> unknown_label = 'label_0' + >>> unknown_label = "label_0" >>> splitter = SemiSupervisedDataSplitter(adata, unknown_label) >>> splitter.setup() >>> train_dl = splitter.train_dataloader() diff --git a/src/scvi/distributions/_normal.py b/src/scvi/distributions/_normal.py index 05eb337797..d9336bbb0f 100644 --- a/src/scvi/distributions/_normal.py +++ b/src/scvi/distributions/_normal.py @@ -1,8 +1,12 @@ from __future__ import annotations -import torch +from typing import TYPE_CHECKING + from torch.distributions import Normal as NormalTorch +if TYPE_CHECKING: + import torch + class Normal(NormalTorch): """Normal distribution. 
diff --git a/src/scvi/external/cellassign/_model.py b/src/scvi/external/cellassign/_model.py index 3607819856..a6ecaeb5ef 100644 --- a/src/scvi/external/cellassign/_model.py +++ b/src/scvi/external/cellassign/_model.py @@ -1,11 +1,11 @@ from __future__ import annotations import logging +from typing import TYPE_CHECKING import numpy as np import pandas as pd import torch -from anndata import AnnData from lightning.pytorch.callbacks import Callback from scvi import REGISTRY_KEYS @@ -25,6 +25,9 @@ from scvi.utils import setup_anndata_dsp from scvi.utils._docstrings import devices_dsp +if TYPE_CHECKING: + from anndata import AnnData + logger = logging.getLogger(__name__) B = 10 diff --git a/src/scvi/external/contrastivevi/_model.py b/src/scvi/external/contrastivevi/_model.py index 1eba0c40cc..c5622d4588 100644 --- a/src/scvi/external/contrastivevi/_model.py +++ b/src/scvi/external/contrastivevi/_model.py @@ -4,13 +4,12 @@ import logging import warnings -from collections.abc import Iterable, Sequence from functools import partial +from typing import TYPE_CHECKING import numpy as np import pandas as pd import torch -from anndata import AnnData from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager @@ -38,6 +37,11 @@ from ._contrastive_data_splitting import ContrastiveDataSplitter from ._module import ContrastiveVAE +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + from anndata import AnnData + logger = logging.getLogger(__name__) Number = int | float diff --git a/src/scvi/external/gimvi/_model.py b/src/scvi/external/gimvi/_model.py index 7789970555..8cb78eaa9f 100644 --- a/src/scvi/external/gimvi/_model.py +++ b/src/scvi/external/gimvi/_model.py @@ -4,10 +4,10 @@ import os import warnings from itertools import cycle +from typing import TYPE_CHECKING import numpy as np import torch -from anndata import AnnData from torch.utils.data import DataLoader from scvi import REGISTRY_KEYS, settings @@ -15,7 +15,7 @@ from 
scvi.data._compat import registry_from_setup_dict from scvi.data._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY from scvi.data.fields import CategoricalObsField, LayerField -from scvi.dataloaders import AnnDataLoader, DataSplitter +from scvi.dataloaders import DataSplitter from scvi.model._utils import _init_library_size, parse_device_args from scvi.model.base import BaseModelClass, VAEMixin from scvi.train import Trainer @@ -26,6 +26,11 @@ from ._task import GIMVITrainingPlan from ._utils import _load_legacy_saved_gimvi_files, _load_saved_gimvi_files +if TYPE_CHECKING: + from anndata import AnnData + + from scvi.dataloaders import AnnDataLoader + logger = logging.getLogger(__name__) diff --git a/src/scvi/external/mrvi/_components.py b/src/scvi/external/mrvi/_components.py index 3ead06092b..906dd59e7f 100644 --- a/src/scvi/external/mrvi/_components.py +++ b/src/scvi/external/mrvi/_components.py @@ -1,13 +1,16 @@ from __future__ import annotations -from collections.abc import Callable -from typing import Any, Literal +from typing import TYPE_CHECKING import flax.linen as nn import jax import jax.numpy as jnp import numpyro.distributions as dist +if TYPE_CHECKING: + from collections.abc import Callable + from typing import Any, Literal + PYTORCH_DEFAULT_SCALE = 1 / 3 diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py index 38245d62f7..a0d362018b 100644 --- a/src/scvi/external/mrvi/_model.py +++ b/src/scvi/external/mrvi/_model.py @@ -2,15 +2,12 @@ import logging import warnings -from typing import Literal +from typing import TYPE_CHECKING import jax import jax.numpy as jnp import numpy as np -import numpy.typing as npt import xarray as xr -from anndata import AnnData -from numpyro.distributions import Distribution from tqdm import tqdm from scvi import REGISTRY_KEYS @@ -21,6 +18,13 @@ from scvi.utils import setup_anndata_dsp from scvi.utils._docstrings import devices_dsp +if TYPE_CHECKING: + from typing import Literal + + import 
numpy.typing as npt + from anndata import AnnData + from numpyro.distributions import Distribution + logger = logging.getLogger(__name__) DEFAULT_TRAIN_KWARGS = { diff --git a/src/scvi/external/mrvi/_module.py b/src/scvi/external/mrvi/_module.py index 03a1a5c37d..303f4589d5 100644 --- a/src/scvi/external/mrvi/_module.py +++ b/src/scvi/external/mrvi/_module.py @@ -1,8 +1,7 @@ from __future__ import annotations import warnings -from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING import flax.linen as nn import jax @@ -14,6 +13,10 @@ from scvi.external.mrvi._components import AttentionBlock, Dense from scvi.module.base import JaxBaseModuleClass, LossOutput, flax_configure +if TYPE_CHECKING: + from collections.abc import Callable + from typing import Any + DEFAULT_PX_KWARGS = { "n_hidden": 32, "stop_gradients": False, diff --git a/src/scvi/external/mrvi/_types.py b/src/scvi/external/mrvi/_types.py index 7b7e33e4c9..c9a5e5ee41 100644 --- a/src/scvi/external/mrvi/_types.py +++ b/src/scvi/external/mrvi/_types.py @@ -1,11 +1,14 @@ from __future__ import annotations -from collections.abc import Callable, Iterable from dataclasses import dataclass -from typing import Literal +from typing import TYPE_CHECKING from xarray import DataArray +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + from typing import Literal + @dataclass(frozen=True) class MRVIReduction: diff --git a/src/scvi/external/mrvi/_utils.py b/src/scvi/external/mrvi/_utils.py index 9b327c832d..6335492c0d 100644 --- a/src/scvi/external/mrvi/_utils.py +++ b/src/scvi/external/mrvi/_utils.py @@ -1,9 +1,16 @@ from __future__ import annotations -from jax import Array, jit -from jax.typing import ArrayLike +from typing import TYPE_CHECKING -from scvi.external.mrvi._types import MRVIReduction, _ComputeLocalStatisticsRequirements +from jax import jit + +from scvi.external.mrvi._types import _ComputeLocalStatisticsRequirements + +if TYPE_CHECKING: + from jax 
import Array + from jax.typing import ArrayLike + + from scvi.external.mrvi._types import MRVIReduction def _parse_local_statistics_requirements( diff --git a/src/scvi/external/poissonvi/_model.py b/src/scvi/external/poissonvi/_model.py index 5d18feb03c..a27d3b0c60 100644 --- a/src/scvi/external/poissonvi/_model.py +++ b/src/scvi/external/poissonvi/_model.py @@ -1,14 +1,11 @@ from __future__ import annotations import logging -from collections.abc import Iterable, Sequence from functools import partial -from typing import Literal +from typing import TYPE_CHECKING import numpy as np -import pandas as pd import torch -from anndata import AnnData from scvi import REGISTRY_KEYS from scvi.data import AnnDataManager @@ -29,6 +26,13 @@ from scvi.utils import setup_anndata_dsp from scvi.utils._docstrings import de_dsp +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + from typing import Literal + + import pandas as pd + from anndata import AnnData + logger = logging.getLogger(__name__) diff --git a/src/scvi/external/scar/_model.py b/src/scvi/external/scar/_model.py index b7affed0bc..c2a5b481d3 100644 --- a/src/scvi/external/scar/_model.py +++ b/src/scvi/external/scar/_model.py @@ -1,12 +1,11 @@ from __future__ import annotations import logging -from typing import Literal +from typing import TYPE_CHECKING import numpy as np import pandas as pd import torch -from anndata import AnnData from torch.distributions.multinomial import Multinomial from scvi import REGISTRY_KEYS @@ -23,6 +22,11 @@ from ._module import SCAR_VAE +if TYPE_CHECKING: + from typing import Literal + + from anndata import AnnData + logger = logging.getLogger(__name__) diff --git a/src/scvi/external/scbasset/_model.py b/src/scvi/external/scbasset/_model.py index 7f86ad6e10..c93ed9a254 100644 --- a/src/scvi/external/scbasset/_model.py +++ b/src/scvi/external/scbasset/_model.py @@ -2,12 +2,11 @@ import logging from pathlib import Path -from typing import Literal +from typing import 
TYPE_CHECKING import numpy as np import pandas as pd import torch -from anndata import AnnData from scvi.data import AnnDataManager from scvi.data._download import _download @@ -20,6 +19,11 @@ from scvi.utils import dependencies, setup_anndata_dsp from scvi.utils._docstrings import devices_dsp +if TYPE_CHECKING: + from typing import Literal + + from anndata import AnnData + logger = logging.getLogger(__name__) @@ -48,7 +52,7 @@ class SCBASSET(BaseModelClass): -------- >>> adata = anndata.read_h5ad(path_to_sc_anndata) >>> scvi.data.add_dna_sequence(adata) - >>> adata = adata.transpose() # regions by cells + >>> adata = adata.transpose() # regions by cells >>> scvi.external.SCBASSET.setup_anndata(adata, dna_code_key="dna_code") >>> model = scvi.external.SCBASSET(adata) >>> model.train() diff --git a/src/scvi/external/solo/_model.py b/src/scvi/external/solo/_model.py index c686d58f5f..69f29bade4 100644 --- a/src/scvi/external/solo/_model.py +++ b/src/scvi/external/solo/_model.py @@ -3,8 +3,8 @@ import io import logging import warnings -from collections.abc import Sequence from contextlib import redirect_stdout +from typing import TYPE_CHECKING import anndata import numpy as np @@ -16,7 +16,6 @@ from scvi.data import AnnDataManager from scvi.data.fields import CategoricalObsField, LayerField from scvi.dataloaders import DataSplitter -from scvi.model import SCVI from scvi.model._utils import get_max_epochs_heuristic from scvi.model.base import BaseModelClass from scvi.module import Classifier @@ -25,6 +24,11 @@ from scvi.utils import setup_anndata_dsp from scvi.utils._docstrings import devices_dsp +if TYPE_CHECKING: + from collections.abc import Sequence + + from scvi.model import SCVI + logger = logging.getLogger(__name__) LABELS_KEY = "_solo_doub_sim" diff --git a/src/scvi/external/stereoscope/_model.py b/src/scvi/external/stereoscope/_model.py index 10f8e1f089..05e1ad0bf7 100644 --- a/src/scvi/external/stereoscope/_model.py +++ 
b/src/scvi/external/stereoscope/_model.py @@ -1,12 +1,11 @@ from __future__ import annotations import logging -from typing import Literal +from typing import TYPE_CHECKING import numpy as np import pandas as pd import torch -from anndata import AnnData from scvi import REGISTRY_KEYS from scvi.data import AnnDataManager @@ -16,6 +15,11 @@ from scvi.utils import setup_anndata_dsp from scvi.utils._docstrings import devices_dsp +if TYPE_CHECKING: + from typing import Literal + + from anndata import AnnData + logger = logging.getLogger(__name__) diff --git a/src/scvi/external/tangram/_model.py b/src/scvi/external/tangram/_model.py index 88f658e53b..48cc26b804 100644 --- a/src/scvi/external/tangram/_model.py +++ b/src/scvi/external/tangram/_model.py @@ -1,17 +1,14 @@ from __future__ import annotations import logging -from typing import Literal +from typing import TYPE_CHECKING import flax import jax -import jax.numpy as jnp import numpy as np import pandas as pd import scipy from anndata import AnnData -from jaxlib.xla_extension import Device -from mudata import MuData from scvi.data import AnnDataManager, AnnDataManagerValidationCheck, fields from scvi.external.tangram._module import TANGRAM_REGISTRY_KEYS, TangramMapper @@ -21,6 +18,13 @@ from scvi.utils import setup_anndata_dsp, track from scvi.utils._docstrings import devices_dsp +if TYPE_CHECKING: + from typing import Literal + + import jax.numpy as jnp + from jaxlib.xla_extension import Device + from mudata import MuData + logger = logging.getLogger(__name__) @@ -52,7 +56,7 @@ class Tangram(BaseModelClass): >>> from scvi.external import Tangram >>> ad_sc = anndata.read_h5ad(path_to_sc_anndata) >>> ad_sp = anndata.read_h5ad(path_to_sp_anndata) - >>> markers = pd.read_csv(path_to_markers, index_col=0) # genes to use for mapping + >>> markers = pd.read_csv(path_to_markers, index_col=0) # genes to use for mapping >>> mdata = mudata.MuData( { "sp_full": ad_sp, diff --git a/src/scvi/external/velovi/_model.py 
b/src/scvi/external/velovi/_model.py index 008b699360..39c5bc9c5f 100644 --- a/src/scvi/external/velovi/_model.py +++ b/src/scvi/external/velovi/_model.py @@ -2,14 +2,12 @@ import logging import warnings -from collections.abc import Iterable, Sequence -from typing import Literal +from typing import TYPE_CHECKING import numpy as np import pandas as pd import torch import torch.nn.functional as F -from anndata import AnnData from joblib import Parallel, delayed from scipy.stats import ttest_ind @@ -23,6 +21,12 @@ from scvi.train import TrainingPlan, TrainRunner from scvi.utils._docstrings import devices_dsp, setup_anndata_dsp +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + from typing import Literal + + from anndata import AnnData + logger = logging.getLogger(__name__) diff --git a/src/scvi/external/velovi/_module.py b/src/scvi/external/velovi/_module.py index 176fb0081c..114ac74f2d 100644 --- a/src/scvi/external/velovi/_module.py +++ b/src/scvi/external/velovi/_module.py @@ -1,9 +1,7 @@ from __future__ import annotations -from collections.abc import Iterable -from typing import Literal +from typing import TYPE_CHECKING -import numpy as np import torch import torch.nn.functional as F from torch import nn as nn @@ -15,6 +13,12 @@ from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data from scvi.nn import Encoder, FCLayers +if TYPE_CHECKING: + from collections.abc import Iterable + from typing import Literal + + import numpy as np + class DecoderVELOVI(nn.Module): """Decodes data from latent space of ``n_input`` dimensions ``n_output``dimensions. 
diff --git a/src/scvi/hub/_metadata.py b/src/scvi/hub/_metadata.py index bb8ef48fac..adbaa0f1e0 100644 --- a/src/scvi/hub/_metadata.py +++ b/src/scvi/hub/_metadata.py @@ -3,8 +3,8 @@ import json import os from dataclasses import asdict, dataclass, field +from typing import TYPE_CHECKING -import torch from huggingface_hub import ModelCard, ModelCardData from scvi.data import AnnDataManager @@ -15,6 +15,9 @@ from ._template import template from ._url import validate_url +if TYPE_CHECKING: + import torch + @dataclass class HubMetadata: diff --git a/src/scvi/hub/_model.py b/src/scvi/hub/_model.py index 197bd87bb6..cd0ee758a1 100644 --- a/src/scvi/hub/_model.py +++ b/src/scvi/hub/_model.py @@ -8,10 +8,10 @@ import warnings from dataclasses import asdict from pathlib import Path +from typing import TYPE_CHECKING import anndata import rich -from anndata import AnnData from huggingface_hub import ModelCard, snapshot_download from rich.markdown import Markdown @@ -19,11 +19,15 @@ from scvi.data import cellxgene from scvi.data._download import _download from scvi.hub._metadata import HubMetadata, HubModelCardHelper -from scvi.model.base import BaseModelClass from scvi.utils import dependencies from ._constants import _SCVI_HUB +if TYPE_CHECKING: + from anndata import AnnData + + from scvi.model.base import BaseModelClass + logger = logging.getLogger(__name__) diff --git a/src/scvi/model/_amortizedlda.py b/src/scvi/model/_amortizedlda.py index 2b87006291..1817d96fc0 100644 --- a/src/scvi/model/_amortizedlda.py +++ b/src/scvi/model/_amortizedlda.py @@ -2,13 +2,12 @@ import collections.abc import logging -from collections.abc import Sequence +from typing import TYPE_CHECKING import numpy as np import pandas as pd import pyro import torch -from anndata import AnnData from scvi._constants import REGISTRY_KEYS from scvi.data import AnnDataManager @@ -18,6 +17,11 @@ from .base import BaseModelClass, PyroSviTrainMixin +if TYPE_CHECKING: + from collections.abc import Sequence + + 
from anndata import AnnData + logger = logging.getLogger(__name__) diff --git a/src/scvi/model/_autozi.py b/src/scvi/model/_autozi.py index 51475191be..08b4e35131 100644 --- a/src/scvi/model/_autozi.py +++ b/src/scvi/model/_autozi.py @@ -1,12 +1,10 @@ from __future__ import annotations import logging -from collections.abc import Sequence -from typing import Literal +from typing import TYPE_CHECKING import numpy as np import torch -from anndata import AnnData from torch import logsumexp from torch.distributions import Beta @@ -20,6 +18,12 @@ from .base import BaseModelClass, VAEMixin +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import Literal + + from anndata import AnnData + logger = logging.getLogger(__name__) # register buffer diff --git a/src/scvi/model/_condscvi.py b/src/scvi/model/_condscvi.py index b154d5e55e..b5e49a711d 100644 --- a/src/scvi/model/_condscvi.py +++ b/src/scvi/model/_condscvi.py @@ -2,10 +2,10 @@ import logging import warnings +from typing import TYPE_CHECKING import numpy as np import torch -from anndata import AnnData from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager, fields @@ -19,6 +19,9 @@ from scvi.utils import setup_anndata_dsp from scvi.utils._docstrings import devices_dsp +if TYPE_CHECKING: + from anndata import AnnData + logger = logging.getLogger(__name__) diff --git a/src/scvi/model/_destvi.py b/src/scvi/model/_destvi.py index 1721efff50..fb5f5a9add 100644 --- a/src/scvi/model/_destvi.py +++ b/src/scvi/model/_destvi.py @@ -1,23 +1,28 @@ from __future__ import annotations import logging -from collections import OrderedDict -from collections.abc import Sequence +from typing import TYPE_CHECKING import numpy as np import pandas as pd import torch -from anndata import AnnData from scvi import REGISTRY_KEYS from scvi.data import AnnDataManager from scvi.data.fields import LayerField, NumericalObsField -from scvi.model import CondSCVI from scvi.model.base import BaseModelClass, 
UnsupervisedTrainingMixin from scvi.module import MRDeconv from scvi.utils import setup_anndata_dsp from scvi.utils._docstrings import devices_dsp +if TYPE_CHECKING: + from collections import OrderedDict + from collections.abc import Sequence + + from anndata import AnnData + + from scvi.model import CondSCVI + logger = logging.getLogger(__name__) diff --git a/src/scvi/model/_jaxscvi.py b/src/scvi/model/_jaxscvi.py index 5f57d6f6e7..2a212aad53 100644 --- a/src/scvi/model/_jaxscvi.py +++ b/src/scvi/model/_jaxscvi.py @@ -1,12 +1,9 @@ from __future__ import annotations import logging -from collections.abc import Sequence -from typing import Literal +from typing import TYPE_CHECKING import jax.numpy as jnp -import numpy as np -from anndata import AnnData from scvi import REGISTRY_KEYS from scvi.data import AnnDataManager @@ -16,6 +13,13 @@ from .base import BaseModelClass, JaxTrainingMixin +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import Literal + + import numpy as np + from anndata import AnnData + logger = logging.getLogger(__name__) diff --git a/src/scvi/model/_linear_scvi.py b/src/scvi/model/_linear_scvi.py index 5c3618540a..0cbefc3968 100644 --- a/src/scvi/model/_linear_scvi.py +++ b/src/scvi/model/_linear_scvi.py @@ -1,10 +1,9 @@ from __future__ import annotations import logging -from typing import Literal +from typing import TYPE_CHECKING import pandas as pd -from anndata import AnnData from scvi import REGISTRY_KEYS from scvi.data import AnnDataManager @@ -16,6 +15,11 @@ from .base import BaseModelClass, RNASeqMixin, VAEMixin +if TYPE_CHECKING: + from typing import Literal + + from anndata import AnnData + logger = logging.getLogger(__name__) diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index f28d6c1ab4..e17fa012ca 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -2,20 +2,17 @@ import logging import warnings -from collections.abc import Iterable, Sequence from collections.abc 
import Iterable as IterableClass from functools import partial -from typing import Literal +from typing import TYPE_CHECKING import numpy as np import pandas as pd import torch -from anndata import AnnData from scipy.sparse import csr_matrix, vstack from torch.distributions import Normal from scvi import REGISTRY_KEYS, settings -from scvi._types import Number from scvi.data import AnnDataManager from scvi.data.fields import ( CategoricalJointObsField, @@ -42,6 +39,14 @@ from scvi.train._callbacks import SaveBestState from scvi.utils._docstrings import de_dsp, devices_dsp, setup_anndata_dsp +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + from typing import Literal + + from anndata import AnnData + + from scvi._types import Number + logger = logging.getLogger(__name__) diff --git a/src/scvi/model/_peakvi.py b/src/scvi/model/_peakvi.py index 890cbde26d..a7204c17f3 100644 --- a/src/scvi/model/_peakvi.py +++ b/src/scvi/model/_peakvi.py @@ -2,14 +2,12 @@ import logging import warnings -from collections.abc import Iterable, Sequence from functools import partial -from typing import Literal +from typing import TYPE_CHECKING import numpy as np import pandas as pd import torch -from anndata import AnnData from scipy.sparse import csr_matrix, vstack from scvi import settings @@ -33,6 +31,12 @@ from .base import ArchesMixin, BaseModelClass, VAEMixin +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + from typing import Literal + + from anndata import AnnData + logger = logging.getLogger(__name__) diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index 55c6e7a980..dfab56bb74 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -2,9 +2,8 @@ import logging import warnings -from collections.abc import Sequence from copy import deepcopy -from typing import Literal +from typing import TYPE_CHECKING import numpy as np import pandas as pd @@ -12,7 +11,6 @@ from anndata import AnnData from scvi import 
REGISTRY_KEYS, settings -from scvi._types import MinifiedDataType from scvi.data import AnnDataManager from scvi.data._constants import ( _ADATA_MINIFY_TYPE_UNS_KEY, @@ -21,7 +19,6 @@ ) from scvi.data._utils import _get_adata_minify_type, _is_minified, get_anndata_attribute from scvi.data.fields import ( - BaseAnnDataField, CategoricalJointObsField, CategoricalObsField, LabelsWithUnlabeledObsField, @@ -40,9 +37,21 @@ from scvi.utils import setup_anndata_dsp from scvi.utils._docstrings import devices_dsp -from ._scvi import SCVI from .base import ArchesMixin, BaseMinifiedModeModelClass, RNASeqMixin, VAEMixin +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import Literal + + from anndata import AnnData + + from scvi._types import MinifiedDataType + from scvi.data.fields import ( + BaseAnnDataField, + ) + + from ._scvi import SCVI + _SCANVI_LATENT_QZM = "_scanvi_latent_qzm" _SCANVI_LATENT_QZV = "_scanvi_latent_qzv" _SCANVI_OBSERVED_LIB_SIZE = "_scanvi_observed_lib_size" diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index 95c88c3541..ee6b7765e3 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -2,18 +2,15 @@ import logging import warnings -from typing import Literal +from typing import TYPE_CHECKING import numpy as np -from anndata import AnnData from scvi import REGISTRY_KEYS, settings -from scvi._types import MinifiedDataType from scvi.data import AnnDataManager from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE from scvi.data._utils import _get_adata_minify_type from scvi.data.fields import ( - BaseAnnDataField, CategoricalJointObsField, CategoricalObsField, LayerField, @@ -30,6 +27,16 @@ from .base import ArchesMixin, BaseMinifiedModeModelClass, RNASeqMixin, VAEMixin +if TYPE_CHECKING: + from typing import Literal + + from anndata import AnnData + + from scvi._types import MinifiedDataType + from scvi.data.fields import ( + BaseAnnDataField, + ) + _SCVI_LATENT_QZM = 
"_scvi_latent_qzm" _SCVI_LATENT_QZV = "_scvi_latent_qzv" _SCVI_OBSERVED_LIB_SIZE = "_scvi_observed_lib_size" diff --git a/src/scvi/model/_totalvi.py b/src/scvi/model/_totalvi.py index a3d9f4da54..929ee1ad6c 100644 --- a/src/scvi/model/_totalvi.py +++ b/src/scvi/model/_totalvi.py @@ -2,19 +2,15 @@ import logging import warnings -from collections.abc import Iterable, Sequence from collections.abc import Iterable as IterableClass from functools import partial -from typing import Literal +from typing import TYPE_CHECKING import numpy as np import pandas as pd import torch -from anndata import AnnData -from mudata import MuData from scvi import REGISTRY_KEYS, settings -from scvi._types import Number from scvi.data import AnnDataManager, fields from scvi.data._utils import _check_nonnegative_integers from scvi.dataloaders import DataSplitter @@ -32,6 +28,15 @@ from .base import ArchesMixin, BaseModelClass, RNASeqMixin, VAEMixin +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + from typing import Literal + + from anndata import AnnData + from mudata import MuData + + from scvi._types import Number + logger = logging.getLogger(__name__) diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index 09d597ef05..fd47bf1926 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -5,7 +5,7 @@ import os import warnings from abc import ABCMeta, abstractmethod -from collections.abc import Sequence +from typing import TYPE_CHECKING from uuid import uuid4 import numpy as np @@ -15,7 +15,6 @@ from mudata import MuData from scvi import REGISTRY_KEYS, settings -from scvi._types import AnnOrMuData, MinifiedDataType from scvi.data import AnnDataManager from scvi.data._compat import registry_from_setup_dict from scvi.data._constants import ( @@ -37,6 +36,11 @@ from scvi.utils import attrdict, setup_anndata_dsp from scvi.utils._docstrings import devices_dsp +if TYPE_CHECKING: + from collections.abc import 
Sequence + + from scvi._types import AnnOrMuData, MinifiedDataType + logger = logging.getLogger(__name__) @@ -145,9 +149,9 @@ def to_device(self, device: str | int): -------- >>> adata = scvi.data.synthetic_iid() >>> model = scvi.model.SCVI(adata) - >>> model.to_device('cpu') # moves model to CPU - >>> model.to_device('cuda:0') # moves model to GPU 0 - >>> model.to_device(0) # also moves model to GPU 0 + >>> model.to_device("cpu") # moves model to CPU + >>> model.to_device("cuda:0") # moves model to GPU 0 + >>> model.to_device(0) # also moves model to GPU 0 """ my_device = torch.device(device) self.module.to(my_device) diff --git a/src/scvi/model/base/_embedding_mixin.py b/src/scvi/model/base/_embedding_mixin.py index 4c54c1a295..b099019c73 100644 --- a/src/scvi/model/base/_embedding_mixin.py +++ b/src/scvi/model/base/_embedding_mixin.py @@ -1,12 +1,16 @@ from __future__ import annotations -import numpy as np +from typing import TYPE_CHECKING + import torch -from anndata import AnnData from scvi import REGISTRY_KEYS from scvi.module.base import EmbeddingModuleMixin +if TYPE_CHECKING: + import numpy as np + from anndata import AnnData + class EmbeddingMixin: """``EXPERIMENTAL`` Mixin class for initializing and using embeddings in a model. 
diff --git a/src/scvi/model/base/_log_likelihood.py b/src/scvi/model/base/_log_likelihood.py index 0c07e38a02..97df6f557b 100644 --- a/src/scvi/model/base/_log_likelihood.py +++ b/src/scvi/model/base/_log_likelihood.py @@ -1,12 +1,16 @@ from __future__ import annotations -from collections.abc import Callable, Iterator -from typing import Any +from typing import TYPE_CHECKING import torch -from torch import Tensor -from scvi.module.base import LossOutput +if TYPE_CHECKING: + from collections.abc import Callable, Iterator + from typing import Any + + from torch import Tensor + + from scvi.module.base import LossOutput def compute_elbo( diff --git a/src/scvi/model/base/_pyromixin.py b/src/scvi/model/base/_pyromixin.py index 58992bb085..8623764be3 100755 --- a/src/scvi/model/base/_pyromixin.py +++ b/src/scvi/model/base/_pyromixin.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from collections.abc import Callable +from typing import TYPE_CHECKING import numpy as np import torch @@ -15,6 +15,9 @@ from scvi.utils import track from scvi.utils._docstrings import devices_dsp +if TYPE_CHECKING: + from collections.abc import Callable + logger = logging.getLogger(__name__) diff --git a/src/scvi/model/base/_rnamixin.py b/src/scvi/model/base/_rnamixin.py index d7a7b93b16..ea681724a0 100644 --- a/src/scvi/model/base/_rnamixin.py +++ b/src/scvi/model/base/_rnamixin.py @@ -4,23 +4,28 @@ import logging import warnings from functools import partial -from typing import Literal +from typing import TYPE_CHECKING import numpy as np import pandas as pd import torch import torch.distributions as db -from anndata import AnnData from pyro.distributions.util import deep_to from scvi import REGISTRY_KEYS, settings -from scvi._types import Number from scvi.distributions._utils import DistributionConcatenator, subset_distribution from scvi.model._utils import _get_batch_code_from_category, scrna_raw_counts_properties from scvi.model.base._de_core import _de_core from 
scvi.module.base._decorators import _move_data_to_device from scvi.utils import de_dsp, dependencies, unsupported_if_adata_minified +if TYPE_CHECKING: + from typing import Literal + + from anndata import AnnData + + from scvi._types import Number + try: from sparse import GCXS except ImportError: diff --git a/src/scvi/model/base/_save_load.py b/src/scvi/model/base/_save_load.py index 63c41adfda..c9dc2ca57c 100644 --- a/src/scvi/model/base/_save_load.py +++ b/src/scvi/model/base/_save_load.py @@ -3,22 +3,27 @@ import logging import os import warnings -from typing import Literal +from typing import TYPE_CHECKING import anndata import mudata import numpy as np -import numpy.typing as npt import pandas as pd import torch from anndata import AnnData, read_h5ad from scvi import settings -from scvi._types import AnnOrMuData from scvi.data._constants import _SETUP_METHOD_NAME from scvi.data._download import _download from scvi.model.base._constants import SAVE_KEYS +if TYPE_CHECKING: + from typing import Literal + + import numpy.typing as npt + + from scvi._types import AnnOrMuData + logger = logging.getLogger(__name__) diff --git a/src/scvi/model/base/_training_mixin.py b/src/scvi/model/base/_training_mixin.py index de3efd9fcb..82e83704fd 100644 --- a/src/scvi/model/base/_training_mixin.py +++ b/src/scvi/model/base/_training_mixin.py @@ -1,12 +1,15 @@ from __future__ import annotations -from lightning import LightningDataModule +from typing import TYPE_CHECKING from scvi.dataloaders import DataSplitter from scvi.model._utils import get_max_epochs_heuristic, use_distributed_sampler from scvi.train import TrainingPlan, TrainRunner from scvi.utils._docstrings import devices_dsp +if TYPE_CHECKING: + from lightning import LightningDataModule + class UnsupervisedTrainingMixin: """General purpose unsupervised train method.""" diff --git a/src/scvi/model/base/_vaemixin.py b/src/scvi/model/base/_vaemixin.py index f134c31fc7..7de4fe43c9 100644 --- a/src/scvi/model/base/_vaemixin.py 
+++ b/src/scvi/model/base/_vaemixin.py @@ -1,15 +1,21 @@ from __future__ import annotations import logging -from collections.abc import Iterator, Sequence +from typing import TYPE_CHECKING -import numpy.typing as npt import torch -from anndata import AnnData -from torch import Tensor from scvi.utils import unsupported_if_adata_minified +if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + + import numpy.typing as npt + from anndata import AnnData + from torch import Tensor + from torch.distributions import Distribution + + logger = logging.getLogger(__name__) @@ -265,7 +271,7 @@ def get_latent_representation( a tuple of arrays ``(n_obs, n_latent)`` with the mean and variance of the latent distribution. """ - from torch.distributions import Distribution, Normal + from torch.distributions import Normal from torch.nn.functional import softmax from scvi.module._constants import MODULE_KEYS diff --git a/src/scvi/model/utils/_mde.py b/src/scvi/model/utils/_mde.py index bc719efb0e..b1e499357f 100644 --- a/src/scvi/model/utils/_mde.py +++ b/src/scvi/model/utils/_mde.py @@ -1,16 +1,19 @@ from __future__ import annotations import logging +from typing import TYPE_CHECKING -import numpy as np import pandas as pd import torch -from scipy.sparse import spmatrix from scvi import settings from scvi.model._utils import parse_device_args from scvi.utils._docstrings import devices_dsp +if TYPE_CHECKING: + import numpy as np + from scipy.sparse import spmatrix + logger = logging.getLogger(__name__) diff --git a/src/scvi/module/_scanvae.py b/src/scvi/module/_scanvae.py index 828439753d..55c7d0a8a0 100644 --- a/src/scvi/module/_scanvae.py +++ b/src/scvi/module/_scanvae.py @@ -1,17 +1,15 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence -from typing import Literal +from typing import TYPE_CHECKING import numpy as np import torch -from torch.distributions import Categorical, Distribution, Normal +from torch.distributions import 
Categorical, Normal from torch.distributions import kl_divergence as kl from torch.nn import functional as F from scvi import REGISTRY_KEYS from scvi.data import _constants -from scvi.model.base import BaseModelClass from scvi.module.base import LossOutput, auto_move_data from scvi.nn import Decoder, Encoder @@ -19,6 +17,14 @@ from ._utils import broadcast_labels from ._vae import VAE +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + from typing import Literal + + from torch.distributions import Distribution + + from scvi.model.base import BaseModelClass + class SCANVAE(VAE): """Single-cell annotation using variational inference. diff --git a/src/scvi/module/_vae.py b/src/scvi/module/_vae.py index c40faf60f2..920b65ca18 100644 --- a/src/scvi/module/_vae.py +++ b/src/scvi/module/_vae.py @@ -2,12 +2,10 @@ import logging import warnings -from collections.abc import Callable -from typing import Literal +from typing import TYPE_CHECKING import numpy as np import torch -from torch.distributions import Distribution from torch.nn.functional import one_hot from scvi import REGISTRY_KEYS, settings @@ -19,6 +17,12 @@ auto_move_data, ) +if TYPE_CHECKING: + from collections.abc import Callable + from typing import Literal + + from torch.distributions import Distribution + logger = logging.getLogger(__name__) diff --git a/src/scvi/module/_vaec.py b/src/scvi/module/_vaec.py index a2c6b19448..627ca6a810 100644 --- a/src/scvi/module/_vaec.py +++ b/src/scvi/module/_vaec.py @@ -1,13 +1,17 @@ from __future__ import annotations -import numpy as np +from typing import TYPE_CHECKING + import torch -from torch.distributions import Distribution from scvi import REGISTRY_KEYS from scvi.module._constants import MODULE_KEYS from scvi.module.base import BaseModuleClass, auto_move_data +if TYPE_CHECKING: + import numpy as np + from torch.distributions import Distribution + class VAEC(BaseModuleClass): """Conditional Variational auto-encoder model. 
diff --git a/src/scvi/module/base/_base_module.py b/src/scvi/module/base/_base_module.py index c61b758b12..39097c9039 100644 --- a/src/scvi/module/base/_base_module.py +++ b/src/scvi/module/base/_base_module.py @@ -1,31 +1,36 @@ from __future__ import annotations from abc import abstractmethod -from collections.abc import Callable, Iterable from dataclasses import field -from typing import Any +from typing import TYPE_CHECKING import flax import jax -import jax.numpy as jnp import numpy as np import pyro -import torch from flax.training import train_state from jax import random -from jaxlib.xla_extension import Device -from numpyro.distributions import Distribution -from pyro.infer.predictive import Predictive from torch import nn from scvi import settings -from scvi._types import LossRecord, MinifiedDataType, Tensor from scvi.data._constants import ADATA_MINIFY_TYPE from scvi.utils._jax import device_selecting_PRNGKey from ._decorators import auto_move_data from ._pyro import AutoMoveDataPredictive +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + from typing import Any + + import jax.numpy as jnp + import torch + from jaxlib.xla_extension import Device + from numpyro.distributions import Distribution + from pyro.infer.predictive import Predictive + + from scvi._types import LossRecord, MinifiedDataType, Tensor + @flax.struct.dataclass class LossOutput: diff --git a/src/scvi/nn/_embedding.py b/src/scvi/nn/_embedding.py index 396e6a47a8..50ffeb121d 100644 --- a/src/scvi/nn/_embedding.py +++ b/src/scvi/nn/_embedding.py @@ -1,10 +1,13 @@ from __future__ import annotations -from collections.abc import Callable +from typing import TYPE_CHECKING import torch from torch import nn +if TYPE_CHECKING: + from collections.abc import Callable + def _partial_freeze_hook_factory(freeze: int) -> Callable[[torch.Tensor], torch.Tensor]: """Factory for a hook that freezes the first ``freeze`` entries in the gradient. 
diff --git a/src/scvi/train/_callbacks.py b/src/scvi/train/_callbacks.py index b097f03440..4f6602dd68 100644 --- a/src/scvi/train/_callbacks.py +++ b/src/scvi/train/_callbacks.py @@ -6,9 +6,9 @@ from copy import deepcopy from datetime import datetime from shutil import rmtree +from typing import TYPE_CHECKING import flax -import lightning.pytorch as pl import numpy as np import torch from lightning.pytorch.callbacks import Callback, ModelCheckpoint @@ -16,10 +16,14 @@ from lightning.pytorch.utilities import rank_zero_info from scvi import settings -from scvi.dataloaders import AnnDataLoader from scvi.model.base import BaseModelClass from scvi.model.base._save_load import _load_saved_files +if TYPE_CHECKING: + import lightning.pytorch as pl + + from scvi.dataloaders import AnnDataLoader + MetricCallable = Callable[[BaseModelClass], float] @@ -220,20 +224,20 @@ def __init__( if mode == "min": self.monitor_op = np.less - self.best_module_metric_val = np.Inf + self.best_module_metric_val = np.inf self.mode = "min" elif mode == "max": self.monitor_op = np.greater - self.best_module_metric_val = -np.Inf + self.best_module_metric_val = -np.inf self.mode = "max" else: if "acc" in self.monitor or self.monitor.startswith("fmeasure"): self.monitor_op = np.greater - self.best_module_metric_val = -np.Inf + self.best_module_metric_val = -np.inf self.mode = "max" else: self.monitor_op = np.less - self.best_module_metric_val = np.Inf + self.best_module_metric_val = np.inf self.mode = "min" def check_monitor_top(self, current): diff --git a/src/scvi/utils/_track.py b/src/scvi/utils/_track.py index 3fc2878cb5..bac927e20c 100644 --- a/src/scvi/utils/_track.py +++ b/src/scvi/utils/_track.py @@ -36,7 +36,8 @@ def track( -------- >>> from scvi.utils import track >>> my_list = [1, 2, 3] - >>> for i in track(my_list): print(i) + >>> for i in track(my_list): + ... 
print(i) """ if style is None: style = settings.progress_bar_style diff --git a/tests/autotune/test_experiment.py b/tests/autotune/test_experiment.py index be7f1ed7a6..343dafd478 100644 --- a/tests/autotune/test_experiment.py +++ b/tests/autotune/test_experiment.py @@ -1,4 +1,4 @@ -from pytest import raises +import pytest from scvi import settings from scvi.autotune import AutotuneExperiment @@ -27,108 +27,108 @@ def test_experiment_init(save_path: str): assert hasattr(experiment, "id") assert experiment.id is not None assert isinstance(experiment.id, str) - with raises(AttributeError): + with pytest.raises(AttributeError): experiment.id = "new_id" assert hasattr(experiment, "data") assert experiment.data is not None assert experiment.data is adata - with raises(AttributeError): + with pytest.raises(AttributeError): experiment.data = "new_adata" assert hasattr(experiment, "setup_method_name") assert experiment.setup_method_name is not None assert experiment.setup_method_name == "setup_anndata" - with raises(AttributeError): + with pytest.raises(AttributeError): experiment.setup_method_name = "new_setup_method_name" assert hasattr(experiment, "setup_method_args") assert experiment.setup_method_args is not None - with raises(AttributeError): + with pytest.raises(AttributeError): experiment.setup_method_args = "new_setup_method_args" assert hasattr(experiment, "model_cls") assert experiment.model_cls is not None assert experiment.model_cls is SCVI - with raises(AttributeError): + with pytest.raises(AttributeError): experiment.model_cls = "new_model_cls" assert hasattr(experiment, "metrics") assert experiment.metrics is not None assert experiment.metrics == ["elbo_validation"] - with raises(AttributeError): + with pytest.raises(AttributeError): experiment.metrics = "new_metrics" assert hasattr(experiment, "mode") assert experiment.mode is not None assert experiment.mode == "min" - with raises(AttributeError): + with pytest.raises(AttributeError): experiment.mode = 
"new_mode" assert hasattr(experiment, "search_space") assert experiment.search_space is not None - with raises(AttributeError): + with pytest.raises(AttributeError): experiment.search_space = "new_search_space" assert hasattr(experiment, "num_samples") assert experiment.num_samples is not None assert experiment.num_samples == 1 - with raises(AttributeError): + with pytest.raises(AttributeError): experiment.num_samples = 2 assert hasattr(experiment, "scheduler") assert experiment.scheduler is not None - with raises(AttributeError): + with pytest.raises(AttributeError): experiment.scheduler = "new_scheduler" assert hasattr(experiment, "scheduler_kwargs") assert experiment.scheduler_kwargs is not None assert experiment.scheduler_kwargs == {} - with raises(AttributeError): + with pytest.raises(AttributeError): experiment.scheduler_kwargs = "new_scheduler_kwargs" assert hasattr(experiment, "searcher") assert experiment.searcher is not None - with raises(AttributeError): + with pytest.raises(AttributeError): experiment.searcher = "new_searcher" assert hasattr(experiment, "searcher_kwargs") assert experiment.searcher_kwargs is not None assert experiment.searcher_kwargs == {} - with raises(AttributeError): + with pytest.raises(AttributeError): experiment.searcher_kwargs = "new_searcher_kwargs" assert hasattr(experiment, "seed") assert experiment.seed is not None assert experiment.seed == settings.seed - with raises(AttributeError): + with pytest.raises(AttributeError): experiment.seed = 2 assert hasattr(experiment, "resources") assert experiment.resources is not None assert experiment.resources == {} - with raises(AttributeError): + with pytest.raises(AttributeError): experiment.resources = "new_resources" assert hasattr(experiment, "name") assert experiment.name is not None assert experiment.name.startswith("scvi") - with raises(AttributeError): + with pytest.raises(AttributeError): experiment.name = "new_name" assert hasattr(experiment, "logging_dir") assert 
experiment.logging_dir is not None - with raises(AttributeError): + with pytest.raises(AttributeError): experiment.logging_dir = "new_logging_dir" - with raises(AttributeError): + with pytest.raises(AttributeError): _ = experiment.result_grid # set after running the tuner def test_experiment_no_setup_anndata(): adata = synthetic_iid() - with raises(ValueError): + with pytest.raises(ValueError): _ = AutotuneExperiment( SCVI, adata, @@ -147,7 +147,7 @@ def test_experiment_invalid_metrics(): adata = synthetic_iid() SCVI.setup_anndata(adata) - with raises(ValueError): + with pytest.raises(ValueError): _ = AutotuneExperiment( SCVI, adata, @@ -160,7 +160,7 @@ def test_experiment_invalid_metrics(): }, num_samples=1, ) - with raises(ValueError): + with pytest.raises(ValueError): _ = AutotuneExperiment( SCVI, adata, @@ -173,7 +173,7 @@ def test_experiment_invalid_metrics(): }, num_samples=1, ) - with raises(TypeError): + with pytest.raises(TypeError): _ = AutotuneExperiment( SCVI, adata, @@ -192,7 +192,7 @@ def test_experiment_invalid_mode(): adata = synthetic_iid() SCVI.setup_anndata(adata) - with raises(ValueError): + with pytest.raises(ValueError): _ = AutotuneExperiment( SCVI, adata, @@ -205,7 +205,7 @@ def test_experiment_invalid_mode(): }, num_samples=1, ) - with raises(ValueError): + with pytest.raises(ValueError): _ = AutotuneExperiment( SCVI, adata, @@ -224,7 +224,7 @@ def test_experiment_invalid_search_space(): adata = synthetic_iid() SCVI.setup_anndata(adata) - with raises(TypeError): + with pytest.raises(TypeError): _ = AutotuneExperiment( SCVI, adata, @@ -233,7 +233,7 @@ def test_experiment_invalid_search_space(): search_space=None, num_samples=1, ) - with raises(ValueError): + with pytest.raises(ValueError): _ = AutotuneExperiment( SCVI, adata, @@ -242,7 +242,7 @@ def test_experiment_invalid_search_space(): search_space={}, num_samples=1, ) - with raises(KeyError): + with pytest.raises(KeyError): _ = AutotuneExperiment( SCVI, adata, @@ -261,7 +261,7 @@ def 
test_experiment_invalid_num_samples(): adata = synthetic_iid() SCVI.setup_anndata(adata) - with raises(TypeError): + with pytest.raises(TypeError): _ = AutotuneExperiment( SCVI, adata, @@ -280,7 +280,7 @@ def test_experiment_invalid_scheduler(): adata = synthetic_iid() SCVI.setup_anndata(adata) - with raises(ValueError): + with pytest.raises(ValueError): _ = AutotuneExperiment( SCVI, adata, @@ -294,7 +294,7 @@ def test_experiment_invalid_scheduler(): num_samples=1, scheduler="invalid option", ) - with raises(TypeError): + with pytest.raises(TypeError): _ = AutotuneExperiment( SCVI, adata, @@ -314,7 +314,7 @@ def test_experiment_invalid_scheduler_kwargs(): adata = synthetic_iid() SCVI.setup_anndata(adata) - with raises(TypeError): + with pytest.raises(TypeError): _ = AutotuneExperiment( SCVI, adata, @@ -334,7 +334,7 @@ def test_experiment_invalid_searcher(): adata = synthetic_iid() SCVI.setup_anndata(adata) - with raises(ValueError): + with pytest.raises(ValueError): _ = AutotuneExperiment( SCVI, adata, @@ -348,7 +348,7 @@ def test_experiment_invalid_searcher(): num_samples=1, searcher="invalid option", ) - with raises(TypeError): + with pytest.raises(TypeError): _ = AutotuneExperiment( SCVI, adata, @@ -368,7 +368,7 @@ def test_experiment_invalid_searcher_kwargs(): adata = synthetic_iid() SCVI.setup_anndata(adata) - with raises(TypeError): + with pytest.raises(TypeError): _ = AutotuneExperiment( SCVI, adata, @@ -388,7 +388,7 @@ def test_experiment_invalid_seed(): adata = synthetic_iid() SCVI.setup_anndata(adata) - with raises(TypeError): + with pytest.raises(TypeError): _ = AutotuneExperiment( SCVI, adata, diff --git a/tests/conftest.py b/tests/conftest.py index 4fdc21a1c3..8d1973581f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -144,4 +144,4 @@ def set_seed(request): from scvi import settings settings.seed = int(request.config.getoption("--seed")) - yield + return None diff --git a/tests/criticism/test_criticism.py b/tests/criticism/test_criticism.py 
index 315bffea64..c6dcc7e7dd 100644 --- a/tests/criticism/test_criticism.py +++ b/tests/criticism/test_criticism.py @@ -1,8 +1,9 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np import pandas as pd -from anndata import AnnData from sparse import GCXS from xarray import Dataset @@ -10,6 +11,9 @@ from scvi.data import synthetic_iid from scvi.model import SCVI +if TYPE_CHECKING: + from anndata import AnnData + def get_ppc_with_samples(adata: AnnData, n_samples: int = 2, indices: list[int] | None = None): # create and train models diff --git a/tests/data/test_preprocessing.py b/tests/data/test_preprocessing.py index 40fe725ffa..b9677cf145 100644 --- a/tests/data/test_preprocessing.py +++ b/tests/data/test_preprocessing.py @@ -1,6 +1,6 @@ def test_poisson_gene_selection(): import numpy as np - from pytest import raises + import pytest from scvi.data import poisson_gene_selection, synthetic_iid @@ -24,10 +24,10 @@ def test_poisson_gene_selection(): X = adata.X adata.X = -X - with raises(ValueError): + with pytest.raises(ValueError): poisson_gene_selection(adata, batch_key="batch", n_top_genes=n_top_genes) adata.X = 0.25 * X - with raises(ValueError): + with pytest.raises(ValueError): poisson_gene_selection(adata, batch_key="batch", n_top_genes=n_top_genes) diff --git a/tests/external/gimvi/test_gimvi.py b/tests/external/gimvi/test_gimvi.py index a343e8f937..c97d2ed535 100644 --- a/tests/external/gimvi/test_gimvi.py +++ b/tests/external/gimvi/test_gimvi.py @@ -121,9 +121,8 @@ def test_gimvi(): labels_key="labels", ) model = GIMVI(adata_seq, adata_spatial, n_latent=10) - assert hasattr(model.module, "library_log_means_0") and not hasattr( - model.module, "library_log_means_1" - ) + assert hasattr(model.module, "library_log_means_0") + assert not hasattr(model.module, "library_log_means_1") model.train(1, check_val_every_n_epoch=1, train_size=0.5) model.get_latent_representation() model.get_imputed_values() @@ -153,9 +152,8 @@ def 
test_gimvi_model_library_size(): labels_key="labels", ) model = GIMVI(adata_seq, adata_spatial, model_library_size=[True, True], n_latent=10) - assert hasattr(model.module, "library_log_means_0") and hasattr( - model.module, "library_log_means_1" - ) + assert hasattr(model.module, "library_log_means_0") + assert hasattr(model.module, "library_log_means_1") model.train(1, check_val_every_n_epoch=1, train_size=0.5) model.get_latent_representation() model.get_imputed_values() diff --git a/tests/external/mrvi/test_model.py b/tests/external/mrvi/test_model.py index 63c2d1e5a0..05edb27496 100644 --- a/tests/external/mrvi/test_model.py +++ b/tests/external/mrvi/test_model.py @@ -1,15 +1,19 @@ from __future__ import annotations import os -from typing import Any +from typing import TYPE_CHECKING import numpy as np import pytest -from anndata import AnnData from scvi.data import synthetic_iid from scvi.external import MRVI +if TYPE_CHECKING: + from typing import Any + + from anndata import AnnData + @pytest.fixture(scope="session") def adata(): @@ -54,7 +58,7 @@ def test_mrvi(model: MRVI, adata: AnnData, save_path: str): @pytest.mark.optional @pytest.mark.parametrize( - "setup_kwargs, de_kwargs", + ("setup_kwargs", "de_kwargs"), [ ( {"sample_key": "sample_str", "batch_key": "batch"}, diff --git a/tests/external/tangram/test_tangram.py b/tests/external/tangram/test_tangram.py index d1c9707d3e..b1a166b1b3 100644 --- a/tests/external/tangram/test_tangram.py +++ b/tests/external/tangram/test_tangram.py @@ -21,7 +21,7 @@ def _get_mdata(sparse_format: str | None = None): @pytest.mark.parametrize( - "density_prior_key,constrained", + ("density_prior_key", "constrained"), [ (None, False), ("rna_count_based_density", False), @@ -61,10 +61,10 @@ def test_tangram_errors(): with pytest.raises(ValueError): Tangram(mdata, constrained=True, target_count=None) + Tangram.setup_mudata( + mdata, + density_prior_key="bad_prior", + modalities=modalities, + ) with pytest.raises(ValueError): - 
Tangram.setup_mudata( - mdata, - density_prior_key="bad_prior", - modalities=modalities, - ) Tangram(mdata) diff --git a/tests/hub/test_hub_model.py b/tests/hub/test_hub_model.py index a27b81772d..8372e5c6a0 100644 --- a/tests/hub/test_hub_model.py +++ b/tests/hub/test_hub_model.py @@ -186,9 +186,11 @@ def test_hub_model_save(save_anndata: bool, save_path: str): hub_model.save(overwrite=True) card_path = os.path.join(model_path, _SCVI_HUB.MODEL_CARD_FILE_NAME) - assert os.path.exists(card_path) and os.path.isfile(card_path) + assert os.path.exists(card_path) + assert os.path.isfile(card_path) metadata_path = os.path.join(model_path, _SCVI_HUB.METADATA_FILE_NAME) - assert os.path.exists(metadata_path) and os.path.isfile(metadata_path) + assert os.path.exists(metadata_path) + assert os.path.isfile(metadata_path) with pytest.raises(FileExistsError): hub_model.save(overwrite=False) diff --git a/tests/model/test_amortizedlda.py b/tests/model/test_amortizedlda.py index 770e311a6c..337a634db6 100644 --- a/tests/model/test_amortizedlda.py +++ b/tests/model/test_amortizedlda.py @@ -69,22 +69,18 @@ def test_lda_model(n_topics: int = 5): adata_gbt = mod.get_feature_by_topic().to_numpy() assert np.allclose(adata_gbt.sum(axis=0), 1) adata_lda = mod.get_latent_representation(adata).to_numpy() - assert ( - adata_lda.shape == (adata.n_obs, n_topics) - and np.all((adata_lda <= 1) & (adata_lda >= 0)) - and np.allclose(adata_lda.sum(axis=1), 1) - ) + assert adata_lda.shape == (adata.n_obs, n_topics) + assert np.all((adata_lda <= 1) & (adata_lda >= 0)) + assert np.allclose(adata_lda.sum(axis=1), 1) mod.get_elbo() mod.get_perplexity() adata2 = synthetic_iid() AmortizedLDA.setup_anndata(adata2) adata2_lda = mod.get_latent_representation(adata2).to_numpy() - assert ( - adata2_lda.shape == (adata2.n_obs, n_topics) - and np.all((adata2_lda <= 1) & (adata2_lda >= 0)) - and np.allclose(adata2_lda.sum(axis=1), 1) - ) + assert adata2_lda.shape == (adata2.n_obs, n_topics) + assert 
np.all((adata2_lda <= 1) & (adata2_lda >= 0)) + assert np.allclose(adata2_lda.sum(axis=1), 1) mod.get_elbo(adata2) mod.get_perplexity(adata2) diff --git a/tests/model/test_autozi.py b/tests/model/test_autozi.py index 578e51939b..fcf2b82cf5 100644 --- a/tests/model/test_autozi.py +++ b/tests/model/test_autozi.py @@ -122,9 +122,8 @@ def test_autozi(): use_observed_lib_size=False, ) autozivae.train(1, plan_kwargs={"lr": 1e-2}, check_val_every_n_epoch=1) - assert hasattr(autozivae.module, "library_log_means") and hasattr( - autozivae.module, "library_log_vars" - ) + assert hasattr(autozivae.module, "library_log_means") + assert hasattr(autozivae.module, "library_log_vars") assert len(autozivae.history["elbo_train"]) == 1 assert len(autozivae.history["elbo_validation"]) == 1 autozivae.get_elbo(indices=autozivae.validation_indices) diff --git a/tests/model/test_differential.py b/tests/model/test_differential.py index 46839f2826..1da06ea07a 100644 --- a/tests/model/test_differential.py +++ b/tests/model/test_differential.py @@ -82,7 +82,8 @@ def test_differential_computation(save_path): cell_idx2 = ~cell_idx1 dc.get_bayes_factors(cell_idx1, cell_idx2, mode="vanilla", use_permutation=True) res = dc.get_bayes_factors(cell_idx1, cell_idx2, mode="change", use_permutation=False) - assert (res["delta"] == 0.5) and (res["pseudocounts"] == 0.0) + assert res["delta"] == 0.5 + assert res["pseudocounts"] == 0.0 res = dc.get_bayes_factors( cell_idx1, cell_idx2, mode="change", use_permutation=False, delta=None ) diff --git a/tests/model/test_jaxscvi.py b/tests/model/test_jaxscvi.py index 38b7fe4096..dac25b7c89 100644 --- a/tests/model/test_jaxscvi.py +++ b/tests/model/test_jaxscvi.py @@ -24,7 +24,8 @@ def test_jax_scvi(n_latent=5): z1 = model.get_latent_representation(give_mean=True, n_samples=1) assert z1.ndim == 2 z2 = model.get_latent_representation(give_mean=False, n_samples=15) - assert (z2.ndim == 3) and (z2.shape[0] == 15) + assert z2.ndim == 3 + assert z2.shape[0] == 15 def 
test_jax_scvi_training(n_latent: int = 5, dropout_rate: float = 0.1): diff --git a/tests/model/test_models_with_minified_data.py b/tests/model/test_models_with_minified_data.py index f707ad1ae2..5f87ed804c 100644 --- a/tests/model/test_models_with_minified_data.py +++ b/tests/model/test_models_with_minified_data.py @@ -169,10 +169,10 @@ def test_scanvi_from_scvi(save_path): model.save(save_path, overwrite=True) loaded_model = SCVI.load(save_path, adata=adata_before_setup) + adata2 = synthetic_iid() + # just add this to pretend the data is minified + adata2.uns[_ADATA_MINIFY_TYPE_UNS_KEY] = ADATA_MINIFY_TYPE.LATENT_POSTERIOR with pytest.raises(ValueError) as e: - adata2 = synthetic_iid() - # just add this to pretend the data is minified - adata2.uns[_ADATA_MINIFY_TYPE_UNS_KEY] = ADATA_MINIFY_TYPE.LATENT_POSTERIOR scvi.model.SCANVI.from_scvi_model(loaded_model, "label_0", adata=adata2) assert str(e.value) == "Please provide a non-minified `adata` to initialize scanvi." diff --git a/tests/model/test_pyro.py b/tests/model/test_pyro.py index 0115a6bfbd..1c0e4a946c 100644 --- a/tests/model/test_pyro.py +++ b/tests/model/test_pyro.py @@ -1,12 +1,12 @@ from __future__ import annotations import os +from typing import TYPE_CHECKING import numpy as np import pyro import pyro.distributions as dist import torch -from anndata import AnnData from pyro import clear_param_store from pyro.infer.autoguide import AutoNormal, init_to_mean from pyro.nn import PyroModule, PyroSample @@ -27,6 +27,9 @@ from scvi.nn import DecoderSCVI, Encoder from scvi.train import LowLevelPyroTrainingPlan, PyroTrainingPlan, Trainer +if TYPE_CHECKING: + from anndata import AnnData + class BayesianRegressionPyroModel(PyroModule): def __init__(self, in_features, out_features, per_cell_weight=False): diff --git a/tests/model/test_scanvi.py b/tests/model/test_scanvi.py index 2bf791e3e2..aede994029 100644 --- a/tests/model/test_scanvi.py +++ b/tests/model/test_scanvi.py @@ -162,7 +162,8 @@ def test_scanvi(): 
assert scanvi_model.module.state_dict() is not m.module.state_dict() scanvi_pxr = scanvi_model.module.state_dict().get("px_r", None) scvi_pxr = m.module.state_dict().get("px_r", None) - assert scanvi_pxr is not None and scvi_pxr is not None + assert scanvi_pxr is not None + assert scvi_pxr is not None assert scanvi_pxr is not scvi_pxr scanvi_model.train(1) @@ -272,7 +273,8 @@ def test_scanvi_with_external_indices(): assert scanvi_model.module.state_dict() is not m.module.state_dict() scanvi_pxr = scanvi_model.module.state_dict().get("px_r", None) scvi_pxr = m.module.state_dict().get("px_r", None) - assert scanvi_pxr is not None and scvi_pxr is not None + assert scanvi_pxr is not None + assert scvi_pxr is not None assert scanvi_pxr is not scvi_pxr scanvi_model.train(1) diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py index 6276693e19..36a3b7f7cd 100644 --- a/tests/model/test_scvi.py +++ b/tests/model/test_scvi.py @@ -233,6 +233,7 @@ def test_scvi(gene_likelihood: str, n_latent: int = 5): # test view_anndata_setup with different anndata before transfer setup with pytest.raises(ValueError): model.view_anndata_setup(adata=adata2) + with pytest.raises(ValueError): model.view_anndata_setup(adata=adata2, hide_state_registries=True) # test get methods with different anndata model.get_elbo(adata2) @@ -492,9 +493,9 @@ def test_setting_adata_attr(n_latent: int = 5): adata2 = synthetic_iid() model.adata = adata2 + rep = model.get_latent_representation(adata) + rep2 = model.get_latent_representation() with pytest.raises(AssertionError): - rep = model.get_latent_representation(adata) - rep2 = model.get_latent_representation() np.testing.assert_array_equal(rep, rep2) orig_manager = model.get_anndata_manager(adata) @@ -503,10 +504,9 @@ def test_setting_adata_attr(n_latent: int = 5): adata3 = synthetic_iid() del adata3.obs["batch"] - # validation catches no batch + # validation catches no batch column. 
with pytest.raises(KeyError): model.adata = adata3 - model.get_latent_representation() def assert_dict_is_subset(d1, d2): @@ -1065,14 +1065,11 @@ def test_scvi_library_size_update(save_path): SCVI.setup_anndata(adata1, batch_key="batch", labels_key="labels") model = SCVI(adata1, n_latent=n_latent, use_observed_lib_size=False) - assert ( - getattr(model.module, "library_log_means", None) is not None - and model.module.library_log_means.shape == (1, 2) - and model.module.library_log_means.count_nonzero().item() == 2 - ) - assert getattr( - model.module, "library_log_vars", None - ) is not None and model.module.library_log_vars.shape == ( + assert getattr(model.module, "library_log_means", None) is not None + assert model.module.library_log_means.shape == (1, 2) + assert model.module.library_log_means.count_nonzero().item() == 2 + assert getattr(model.module, "library_log_vars", None) is not None + assert model.module.library_log_vars.shape == ( 1, 2, ) @@ -1086,17 +1083,13 @@ def test_scvi_library_size_update(save_path): adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"]) model2 = SCVI.load_query_data(adata2, dir_path, inplace_subset_query_vars=True) - assert ( - getattr(model2.module, "library_log_means", None) is not None - and model2.module.library_log_means.shape == (1, 4) - and model2.module.library_log_means[:, :2].equal(model.module.library_log_means) - and model2.module.library_log_means.count_nonzero().item() == 4 - ) - assert ( - getattr(model2.module, "library_log_vars", None) is not None - and model2.module.library_log_vars.shape == (1, 4) - and model2.module.library_log_vars[:, :2].equal(model.module.library_log_vars) - ) + assert getattr(model2.module, "library_log_means", None) is not None + assert model2.module.library_log_means.shape == (1, 4) + assert model2.module.library_log_means[:, :2].equal(model.module.library_log_means) + assert model2.module.library_log_means.count_nonzero().item() == 4 + assert 
getattr(model2.module, "library_log_vars", None) is not None + assert model2.module.library_log_vars.shape == (1, 4) + assert model2.module.library_log_vars[:, :2].equal(model.module.library_log_vars) def test_set_seed(n_latent: int = 5, seed: int = 1): diff --git a/tests/model/test_totalvi.py b/tests/model/test_totalvi.py index 9b2476faae..e47e4bfec2 100644 --- a/tests/model/test_totalvi.py +++ b/tests/model/test_totalvi.py @@ -278,7 +278,8 @@ def test_totalvi_model_library_size(save_path): n_latent = 10 model = TOTALVI(adata, n_latent=n_latent, use_observed_lib_size=False) - assert hasattr(model.module, "library_log_means") and hasattr(model.module, "library_log_vars") + assert hasattr(model.module, "library_log_means") + assert hasattr(model.module, "library_log_vars") model.train(1, train_size=0.5) assert model.is_trained is True model.get_elbo() @@ -300,16 +301,14 @@ def test_totalvi_size_factor(): # Test size_factor_key overrides use_observed_lib_size. model = TOTALVI(adata, n_latent=n_latent, use_observed_lib_size=False) - assert not hasattr(model.module, "library_log_means") and not hasattr( - model.module, "library_log_vars" - ) + assert not hasattr(model.module, "library_log_means") + assert not hasattr(model.module, "library_log_vars") assert model.module.use_size_factor_key model.train(1, train_size=0.5) model = TOTALVI(adata, n_latent=n_latent, use_observed_lib_size=True) - assert not hasattr(model.module, "library_log_means") and not hasattr( - model.module, "library_log_vars" - ) + assert not hasattr(model.module, "library_log_means") + assert not hasattr(model.module, "library_log_vars") assert model.module.use_size_factor_key model.train(1, train_size=0.5) @@ -531,7 +530,8 @@ def test_totalvi_model_library_size_mudata(): n_latent = 10 model = TOTALVI(mdata, n_latent=n_latent, use_observed_lib_size=False) - assert hasattr(model.module, "library_log_means") and hasattr(model.module, "library_log_vars") + assert hasattr(model.module, 
"library_log_means") + assert hasattr(model.module, "library_log_vars") model.train(1, train_size=0.5) assert model.is_trained is True model.get_elbo() @@ -560,16 +560,14 @@ def test_totalvi_size_factor_mudata(): # Test size_factor_key overrides use_observed_lib_size. model = TOTALVI(mdata, n_latent=n_latent, use_observed_lib_size=False) - assert not hasattr(model.module, "library_log_means") and not hasattr( - model.module, "library_log_vars" - ) + assert not hasattr(model.module, "library_log_means") + assert not hasattr(model.module, "library_log_vars") assert model.module.use_size_factor_key model.train(1, train_size=0.5) model = TOTALVI(mdata, n_latent=n_latent, use_observed_lib_size=True) - assert not hasattr(model.module, "library_log_means") and not hasattr( - model.module, "library_log_vars" - ) + assert not hasattr(model.module, "library_log_means") + assert not hasattr(model.module, "library_log_vars") assert model.module.use_size_factor_key model.train(1, train_size=0.5) diff --git a/tests/nn/test_embedding.py b/tests/nn/test_embedding.py index 855e48aeb2..3ed5486967 100644 --- a/tests/nn/test_embedding.py +++ b/tests/nn/test_embedding.py @@ -2,16 +2,16 @@ from os.path import join +import pytest import torch -from pytest import mark, raises from scvi.nn import Embedding -@mark.parametrize("num_embeddings", [10]) -@mark.parametrize("embedding_dim", [5]) -@mark.parametrize("init", [2, [0, 1]]) -@mark.parametrize("freeze_prev", [True, False]) +@pytest.mark.parametrize("num_embeddings", [10]) +@pytest.mark.parametrize("embedding_dim", [5]) +@pytest.mark.parametrize("init", [2, [0, 1]]) +@pytest.mark.parametrize("freeze_prev", [True, False]) def test_embedding_extend( num_embeddings: int, embedding_dim: int, @@ -48,9 +48,9 @@ def test_embedding_extend( def test_embedding_extend_invalid_init(num_embeddings: int = 10, embedding_dim: int = 5): embedding = Embedding(num_embeddings, embedding_dim) - with raises(ValueError): + with pytest.raises(ValueError): 
Embedding.extend(embedding, init=0) - with raises(TypeError): + with pytest.raises(TypeError): Embedding.extend(embedding, init="invalid") diff --git a/tests/train/test_trainingplans.py b/tests/train/test_trainingplans.py index 8c4af9f142..cf623aad49 100644 --- a/tests/train/test_trainingplans.py +++ b/tests/train/test_trainingplans.py @@ -9,14 +9,13 @@ @pytest.mark.parametrize( - "current,n_warm_up,min_kl_weight,max_kl_weight,expected", + ("current", "n_warm_up", "min_kl_weight", "max_kl_weight", "expected"), [ (0, 400, 0.0, 1.0, 0.0), (200, 400, 0.0, 1.0, 0.5), (400, 400, 0.0, 1.0, 1.0), (0, 400, 0.5, 1.0, 0.5), (200, 400, 0.5, 1.0, 0.75), - (400, 400, 0.0, 1.0, 1.0), (400, 400, 0.0, 2.0, 2.0), ], ) @@ -40,7 +39,7 @@ def test_compute_kl_weight_min_greater_max(): @pytest.mark.parametrize( - "epoch,step,n_epochs_kl_warmup,n_steps_kl_warmup,expected", + ("epoch", "step", "n_epochs_kl_warmup", "n_steps_kl_warmup", "expected"), [ (0, 100, 100, 100, 0.0), (50, 200, 100, 1000, 0.5), From 54b393287712ff3390dfb80f2b041b541ed04e5d Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Wed, 18 Sep 2024 00:08:21 +0300 Subject: [PATCH 07/22] docs: fix progress bar interactive style in RTD (#2959) close https://github.com/scverse/scvi-tools/issues/2958 --- src/scvi/utils/_track.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scvi/utils/_track.py b/src/scvi/utils/_track.py index bac927e20c..c2493340b8 100644 --- a/src/scvi/utils/_track.py +++ b/src/scvi/utils/_track.py @@ -4,7 +4,7 @@ from rich.console import Console from rich.progress import track as track_base -from tqdm import tqdm as tqdm_base +from tqdm.auto import tqdm as tqdm_base from scvi import settings From e8f9b369865e80567d5e900d32d8e7731497b974 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 18 Sep 2024 01:32:44 +0300 Subject: [PATCH 08/22] docs: automated update of tutorials (#2977) automated update of tutorials submodule 
Co-authored-by: ori-kron-wis <175299014+ori-kron-wis@users.noreply.github.com> --- docs/tutorials/notebooks | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 0f9908fdc7..1a9b54617a 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 0f9908fdc7565b147493db0bca6c46d2608db5eb +Subproject commit 1a9b54617a1f685f29c17a2be421c2f964152183 From e0e0209183e03ff97e014fc3a73819e503c3504d Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Wed, 18 Sep 2024 15:36:12 +0300 Subject: [PATCH 09/22] Update Dockerfile (#2978) --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 5f4d005e50..305b8bd25c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM nvidia/cuda:12.5.0-runtime-ubuntu22.04 +FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04 FROM python:3.12 AS base RUN pip install --no-cache-dir uv From 3738f131daa6e5c879b1844af55d57ae007f5c3c Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Fri, 20 Sep 2024 11:10:25 +0300 Subject: [PATCH 10/22] test: Fix hugging face create repo test which was failing (#2982) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added a check on whether repo exists before creating it, which failed …previous test --------- Co-authored-by: Can Ergen --- tests/hub/test_hub_model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/hub/test_hub_model.py b/tests/hub/test_hub_model.py index 8372e5c6a0..98e6139dcf 100644 --- a/tests/hub/test_hub_model.py +++ b/tests/hub/test_hub_model.py @@ -222,7 +222,10 @@ def test_hub_model_large_training_adata(request, save_path): @pytest.mark.private def test_hub_model_create_repo_hf(save_path: str): - from huggingface_hub import delete_repo + from huggingface_hub import delete_repo, repo_exists + + if repo_exists("scvi-tools/test-scvi-create"): + 
delete_repo("scvi-tools/test-scvi-create", token=os.environ["HF_API_TOKEN"]) hub_model = prep_scvi_hub_model(save_path) hub_model.push_to_huggingface_hub( From 08be4b3ce35271496aeec8fea3f94bbb134b1bd3 Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Fri, 20 Sep 2024 11:59:17 +0300 Subject: [PATCH 11/22] Update pyproject.toml (#2985) update ruff target version to py312 --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- pyproject.toml | 2 +- tests/conftest.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7d4c961d48..9b1221c290 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,7 +137,7 @@ markers = [ src = ["src"] line-length = 99 indent-width = 4 -target-version = "py310" +target-version = "py312" # Exclude a variety of commonly ignored directories. exclude = [ diff --git a/tests/conftest.py b/tests/conftest.py index 8d1973581f..f3d0ab5526 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import shutil -from distutils.dir_util import copy_tree import pytest +from distutils.dir_util import copy_tree import scvi from tests.data.utils import generic_setup_adata_manager From 70371593e1e7cef913a69467ba1643adcd072b33 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Sat, 21 Sep 2024 07:34:47 +0300 Subject: [PATCH 12/22] docs: Counterfactual prediction user guide (#2960) Addresses #1176 --------- Co-authored-by: Can Ergen --- .../background/counterfactual_prediction.md | 60 ++++++++++++++++++- .../figures/counterfactual_cartoon.svg | 60 +++++++++++++++++++ 2 files changed, 118 insertions(+), 2 deletions(-) create mode 100644 docs/user_guide/background/figures/counterfactual_cartoon.svg diff --git a/docs/user_guide/background/counterfactual_prediction.md b/docs/user_guide/background/counterfactual_prediction.md index 6cd1e9923f..ef4ac6bf7b 100644 --- a/docs/user_guide/background/counterfactual_prediction.md +++ 
b/docs/user_guide/background/counterfactual_prediction.md @@ -1,5 +1,61 @@ # Counterfactual prediction -:::{note} -This page is under construction. +Once we have trained a model to predict a variable of interest or a generative model to learn the data distribution, we are often interested in making predictions for new samples. However, predictions over test samples may not reveal exactly what the model has learned about how the input features relate to the target variable of interest. For example, we may want to answer the question: How would the model predict the expression levels of gene Y in cell Z if gene X is knocked out? Even if we do not have a data point corresponding to this scenario, we can instead perturb the input to see what the model reports. + +:::{warning} +We are using the term "counterfactual prediction" here rather loosely. In particular, we are not following the rigorous definition of counterfactual prediction in the causality literature[^ref1]. While closely related in spirit, we are making counterfactual queries with statistical models to gain some insight into what the model has learned about the data distribution. +::: + +:::{figure} figures/counterfactual_cartoon.svg +:align: center +:alt: Cartoon of the counterfactual prediction task across two conditions. +:class: img-fluid + +Cartoon of the counterfactual prediction task across two conditions. This counterfactual prediction can be thought of as an interpolation of nearby points in the feature space originating from condition B. ::: + +## Preliminaries + +Suppose we have a trained model $f_\theta$ that takes in a data point $x$ (e.g., gene expression counts) and a condition $c$ (e.g., treatment group) and returns a prediction $\hat{y}$. +Each data point takes the form of a tuple $(x,c) \in \mathcal{D}$. +We can define a *counterfactual query* as a pair $(x,c')$ where $c' \neq c$, +and the respective model output as the *counterfactual prediction*, $\hat{y}' = f_\theta(x,c')$.
+ +Here, we separate $c$ out from $x$ to make the counterfactual portion of the query explicit, but it can be thought of as another dimension of $x$. + +## In-distribution vs. out-of-distribution + +Since we are working with statistical models rather than causal models, we have to be careful about when we can rely on counterfactual predictions. At a high level, if we assume the true function relating the features to the target is smooth, we can trust counterfactual predictions for queries that are similar to points in the training data. + +Say we have a counterfactual query $(x,c')$, and we have data points in the training set $(x',c')$ with $x'$ close to $x$ (i.e., $\|x - x'\|$ is small). +If our model predicts $y$ for $(x', c')$ well, +we can reasonably trust the counterfactual prediction for $(x,c')$. +Otherwise, if $(x,c')$ is very different from any point in the training data +with condition $c'$, we cannot make any guarantees about the accuracy of the counterfactual prediction. +Dimensionality reduction techniques or harmonization methods may help create more overlap between the features $x$ across the conditions, setting the stage for more reliable counterfactual predictions. + +## Applications + +The most direct application of counterfactual prediction in scvi-tools can be found in the `transform_batch` kwarg of the {func}`~scvi.model.SCVI.get_normalized_expression` function. In this case, we can pass in a counterfactual batch label to get a prediction of what the normalized expression would be for a cell if it were a member of that batch. This can be useful if one wants to compare cells across different batches in the gene space.
+ +The described approach to counterfactual prediction has also been used in a variety of applications, including: +- characterizing cell-type-specific sample-level effects [^ref2] +- predicting chemical perturbation responses in different cell types [^ref2][^ref3] +- predicting infection/perturbation responses across species [^ref4] + +For more details on how counterfactual prediction is used in another method implemented in scvi-tools, see the {doc}`/user_guide/models/mrvi`. + +[^ref1]: + Judea Pearl. Causality. Cambridge university press, 2009. +[^ref2]: + Pierre Boyeau, Justin Hong, Adam Gayoso, Martin Kim, Jose L McFaline-Figueroa, Michael Jordan, Elham Azizi, Can Ergen, Nir Yosef (2024), + _Deep generative modeling of sample-level heterogeneity in single-cell genomics_, + [bioRxiv](https://doi.org/10.1101/2022.10.04.510898). +[^ref3]: + Mohammad Lotfollahi, Anna Klimovskaia Susmelj, Carlo De Donno, Leon Hetzel, Yuge Ji, Ignacio L Ibarra, Sanjay R Srivatsan, Mohsen Naghipourfar, Riza M Daza, Beth Martin, Jay Shendure, Jose L McFaline‐Figueroa, Pierre Boyeau, F Alexander Wolf, Nafissa Yakubova, Stephan Günnemann, Cole Trapnell, David Lopez‐Paz, Fabian J Theis (2023), + _Predicting cellular responses to complex perturbations in high‐throughput screens_, + [Molecular Systems Biology](https://doi.org/10.15252/msb.202211517). +[^ref4]: + Mohammad Lotfollahi, F Alexander Wolf, Fabian J Theis (2019), + _scGen predicts single-cell perturbation responses_, + [Nature Methods](https://doi.org/10.1038/s41592-019-0494-8). 
diff --git a/docs/user_guide/background/figures/counterfactual_cartoon.svg b/docs/user_guide/background/figures/counterfactual_cartoon.svg new file mode 100644 index 0000000000..153293f677 --- /dev/null +++ b/docs/user_guide/background/figures/counterfactual_cartoon.svg @@ -0,0 +1,60 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From 7519daf68be90106434ff15a64e1eff134371016 Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Sun, 22 Sep 2024 11:11:48 +0300 Subject: [PATCH 13/22] test: update the resolution test workflow (#2987) It was never passed There's not really a strong incentive to get it passing as people should be using more up-to-date packages anyway --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/test_linux_resolution.yml | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/.github/workflows/test_linux_resolution.yml b/.github/workflows/test_linux_resolution.yml index d41447365c..40faecbff2 100644 --- a/.github/workflows/test_linux_resolution.yml +++ b/.github/workflows/test_linux_resolution.yml @@ -34,12 +34,7 @@ jobs: matrix: os: [ubuntu-latest] python: ["3.10", "3.11", "3.12"] - install-flags: - [ - "--prerelease if-necessary-or-explicit", - "--resolution lowest-direct", - "--resolution lowest", - ] + install-flags: ["--prerelease if-necessary-or-explicit"] name: integration From e4e3bedab14f91ab379ef26e5da2dd3878de41a9 Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Sun, 22 Sep 2024 14:08:05 +0300 Subject: [PATCH 14/22] Update pyproject.toml (#2989) Added scib-metrics package to tutorials --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 9b1221c290..43c9bd4768 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,6 +109,7 @@ tutorials = [ "igraph", "scikit-misc", "scrublet", + "scib-metrics", "scvi-tools[optional]", 
"squidpy", ] From 4199a474eb480d0b508b32d67c0982f0c6be3866 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Sep 2024 10:51:05 +0300 Subject: [PATCH 15/22] [pre-commit.ci] pre-commit autoupdate (#2993) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.6.5 → v0.6.7](https://github.com/astral-sh/ruff-pre-commit/compare/v0.6.5...v0.6.7) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e0239d02f7..4a72a3f797 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,7 +41,7 @@ repos: )$ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.5 + rev: v0.6.7 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] From a557cf9c52e8df3790f06705adb9b4bb8cb21dac Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Thu, 26 Sep 2024 12:05:42 +0300 Subject: [PATCH 16/22] release: updated v1.2.0 and RTD for py312 (#2991) Co-authored-by: Can Ergen --- CHANGELOG.md | 4 ++-- docs/tutorials/notebooks | 2 +- pyproject.toml | 4 ++-- src/scvi/module/_totalvae.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4a4c22dc63..cb781b50a6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,9 +4,9 @@ Starting from version 0.20.1, this format is based on [Keep a Changelog], and th to [Semantic Versioning]. Full commit history is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/). 
-## Version 1.2 (unreleased) +## Version 1.2 -### 1.2.0 (unreleased) +### 1.2.0 (2024-09-26) #### Added diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 1a9b54617a..43eb27fc1d 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 1a9b54617a1f685f29c17a2be421c2f964152183 +Subproject commit 43eb27fc1dac500009aab6cb409cd1de71eff446 diff --git a/pyproject.toml b/pyproject.toml index 43c9bd4768..d890ce4843 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ requires = ["hatchling"] [project] name = "scvi-tools" -version = "1.1.6" +version = "1.2.0" description = "Deep probabilistic analysis of single-cell omics data." readme = "README.md" requires-python = ">=3.10" @@ -15,6 +15,7 @@ authors = [ maintainers = [ {name = "The scvi-tools development team", email = "ori.kronfeld@weizmann.ac.il"}, ] + urls.Documentation = "https://scvi-tools.org" urls.Source = "https://github.com/scverse/scvi-tools" urls.Home-page = "https://scvi-tools.org" @@ -55,7 +56,6 @@ dependencies = [ "xarray>=2023.2.0", ] - [project.optional-dependencies] tests = ["pytest", "pytest-pretty", "coverage", "scvi-tools[optional]"] editing = ["jupyter", "pre-commit"] diff --git a/src/scvi/module/_totalvae.py b/src/scvi/module/_totalvae.py index bdb80638b6..d3fb5488da 100644 --- a/src/scvi/module/_totalvae.py +++ b/src/scvi/module/_totalvae.py @@ -741,8 +741,8 @@ def marginal_ll(self, tensors, n_mc_samples, return_mean: bool = True): return log_lkl def on_load(self, model: BaseModelClass): - manager = model.get_anndata_manager(model.adata) - source_version = manager.registry[_constants._SCVI_VERSION_KEY] + manager = model.get_anndata_manager(model.adata, required=True) + source_version = manager._source_registry[_constants._SCVI_VERSION_KEY] version_split = source_version.split(".") if int(version_split[0]) >= 1 and int(version_split[1]) >= 2: return From 5b776fd1592a075e589b3428c11df5a5e6e9957a Mon Sep 17 00:00:00 2001 From: 
"github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 26 Sep 2024 15:42:36 +0300 Subject: [PATCH 17/22] docs: automated update of tutorials (#2996) automated update of tutorials submodule Co-authored-by: ori-kron-wis <175299014+ori-kron-wis@users.noreply.github.com> --- docs/tutorials/notebooks | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 43eb27fc1d..09bd7fc99b 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 43eb27fc1dac500009aab6cb409cd1de71eff446 +Subproject commit 09bd7fc99bcb4c64c1661b17b7495ebe77fc72f3 From f1f58b9b0a027099dd067840911acca588cec1d6 Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Thu, 26 Sep 2024 16:21:13 +0300 Subject: [PATCH 18/22] SCVI-Tools v1.2.0 Release (#2998) close https://github.com/scverse/scvi-tools/issues/2889 --------- Co-authored-by: Lumberbot (aka Jack) <39504233+meeseeksmachine@users.noreply.github.com> Co-authored-by: Martin Kim <46072231+martinkim0@users.noreply.github.com> Co-authored-by: Can Ergen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Justin Hong Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: ori-kron-wis <175299014+ori-kron-wis@users.noreply.github.com> From efbbc68d477199fd71009f783d212fba7e861bfa Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Mon, 30 Sep 2024 12:09:05 -0400 Subject: [PATCH 19/22] Fix warning for DA function in MrVI (#2999) Mentioned in https://discourse.scverse.org/t/error-in-mrvi-differential-abundance/2486 --- src/scvi/external/mrvi/_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py index a0d362018b..4f3dabfbc4 100644 --- a/src/scvi/external/mrvi/_model.py +++ b/src/scvi/external/mrvi/_model.py @@ -843,7 +843,7 @@ def 
differential_abundance( if n_cov_values > n_samples / 2: warnings.warn( f"The covariate '{key}' does not seem to refer to a discrete key. " - f"It has {len(n_cov_values)} unique values, which exceeds one half of the " + f"It has {n_cov_values} unique values, which exceeds one half of the " f"total samples ({n_samples}).", UserWarning, stacklevel=2, From eb7f0055c46aade759c7da686971f91b385b015f Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Mon, 30 Sep 2024 19:09:17 +0300 Subject: [PATCH 20/22] Update pyproject.toml (#3000) removal of gdown from requirements as it is not needed in cellassign tutorial --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d890ce4843..18d55569d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,7 +99,6 @@ optional = [ ] tutorials = [ "cell2location", - "gdown", "jupyter", "leidenalg", "muon", From 0bd640490fb7f75f863c6470a33916c3cf9dccf3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 1 Oct 2024 09:44:52 +0300 Subject: [PATCH 21/22] [pre-commit.ci] pre-commit autoupdate (#3003) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/igorshubovych/markdownlint-cli: v0.41.0 → v0.42.0](https://github.com/igorshubovych/markdownlint-cli/compare/v0.41.0...v0.42.0) - [github.com/astral-sh/ruff-pre-commit: v0.6.7 → v0.6.8](https://github.com/astral-sh/ruff-pre-commit/compare/v0.6.7...v0.6.8) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4a72a3f797..c7ae9e9c99 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,7 +31,7 @@ repos: )$ - repo: https://github.com/igorshubovych/markdownlint-cli - rev: v0.41.0 + rev: v0.42.0 hooks: - id: 
markdownlint-fix exclude: | @@ -41,7 +41,7 @@ repos: )$ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.7 + rev: v0.6.8 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] From b8bc97040a81fba43775eb85b8a74ea9feca6fac Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 1 Oct 2024 16:30:45 +0300 Subject: [PATCH 22/22] docs: automated update of tutorials (#3005) automated update of tutorials submodule Co-authored-by: ori-kron-wis <175299014+ori-kron-wis@users.noreply.github.com> --- docs/tutorials/notebooks | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 09bd7fc99b..6ff469f5a9 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 09bd7fc99bcb4c64c1661b17b7495ebe77fc72f3 +Subproject commit 6ff469f5a9ec3e26324fcb27ac487d8486c6942f