From 61bbedb2a412fc2668ab567b8bb216d0675b9c3a Mon Sep 17 00:00:00 2001 From: Martin Kim Date: Wed, 26 Jun 2024 12:57:03 -0700 Subject: [PATCH 01/51] wip --- src/scvi/model/_scanvi.py | 92 +------------ src/scvi/model/_scvi.py | 85 ------------ src/scvi/model/base/_base_model.py | 125 ++++++++++++++---- src/scvi/model/utils/_minification.py | 50 ++++--- tests/model/test_models_with_minified_data.py | 47 +++---- 5 files changed, 148 insertions(+), 251 deletions(-) diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index 55c6e7a980..52d7d03382 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -12,41 +12,29 @@ from anndata import AnnData from scvi import REGISTRY_KEYS, settings -from scvi._types import MinifiedDataType from scvi.data import AnnDataManager from scvi.data._constants import ( - _ADATA_MINIFY_TYPE_UNS_KEY, _SETUP_ARGS_KEY, - ADATA_MINIFY_TYPE, ) from scvi.data._utils import _get_adata_minify_type, _is_minified, get_anndata_attribute from scvi.data.fields import ( - BaseAnnDataField, CategoricalJointObsField, CategoricalObsField, LabelsWithUnlabeledObsField, LayerField, NumericalJointObsField, NumericalObsField, - ObsmField, - StringUnsField, ) from scvi.dataloaders import SemiSupervisedDataSplitter +from scvi.model._scvi import SCVI from scvi.model._utils import _init_library_size, get_max_epochs_heuristic -from scvi.model.utils import get_minified_adata_scrna +from scvi.model.base import ArchesMixin, BaseMinifiedModeModelClass, RNASeqMixin, VAEMixin from scvi.module import SCANVAE from scvi.train import SemiSupervisedTrainingPlan, TrainRunner from scvi.train._callbacks import SubSampleLabels from scvi.utils import setup_anndata_dsp from scvi.utils._docstrings import devices_dsp -from ._scvi import SCVI -from .base import ArchesMixin, BaseMinifiedModeModelClass, RNASeqMixin, VAEMixin - -_SCANVI_LATENT_QZM = "_scanvi_latent_qzm" -_SCANVI_LATENT_QZV = "_scanvi_latent_qzv" -_SCANVI_OBSERVED_LIB_SIZE = "_scanvi_observed_lib_size" - logger = logging.getLogger(__name__) @@ -476,79 +464,3 @@ def setup_anndata( adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) - - @staticmethod - def _get_fields_for_adata_minification( - minified_data_type: MinifiedDataType, - ) -> list[BaseAnnDataField]: - """Return the fields required for adata minification of the given minified_data_type.""" - if minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR: - fields = [ - ObsmField( - REGISTRY_KEYS.LATENT_QZM_KEY, - _SCANVI_LATENT_QZM, - ), - ObsmField( - REGISTRY_KEYS.LATENT_QZV_KEY, - _SCANVI_LATENT_QZV, - ), - NumericalObsField( - REGISTRY_KEYS.OBSERVED_LIB_SIZE, - _SCANVI_OBSERVED_LIB_SIZE, - ), - ] - else: - raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") - fields.append( - StringUnsField( - REGISTRY_KEYS.MINIFY_TYPE_KEY, - _ADATA_MINIFY_TYPE_UNS_KEY, - ), - ) - return fields - - def minify_adata( - self, - minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR, - use_latent_qzm_key: str = "X_latent_qzm", - use_latent_qzv_key: str = "X_latent_qzv", - ): - """Minifies the model's adata. - - Minifies the adata, and registers new anndata fields: latent qzm, latent qzv, adata uns - containing minified-adata type, and library size. - This also sets the appropriate property on the module to indicate that the adata is - minified. 
- - Parameters - ---------- - minified_data_type - How to minify the data. Currently only supports `latent_posterior_parameters`. - If minified_data_type == `latent_posterior_parameters`: - - * the original count data is removed (`adata.X`, adata.raw, and any layers) - * the parameters of the latent representation of the original data is stored - * everything else is left untouched - use_latent_qzm_key - Key to use in `adata.obsm` where the latent qzm params are stored - use_latent_qzv_key - Key to use in `adata.obsm` where the latent qzv params are stored - - Notes - ----- - The modification is not done inplace -- instead the model is assigned a new (minified) - version of the adata. - """ - if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: - raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") - - if self.module.use_observed_lib_size is False: - raise ValueError("Cannot minify the data if `use_observed_lib_size` is False") - - minified_adata = get_minified_adata_scrna(self.adata, minified_data_type) - minified_adata.obsm[_SCANVI_LATENT_QZM] = self.adata.obsm[use_latent_qzm_key] - minified_adata.obsm[_SCANVI_LATENT_QZV] = self.adata.obsm[use_latent_qzv_key] - counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) - minified_adata.obs[_SCANVI_OBSERVED_LIB_SIZE] = np.squeeze(np.asarray(counts.sum(axis=1))) - self._update_adata_and_manager_post_minification(minified_adata, minified_data_type) - self.module.minified_data_type = minified_data_type diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index 95c88c3541..a0e99281de 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -4,27 +4,20 @@ import warnings from typing import Literal -import numpy as np from anndata import AnnData from scvi import REGISTRY_KEYS, settings -from scvi._types import MinifiedDataType from scvi.data import AnnDataManager -from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE from scvi.data._utils import _get_adata_minify_type from scvi.data.fields import ( - BaseAnnDataField, CategoricalJointObsField, CategoricalObsField, LayerField, NumericalJointObsField, NumericalObsField, - ObsmField, - StringUnsField, ) from scvi.model._utils import _init_library_size from scvi.model.base import EmbeddingMixin, UnsupervisedTrainingMixin -from scvi.model.utils import get_minified_adata_scrna from scvi.module import VAE from scvi.utils import setup_anndata_dsp @@ -224,81 +217,3 @@ def setup_anndata( adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) - - @staticmethod - def _get_fields_for_adata_minification( - minified_data_type: MinifiedDataType, - ) -> list[BaseAnnDataField]: - """Return the fields required for adata minification of the given minified_data_type.""" - if minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR: - fields = [ - ObsmField( - REGISTRY_KEYS.LATENT_QZM_KEY, - _SCVI_LATENT_QZM, - ), - ObsmField( - REGISTRY_KEYS.LATENT_QZV_KEY, - _SCVI_LATENT_QZV, - ), - NumericalObsField( - REGISTRY_KEYS.OBSERVED_LIB_SIZE, - _SCVI_OBSERVED_LIB_SIZE, - ), - ] - else: - raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") - fields.append( - StringUnsField( - REGISTRY_KEYS.MINIFY_TYPE_KEY, - _ADATA_MINIFY_TYPE_UNS_KEY, - ), - ) - return fields - - def minify_adata( - self, - minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR, - use_latent_qzm_key: str = 
"X_latent_qzm", - use_latent_qzv_key: str = "X_latent_qzv", - ) -> None: - """Minifies the model's adata. - - Minifies the adata, and registers new anndata fields: latent qzm, latent qzv, adata uns - containing minified-adata type, and library size. - This also sets the appropriate property on the module to indicate that the adata is - minified. - - Parameters - ---------- - minified_data_type - How to minify the data. Currently only supports `latent_posterior_parameters`. - If minified_data_type == `latent_posterior_parameters`: - - * the original count data is removed (`adata.X`, adata.raw, and any layers) - * the parameters of the latent representation of the original data is stored - * everything else is left untouched - use_latent_qzm_key - Key to use in `adata.obsm` where the latent qzm params are stored - use_latent_qzv_key - Key to use in `adata.obsm` where the latent qzv params are stored - - Notes - ----- - The modification is not done inplace -- instead the model is assigned a new (minified) - version of the adata. - """ - # TODO(adamgayoso): Add support for a scenario where we want to cache the latent posterior - # without removing the original counts. - if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: - raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") - - if self.module.use_observed_lib_size is False: - raise ValueError("Cannot minify the data if `use_observed_lib_size` is False") - - minified_adata = get_minified_adata_scrna(self.adata, minified_data_type) - minified_adata.obsm[_SCVI_LATENT_QZM] = self.adata.obsm[use_latent_qzm_key] - minified_adata.obsm[_SCVI_LATENT_QZV] = self.adata.obsm[use_latent_qzv_key] - counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) - minified_adata.obs[_SCVI_OBSERVED_LIB_SIZE] = np.squeeze(np.asarray(counts.sum(axis=1))) - self._update_adata_and_manager_post_minification(minified_adata, minified_data_type) - self.module.minified_data_type = minified_data_type diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index e991da3eb3..e23aaeff1d 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -16,13 +16,15 @@ from scvi import REGISTRY_KEYS, settings from scvi._types import AnnOrMuData, MinifiedDataType -from scvi.data import AnnDataManager +from scvi.data import AnnDataManager, fields from scvi.data._compat import registry_from_setup_dict from scvi.data._constants import ( + _ADATA_MINIFY_TYPE_UNS_KEY, _MODEL_NAME_KEY, _SCVI_UUID_KEY, _SETUP_ARGS_KEY, _SETUP_METHOD_NAME, + ADATA_MINIFY_TYPE, ) from scvi.data._utils import _assign_adata_uuid, _check_if_view, _get_adata_minify_type from scvi.dataloaders import AnnDataLoader @@ -34,6 +36,7 @@ _load_saved_files, _validate_var_names, ) +from scvi.model.utils import get_minified_adata_scrna from scvi.utils import attrdict, setup_anndata_dsp from scvi.utils._docstrings import devices_dsp @@ -83,6 +86,9 @@ class BaseModelClass(metaclass=BaseModelMetaClass): 1. 
:doc:`/tutorials/notebooks/dev/model_user_guide` """ + _LATENT_QZM_KEY = "latent_qzm" + _LATENT_QZV_KEY = "latent_qzv" + _OBSERVED_LIB_SIZE_KEY = "observed_lib_size" _data_loader_cls = AnnDataLoader def __init__(self, adata: AnnOrMuData | None = None): @@ -881,53 +887,126 @@ def view_anndata_setup( class BaseMinifiedModeModelClass(BaseModelClass): - """Abstract base class for scvi-tools models that can handle minified data.""" + """Base class for models that can handle minified data.""" @property def minified_data_type(self) -> MinifiedDataType | None: - """The type of minified data associated with this model, if applicable.""" + """Type of minified data associated with this model.""" return ( self.adata_manager.get_from_registry(REGISTRY_KEYS.MINIFY_TYPE_KEY) if REGISTRY_KEYS.MINIFY_TYPE_KEY in self.adata_manager.data_registry else None ) - @abstractmethod def minify_adata( self, - *args, - **kwargs, - ): - """Minifies the model's adata. + minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR, + use_latent_qzm_key: str = "X_latent_qzm", + use_latent_qzv_key: str = "X_latent_qzv", + keep_count_data: bool = False, + ) -> None: + """Minify the model's :attr:`~scvi.model.base.BaseModelClass.adata`. + + Minifies the :class:`~anndata.AnnData` object associated with the model according to the + method specified by ``minified_data_type`` and registers the new fields with the model's + :class:`~scvi.data.AnnDataManager`. This also sets the ``minified_data_type`` attribute + of the underlying :class:`~scvi.module.base.BaseModuleClass` instance. - Minifies the adata, and registers new anndata fields as required (can be model-specific). - This also sets the appropriate property on the module to indicate that the adata is - minified. + Parameters + ---------- + minified_data_type + Method for minifying the data. One of the following: + + - ``"latent_posterior"``: Store the latent posterior mean and variance in + :attr:`~anndata.AnnData.obsm` using the keys ``use_latent_qzm_key`` and + ``use_latent_qzv_key``. + use_latent_qzm_key + Key to use for storing the latent posterior mean in :attr:`~anndata.AnnData.obsm` when + ``minified_data_type`` is ``"latent_posterior"``. + use_latent_qzv_key + Key to use for storing the latent posterior variance in :attr:`~anndata.AnnData.obsm` + when ``minified_data_type`` is ``"latent_posterior"``. + keep_count_data + If ``True``, the full count matrix is kept in the minified + :attr:`~scvi.model.base.BaseModelClass.adata`. Notes ----- The modification is not done inplace -- instead the model is assigned a new (minified) - version of the adata. + version of the :class:`~anndata.AnnData`. """ + if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + raise NotImplementedError( + f"Minification method {minified_data_type} is not supported." + ) + elif not getattr(self.module, "use_observed_lib_size", True): + raise ValueError( + "Minification is not supported for models that do not use observed library size." 
+ ) - @staticmethod - @abstractmethod - def _get_fields_for_adata_minification(minified_data_type: MinifiedDataType): - """Return the anndata fields required for adata minification of the given type.""" + mini_adata = get_minified_adata_scrna( + adata_manager=self.adata_manager, + minified_data_type=minified_data_type, + keep_count_data=keep_count_data, + ) + mini_adata.obsm[self._LATENT_QZM_KEY] = self.adata.obsm[use_latent_qzm_key] + mini_adata.obsm[self._LATENT_QZV_KEY] = self.adata.obsm[use_latent_qzv_key] + mini_adata.obs[self._OBSERVED_LIB_SIZE_KEY] = np.squeeze( + np.asarray(self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY).sum(axis=-1)) + ) + self._update_adata_and_manager_post_minification( + mini_adata, + minified_data_type, + keep_count_data=keep_count_data, + ) + self.module.minified_data_type = minified_data_type + + @classmethod + def _get_fields_for_adata_minification( + cls, + minified_data_type: MinifiedDataType, + keep_count_data: bool, + ): + """Return the fields required for minification of the given type.""" + if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + raise NotImplementedError( + f"Minification method {minified_data_type} is not supported." + ) + + mini_fields = [ + fields.ObsmField(REGISTRY_KEYS.LATENT_QZM_KEY, cls._LATENT_QZM_KEY), + fields.ObsmField(REGISTRY_KEYS.LATENT_QZV_KEY, cls._LATENT_QZV_KEY), + fields.NumericalObsField(REGISTRY_KEYS.OBSERVED_LIB_SIZE, cls._OBSERVED_LIB_SIZE_KEY), + fields.StringUnsField(REGISTRY_KEYS.MINIFY_TYPE_KEY, _ADATA_MINIFY_TYPE_UNS_KEY), + ] + if keep_count_data: + mini_fields.append(fields.LayerField(REGISTRY_KEYS.X_KEY, None, is_count_data=True)) + + return mini_fields def _update_adata_and_manager_post_minification( - self, minified_adata: AnnOrMuData, minified_data_type: MinifiedDataType + self, + minified_adata: AnnOrMuData, + minified_data_type: MinifiedDataType, + keep_count_data: bool, ): - """Update the anndata and manager inplace after creating a minified adata.""" - # Register this new adata with the model, creating a new manager in the cache + """Update the :class:`~anndata.AnnData` and :class:`~scvi.data.AnnDataManager` in-place. + + Parameters + ---------- + minified_adata + Minified version of :attr:`~scvi.model.base.BaseModelClass.adata`. + minified_data_type + Method used for minifying the data. + keep_count_data + If ``True``, the full count matrix is kept in the minified + :attr:`~scvi.model.base.BaseModelClass.adata`. 
+ """ self._validate_anndata(minified_adata) new_adata_manager = self.get_anndata_manager(minified_adata, required=True) - # This inplace edits the manager new_adata_manager.register_new_fields( - self._get_fields_for_adata_minification(minified_data_type) + self._get_fields_for_adata_minification(minified_data_type, keep_count_data) ) - # We set the adata attribute of the model as this will update self.registry_ - # and self.adata_manager with the new adata manager self.adata = minified_adata @property diff --git a/src/scvi/model/utils/_minification.py b/src/scvi/model/utils/_minification.py index cf84687bc5..e63bc3947c 100644 --- a/src/scvi/model/utils/_minification.py +++ b/src/scvi/model/utils/_minification.py @@ -1,7 +1,11 @@ +from __future__ import annotations + from anndata import AnnData from scipy.sparse import csr_matrix +from scvi import REGISTRY_KEYS from scvi._types import MinifiedDataType +from scvi.data import AnnDataManager from scvi.data._constants import ( _ADATA_MINIFY_TYPE_UNS_KEY, _SCVI_UUID_KEY, @@ -10,34 +14,26 @@ def get_minified_adata_scrna( - adata: AnnData, - minified_data_type: MinifiedDataType, + adata_manager: AnnDataManager, + minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR, + keep_count_data: bool = False, ) -> AnnData: - """Returns a minified adata that works for most scrna models (such as SCVI, SCANVI). - - Parameters - ---------- - adata - Original adata, of which we to create a minified version. - minified_data_type - How to minify the data. - """ + """Get a minified version of an :class:`~anndata.AnnData` or :class:`~mudata.MuData` object.""" if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: - raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") + raise NotImplementedError(f"Minification method {minified_data_type} is not supported.") - all_zeros = csr_matrix(adata.X.shape) - layers = {layer: all_zeros.copy() for layer in adata.layers} - bdata = AnnData( - X=all_zeros, - layers=layers, - uns=adata.uns.copy(), - obs=adata.obs, - var=adata.var, - varm=adata.varm, - obsm=adata.obsm, - obsp=adata.obsp, + counts = adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) + mini_adata = AnnData( + X=counts if keep_count_data else csr_matrix(counts.shape), + obs=adata_manager.adata.obs.copy(), + var=adata_manager.adata.var.copy(), + uns=adata_manager.adata.uns.copy(), + obsm=adata_manager.adata.obsm.copy(), + varm=adata_manager.adata.varm.copy(), + obsp=adata_manager.adata.obsp.copy(), + varp=adata_manager.adata.varp.copy(), ) - # Remove scvi uuid key to make bdata fresh w.r.t. 
the model's manager - del bdata.uns[_SCVI_UUID_KEY] - bdata.uns[_ADATA_MINIFY_TYPE_UNS_KEY] = minified_data_type - return bdata + del mini_adata.uns[_SCVI_UUID_KEY] + mini_adata.uns[_ADATA_MINIFY_TYPE_UNS_KEY] = minified_data_type + + return mini_adata diff --git a/tests/model/test_models_with_minified_data.py b/tests/model/test_models_with_minified_data.py index 9069d6781f..7db1053552 100644 --- a/tests/model/test_models_with_minified_data.py +++ b/tests/model/test_models_with_minified_data.py @@ -1,56 +1,56 @@ +from __future__ import annotations + import numpy as np +import numpy.typing as npt import pytest +from anndata import AnnData import scvi from scvi.data import synthetic_iid from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE from scvi.data._utils import _is_minified from scvi.model import SCANVI, SCVI +from scvi.model.base import BaseMinifiedModeModelClass _SCVI_OBSERVED_LIB_SIZE = "_scvi_observed_lib_size" _SCANVI_OBSERVED_LIB_SIZE = "_scanvi_observed_lib_size" -def prep_model(cls=SCVI, layer=None, use_size_factor=False): - # create a synthetic dataset +def prep_model( + cls: BaseMinifiedModeModelClass = SCVI, + layer: str | None = None, + use_size_factor: bool = False, + n_latent: int = 5, +) -> tuple[BaseMinifiedModeModelClass, AnnData, npt.NDArray, AnnData]: adata = synthetic_iid() - adata_counts = adata.X + counts = adata.X if use_size_factor: adata.obs["size_factor"] = np.random.randint(1, 5, size=(adata.shape[0],)) if layer is not None: adata.layers[layer] = adata.X.copy() adata.X = np.zeros_like(adata.X) - adata.var["n_counts"] = np.squeeze(np.asarray(np.sum(adata_counts, axis=0))) - adata.varm["my_varm"] = np.random.negative_binomial(5, 0.3, size=(adata.shape[1], 3)) - adata.layers["my_layer"] = np.ones_like(adata.X) + adata_before_setup = adata.copy() - # run setup_anndata setup_kwargs = { "layer": layer, "batch_key": "batch", "labels_key": "labels", + "size_factor_key": "size_factor" if use_size_factor else None, } if cls == SCANVI: setup_kwargs["unlabeled_category"] = "unknown" - if use_size_factor: - setup_kwargs["size_factor_key"] = "size_factor" cls.setup_anndata( adata, **setup_kwargs, ) - # create and train the model - model = cls(adata, n_latent=5) + model = cls(adata, n_latent=n_latent) model.train(1, check_val_every_n_epoch=1, train_size=0.5) - # get the adata lib size - adata_lib_size = np.squeeze(np.asarray(adata_counts.sum(axis=1))) - assert ( - np.min(adata_lib_size) > 0 - ) # make sure it's not all zeros and there are no negative values + lib_size = np.squeeze(np.asarray(counts.sum(axis=-1))) - return model, adata, adata_lib_size, adata_before_setup + return model, adata, lib_size, adata_before_setup def assert_approx_equal(a, b): @@ -60,11 +60,11 @@ def assert_approx_equal(a, b): def run_test_for_model_with_minified_adata( - cls=SCVI, + cls: BaseMinifiedModeModelClass = SCVI, n_samples: int = 1, give_mean: bool = False, layer: str = None, - use_size_factor=False, + use_size_factor: bool = False, ): model, adata, adata_lib_size, _ = prep_model(cls, layer, use_size_factor) @@ -81,19 +81,14 @@ def run_test_for_model_with_minified_adata( assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR assert model.adata_manager.registry is model.registry_ - # make sure the original adata we set up the model with was not changed + assert not _is_minified(adata) assert adata is not model.adata - assert _is_minified(adata) is False - assert adata_orig.layers.keys() == model.adata.layers.keys() orig_obs_df = adata_orig.obs - 
obs_keys = _SCANVI_OBSERVED_LIB_SIZE if cls == SCANVI else _SCVI_OBSERVED_LIB_SIZE - orig_obs_df[obs_keys] = adata_lib_size + orig_obs_df[BaseMinifiedModeModelClass._OBSERVED_LIB_SIZE_KEY] = adata_lib_size assert model.adata.obs.equals(orig_obs_df) assert model.adata.var_names.equals(adata_orig.var_names) assert model.adata.var.equals(adata_orig.var) - assert model.adata.varm.keys() == adata_orig.varm.keys() - np.testing.assert_array_equal(model.adata.varm["my_varm"], adata_orig.varm["my_varm"]) scvi.settings.seed = 1 keys = ["mean", "dispersions", "dropout"] From 1bb2f2f64ccccd8244246f5dc26431bca50bf9f8 Mon Sep 17 00:00:00 2001 From: Martin Kim Date: Wed, 3 Jul 2024 11:24:09 -0700 Subject: [PATCH 02/51] wip --- src/scvi/data/_constants.py | 1 + src/scvi/model/base/_base_model.py | 24 +++++++++++------------- src/scvi/model/utils/_minification.py | 16 +--------------- 3 files changed, 13 insertions(+), 28 deletions(-) diff --git a/src/scvi/data/_constants.py b/src/scvi/data/_constants.py index 9efa664537..09ab91aeaa 100644 --- a/src/scvi/data/_constants.py +++ b/src/scvi/data/_constants.py @@ -37,6 +37,7 @@ class _ADATA_MINIFY_TYPE_NT(NamedTuple): LATENT_POSTERIOR: str = "latent_posterior_parameters" + LATENT_POSTERIOR_WITH_COUNTS: str = "latent_posterior_parameters_with_counts" ADATA_MINIFY_TYPE = _ADATA_MINIFY_TYPE_NT() diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index e23aaeff1d..9b35b47c58 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -903,7 +903,6 @@ def minify_adata( minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR, use_latent_qzm_key: str = "X_latent_qzm", use_latent_qzv_key: str = "X_latent_qzv", - keep_count_data: bool = False, ) -> None: """Minify the model's :attr:`~scvi.model.base.BaseModelClass.adata`. @@ -917,25 +916,25 @@ def minify_adata( minified_data_type Method for minifying the data. One of the following: - - ``"latent_posterior"``: Store the latent posterior mean and variance in + - ``"latent_posterior_parameters"``: Store the latent posterior mean and variance in :attr:`~anndata.AnnData.obsm` using the keys ``use_latent_qzm_key`` and ``use_latent_qzv_key``. + - ``"latent_posterior_parameters_with_counts"``: Store the latent posterior mean and + variance in :attr:`~anndata.AnnData.obsm` using the keys ``use_latent_qzm_key`` and + ``use_latent_qzv_key``, and the raw count data in :attr:`~anndata.AnnData.X`. use_latent_qzm_key Key to use for storing the latent posterior mean in :attr:`~anndata.AnnData.obsm` when ``minified_data_type`` is ``"latent_posterior"``. use_latent_qzv_key Key to use for storing the latent posterior variance in :attr:`~anndata.AnnData.obsm` when ``minified_data_type`` is ``"latent_posterior"``. - keep_count_data - If ``True``, the full count matrix is kept in the minified - :attr:`~scvi.model.base.BaseModelClass.adata`. Notes ----- The modification is not done inplace -- instead the model is assigned a new (minified) version of the :class:`~anndata.AnnData`. """ - if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + if minified_data_type not in ADATA_MINIFY_TYPE: raise NotImplementedError( f"Minification method {minified_data_type} is not supported." ) @@ -944,11 +943,13 @@ def minify_adata( "Minification is not supported for models that do not use observed library size." 
) + keep_count_data = minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS mini_adata = get_minified_adata_scrna( adata_manager=self.adata_manager, - minified_data_type=minified_data_type, keep_count_data=keep_count_data, ) + del mini_adata.uns[_SCVI_UUID_KEY] + mini_adata.uns[_ADATA_MINIFY_TYPE_UNS_KEY] = minified_data_type mini_adata.obsm[self._LATENT_QZM_KEY] = self.adata.obsm[use_latent_qzm_key] mini_adata.obsm[self._LATENT_QZV_KEY] = self.adata.obsm[use_latent_qzv_key] mini_adata.obs[self._OBSERVED_LIB_SIZE_KEY] = np.squeeze( @@ -957,7 +958,6 @@ def minify_adata( self._update_adata_and_manager_post_minification( mini_adata, minified_data_type, - keep_count_data=keep_count_data, ) self.module.minified_data_type = minified_data_type @@ -965,10 +965,9 @@ def minify_adata( def _get_fields_for_adata_minification( cls, minified_data_type: MinifiedDataType, - keep_count_data: bool, ): """Return the fields required for minification of the given type.""" - if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + if minified_data_type not in ADATA_MINIFY_TYPE: raise NotImplementedError( f"Minification method {minified_data_type} is not supported." ) @@ -979,7 +978,7 @@ def _get_fields_for_adata_minification( fields.NumericalObsField(REGISTRY_KEYS.OBSERVED_LIB_SIZE, cls._OBSERVED_LIB_SIZE_KEY), fields.StringUnsField(REGISTRY_KEYS.MINIFY_TYPE_KEY, _ADATA_MINIFY_TYPE_UNS_KEY), ] - if keep_count_data: + if minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS: mini_fields.append(fields.LayerField(REGISTRY_KEYS.X_KEY, None, is_count_data=True)) return mini_fields @@ -988,7 +987,6 @@ def _update_adata_and_manager_post_minification( self, minified_adata: AnnOrMuData, minified_data_type: MinifiedDataType, - keep_count_data: bool, ): """Update the :class:`~anndata.AnnData` and :class:`~scvi.data.AnnDataManager` in-place. 
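For orientation, the minification workflow these hunks converge on looks roughly like the sketch below. This is a minimal sketch, not part of the diff: it assumes a trained SCVI model on synthetic data, and `get_latent_representation(return_dist=True)` is the usual scvi-tools way to obtain the posterior parameters that `minify_adata` reads from `.obsm` under its default keys.

    import scvi
    from scvi.data import synthetic_iid

    adata = synthetic_iid()
    scvi.model.SCVI.setup_anndata(adata, batch_key="batch")
    model = scvi.model.SCVI(adata, n_latent=10)
    model.train(max_epochs=1)

    # Cache the latent posterior parameters under the default keys that
    # minify_adata() expects to find in .obsm.
    qzm, qzv = model.get_latent_representation(return_dist=True)
    model.adata.obsm["X_latent_qzm"] = qzm
    model.adata.obsm["X_latent_qzv"] = qzv

    # Drop the counts entirely ("latent_posterior_parameters"), or retain them
    # via the "latent_posterior_parameters_with_counts" variant added here.
    model.minify_adata(minified_data_type="latent_posterior_parameters")
    assert model.minified_data_type == "latent_posterior_parameters"
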
@@ -1005,7 +1003,7 @@ def _update_adata_and_manager_post_minification( self._validate_anndata(minified_adata) new_adata_manager = self.get_anndata_manager(minified_adata, required=True) new_adata_manager.register_new_fields( - self._get_fields_for_adata_minification(minified_data_type, keep_count_data) + self._get_fields_for_adata_minification(minified_data_type) ) self.adata = minified_adata diff --git a/src/scvi/model/utils/_minification.py b/src/scvi/model/utils/_minification.py index e63bc3947c..4c771f10ba 100644 --- a/src/scvi/model/utils/_minification.py +++ b/src/scvi/model/utils/_minification.py @@ -4,26 +4,16 @@ from scipy.sparse import csr_matrix from scvi import REGISTRY_KEYS -from scvi._types import MinifiedDataType from scvi.data import AnnDataManager -from scvi.data._constants import ( - _ADATA_MINIFY_TYPE_UNS_KEY, - _SCVI_UUID_KEY, - ADATA_MINIFY_TYPE, -) def get_minified_adata_scrna( adata_manager: AnnDataManager, - minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR, keep_count_data: bool = False, ) -> AnnData: """Get a minified version of an :class:`~anndata.AnnData` or :class:`~mudata.MuData` object.""" - if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: - raise NotImplementedError(f"Minification method {minified_data_type} is not supported.") - counts = adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) - mini_adata = AnnData( + return AnnData( X=counts if keep_count_data else csr_matrix(counts.shape), obs=adata_manager.adata.obs.copy(), var=adata_manager.adata.var.copy(), @@ -33,7 +23,3 @@ def get_minified_adata_scrna( obsp=adata_manager.adata.obsp.copy(), varp=adata_manager.adata.varp.copy(), ) - del mini_adata.uns[_SCVI_UUID_KEY] - mini_adata.uns[_ADATA_MINIFY_TYPE_UNS_KEY] = minified_data_type - - return mini_adata From 245c7107fc036735de5f452d1e874a6a601b1846 Mon Sep 17 00:00:00 2001 From: Martin Kim Date: Wed, 3 Jul 2024 11:31:50 -0700 Subject: [PATCH 03/51] keep empty layers for registry --- src/scvi/model/utils/_minification.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/scvi/model/utils/_minification.py b/src/scvi/model/utils/_minification.py index 4c771f10ba..5a99395c09 100644 --- a/src/scvi/model/utils/_minification.py +++ b/src/scvi/model/utils/_minification.py @@ -13,8 +13,10 @@ def get_minified_adata_scrna( ) -> AnnData: """Get a minified version of an :class:`~anndata.AnnData` or :class:`~mudata.MuData` object.""" counts = adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) + all_zeros = csr_matrix(counts.shape) return AnnData( - X=counts if keep_count_data else csr_matrix(counts.shape), + X=counts if keep_count_data else all_zeros, + layers={layer: all_zeros.copy() for layer in adata_manager.adata.layers}, obs=adata_manager.adata.obs.copy(), var=adata_manager.adata.var.copy(), uns=adata_manager.adata.uns.copy(), From 8d9d0124853a2f12768906b101ea0dc64206a7fd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 6 Nov 2024 09:57:12 +0000 Subject: [PATCH 04/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/model/_scanvi.py | 5 ----- src/scvi/model/_scvi.py | 6 ------ 2 files changed, 11 deletions(-) diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index 93670c8a73..145ea7aa4f 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -41,11 +41,6 @@ from anndata import AnnData - from scvi._types import MinifiedDataType - from 
scvi.data.fields import (
-        BaseAnnDataField,
-    )
-
     from ._scvi import SCVI

 _SCANVI_LATENT_QZM = "_scanvi_latent_qzm"
diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py
index aeeea1b77f..a27dd1bb12 100644
--- a/src/scvi/model/_scvi.py
+++ b/src/scvi/model/_scvi.py
@@ -4,8 +4,6 @@
 import warnings
 from typing import TYPE_CHECKING

-import numpy as np
-
 from scvi import REGISTRY_KEYS, settings
 from scvi.data import AnnDataManager
 from scvi.data._utils import _get_adata_minify_type
@@ -28,10 +26,6 @@

     from anndata import AnnData

-    from scvi._types import MinifiedDataType
-    from scvi.data.fields import (
-        BaseAnnDataField,
-    )

 _SCVI_LATENT_QZM = "_scvi_latent_qzm"
 _SCVI_LATENT_QZV = "_scvi_latent_qzv"

From 243c20638aa49702dfc37b41ca74b012317842d1 Mon Sep 17 00:00:00 2001
From: ori-kron-wis
Date: Wed, 6 Nov 2024 12:00:31 +0200
Subject: [PATCH 05/51] updated conflicts during merge

---
 src/scvi/model/_scanvi.py                     | 2 --
 src/scvi/model/utils/_minification.py         | 4 +++-
 tests/model/test_models_with_minified_data.py | 7 +++++--
 3 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py
index 145ea7aa4f..f37c4c353a 100644
--- a/src/scvi/model/_scanvi.py
+++ b/src/scvi/model/_scanvi.py
@@ -24,9 +24,7 @@
     NumericalObsField,
 )
 from scvi.dataloaders import SemiSupervisedDataSplitter
-from scvi.model._scvi import SCVI
 from scvi.model._utils import _init_library_size, get_max_epochs_heuristic
-from scvi.model.base import ArchesMixin, BaseMinifiedModeModelClass, RNASeqMixin, VAEMixin
 from scvi.module import SCANVAE
 from scvi.train import SemiSupervisedTrainingPlan, TrainRunner
 from scvi.train._callbacks import SubSampleLabels
diff --git a/src/scvi/model/utils/_minification.py b/src/scvi/model/utils/_minification.py
index 5a99395c09..e981ed3faf 100644
--- a/src/scvi/model/utils/_minification.py
+++ b/src/scvi/model/utils/_minification.py
@@ -2,10 +2,12 @@

 from anndata import AnnData
 from scipy.sparse import csr_matrix
+from typing import TYPE_CHECKING

 from scvi import REGISTRY_KEYS
-from scvi.data import AnnDataManager

+if TYPE_CHECKING:
+    from scvi.data import AnnDataManager

 def get_minified_adata_scrna(
     adata_manager: AnnDataManager,
diff --git a/tests/model/test_models_with_minified_data.py b/tests/model/test_models_with_minified_data.py
index e244fc7721..af35f93d83 100644
--- a/tests/model/test_models_with_minified_data.py
+++ b/tests/model/test_models_with_minified_data.py
@@ -1,9 +1,8 @@
 from __future__ import annotations

 import numpy as np
-import numpy.typing as npt
 import pytest
-from anndata import AnnData
+from typing import TYPE_CHECKING

 import scvi
 from scvi.data import synthetic_iid
@@ -12,6 +11,10 @@
 from scvi.model import SCANVI, SCVI
 from scvi.model.base import BaseMinifiedModeModelClass

+if TYPE_CHECKING:
+    import numpy.typing as npt
+    from anndata import AnnData
+
 _SCVI_OBSERVED_LIB_SIZE = "_scvi_observed_lib_size"
 _SCANVI_OBSERVED_LIB_SIZE = "_scanvi_observed_lib_size"

From 23c60a53662942f466047581c4193fd89d2f0b2f Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 6 Nov 2024 10:00:47 +0000
Subject: [PATCH 06/51] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/scvi/model/utils/_minification.py         | 4 +++-
 tests/model/test_models_with_minified_data.py | 3 ++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/scvi/model/utils/_minification.py b/src/scvi/model/utils/_minification.py
index 
e981ed3faf..ce9c8bb4de 100644 --- a/src/scvi/model/utils/_minification.py +++ b/src/scvi/model/utils/_minification.py @@ -1,14 +1,16 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from anndata import AnnData from scipy.sparse import csr_matrix -from typing import TYPE_CHECKING from scvi import REGISTRY_KEYS if TYPE_CHECKING: from scvi.data import AnnDataManager + def get_minified_adata_scrna( adata_manager: AnnDataManager, keep_count_data: bool = False, diff --git a/tests/model/test_models_with_minified_data.py b/tests/model/test_models_with_minified_data.py index af35f93d83..5b23cf4e3a 100644 --- a/tests/model/test_models_with_minified_data.py +++ b/tests/model/test_models_with_minified_data.py @@ -1,8 +1,9 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np import pytest -from typing import TYPE_CHECKING import scvi from scvi.data import synthetic_iid From d0ad5f559da43b94db177391c55695d52a5a7884 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 6 Nov 2024 19:19:57 +0200 Subject: [PATCH 07/51] Added mudata support for MULTIVI as well as tests --- CHANGELOG.md | 2 + src/scvi/model/_multivi.py | 104 +++++++++++++++++++++++++++++++++++- src/scvi/model/_totalvi.py | 2 +- tests/model/test_multivi.py | 97 ++++++++++++++++++++++++++++++++- 4 files changed, 201 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 046502dbd8..6d3cbae594 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ to [Semantic Versioning]. Full commit history is available in the ### 1.2.1 (2024-XX-XX) #### Added +- Experimental MuData support for {class}`~scvi.model.MULTIVI` via the method + {meth}`~scvi.model.MULTIVI.setup_mudata` {pr}`30xx`. #### Fixed diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index e17fa012ca..2e3613c66c 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -13,7 +13,7 @@ from torch.distributions import Normal from scvi import REGISTRY_KEYS, settings -from scvi.data import AnnDataManager +from scvi.data import AnnDataManager, fields from scvi.data.fields import ( CategoricalJointObsField, CategoricalObsField, @@ -44,6 +44,7 @@ from typing import Literal from anndata import AnnData + from mudata import MuData from scvi._types import Number @@ -1117,3 +1118,104 @@ def _check_adata_modality_weights(self, adata): """ if (adata is not None) and (self.module.modality_weights == "cell"): raise RuntimeError("Held out data not permitted when using per cell weights") + + @classmethod + @setup_anndata_dsp.dedent + def setup_mudata( + cls, + mdata: MuData, + rna_layer: str | None = None, + protein_layer: str | None = None, + batch_key: str | None = None, + size_factor_key: str | None = None, + categorical_covariate_keys: list[str] | None = None, + continuous_covariate_keys: list[str] | None = None, + idx_layer: str | None = None, + modalities: dict[str, str] | None = None, + **kwargs, + ): + """%(summary_mdata)s. + + Parameters + ---------- + %(param_mdata)s + rna_layer + RNA layer key. If `None`, will use `.X` of specified modality key. + protein_layer + ATAC layer key. If `None`, will use `.X` of specified modality key. 
+ %(param_batch_key)s + %(param_size_factor_key)s + %(param_cat_cov_keys)s + %(param_cont_cov_keys)s + %(idx_layer)s + %(param_modalities)s + + Examples + -------- + >>> mdata = muon.read_10x_h5("filtered_feature_bc_matrix.h5") + >>> scvi.model.MULTIVI.setup_mudata( + mdata, modalities={"rna_layer": "rna", "protein_layer": "atac"} + ) + >>> vae = scvi.model.MULTIVI(mdata) + """ + setup_method_args = cls._get_setup_method_args(**locals()) + + if modalities is None: + raise ValueError("Modalities cannot be None.") + modalities = cls._create_modalities_attr_dict(modalities, setup_method_args) + mdata.obs["_indices"] = np.arange(mdata.n_obs) + + batch_field = fields.MuDataCategoricalObsField( + REGISTRY_KEYS.BATCH_KEY, + batch_key, + mod_key=modalities.batch_key, + ) + mudata_fields = [ + fields.MuDataLayerField( + REGISTRY_KEYS.X_KEY, + rna_layer, + mod_key=modalities.rna_layer, + is_count_data=True, + mod_required=True, + ), + batch_field, + fields.MuDataCategoricalObsField( + REGISTRY_KEYS.LABELS_KEY, + None, + mod_key=None, + ), + fields.MuDataNumericalObsField( + REGISTRY_KEYS.SIZE_FACTOR_KEY, + size_factor_key, + mod_key=modalities.size_factor_key, + required=False, + ), + fields.MuDataCategoricalJointObsField( + REGISTRY_KEYS.CAT_COVS_KEY, + categorical_covariate_keys, + mod_key=modalities.categorical_covariate_keys, + ), + fields.MuDataNumericalJointObsField( + REGISTRY_KEYS.CONT_COVS_KEY, + continuous_covariate_keys, + mod_key=modalities.continuous_covariate_keys, + ), + fields.MuDataNumericalObsField( + REGISTRY_KEYS.INDICES_KEY, + "_indices", + mod_key=modalities.idx_layer, + required=False, + ), + fields.MuDataProteinLayerField( + REGISTRY_KEYS.PROTEIN_EXP_KEY, + protein_layer, + mod_key=modalities.protein_layer, + use_batch_mask=True, + batch_field=batch_field, + is_count_data=True, + mod_required=True, + ), + ] + adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args) + adata_manager.register_fields(mdata, **kwargs) + cls.register_manager(adata_manager) diff --git a/src/scvi/model/_totalvi.py b/src/scvi/model/_totalvi.py index 929ee1ad6c..3ad947d64b 100644 --- a/src/scvi/model/_totalvi.py +++ b/src/scvi/model/_totalvi.py @@ -1275,7 +1275,7 @@ def setup_mudata( -------- >>> mdata = muon.read_10x_h5("pbmc_10k_protein_v3_filtered_feature_bc_matrix.h5") >>> scvi.model.TOTALVI.setup_mudata( - mdata, modalities={"rna_layer": "rna": "protein_layer": "prot"} + mdata, modalities={"rna_layer": "rna", "protein_layer": "prot"} ) >>> vae = scvi.model.TOTALVI(mdata) """ diff --git a/tests/model/test_multivi.py b/tests/model/test_multivi.py index bdeebba550..83a2965a31 100644 --- a/tests/model/test_multivi.py +++ b/tests/model/test_multivi.py @@ -3,7 +3,8 @@ from scvi.data import synthetic_iid from scvi.model import MULTIVI - +from mudata import MuData +import muon def test_multivi(): data = synthetic_iid() @@ -66,7 +67,7 @@ def test_multivi(): n_regions=50, modality_weights="cell", ) - assert vae.n_proteins == data.obsm["protein_expression"].shape[1] + #assert vae.n_regions == data.obsm["protein_expression"].shape[1] vae.train(3) @@ -76,3 +77,95 @@ def test_multivi_single_batch(): vae = MULTIVI(data, n_genes=50, n_regions=50) with pytest.warns(UserWarning): vae.train(3) + + +def test_multivi_mudata(): + #optional data - big one + url = "https://cf.10xgenomics.com/samples/cell-arc/2.0.0/10k_PBMC_Multiome_nextgem_Chromium_X/10k_PBMC_Multiome_nextgem_Chromium_X_filtered_feature_bc_matrix.h5" + mdata = muon.read_10x_h5("data/multiome10k.h5mu", backup_url=url) + 
mdata + MULTIVI.setup_mudata(mdata, modalities={"rna_layer": "rna", "protein_layer": "atac"}) + vae = MULTIVI(mdata, n_genes=50, n_regions=50) + + adata = synthetic_iid() + protein_adata = synthetic_iid(n_genes=50) + mdata = MuData({"rna": adata, "protein": protein_adata}) + MULTIVI.setup_mudata( + mdata, + batch_key="batch", + modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, + ) + + n_obs = mdata.n_obs + n_genes = np.min([adata.n_vars,protein_adata.n_vars]) + n_regions = protein_adata.X.shape[1] + n_latent = 10 + + model = MULTIVI(mdata, n_latent=n_latent, n_genes=n_genes, n_regions=n_regions) + model.train(1, train_size=0.5) + assert model.is_trained is True + z = model.get_latent_representation() + assert z.shape == (n_obs, n_latent) + model.get_elbo() + #model.get_marginal_ll(n_mc_samples=3) + model.get_reconstruction_error() + model.get_normalized_expression() + model.get_normalized_expression(transform_batch=["batch_0", "batch_1"]) + #model.get_latent_library_size() + #model.get_protein_foreground_probability() + #model.get_protein_foreground_probability(transform_batch=["batch_0", "batch_1"]) + #post_pred = model.posterior_predictive_sample(n_samples=2) + #assert post_pred.shape == (n_obs, n_genes + n_regions, 2) + #post_pred = model.posterior_predictive_sample(n_samples=1) + #assert post_pred.shape == (n_obs, n_genes + n_regions) + # feature_correlation_matrix1 = model.get_feature_correlation_matrix(correlation_type="spearman") + # feature_correlation_matrix1 = model.get_feature_correlation_matrix( + # correlation_type="spearman", transform_batch=["batch_0", "batch_1"] + # ) + # feature_correlation_matrix2 = model.get_feature_correlation_matrix(correlation_type="pearson") + # assert feature_correlation_matrix1.shape == ( + # n_genes + n_regions, + # n_genes + n_regions, + # ) + # assert feature_correlation_matrix2.shape == ( + # n_genes + n_regions, + # n_genes + n_regions, + # ) + + model.get_elbo(indices=model.validation_indices) + #model.get_marginal_ll(indices=model.validation_indices, n_mc_samples=3) + model.get_reconstruction_error(indices=model.validation_indices) + + adata2 = synthetic_iid() + protein_adata2 = synthetic_iid(n_genes=50) + mdata2 = MuData({"rna": adata, "protein": protein_adata}) + MULTIVI.setup_mudata( + mdata2, + batch_key="batch", + modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, + ) + norm_exp = model.get_normalized_expression(mdata2, indices=[1, 2, 3]) + # assert norm_exp[0].shape == (3, adata2.n_vars) + # assert norm_exp[1].shape == (3, protein_adata2.n_vars) + # norm_exp = model.get_normalized_expression( + # mdata2, + # gene_list=adata2.var_names[:5].to_list(), + # protein_list=protein_adata2.var_names[:3].to_list(), + # transform_batch=["batch_0", "batch_1"], + # ) + + # latent_lib_size = model.get_latent_library_size(mdata2, indices=[1, 2, 3]) + # assert latent_lib_size.shape == (3, 1) + + # pro_foreground_prob = model.get_protein_foreground_probability( + # mdata2, indices=[1, 2, 3], protein_list=["gene_1", "gene_2"] + # ) + # assert pro_foreground_prob.shape == (3, 2) + # model.posterior_predictive_sample(mdata2) + # model.get_feature_correlation_matrix(mdata2) + + # test transfer_anndata_setup + view + adata2 = synthetic_iid() + protein_adata2 = synthetic_iid(n_genes=50) + mdata2 = MuData({"rna": adata2, "protein": protein_adata2}) + #model.get_elbo(mdata2[:10]) From 94505462d8627037b6fb24e4857dffd4521dc54a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 6 Nov 2024 17:23:22 +0000 Subject: [PATCH 08/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- CHANGELOG.md | 1 + tests/model/test_multivi.py | 31 ++++++++++++++++--------------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d3cbae594..ecf405e1f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ to [Semantic Versioning]. Full commit history is available in the ### 1.2.1 (2024-XX-XX) #### Added + - Experimental MuData support for {class}`~scvi.model.MULTIVI` via the method {meth}`~scvi.model.MULTIVI.setup_mudata` {pr}`30xx`. diff --git a/tests/model/test_multivi.py b/tests/model/test_multivi.py index 83a2965a31..c195bbc368 100644 --- a/tests/model/test_multivi.py +++ b/tests/model/test_multivi.py @@ -1,10 +1,11 @@ +import muon import numpy as np import pytest +from mudata import MuData from scvi.data import synthetic_iid from scvi.model import MULTIVI -from mudata import MuData -import muon + def test_multivi(): data = synthetic_iid() @@ -67,7 +68,7 @@ def test_multivi(): n_regions=50, modality_weights="cell", ) - #assert vae.n_regions == data.obsm["protein_expression"].shape[1] + # assert vae.n_regions == data.obsm["protein_expression"].shape[1] vae.train(3) @@ -80,7 +81,7 @@ def test_multivi_single_batch(): def test_multivi_mudata(): - #optional data - big one + # optional data - big one url = "https://cf.10xgenomics.com/samples/cell-arc/2.0.0/10k_PBMC_Multiome_nextgem_Chromium_X/10k_PBMC_Multiome_nextgem_Chromium_X_filtered_feature_bc_matrix.h5" mdata = muon.read_10x_h5("data/multiome10k.h5mu", backup_url=url) mdata @@ -97,7 +98,7 @@ def test_multivi_mudata(): ) n_obs = mdata.n_obs - n_genes = np.min([adata.n_vars,protein_adata.n_vars]) + n_genes = np.min([adata.n_vars, protein_adata.n_vars]) n_regions = protein_adata.X.shape[1] n_latent = 10 @@ -107,17 +108,17 @@ def test_multivi_mudata(): z = model.get_latent_representation() assert z.shape == (n_obs, n_latent) model.get_elbo() - #model.get_marginal_ll(n_mc_samples=3) + # model.get_marginal_ll(n_mc_samples=3) model.get_reconstruction_error() model.get_normalized_expression() model.get_normalized_expression(transform_batch=["batch_0", "batch_1"]) - #model.get_latent_library_size() - #model.get_protein_foreground_probability() - #model.get_protein_foreground_probability(transform_batch=["batch_0", "batch_1"]) - #post_pred = model.posterior_predictive_sample(n_samples=2) - #assert post_pred.shape == (n_obs, n_genes + n_regions, 2) - #post_pred = model.posterior_predictive_sample(n_samples=1) - #assert post_pred.shape == (n_obs, n_genes + n_regions) + # model.get_latent_library_size() + # model.get_protein_foreground_probability() + # model.get_protein_foreground_probability(transform_batch=["batch_0", "batch_1"]) + # post_pred = model.posterior_predictive_sample(n_samples=2) + # assert post_pred.shape == (n_obs, n_genes + n_regions, 2) + # post_pred = model.posterior_predictive_sample(n_samples=1) + # assert post_pred.shape == (n_obs, n_genes + n_regions) # feature_correlation_matrix1 = model.get_feature_correlation_matrix(correlation_type="spearman") # feature_correlation_matrix1 = model.get_feature_correlation_matrix( # correlation_type="spearman", transform_batch=["batch_0", "batch_1"] @@ -133,7 +134,7 @@ def test_multivi_mudata(): # ) model.get_elbo(indices=model.validation_indices) - #model.get_marginal_ll(indices=model.validation_indices, 
n_mc_samples=3) + # model.get_marginal_ll(indices=model.validation_indices, n_mc_samples=3) model.get_reconstruction_error(indices=model.validation_indices) adata2 = synthetic_iid() @@ -168,4 +169,4 @@ def test_multivi_mudata(): adata2 = synthetic_iid() protein_adata2 = synthetic_iid(n_genes=50) mdata2 = MuData({"rna": adata2, "protein": protein_adata2}) - #model.get_elbo(mdata2[:10]) + # model.get_elbo(mdata2[:10]) From 079faffe3acbc8d1adf5d59b9133cc837416d4b9 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 6 Nov 2024 19:30:32 +0200 Subject: [PATCH 09/51] needed muon --- CHANGELOG.md | 2 +- pyproject.toml | 5 +++-- tests/model/test_multivi.py | 11 +++++++---- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d3cbae594..10da7435ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ to [Semantic Versioning]. Full commit history is available in the #### Added - Experimental MuData support for {class}`~scvi.model.MULTIVI` via the method - {meth}`~scvi.model.MULTIVI.setup_mudata` {pr}`30xx`. + {meth}`~scvi.model.MULTIVI.setup_mudata` {pr}`3038`. #### Fixed diff --git a/pyproject.toml b/pyproject.toml index 34c9d6b58b..4d19105208 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,8 @@ census = ["cellxgene-census"] hub = ["huggingface_hub"] # scvi.model.utils.mde dependencies pymde = ["pymde"] +# mudata dependencies +muon = ["muon"] # scvi.data.add_dna_sequence regseq = ["biopython>=1.81", "genomepy"] # read loom @@ -96,13 +98,12 @@ loompy = ["loompy>=3.0.6"] scanpy = ["scanpy>=1.6"] optional = [ - "scvi-tools[autotune,aws,hub,loompy,pymde,regseq,scanpy]" + "scvi-tools[autotune,aws,hub,loompy,muon,pymde,regseq,scanpy]" ] tutorials = [ "cell2location", "jupyter", "leidenalg", - "muon", "plotnine", "pooch", "pynndescent", diff --git a/tests/model/test_multivi.py b/tests/model/test_multivi.py index 83a2965a31..3420070906 100644 --- a/tests/model/test_multivi.py +++ b/tests/model/test_multivi.py @@ -81,9 +81,10 @@ def test_multivi_single_batch(): def test_multivi_mudata(): #optional data - big one - url = "https://cf.10xgenomics.com/samples/cell-arc/2.0.0/10k_PBMC_Multiome_nextgem_Chromium_X/10k_PBMC_Multiome_nextgem_Chromium_X_filtered_feature_bc_matrix.h5" + url = ("https://cf.10xgenomics.com/samples/cell-arc/2.0.0/10k_PBMC_Multiome_nextgem_Chromium_X" + "/10k_PBMC_Multiome_nextgem_Chromium_X_filtered_feature_bc_matrix.h5") mdata = muon.read_10x_h5("data/multiome10k.h5mu", backup_url=url) - mdata + #mdata MULTIVI.setup_mudata(mdata, modalities={"rna_layer": "rna", "protein_layer": "atac"}) vae = MULTIVI(mdata, n_genes=50, n_regions=50) @@ -118,11 +119,13 @@ def test_multivi_mudata(): #assert post_pred.shape == (n_obs, n_genes + n_regions, 2) #post_pred = model.posterior_predictive_sample(n_samples=1) #assert post_pred.shape == (n_obs, n_genes + n_regions) - # feature_correlation_matrix1 = model.get_feature_correlation_matrix(correlation_type="spearman") + # feature_correlation_matrix1 = (model.get_feature_correlation_matrix + # (correlation_type="spearman")) # feature_correlation_matrix1 = model.get_feature_correlation_matrix( # correlation_type="spearman", transform_batch=["batch_0", "batch_1"] # ) - # feature_correlation_matrix2 = model.get_feature_correlation_matrix(correlation_type="pearson") + # feature_correlation_matrix2 = (model.get_feature_correlation_matrix + # (correlation_type="pearson")) # assert feature_correlation_matrix1.shape == ( # n_genes + n_regions, # n_genes + n_regions, From 
b420037cc9dbf7882794fb7cc1c7158a1cef6bbb Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 6 Nov 2024 17:33:46 +0000
Subject: [PATCH 10/51] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/model/test_multivi.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/tests/model/test_multivi.py b/tests/model/test_multivi.py
index 0dd3e35302..88a6be7069 100644
--- a/tests/model/test_multivi.py
+++ b/tests/model/test_multivi.py
@@ -82,8 +82,10 @@ def test_multivi_single_batch():

 def test_multivi_mudata():
     # optional data - big one
-    url = ("https://cf.10xgenomics.com/samples/cell-arc/2.0.0/10k_PBMC_Multiome_nextgem_Chromium_X"
-           "/10k_PBMC_Multiome_nextgem_Chromium_X_filtered_feature_bc_matrix.h5")
+    url = (
+        "https://cf.10xgenomics.com/samples/cell-arc/2.0.0/10k_PBMC_Multiome_nextgem_Chromium_X"
+        "/10k_PBMC_Multiome_nextgem_Chromium_X_filtered_feature_bc_matrix.h5"
+    )
     mdata = muon.read_10x_h5("data/multiome10k.h5mu", backup_url=url)
     mdata
     MULTIVI.setup_mudata(mdata, modalities={"rna_layer": "rna", "protein_layer": "atac"})

From 815555c06c4ea06677faf44deb83bb13823b1210 Mon Sep 17 00:00:00 2001
From: ori-kron-wis
Date: Thu, 7 Nov 2024 16:35:40 +0200
Subject: [PATCH 11/51] Added ATAC/PROTEIN + RNA capability for MultiVI + more tests like in totalvi

---
 src/scvi/model/_multivi.py  | 182 ++++-----------
 src/scvi/model/_totalvi.py  |   7 +-
 tests/model/test_multivi.py | 443 +++++++++++++++++++++++++++++++-----
 3 files changed, 424 insertions(+), 208 deletions(-)

diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py
index 2e3613c66c..42eb8d4aec 100644
--- a/src/scvi/model/_multivi.py
+++ b/src/scvi/model/_multivi.py
@@ -46,7 +46,7 @@
     from anndata import AnnData
     from mudata import MuData

-    from scvi._types import Number
+    from scvi._types import Number, AnnOrMuData

 logger = logging.getLogger(__name__)

@@ -60,7 +60,8 @@ class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin):
     Parameters
     ----------
     adata
-        AnnData object that has been registered via :meth:`~scvi.model.MULTIVI.setup_anndata`.
+        AnnData/MuData object that has been registered via
+        :meth:`~scvi.model.MULTIVI.setup_anndata` or :meth:`~scvi.model.MULTIVI.setup_mudata`.
     n_genes
         The number of gene expression features (genes).
     n_regions
@@ -141,7 +142,7 @@

     def __init__(
         self,
-        adata: AnnData,
+        adata: AnnOrMuData,
         n_genes: int,
         n_regions: int,
         modality_weights: Literal["equal", "cell", "universal"] = "equal",
@@ -589,7 +590,8 @@ def get_accessibility_estimates(
         return pd.DataFrame(
             imputed,
             index=adata.obs_names[indices],
-            columns=adata.var_names[self.n_genes :][region_mask],
+            columns=adata["rna"].var_names[self.n_genes :][region_mask] if
+            type(adata).__name__ == "MuData" else adata.var_names[self.n_genes :][region_mask],
         )

     @torch.inference_mode()
@@ -926,130 +928,6 @@

         return result

-    @torch.no_grad()
-    def get_protein_foreground_probability(
-        self,
-        adata: AnnData | None = None,
-        indices: Sequence[int] | None = None,
-        transform_batch: Sequence[Number | str] | None = None,
-        protein_list: Sequence[str] | None = None,
-        n_samples: int = 1,
-        batch_size: int | None = None,
-        use_z_mean: bool = True,
-        return_mean: bool = True,
-        return_numpy: bool | None = None,
-    ):
-        r"""Returns the foreground probability for proteins. 
- - This is denoted as :math:`(1 - \pi_{nt})` in the totalVI paper. - - Parameters - ---------- - adata - AnnData object with equivalent structure to initial AnnData. If ``None``, defaults to - the AnnData object used to initialize the model. - indices - Indices of cells in adata to use. If `None`, all cells are used. - transform_batch - Batch to condition on. - If transform_batch is: - - * ``None`` - real observed batch is used - * ``int`` - batch transform_batch is used - * ``List[int]`` - average over batches in list - protein_list - Return protein expression for a subset of genes. - This can save memory when working with large datasets and few genes are - of interest. - n_samples - Number of posterior samples to use for estimation. - batch_size - Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. - return_mean - Whether to return the mean of the samples. - return_numpy - Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame - includes gene names as columns. If either ``n_samples=1`` or ``return_mean=True``, - defaults to ``False``. Otherwise, it defaults to `True`. - - Returns - ------- - - **foreground_probability** - probability foreground for each protein - - If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`. - Otherwise, shape is `(cells, genes)`. In this case, return type is - :class:`~pandas.DataFrame` unless `return_numpy` is True. - """ - adata = self._validate_anndata(adata) - post = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) - - if protein_list is None: - protein_mask = slice(None) - else: - all_proteins = self.scvi_setup_dict_["protein_names"] - protein_mask = [True if p in protein_list else False for p in all_proteins] - - if n_samples > 1 and return_mean is False: - if return_numpy is False: - warnings.warn( - "`return_numpy` must be `True` if `n_samples > 1` and `return_mean` is " - "`False`, returning an `np.ndarray`.", - UserWarning, - stacklevel=settings.warnings_stacklevel, - ) - return_numpy = True - if indices is None: - indices = np.arange(adata.n_obs) - - py_mixings = [] - if not isinstance(transform_batch, IterableClass): - transform_batch = [transform_batch] - - transform_batch = _get_batch_code_from_category(self.adata_manager, transform_batch) - for tensors in post: - y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY] - py_mixing = torch.zeros_like(y[..., protein_mask]) - if n_samples > 1: - py_mixing = torch.stack(n_samples * [py_mixing]) - for _ in transform_batch: - # generative_kwargs = dict(transform_batch=b) - generative_kwargs = {"use_z_mean": use_z_mean} - inference_kwargs = {"n_samples": n_samples} - _, generative_outputs = self.module.forward( - tensors=tensors, - inference_kwargs=inference_kwargs, - generative_kwargs=generative_kwargs, - compute_loss=False, - ) - py_mixing += torch.sigmoid(generative_outputs["py_"]["mixing"])[ - ..., protein_mask - ].cpu() - py_mixing /= len(transform_batch) - py_mixings += [py_mixing] - if n_samples > 1: - # concatenate along batch dimension -> result shape = (samples, cells, features) - py_mixings = torch.cat(py_mixings, dim=1) - # (cells, features, samples) - py_mixings = py_mixings.permute(1, 2, 0) - else: - py_mixings = torch.cat(py_mixings, dim=0) - - if return_mean is True and n_samples > 1: - py_mixings = torch.mean(py_mixings, dim=-1) - - py_mixings = py_mixings.cpu().numpy() - - if return_numpy is True: - return 1 - py_mixings - else: - pro_names = 
self.protein_state_registry.column_names - foreground_prob = pd.DataFrame( - 1 - py_mixings, - columns=pro_names[protein_mask], - index=adata.obs_names[indices], - ) - return foreground_prob - @classmethod @setup_anndata_dsp.dedent def setup_anndata( @@ -1125,6 +1003,7 @@ def setup_mudata( cls, mdata: MuData, rna_layer: str | None = None, + atac_layer: str | None = None, protein_layer: str | None = None, batch_key: str | None = None, size_factor_key: str | None = None, @@ -1142,6 +1021,8 @@ def setup_mudata( rna_layer RNA layer key. If `None`, will use `.X` of specified modality key. protein_layer + Protein layer key. If `None`, will use `.X` of specified modality key. + atac_layer ATAC layer key. If `None`, will use `.X` of specified modality key. %(param_batch_key)s %(param_size_factor_key)s @@ -1171,13 +1052,6 @@ def setup_mudata( mod_key=modalities.batch_key, ) mudata_fields = [ - fields.MuDataLayerField( - REGISTRY_KEYS.X_KEY, - rna_layer, - mod_key=modalities.rna_layer, - is_count_data=True, - mod_required=True, - ), batch_field, fields.MuDataCategoricalObsField( REGISTRY_KEYS.LABELS_KEY, @@ -1206,16 +1080,36 @@ def setup_mudata( mod_key=modalities.idx_layer, required=False, ), - fields.MuDataProteinLayerField( - REGISTRY_KEYS.PROTEIN_EXP_KEY, - protein_layer, - mod_key=modalities.protein_layer, - use_batch_mask=True, - batch_field=batch_field, - is_count_data=True, - mod_required=True, - ), ] + if modalities.rna_layer is not None: + mudata_fields.append( + fields.MuDataLayerField( + REGISTRY_KEYS.X_KEY, + rna_layer, + mod_key=modalities.rna_layer, + is_count_data=True, + mod_required=True, + )) + if modalities.atac_layer is not None: + mudata_fields.append( + fields.MuDataLayerField( + REGISTRY_KEYS.X_KEY, + atac_layer, + mod_key=modalities.atac_layer, + is_count_data=True, + mod_required=True, + )) + if modalities.protein_layer is not None: + mudata_fields.append( + fields.MuDataProteinLayerField( + REGISTRY_KEYS.PROTEIN_EXP_KEY, + protein_layer, + mod_key=modalities.protein_layer, + use_batch_mask=True, + batch_field=batch_field, + is_count_data=True, + mod_required=True, + )) adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(mdata, **kwargs) cls.register_manager(adata_manager) diff --git a/src/scvi/model/_totalvi.py b/src/scvi/model/_totalvi.py index 3ad947d64b..379a307a63 100644 --- a/src/scvi/model/_totalvi.py +++ b/src/scvi/model/_totalvi.py @@ -35,7 +35,7 @@ from anndata import AnnData from mudata import MuData - from scvi._types import Number + from scvi._types import Number, AnnOrMuData logger = logging.getLogger(__name__) @@ -46,7 +46,8 @@ class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass): Parameters ---------- adata - AnnData object that has been registered via :meth:`~scvi.model.TOTALVI.setup_anndata`. + AnnData/MuData object that has been registered via + :meth:`~scvi.model.TOTALVI.setup_anndata` or :meth:`~scvi.model.TOTALVI.setup_mudata`. n_latent Dimensionality of the latent space. 
gene_dispersion @@ -108,7 +109,7 @@ class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass): def __init__( self, - adata: AnnData, + adata: AnnOrMuData, n_latent: int = 20, gene_dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene", protein_dispersion: Literal["protein", "protein-batch", "protein-label"] = "protein", diff --git a/tests/model/test_multivi.py b/tests/model/test_multivi.py index 88a6be7069..80ca75e424 100644 --- a/tests/model/test_multivi.py +++ b/tests/model/test_multivi.py @@ -2,9 +2,14 @@ import numpy as np import pytest from mudata import MuData - +import scanpy as sc +import anndata as ad +import scvi +import os from scvi.data import synthetic_iid from scvi.model import MULTIVI +from scvi import REGISTRY_KEYS +from scvi.utils import attrdict def test_multivi(): @@ -68,7 +73,7 @@ def test_multivi(): n_regions=50, modality_weights="cell", ) - # assert vae.n_regions == data.obsm["protein_expression"].shape[1] + assert vae.n_proteins == data.obsm["protein_expression"].shape[1] vae.train(3) @@ -80,98 +85,414 @@ def test_multivi_single_batch(): vae.train(3) -def test_multivi_mudata(): - # optional data - big one +def test_multivi_mudata_rna_prot_external(): + # Example on how to download protein adata to mudata (from multivi tutotial) - mudata RNA/PROT + adata = scvi.data.pbmcs_10x_cite_seq() + adata.layers["counts"] = adata.X.copy() + sc.pp.normalize_total(adata) + sc.pp.log1p(adata) + adata.obs_names_make_unique() + protein_adata = ad.AnnData(adata.obsm["protein_expression"]) + protein_adata.obs_names = adata.obs_names + del adata.obsm["protein_expression"] + mdata = MuData({"rna": adata, "protein": protein_adata}) + sc.pp.highly_variable_genes( + mdata.mod["rna"], + n_top_genes=4000, + flavor="seurat_v3", + batch_key="batch", + layer="counts", + ) + mdata.mod["rna_subset"] = mdata.mod["rna"][:, mdata.mod["rna"].var["highly_variable"]].copy() + mdata.update() + # mdata + # mdata.mod + MULTIVI.setup_mudata( + mdata, + rna_layer="counts", # mean we use: mdata.mod["rna_subset"].layers["counts"] + protein_layer=None, # mean we use: mdata.mod["protein"].X + batch_key="batch", # the batch is here: mdata.mod["rna_subset"].obs["batch"] + modalities={ + "rna_layer": "rna_subset", + "protein_layer": "protein", + "batch_key": "rna_subset", + }, + ) + model = MULTIVI(mdata, n_genes=50, n_regions=50) + model.train(1, train_size=0.9) + + +def test_multivi_mudata_rna_atac_external(): + # optional data - mudata RNA/ATAC url = ( "https://cf.10xgenomics.com/samples/cell-arc/2.0.0/10k_PBMC_Multiome_nextgem_Chromium_X" "/10k_PBMC_Multiome_nextgem_Chromium_X_filtered_feature_bc_matrix.h5" ) mdata = muon.read_10x_h5("data/multiome10k.h5mu", backup_url=url) - mdata - MULTIVI.setup_mudata(mdata, modalities={"rna_layer": "rna", "protein_layer": "atac"}) - vae = MULTIVI(mdata, n_genes=50, n_regions=50) + # Preprocessing + sc.pp.normalize_total(mdata.mod["rna"]) + sc.pp.log1p(mdata.mod["rna"]) + sc.pp.highly_variable_genes( + mdata.mod["rna"], + n_top_genes=4000, + flavor="seurat_v3", + ) + mdata.mod["rna_subset"] = mdata.mod["rna"][:, mdata.mod["rna"].var["highly_variable"]].copy() + sc.pp.normalize_total(mdata.mod["atac"]) + sc.pp.log1p(mdata.mod["atac"]) + sc.pp.highly_variable_genes( + mdata.mod["atac"], + n_top_genes=4000, + flavor="seurat_v3", + ) + mdata.mod["atac_subset"] = (mdata.mod["atac"][:, mdata.mod["atac"].var["highly_variable"]]. 
+ copy()) + mdata.update() + # mdata + # mdata.mod + MULTIVI.setup_mudata(mdata, modalities={"rna_layer": "rna_subset", + "atac_layer": "atac_subset"}) + model = MULTIVI(mdata, n_genes=50, n_regions=50) + model.train(1, train_size=0.9) - adata = synthetic_iid() - protein_adata = synthetic_iid(n_genes=50) - mdata = MuData({"rna": adata, "protein": protein_adata}) + +def test_multivi_mudata(): + # use of syntetic data of rna/proteins/atac for speed + + # adata = synthetic_iid() + # protein_adata = synthetic_iid() + # atac_adata = synthetic_iid() + # mdata = MuData({"rna": adata, "protein": protein_adata, "atac": atac_adata}) + # MULTIVI.setup_mudata( + # mdata, + # batch_key="batch", + # modalities={"rna_layer": "rna", "protein_layer": "protein", "batch_key": "rna", + # "atac_layer": "atac"}, + # ) + # n_obs = mdata.n_obs + # n_genes = np.min([adata.n_vars, protein_adata.n_vars]) + # n_regions = protein_adata.X.shape[1] + # n_latent = 10 + + mdata = synthetic_iid(return_mudata=True) MULTIVI.setup_mudata( mdata, batch_key="batch", - modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, + modalities={"rna_layer": "rna", "protein_layer": "protein_expression", + "atac_layer": "accessibility"}, ) - n_obs = mdata.n_obs - n_genes = np.min([adata.n_vars, protein_adata.n_vars]) - n_regions = protein_adata.X.shape[1] + # n_genes = np.min([mdata.n_vars, mdata["protein_expression"].n_vars]) + # n_regions = mdata["protein_expression"].X.shape[1] n_latent = 10 - model = MULTIVI(mdata, n_latent=n_latent, n_genes=n_genes, n_regions=n_regions) - model.train(1, train_size=0.5) + model = MULTIVI(mdata, n_latent=n_latent, n_genes=50, n_regions=50) + model.train(1, train_size=0.9) assert model.is_trained is True z = model.get_latent_representation() assert z.shape == (n_obs, n_latent) model.get_elbo() - # model.get_marginal_ll(n_mc_samples=3) model.get_reconstruction_error() model.get_normalized_expression() model.get_normalized_expression(transform_batch=["batch_0", "batch_1"]) - # model.get_latent_library_size() - # model.get_protein_foreground_probability() - # model.get_protein_foreground_probability(transform_batch=["batch_0", "batch_1"]) - # post_pred = model.posterior_predictive_sample(n_samples=2) - # assert post_pred.shape == (n_obs, n_genes + n_regions, 2) - # post_pred = model.posterior_predictive_sample(n_samples=1) - # assert post_pred.shape == (n_obs, n_genes + n_regions) - # feature_correlation_matrix1 = model.get_feature_correlation_matrix - # (correlation_type="spearman") - # feature_correlation_matrix1 = model.get_feature_correlation_matrix( - # correlation_type="spearman", transform_batch=["batch_0", "batch_1"] - # ) - # feature_correlation_matrix2 = (model.get_feature_correlation_matrix - # (correlation_type="pearson")) - # assert feature_correlation_matrix1.shape == ( - # n_genes + n_regions, - # n_genes + n_regions, - # ) - # assert feature_correlation_matrix2.shape == ( - # n_genes + n_regions, - # n_genes + n_regions, - # ) + model.get_accessibility_estimates() + model.get_accessibility_estimates(normalize_cells=True) + model.get_accessibility_estimates(normalize_regions=True) + model.get_library_size_factors() + model.get_region_factors() model.get_elbo(indices=model.validation_indices) - # model.get_marginal_ll(indices=model.validation_indices, n_mc_samples=3) model.get_reconstruction_error(indices=model.validation_indices) + model.get_accessibility_estimates() + model.get_accessibility_estimates(normalize_cells=True) + 
model.get_accessibility_estimates(normalize_regions=True) + model.get_library_size_factors() + model.get_region_factors() + + # adata2 = synthetic_iid() + # protein_adata2 = synthetic_iid(n_genes=50) + # mdata2 = MuData({"rna": adata2, "protein": protein_adata2}) + mdata2 = synthetic_iid(return_mudata=True) + MULTIVI.setup_mudata( + mdata2, + batch_key="batch", + modalities={"rna_layer": "rna", "protein_layer": "protein_expression"}, + ) + norm_exp = model.get_normalized_expression(mdata2, indices=[1, 2, 3]) + assert norm_exp.shape == (3, 50) + # test transfer_anndata_setup + view + mdata3 = synthetic_iid(return_mudata=True) + mdata3.obs["_indices"] = np.arange(mdata3.n_obs) + model.get_elbo(mdata3[:10]) + model.get_accessibility_estimates() + model.get_accessibility_estimates(normalize_cells=True) + model.get_accessibility_estimates(normalize_regions=True) + model.get_library_size_factors() + model.get_region_factors() + + +def test_multivi_auto_transfer_mudata(): + # test automatic transfer_fields + adata = synthetic_iid() + protein_adata = synthetic_iid(n_genes=50) + mdata = MuData({"rna": adata, "protein": protein_adata}) + MULTIVI.setup_mudata( + mdata, + batch_key="batch", + modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, + ) + model = MULTIVI(mdata, n_genes=50, n_regions=50) adata2 = synthetic_iid() protein_adata2 = synthetic_iid(n_genes=50) - mdata2 = MuData({"rna": adata, "protein": protein_adata}) + mdata2 = MuData({"rna": adata2, "protein": protein_adata2}) + mdata2.obs["_indices"] = np.arange(mdata2.n_obs) + model.get_elbo(mdata2) + model.get_accessibility_estimates() + model.get_accessibility_estimates(normalize_cells=True) + model.get_accessibility_estimates(normalize_regions=True) + model.get_library_size_factors() + model.get_region_factors() + + +def test_multivi_incorrect_mapping_mudata(): + # test that we catch incorrect mappings + adata = synthetic_iid() + protein_adata = synthetic_iid(n_genes=50) + mdata = MuData({"rna": adata, "protein": protein_adata}) MULTIVI.setup_mudata( - mdata2, + mdata, batch_key="batch", modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, ) - norm_exp = model.get_normalized_expression(mdata2, indices=[1, 2, 3]) - # assert norm_exp[0].shape == (3, adata2.n_vars) - # assert norm_exp[1].shape == (3, protein_adata2.n_vars) - # norm_exp = model.get_normalized_expression( - # mdata2, - # gene_list=adata2.var_names[:5].to_list(), - # protein_list=protein_adata2.var_names[:3].to_list(), - # transform_batch=["batch_0", "batch_1"], - # ) + model = MULTIVI(mdata, n_genes=50, n_regions=50) + adata2 = synthetic_iid() + protein_adata2 = synthetic_iid(n_genes=50) + mdata2 = MuData({"rna": adata2, "protein": protein_adata2}) + adata2.obs.batch = adata2.obs.batch.cat.rename_categories(["batch_0", "batch_10"]) + with pytest.raises(ValueError): + model.get_elbo(mdata2) + model.get_accessibility_estimates() + model.get_accessibility_estimates(normalize_cells=True) + model.get_accessibility_estimates(normalize_regions=True) + model.get_library_size_factors() + model.get_region_factors() - # latent_lib_size = model.get_latent_library_size(mdata2, indices=[1, 2, 3]) - # assert latent_lib_size.shape == (3, 1) - # pro_foreground_prob = model.get_protein_foreground_probability( - # mdata2, indices=[1, 2, 3], protein_list=["gene_1", "gene_2"] - # ) - # assert pro_foreground_prob.shape == (3, 2) - # model.posterior_predictive_sample(mdata2) - # model.get_feature_correlation_matrix(mdata2) +def 
test_multivi_reordered_mapping_mudata(): + # test that same mapping different order is okay + adata = synthetic_iid() + protein_adata = synthetic_iid(n_genes=50) + mdata = MuData({"rna": adata, "protein": protein_adata}) + MULTIVI.setup_mudata( + mdata, + batch_key="batch", + modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, + ) + model = MULTIVI(mdata, n_genes=50, n_regions=50) + adata2 = synthetic_iid() + protein_adata2 = synthetic_iid(n_genes=50) + mdata2 = MuData({"rna": adata2, "protein": protein_adata2}) + adata2.obs.batch = adata2.obs.batch.cat.rename_categories(["batch_1", "batch_0"]) + mdata2.obs["_indices"] = np.arange(mdata2.n_obs) + model.get_elbo(mdata2) + model.get_accessibility_estimates() + model.get_accessibility_estimates(normalize_cells=True) + model.get_accessibility_estimates(normalize_regions=True) + model.get_library_size_factors() + model.get_region_factors() + + +def test_multivi_model_library_size_mudata(): + adata = synthetic_iid() + protein_adata = synthetic_iid(n_genes=50) + mdata = MuData({"rna": adata, "protein": protein_adata}) + MULTIVI.setup_mudata( + mdata, + batch_key="batch", + modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, + ) + + n_latent = 10 + model = MULTIVI(mdata, n_latent=n_latent, n_genes=50, n_regions=50) + model.train(1, train_size=0.5) + assert model.is_trained is True + model.get_elbo() + model.get_accessibility_estimates() + model.get_accessibility_estimates(normalize_cells=True) + model.get_accessibility_estimates(normalize_regions=True) + model.get_library_size_factors() + model.get_region_factors() + + +def test_multivi_size_factor_mudata(): + adata = synthetic_iid() + adata.obs["size_factor"] = np.random.randint(1, 5, size=(adata.shape[0],)) + protein_adata = synthetic_iid(n_genes=50) + mdata = MuData({"rna": adata, "protein": protein_adata}) + MULTIVI.setup_mudata( + mdata, + batch_key="batch", + size_factor_key="size_factor", + modalities={ + "rna_layer": "rna", + "batch_key": "rna", + "protein_layer": "protein", + "size_factor_key": "rna", + }, + ) + + n_latent = 10 + + # Test size_factor_key overrides use_observed_lib_size. + model = MULTIVI(mdata, n_latent=n_latent, n_genes=50, n_regions=50) + assert model.module.use_size_factor_key + model.train(1, train_size=0.5) + + model = MULTIVI(mdata, n_latent=n_latent, n_genes=50, n_regions=50) + assert model.module.use_size_factor_key + model.train(1, train_size=0.5) + + +def test_multivi_saving_and_loading_mudata(save_path: str="."): + adata = synthetic_iid() + protein_adata = synthetic_iid(n_genes=50) + mdata = MuData({"rna": adata, "protein": protein_adata}) + MULTIVI.setup_mudata( + mdata, + batch_key="batch", + modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, + ) + model = MULTIVI(mdata, n_genes=50, n_regions=50) + model.train(1, train_size=0.2) + z1 = model.get_latent_representation(mdata) + test_idx1 = model.validation_indices + + model.save(save_path, overwrite=True, save_anndata=True) + model.view_setup_args(save_path) + + model = MULTIVI.load(save_path) + model.get_latent_representation() + + # Load with mismatched genes. + tmp_adata = synthetic_iid( + n_genes=200, + ) + tmp_protein_adata = synthetic_iid(n_genes=50) + tmp_mdata = MuData({"rna": tmp_adata, "protein": tmp_protein_adata}) + with pytest.raises(ValueError): + MULTIVI.load(save_path, adata=tmp_mdata) + + # Load with different batches. 
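+    # Editorial sketch (not part of this patch): the expected failure below comes
+    # from the categorical mappings recorded at setup time. Assuming the standard
+    # scvi-tools state-registry layout, they can be inspected with:
+    #
+    #     mapping = model.adata_manager.get_state_registry(
+    #         REGISTRY_KEYS.BATCH_KEY
+    #     ).categorical_mapping
+    #     # e.g. ['batch_0', 'batch_1']; loading data whose categories were renamed
+    #     # to 'batch_2'/'batch_3' then fails validation with a ValueError.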
+ tmp_adata = synthetic_iid() + tmp_adata.obs["batch"] = tmp_adata.obs["batch"].cat.rename_categories(["batch_2", "batch_3"]) + tmp_protein_adata = synthetic_iid(n_genes=50) + tmp_mdata = MuData({"rna": tmp_adata, "protein": tmp_protein_adata}) + with pytest.raises(ValueError): + MULTIVI.load(save_path, adata=tmp_mdata) + + model = MULTIVI.load(save_path, adata=mdata) + assert REGISTRY_KEYS.BATCH_KEY in model.adata_manager.data_registry + assert model.adata_manager.data_registry.batch == attrdict( + {"mod_key": "rna", "attr_name": "obs", "attr_key": "_scvi_batch"} + ) + + z2 = model.get_latent_representation() + test_idx2 = model.validation_indices + np.testing.assert_array_equal(z1, z2) + np.testing.assert_array_equal(test_idx1, test_idx2) + assert model.is_trained is True + + save_path = os.path.join(save_path, "tmp") - # test transfer_anndata_setup + view adata2 = synthetic_iid() protein_adata2 = synthetic_iid(n_genes=50) mdata2 = MuData({"rna": adata2, "protein": protein_adata2}) - # model.get_elbo(mdata2[:10]) + MULTIVI.setup_mudata( + mdata2, + batch_key="batch", + modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, + ) + + +def test_scarches_mudata_prep_layer(save_path: str="."): + n_latent = 5 + mdata1 = synthetic_iid(return_mudata=True) + + mdata1["rna"].layers["counts"] = mdata1["rna"].X.copy() + MULTIVI.setup_mudata( + mdata1, + batch_key="batch", + modalities={"rna_layer": "rna", "protein_layer": "protein_expression"}, + ) + model = MULTIVI(mdata1, n_latent=n_latent, n_genes=50, n_regions=50) + model.train(1, check_val_every_n_epoch=1) + dir_path = os.path.join(save_path, "saved_model/") + model.save(dir_path, overwrite=True) + + # mdata2 has more genes and missing 10 genes from mdata1. + # protein/acessibility features are same as in mdata1 + mdata2 = synthetic_iid(n_genes=110, return_mudata=True) + mdata2["rna"].layers["counts"] = mdata2["rna"].X.copy() + new_var_names_init = [f"Random {i}" for i in range(10)] + new_var_names = new_var_names_init + mdata2["rna"].var_names[10:].to_list() + mdata2["rna"].var_names = new_var_names + + original_protein_values = mdata2["protein_expression"].X.copy() + original_accessibility_values = mdata2["accessibility"].X.copy() + + MULTIVI.prepare_query_mudata(mdata2, dir_path) + # should be padded 0s + assert np.sum(mdata2["rna"][:, mdata2["rna"].var_names[:10]].layers["counts"]) == 0 + np.testing.assert_equal( + mdata2["rna"].var_names[:10].to_numpy(), mdata1["rna"].var_names[:10].to_numpy() + ) + + # values of other modalities should be unchanged + np.testing.assert_equal(original_protein_values, mdata2["protein_expression"].X) + np.testing.assert_equal(original_accessibility_values, mdata2["accessibility"].X) + + # and names should also be the same + np.testing.assert_equal( + mdata2["protein_expression"].var_names.to_numpy(), + mdata1["protein_expression"].var_names.to_numpy(), + ) + np.testing.assert_equal( + mdata2["accessibility"].var_names.to_numpy(), mdata1["accessibility"].var_names.to_numpy() + ) + MULTIVI.load_query_data(mdata2, dir_path) + + +def test_multivi_save_load_mudata_format(save_path: str="."): + mdata = synthetic_iid(return_mudata=True, protein_expression_key="protein") + invalid_mdata = mdata.copy() + invalid_mdata.mod["protein"] = invalid_mdata.mod["protein"][:, :10].copy() + MULTIVI.setup_mudata( + mdata, + modalities={"rna_layer": "rna", "protein_layer": "protein"}, + ) + model = MULTIVI(mdata, n_genes=50, n_regions=50) + model.train(max_epochs=1) + + legacy_model_path = 
os.path.join(save_path, "legacy_model") + model.save( + legacy_model_path, + overwrite=True, + save_anndata=False, + legacy_mudata_format=True, + ) + + with pytest.raises(ValueError): + _ = MULTIVI.load(legacy_model_path, adata=invalid_mdata) + model = MULTIVI.load(legacy_model_path, adata=mdata) + + model_path = os.path.join(save_path, "model") + model.save( + model_path, + overwrite=True, + save_anndata=False, + legacy_mudata_format=False, + ) + with pytest.raises(ValueError): + _ = MULTIVI.load(legacy_model_path, adata=invalid_mdata) + model = MULTIVI.load(model_path, adata=mdata) From e03b00699e24d4fcfda2b86c33c19533d0d9801d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 Nov 2024 14:36:00 +0000 Subject: [PATCH 12/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/model/_multivi.py | 17 ++++++++++------- src/scvi/model/_totalvi.py | 2 +- tests/model/test_multivi.py | 33 ++++++++++++++++++++------------- 3 files changed, 31 insertions(+), 21 deletions(-) diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 42eb8d4aec..0bc5dfe517 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -2,7 +2,6 @@ import logging import warnings -from collections.abc import Iterable as IterableClass from functools import partial from typing import TYPE_CHECKING @@ -46,7 +45,7 @@ from anndata import AnnData from mudata import MuData - from scvi._types import Number, AnnOrMuData + from scvi._types import AnnOrMuData, Number logger = logging.getLogger(__name__) @@ -590,8 +589,9 @@ def get_accessibility_estimates( return pd.DataFrame( imputed, index=adata.obs_names[indices], - columns=adata["rna"].var_names[self.n_genes :][region_mask] if - type(adata).__name__ == "MuData" else adata.var_names[self.n_genes :][region_mask], + columns=adata["rna"].var_names[self.n_genes :][region_mask] + if type(adata).__name__ == "MuData" + else adata.var_names[self.n_genes :][region_mask], ) @torch.inference_mode() @@ -1089,7 +1089,8 @@ def setup_mudata( mod_key=modalities.rna_layer, is_count_data=True, mod_required=True, - )) + ) + ) if modalities.atac_layer is not None: mudata_fields.append( fields.MuDataLayerField( @@ -1098,7 +1099,8 @@ def setup_mudata( mod_key=modalities.atac_layer, is_count_data=True, mod_required=True, - )) + ) + ) if modalities.protein_layer is not None: mudata_fields.append( fields.MuDataProteinLayerField( @@ -1109,7 +1111,8 @@ def setup_mudata( batch_field=batch_field, is_count_data=True, mod_required=True, - )) + ) + ) adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(mdata, **kwargs) cls.register_manager(adata_manager) diff --git a/src/scvi/model/_totalvi.py b/src/scvi/model/_totalvi.py index 379a307a63..a41123af72 100644 --- a/src/scvi/model/_totalvi.py +++ b/src/scvi/model/_totalvi.py @@ -35,7 +35,7 @@ from anndata import AnnData from mudata import MuData - from scvi._types import Number, AnnOrMuData + from scvi._types import AnnOrMuData, Number logger = logging.getLogger(__name__) diff --git a/tests/model/test_multivi.py b/tests/model/test_multivi.py index 80ca75e424..e5c541bd90 100644 --- a/tests/model/test_multivi.py +++ b/tests/model/test_multivi.py @@ -1,14 +1,16 @@ +import os + +import anndata as ad import muon import numpy as np import pytest -from mudata import MuData import scanpy as sc -import anndata as ad +from mudata import MuData + 
import scvi -import os +from scvi import REGISTRY_KEYS from scvi.data import synthetic_iid from scvi.model import MULTIVI -from scvi import REGISTRY_KEYS from scvi.utils import attrdict @@ -145,13 +147,15 @@ def test_multivi_mudata_rna_atac_external(): n_top_genes=4000, flavor="seurat_v3", ) - mdata.mod["atac_subset"] = (mdata.mod["atac"][:, mdata.mod["atac"].var["highly_variable"]]. - copy()) + mdata.mod["atac_subset"] = mdata.mod["atac"][ + :, mdata.mod["atac"].var["highly_variable"] + ].copy() mdata.update() # mdata # mdata.mod - MULTIVI.setup_mudata(mdata, modalities={"rna_layer": "rna_subset", - "atac_layer": "atac_subset"}) + MULTIVI.setup_mudata( + mdata, modalities={"rna_layer": "rna_subset", "atac_layer": "atac_subset"} + ) model = MULTIVI(mdata, n_genes=50, n_regions=50) model.train(1, train_size=0.9) @@ -178,8 +182,11 @@ def test_multivi_mudata(): MULTIVI.setup_mudata( mdata, batch_key="batch", - modalities={"rna_layer": "rna", "protein_layer": "protein_expression", - "atac_layer": "accessibility"}, + modalities={ + "rna_layer": "rna", + "protein_layer": "protein_expression", + "atac_layer": "accessibility", + }, ) n_obs = mdata.n_obs # n_genes = np.min([mdata.n_vars, mdata["protein_expression"].n_vars]) @@ -354,7 +361,7 @@ def test_multivi_size_factor_mudata(): model.train(1, train_size=0.5) -def test_multivi_saving_and_loading_mudata(save_path: str="."): +def test_multivi_saving_and_loading_mudata(save_path: str = "."): adata = synthetic_iid() protein_adata = synthetic_iid(n_genes=50) mdata = MuData({"rna": adata, "protein": protein_adata}) @@ -415,7 +422,7 @@ def test_multivi_saving_and_loading_mudata(save_path: str="."): ) -def test_scarches_mudata_prep_layer(save_path: str="."): +def test_scarches_mudata_prep_layer(save_path: str = "."): n_latent = 5 mdata1 = synthetic_iid(return_mudata=True) @@ -463,7 +470,7 @@ def test_scarches_mudata_prep_layer(save_path: str="."): MULTIVI.load_query_data(mdata2, dir_path) -def test_multivi_save_load_mudata_format(save_path: str="."): +def test_multivi_save_load_mudata_format(save_path: str = "."): mdata = synthetic_iid(return_mudata=True, protein_expression_key="protein") invalid_mdata = mdata.copy() invalid_mdata.mod["protein"] = invalid_mdata.mod["protein"][:, :10].copy() From ea59fd1a7f7164de31897f4c203dfd5cedb5f829 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Thu, 7 Nov 2024 16:37:30 +0200 Subject: [PATCH 13/51] small fix --- tests/model/test_multivi.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/model/test_multivi.py b/tests/model/test_multivi.py index 80ca75e424..c688bdaca3 100644 --- a/tests/model/test_multivi.py +++ b/tests/model/test_multivi.py @@ -272,11 +272,6 @@ def test_multivi_incorrect_mapping_mudata(): adata2.obs.batch = adata2.obs.batch.cat.rename_categories(["batch_0", "batch_10"]) with pytest.raises(ValueError): model.get_elbo(mdata2) - model.get_accessibility_estimates() - model.get_accessibility_estimates(normalize_cells=True) - model.get_accessibility_estimates(normalize_regions=True) - model.get_library_size_factors() - model.get_region_factors() def test_multivi_reordered_mapping_mudata(): From 9118af8a446c7009303942a1a52abad400d2fae8 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Thu, 7 Nov 2024 16:52:00 +0200 Subject: [PATCH 14/51] small fix --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4d19105208..85f1a38f43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ loompy = ["loompy>=3.0.6"] scanpy = 
["scanpy>=1.6"] optional = [ - "scvi-tools[autotune,aws,hub,loompy,muon,pymde,regseq,scanpy]" + "scvi-tools[autotune,aws,hub,loompy,muon,scikit-misc,pymde,regseq,scanpy]" ] tutorials = [ "cell2location", @@ -108,7 +108,6 @@ tutorials = [ "pooch", "pynndescent", "igraph", - "scikit-misc", "scrublet", "scib-metrics", "scvi-tools[optional]", From 870b0fc4b1c916c20ed5d9c61aeabe099a90b8e6 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Thu, 7 Nov 2024 16:54:08 +0200 Subject: [PATCH 15/51] small fix --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 85f1a38f43..e60cf754bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,6 +90,8 @@ hub = ["huggingface_hub"] pymde = ["pymde"] # mudata dependencies muon = ["muon"] +# mudata dependencies +scikit-misc = ["scikit-misc"] # scvi.data.add_dna_sequence regseq = ["biopython>=1.81", "genomepy"] # read loom From 4f08c1b639fe53b46bc893878a7aefb504cb7dd4 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Thu, 7 Nov 2024 16:54:30 +0200 Subject: [PATCH 16/51] small fix --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e60cf754bb..7c4b44ae7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,7 +90,7 @@ hub = ["huggingface_hub"] pymde = ["pymde"] # mudata dependencies muon = ["muon"] -# mudata dependencies +# scanpy dependencies scikit-misc = ["scikit-misc"] # scvi.data.add_dna_sequence regseq = ["biopython>=1.81", "genomepy"] From 9622f88daefa700042664656b2a6dbecf245f8bf Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Thu, 7 Nov 2024 16:55:30 +0200 Subject: [PATCH 17/51] small fix --- pyproject.toml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7c4b44ae7b..bfd330362b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,17 +90,15 @@ hub = ["huggingface_hub"] pymde = ["pymde"] # mudata dependencies muon = ["muon"] -# scanpy dependencies -scikit-misc = ["scikit-misc"] # scvi.data.add_dna_sequence regseq = ["biopython>=1.81", "genomepy"] # read loom loompy = ["loompy>=3.0.6"] # scvi.criticism and read 10x -scanpy = ["scanpy>=1.6"] +scanpy = ["scanpy>=1.6","scikit-misc"] optional = [ - "scvi-tools[autotune,aws,hub,loompy,muon,scikit-misc,pymde,regseq,scanpy]" + "scvi-tools[autotune,aws,hub,loompy,muon,pymde,regseq,scanpy]" ] tutorials = [ "cell2location", From 3299355e7a9884ad170d20f77bf7ce65bef7bf93 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Tue, 12 Nov 2024 18:55:03 +0200 Subject: [PATCH 18/51] Added mudata minification models for MULTIVI & TOTALVI as well as tests --- CHANGELOG.md | 1 + src/scvi/data/_utils.py | 4 +- src/scvi/model/_multivi.py | 121 +++++- src/scvi/model/_totalvi.py | 103 ++++- src/scvi/model/base/__init__.py | 7 +- src/scvi/model/base/_base_model.py | 60 +++ src/scvi/model/utils/__init__.py | 4 +- src/scvi/model/utils/_minification.py | 30 ++ src/scvi/module/_multivae.py | 7 +- src/scvi/module/_totalvae.py | 7 +- .../test_models_with_mudata_minified_data.py | 358 ++++++++++++++++++ 11 files changed, 685 insertions(+), 17 deletions(-) create mode 100644 tests/model/test_models_with_mudata_minified_data.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c261cc775e..0c89a6484e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ to [Semantic Versioning]. Full commit history is available in the #### Added +- Add MuData Minification option to {class}`~scvi.model.MULTIVI` and {class}`~scvi.model.TOTALVI` {pr}`30XX`. 
- Experimental MuData support for {class}`~scvi.model.MULTIVI` via the method {meth}`~scvi.model.MULTIVI.setup_mudata` {pr}`3038`. diff --git a/src/scvi/data/_utils.py b/src/scvi/data/_utils.py index fc6228a29f..3fff82a13b 100644 --- a/src/scvi/data/_utils.py +++ b/src/scvi/data/_utils.py @@ -311,10 +311,12 @@ def _get_adata_minify_type(adata: AnnData) -> MinifiedDataType | None: return adata.uns.get(_constants._ADATA_MINIFY_TYPE_UNS_KEY, None) -def _is_minified(adata: AnnData | str) -> bool: +def _is_minified(adata: AnnOrMuData | str) -> bool: uns_key = _constants._ADATA_MINIFY_TYPE_UNS_KEY if isinstance(adata, AnnData): return adata.uns.get(uns_key, None) is not None + elif isinstance(adata, MuData): + return adata.uns.get(uns_key, None) is not None elif isinstance(adata, str): with h5py.File(adata) as fp: return uns_key in read_elem(fp["uns"]).keys() diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 0bc5dfe517..0e625e9ee1 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -13,13 +13,17 @@ from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager, fields +from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE +from scvi.data._utils import _get_adata_minify_type from scvi.data.fields import ( CategoricalJointObsField, CategoricalObsField, LayerField, NumericalJointObsField, NumericalObsField, + ObsmField, ProteinObsmField, + StringUnsField, ) from scvi.model._utils import ( _get_batch_code_from_category, @@ -28,11 +32,12 @@ ) from scvi.model.base import ( ArchesMixin, - BaseModelClass, + BaseMudataMinifiedModeModelClass, UnsupervisedTrainingMixin, VAEMixin, ) from scvi.model.base._de_core import _de_core +from scvi.model.utils import get_minified_mudata from scvi.module import MULTIVAE from scvi.train import AdversarialTrainingPlan from scvi.train._callbacks import SaveBestState @@ -45,12 +50,19 @@ from anndata import AnnData from mudata import MuData - from scvi._types import AnnOrMuData, Number + from scvi._types import AnnOrMuData, MinifiedDataType, Number + from scvi.data.fields import ( + BaseAnnDataField, + ) + +_MULTIVI_LATENT_QZM = "_multivi_latent_qzm" +_MULTIVI_LATENT_QZV = "_multivi_latent_qzv" +_MULTIVI_OBSERVED_LIB_SIZE = "_multivi_observed_lib_size" logger = logging.getLogger(__name__) -class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin): +class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, ArchesMixin, BaseMudataMinifiedModeModelClass): """Integration of multi-modal and single-modality data :cite:p:`AshuachGabitto21`. MultiVI is used to integrate multiomic datasets with single-modality (expression @@ -174,6 +186,10 @@ def __init__( use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry + # TODO: ADD MINIFICATION CONSIDERATION HERE? + # if not use_size_factor_key and self.minified_data_type is None: + # library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch) + if "n_proteins" in self.summary_stats: n_proteins = self.summary_stats.n_proteins else: @@ -224,6 +240,7 @@ def __init__( self.n_genes = n_genes self.n_regions = n_regions self.n_proteins = n_proteins + self.module.minified_data_type = self.minified_data_type @devices_dsp.dedent def train( @@ -414,6 +431,7 @@ def get_latent_representation( indices: Sequence[int] | None = None, give_mean: bool = True, batch_size: int | None = None, + return_dist: bool = False, ) -> np.ndarray: r"""Return the latent representation for each cell. 
@@ -430,6 +448,9 @@ def get_latent_representation( Give mean of distribution or sample from it. batch_size Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + return_dist + If ``True``, returns the mean and variance of the latent distribution. Otherwise, + returns the mean of the latent distribution. Returns ------- @@ -457,6 +478,8 @@ def get_latent_representation( adata = self._validate_anndata(adata) scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) latent = [] + qz_means = [] + qz_vars = [] for tensors in scdl: inference_inputs = self.module._get_inference_input(tensors) outputs = self.module.inference(**inference_inputs) @@ -473,8 +496,17 @@ def get_latent_representation( else: z = qz_m + if return_dist: + qz_means.append(qz_m.cpu()) + qz_vars.append(qz_v.cpu()) + continue + latent += [z.cpu()] - return torch.cat(latent).numpy() + + if return_dist: + return torch.cat(qz_means).numpy(), torch.cat(qz_vars).numpy() + else: + return torch.cat(latent).numpy() @torch.inference_mode() def get_accessibility_estimates( @@ -1113,6 +1145,87 @@ def setup_mudata( mod_required=True, ) ) + # TODO: register new fields if the adata is minified + mdata_minify_type = _get_adata_minify_type(mdata) + if mdata_minify_type is not None: + mudata_fields += cls._get_fields_for_mudata_minification(mdata_minify_type) adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(mdata, **kwargs) cls.register_manager(adata_manager) + + @staticmethod + def _get_fields_for_mudata_minification( + minified_data_type: MinifiedDataType, + ) -> list[BaseAnnDataField]: + """Return the fields required for adata minification of the given minified_data_type.""" + if minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + fields = [ + ObsmField( + REGISTRY_KEYS.LATENT_QZM_KEY, + _MULTIVI_LATENT_QZM, + ), + ObsmField( + REGISTRY_KEYS.LATENT_QZV_KEY, + _MULTIVI_LATENT_QZV, + ), + NumericalObsField( + REGISTRY_KEYS.OBSERVED_LIB_SIZE, + _MULTIVI_OBSERVED_LIB_SIZE, + ), + ] + else: + raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") + fields.append( + StringUnsField( + REGISTRY_KEYS.MINIFY_TYPE_KEY, + _ADATA_MINIFY_TYPE_UNS_KEY, + ), + ) + return fields + + def minify_mudata( + self, + minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR, + use_latent_qzm_key: str = "X_latent_qzm", + use_latent_qzv_key: str = "X_latent_qzv", + ) -> None: + """Minifies the model's mudata. + + Minifies the mudata, and registers new mudata fields: latent qzm, latent qzv, adata uns + containing minified-adata type, and library size. + This also sets the appropriate property on the module to indicate that the mudata is + minified. + + Parameters + ---------- + minified_data_type + How to minify the data. Currently only supports `latent_posterior_parameters`. + If minified_data_type == `latent_posterior_parameters`: + + * the original count data is removed (`adata.X`, adata.raw, and any layers) + * the parameters of the latent representation of the original data is stored + * everything else is left untouched + use_latent_qzm_key + Key to use in `adata.obsm` where the latent qzm params are stored + use_latent_qzv_key + Key to use in `adata.obsm` where the latent qzv params are stored + + Notes + ----- + The modification is not done inplace -- instead the model is assigned a new (minified) + version of the adata. + """ + # without removing the original counts. 
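+        # Editorial sketch (not part of this patch), mirroring the new tests: with a
+        # trained model and the default obsm keys assumed below, the intended
+        # workflow is
+        #
+        #     qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True)
+        #     model.adata.obsm["X_latent_qzm"] = qzm
+        #     model.adata.obsm["X_latent_qzv"] = qzv
+        #     model.minify_mudata()
+        #     assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR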
+ if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") + + # if self.module.use_observed_lib_size is False: + # raise ValueError("Cannot minify the data if `use_observed_lib_size` is False") + + minified_adata = get_minified_mudata(self.adata, minified_data_type) + minified_adata.obsm[_MULTIVI_LATENT_QZM] = self.adata.obsm[use_latent_qzm_key] + minified_adata.obsm[_MULTIVI_LATENT_QZV] = self.adata.obsm[use_latent_qzv_key] + counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) + minified_adata.obs[_MULTIVI_OBSERVED_LIB_SIZE] = np.squeeze(np.asarray(counts.sum(axis=1))) + self._update_mudata_and_manager_post_minification(minified_adata, minified_data_type) + self.module.minified_data_type = minified_data_type diff --git a/src/scvi/model/_totalvi.py b/src/scvi/model/_totalvi.py index a41123af72..0e3b2cd965 100644 --- a/src/scvi/model/_totalvi.py +++ b/src/scvi/model/_totalvi.py @@ -12,7 +12,9 @@ from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager, fields -from scvi.data._utils import _check_nonnegative_integers +from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE +from scvi.data._utils import _check_nonnegative_integers, _get_adata_minify_type +from scvi.data.fields import NumericalObsField, ObsmField, StringUnsField from scvi.dataloaders import DataSplitter from scvi.model._utils import ( _get_batch_code_from_category, @@ -22,11 +24,12 @@ get_max_epochs_heuristic, ) from scvi.model.base._de_core import _de_core +from scvi.model.utils import get_minified_mudata from scvi.module import TOTALVAE from scvi.train import AdversarialTrainingPlan, TrainRunner from scvi.utils._docstrings import de_dsp, devices_dsp, setup_anndata_dsp -from .base import ArchesMixin, BaseModelClass, RNASeqMixin, VAEMixin +from .base import ArchesMixin, BaseMudataMinifiedModeModelClass, RNASeqMixin, VAEMixin if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -35,12 +38,19 @@ from anndata import AnnData from mudata import MuData - from scvi._types import AnnOrMuData, Number + from scvi._types import AnnOrMuData, MinifiedDataType, Number + from scvi.data.fields import ( + BaseAnnDataField, + ) + +_TOTALVI_LATENT_QZM = "_totalvi_latent_qzm" +_TOTALVI_LATENT_QZV = "_totalvi_latent_qzv" +_TOTALVI_OBSERVED_LIB_SIZE = "_totalvi_observed_lib_size" logger = logging.getLogger(__name__) -class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass): +class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMudataMinifiedModeModelClass): """total Variational Inference :cite:p:`GayosoSteier21`. 
Parameters @@ -162,7 +172,8 @@ def __init__( n_batch = self.summary_stats.n_batch use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry library_log_means, library_log_vars = None, None - if not use_size_factor_key: + # TODO: ADD MINIFICATION CONSIDERATION + if not use_size_factor_key and self.minified_data_type is None: library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch) self.module = self._module_cls( @@ -184,6 +195,7 @@ def __init__( library_log_vars=library_log_vars, **model_kwargs, ) + self.module.minified_data_type = self.minified_data_type self._model_summary_string = ( f"TotalVI Model with the following params: \nn_latent: {n_latent}, " f"gene_dispersion: {gene_dispersion}, protein_dispersion: {protein_dispersion}, " @@ -1331,6 +1343,87 @@ def setup_mudata( mod_required=True, ), ] + # TODO: register new fields if the mudata is minified + mdata_minify_type = _get_adata_minify_type(mdata) + if mdata_minify_type is not None: + mudata_fields += cls._get_fields_for_mudata_minification(mdata_minify_type) adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(mdata, **kwargs) cls.register_manager(adata_manager) + + @staticmethod + def _get_fields_for_mudata_minification( + minified_data_type: MinifiedDataType, + ) -> list[BaseAnnDataField]: + """Return the fields required for mudata minification of the given minified_data_type.""" + if minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + fields = [ + ObsmField( + REGISTRY_KEYS.LATENT_QZM_KEY, + _TOTALVI_LATENT_QZM, + ), + ObsmField( + REGISTRY_KEYS.LATENT_QZV_KEY, + _TOTALVI_LATENT_QZV, + ), + NumericalObsField( + REGISTRY_KEYS.OBSERVED_LIB_SIZE, + _TOTALVI_OBSERVED_LIB_SIZE, + ), + ] + else: + raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") + fields.append( + StringUnsField( + REGISTRY_KEYS.MINIFY_TYPE_KEY, + _ADATA_MINIFY_TYPE_UNS_KEY, + ), + ) + return fields + + def minify_mudata( + self, + minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR, + use_latent_qzm_key: str = "X_latent_qzm", + use_latent_qzv_key: str = "X_latent_qzv", + ) -> None: + """Minifies the model's mudata. + + Minifies the mudata, and registers new mudata fields: latent qzm, latent qzv, adata uns + containing minified-adata type, and library size. + This also sets the appropriate property on the module to indicate that the mudata is + minified. + + Parameters + ---------- + minified_data_type + How to minify the data. Currently only supports `latent_posterior_parameters`. + If minified_data_type == `latent_posterior_parameters`: + + * the original count data is removed (`adata.X`, adata.raw, and any layers) + * the parameters of the latent representation of the original data is stored + * everything else is left untouched + use_latent_qzm_key + Key to use in `adata.obsm` where the latent qzm params are stored + use_latent_qzv_key + Key to use in `adata.obsm` where the latent qzv params are stored + + Notes + ----- + The modification is not done inplace -- instead the model is assigned a new (minified) + version of the adata. + """ + # without removing the original counts. 
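+        # Editorial sketch (not part of this patch): per get_minified_mudata in this
+        # patch, minification replaces each modality's counts with all-zero sparse
+        # matrices while keeping the posterior parameters and observed library size:
+        #
+        #     model.minify_mudata()
+        #     assert model.adata["rna"].X.sum() == 0
+        #     assert "_totalvi_latent_qzm" in model.adata.obsm
+        #     assert "_totalvi_observed_lib_size" in model.adata.obs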
+ if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") + + if self.module.use_observed_lib_size is False: + raise ValueError("Cannot minify the data if `use_observed_lib_size` is False") + + minified_adata = get_minified_mudata(self.adata, minified_data_type) + minified_adata.obsm[_TOTALVI_LATENT_QZM] = self.adata.obsm[use_latent_qzm_key] + minified_adata.obsm[_TOTALVI_LATENT_QZV] = self.adata.obsm[use_latent_qzv_key] + counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) + minified_adata.obs[_TOTALVI_OBSERVED_LIB_SIZE] = np.squeeze(np.asarray(counts.sum(axis=1))) + self._update_mudata_and_manager_post_minification(minified_adata, minified_data_type) + self.module.minified_data_type = minified_data_type diff --git a/src/scvi/model/base/__init__.py b/src/scvi/model/base/__init__.py index e8573f8d53..4b38494caf 100644 --- a/src/scvi/model/base/__init__.py +++ b/src/scvi/model/base/__init__.py @@ -1,5 +1,9 @@ from ._archesmixin import ArchesMixin -from ._base_model import BaseMinifiedModeModelClass, BaseModelClass +from ._base_model import ( + BaseMinifiedModeModelClass, + BaseModelClass, + BaseMudataMinifiedModeModelClass, +) from ._differential import DifferentialComputation from ._embedding_mixin import EmbeddingMixin from ._jaxmixin import JaxTrainingMixin @@ -26,5 +30,6 @@ "DifferentialComputation", "JaxTrainingMixin", "BaseMinifiedModeModelClass", + "BaseMudataMinifiedModeModelClass", "EmbeddingMixin", ] diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index fd47bf1926..137bc9e04a 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -944,3 +944,63 @@ def summary_string(self): hasattr(self, "minified_data_type") and self.minified_data_type is not None ) return summary_string + + +class BaseMudataMinifiedModeModelClass(BaseModelClass): + """Abstract base class for scvi-tools models that can handle minified data.""" + + @property + def minified_data_type(self) -> MinifiedDataType | None: + """The type of minified data associated with this model, if applicable.""" + return ( + self.adata_manager.get_from_registry(REGISTRY_KEYS.MINIFY_TYPE_KEY) + if REGISTRY_KEYS.MINIFY_TYPE_KEY in self.adata_manager.data_registry + else None + ) + + @abstractmethod + def minify_mudata( + self, + *args, + **kwargs, + ): + """Minifies the model's mudata. + + Minifies the mudata, and registers new mudata fields as required (can be model-specific). + This also sets the appropriate property on the module to indicate that the adata is + minified. + + Notes + ----- + The modification is not done inplace -- instead the model is assigned a new (minified) + version of the adata. 
+ """ + + @staticmethod + @abstractmethod + def _get_fields_for_mudata_minification(minified_data_type: MinifiedDataType): + """Return the mudata fields required for adata minification of the given type.""" + + def _update_mudata_and_manager_post_minification( + self, minified_adata: AnnOrMuData, minified_data_type: MinifiedDataType + ): + """Update the mudata and manager inplace after creating a minified adata.""" + # Register this new adata with the model, creating a new manager in the cache + self._validate_anndata(minified_adata) + new_adata_manager = self.get_anndata_manager(minified_adata, required=True) + # This inplace edits the manager + new_adata_manager.register_new_fields( + self._get_fields_for_mudata_minification(minified_data_type) + ) + # We set the adata attribute of the model as this will update self.registry_ + # and self.adata_manager with the new adata manager + self.adata = minified_adata + + @property + def summary_string(self): + """Summary string of the model.""" + summary_string = super().summary_string + summary_string += "\nModel's adata is minified?: {}".format( + hasattr(self, "minified_data_type") and self.minified_data_type is not None + ) + return summary_string diff --git a/src/scvi/model/utils/__init__.py b/src/scvi/model/utils/__init__.py index 003b763e5e..0ee147802d 100644 --- a/src/scvi/model/utils/__init__.py +++ b/src/scvi/model/utils/__init__.py @@ -1,4 +1,4 @@ from ._mde import mde -from ._minification import get_minified_adata_scrna +from ._minification import get_minified_adata_scrna, get_minified_mudata -__all__ = ["mde", "get_minified_adata_scrna"] +__all__ = ["mde", "get_minified_adata_scrna", "get_minified_mudata"] diff --git a/src/scvi/model/utils/_minification.py b/src/scvi/model/utils/_minification.py index cf84687bc5..aab9cb79ff 100644 --- a/src/scvi/model/utils/_minification.py +++ b/src/scvi/model/utils/_minification.py @@ -1,4 +1,5 @@ from anndata import AnnData +from mudata import MuData from scipy.sparse import csr_matrix from scvi._types import MinifiedDataType @@ -41,3 +42,32 @@ def get_minified_adata_scrna( del bdata.uns[_SCVI_UUID_KEY] bdata.uns[_ADATA_MINIFY_TYPE_UNS_KEY] = minified_data_type return bdata + + +def get_minified_mudata( + mdata: MuData, + minified_data_type: MinifiedDataType, +) -> MuData: + """Returns a minified adata that works for most multi modality models (MULTIVI, TOTALVI). + + Parameters + ---------- + mdata + Original adata, of which we to create a minified version. + minified_data_type + How to minify the data. + """ + if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") + + bdata = mdata.copy() + for modality in mdata.mod_names: + all_zeros = csr_matrix(mdata[modality].X.shape) + bdata[modality].X = all_zeros + if len(mdata[modality].layers) > 0: + layers = {layer: all_zeros for layer in mdata[modality].layers} + bdata[modality].layers = layers + # Remove scvi uuid key to make bdata fresh w.r.t. 
the model's manager + del bdata.uns[_SCVI_UUID_KEY] + bdata.uns[_ADATA_MINIFY_TYPE_UNS_KEY] = minified_data_type + return bdata diff --git a/src/scvi/module/_multivae.py b/src/scvi/module/_multivae.py index 6ad24b65d2..cec268d184 100644 --- a/src/scvi/module/_multivae.py +++ b/src/scvi/module/_multivae.py @@ -15,7 +15,7 @@ ZeroInflatedNegativeBinomial, ) from scvi.module._peakvae import Decoder as DecoderPeakVI -from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data +from scvi.module.base import BaseMinifiedModeModuleClass, LossOutput, auto_move_data from scvi.nn import DecoderSCVI, Encoder, FCLayers from ._utils import masked_softmax @@ -179,7 +179,7 @@ def forward(self, z: torch.Tensor, *cat_list: int): return py_, log_pro_back_mean -class MULTIVAE(BaseModuleClass): +class MULTIVAE(BaseMinifiedModeModuleClass): """Variational auto-encoder model for joint paired + unpaired RNA-seq and ATAC-seq data. Parameters @@ -533,6 +533,9 @@ def __init__( def _get_inference_input(self, tensors): """Get input tensors for the inference model.""" + # from scvi.data._constants import ADATA_MINIFY_TYPE + # TODO: ADD MINIFICATION CONSIDERATION + x = tensors[REGISTRY_KEYS.X_KEY] if self.n_input_proteins == 0: y = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False) diff --git a/src/scvi/module/_totalvae.py b/src/scvi/module/_totalvae.py index d3fb5488da..ba54ec0b6f 100644 --- a/src/scvi/module/_totalvae.py +++ b/src/scvi/module/_totalvae.py @@ -18,7 +18,7 @@ ZeroInflatedNegativeBinomial, ) from scvi.model.base import BaseModelClass -from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data +from scvi.module.base import BaseMinifiedModeModuleClass, LossOutput, auto_move_data from scvi.nn import DecoderTOTALVI, EncoderTOTALVI from scvi.nn._utils import ExpActivation @@ -26,7 +26,7 @@ # VAE model -class TOTALVAE(BaseModuleClass): +class TOTALVAE(BaseMinifiedModeModuleClass): """Total variational inference for CITE-seq data. Implements the totalVI model of :cite:p:`GayosoSteier21`. 
@@ -325,6 +325,9 @@ def get_reconstruction_loss( return reconst_loss_gene, reconst_loss_protein def _get_inference_input(self, tensors): + # from scvi.data._constants import ADATA_MINIFY_TYPE + # TODO: ADD MINIFICATION CONSIDERATION + x = tensors[REGISTRY_KEYS.X_KEY] y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY] batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] diff --git a/tests/model/test_models_with_mudata_minified_data.py b/tests/model/test_models_with_mudata_minified_data.py new file mode 100644 index 0000000000..fab87eb2a5 --- /dev/null +++ b/tests/model/test_models_with_mudata_minified_data.py @@ -0,0 +1,358 @@ +import numpy as np +import pytest + +import scvi +from scvi.data import synthetic_iid +from scvi.data._constants import ADATA_MINIFY_TYPE +from scvi.data._utils import _is_minified +from scvi.model import MULTIVI, TOTALVI + +_TOTALVI_OBSERVED_LIB_SIZE = "_totalvi_observed_lib_size" +_MULTIVI_OBSERVED_LIB_SIZE = "_multivi_observed_lib_size" + + +def prep_model_mudata(cls=TOTALVI, layer=None, use_size_factor=False): + # create a synthetic dataset + mdata = synthetic_iid(return_mudata=True) + if use_size_factor: + mdata.obs["size_factor"] = np.random.randint(1, 5, size=(mdata.shape[0],)) + if layer is not None: + for mod in mdata.mod_names: + mdata[mod].layers[layer] = mdata[mod].X.copy() + mdata[mod].X = np.zeros_like(mdata[mod].X) + mdata.var["n_counts"] = np.squeeze( + np.concatenate( + [ + np.asarray(np.sum(mdata["rna"].X, axis=0)), + np.asarray(np.sum(mdata["protein_expression"].X, axis=0)), + np.asarray(np.sum(mdata["accessibility"].X, axis=0)), + ] + ) + ) + mdata.varm["my_varm"] = np.random.negative_binomial(5, 0.3, size=(mdata.shape[1], 3)) + mdata["rna"].layers["my_layer"] = np.ones_like(mdata["rna"].X) + mdata_before_setup = mdata.copy() + + # run setup_anndata + setup_kwargs = { + "batch_key": "batch", + } + if use_size_factor: + setup_kwargs["size_factor_key"] = "size_factor" + + if cls == TOTALVI: + # create and train the model + cls.setup_mudata( + mdata, + modalities={"rna_layer": "rna", "protein_layer": "protein_expression"}, + **setup_kwargs, + ) + model = cls(mdata, n_latent=5) + elif cls == MULTIVI: + # create and train the model + cls.setup_mudata( + mdata, + modalities={ + "rna_layer": "rna", + "protein_layer": "protein_expression", + "atac_layer": "accessibility", + }, + **setup_kwargs, + ) + model = cls(mdata, n_latent=5, n_genes=50, n_regions=50) + else: + raise ValueError("Bad Model name as input to test") + model.train(1, check_val_every_n_epoch=1, train_size=0.5) + + # get the mdata lib size + mdata_lib_size = np.squeeze(np.asarray(mdata["rna"].X.sum(axis=1))) + assert ( + np.min(mdata_lib_size) > 0 + ) # make sure it's not all zeros and there are no negative values + + return model, mdata, mdata_lib_size, mdata_before_setup + + +def assert_approx_equal(a, b): + # Allclose because on GPU, the values are not exactly the same + # as some values are moved to cpu during data minification + np.testing.assert_allclose(a, b, rtol=3e-1, atol=5e-1) + + +def run_test_for_model_with_minified_mudata( + cls=TOTALVI, + layer: str = None, + use_size_factor=False, +): + model, mdata, mdata_lib_size, _ = prep_model_mudata(cls, layer, use_size_factor) + + scvi.settings.seed = 1 + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + scvi.settings.seed = 1 + mdata_orig = mdata.copy() + + model.minify_mudata() + assert model.minified_data_type == 
ADATA_MINIFY_TYPE.LATENT_POSTERIOR + assert model.adata_manager.registry is model.registry_ + + # make sure the original mdata we set up the model with was not changed + assert mdata is not model.adata + assert _is_minified(mdata) is False + assert _is_minified(model.adata) is True + + assert mdata_orig["rna"].layers.keys() == model.adata["rna"].layers.keys() + orig_obs_df = mdata_orig.obs + obs_keys = _TOTALVI_OBSERVED_LIB_SIZE if cls == TOTALVI else _MULTIVI_OBSERVED_LIB_SIZE + orig_obs_df[obs_keys] = mdata_lib_size + assert model.adata.obs.equals(orig_obs_df) + assert model.adata.var_names.equals(mdata_orig.var_names) + assert model.adata.var.equals(mdata_orig.var) + assert model.adata.varm.keys() == mdata_orig.varm.keys() + np.testing.assert_array_equal(model.adata.varm["my_varm"], mdata_orig.varm["my_varm"]) + + +@pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) +@pytest.mark.parametrize("layer", [None, "data_layer"]) +@pytest.mark.parametrize("use_size_factor", [False, True]) +def test_with_minified_mudata(cls, layer: str, use_size_factor: bool): + run_test_for_model_with_minified_mudata(cls=cls, layer=layer, use_size_factor=use_size_factor) + + +@pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) +def test_scvi_with_minified_mdata_get_normalized_expression(cls): + model, mdata, _, _ = prep_model_mudata(cls=cls) + + scvi.settings.seed = 1 + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + scvi.settings.seed = 1 + exprs_orig = model.get_normalized_expression() + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + scvi.settings.seed = 1 + exprs_new = model.get_normalized_expression() + for ii in range(len(exprs_new)): + assert exprs_new[ii].shape == mdata[mdata.mod_names[ii]].shape + + for ii in range(len(exprs_new)): + np.testing.assert_array_equal(exprs_new[ii], exprs_orig[ii]) + + +@pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) +def test_scvi_with_minified_mdata_get_normalized_expression_non_default_gene_list(cls): + model, mdata, _, _ = prep_model_mudata(cls=cls) + + # non-default gene list and n_samples > 1 + gl = mdata.var_names[:5].to_list() + n_samples = 10 + + scvi.settings.seed = 1 + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + scvi.settings.seed = 1 + exprs_orig = model.get_normalized_expression( + gene_list=gl, n_samples=n_samples, library_size="latent" + ) + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + scvi.settings.seed = 1 + # do this so that we generate the same sequence of random numbers in the + # minified and non-minified cases (purely to get the tests to pass). 
This is
+    # because in the non-minified case we sample once more (in the call to z_encoder
+    # during inference)
+    exprs_new = model.get_normalized_expression(
+        gene_list=gl, n_samples=n_samples + 1, return_mean=False, library_size="latent"
+    )
+    exprs_new = exprs_new[0][:, :, 1:].mean(2)
+
+    assert exprs_new.shape == (mdata.shape[0], 5)
+    np.testing.assert_allclose(exprs_new, exprs_orig[0], rtol=3e-1, atol=5e-1)
+
+
+@pytest.mark.parametrize("cls", [TOTALVI, MULTIVI])
+def test_validate_unsupported_if_minified(cls):
+    model, _, _, _ = prep_model_mudata(cls=cls)
+
+    qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True)
+    model.adata.obsm["X_latent_qzm"] = qzm
+    model.adata.obsm["X_latent_qzv"] = qzv
+
+    model.minify_mudata()
+    assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR
+
+    common_err_msg = "The {} function currently does not support minified data."
+
+    with pytest.raises(ValueError) as e:
+        model.get_elbo()
+    assert str(e.value) == common_err_msg.format("VAEMixin.get_elbo")
+
+    with pytest.raises(ValueError) as e:
+        model.get_reconstruction_error()
+    assert str(e.value) == common_err_msg.format("VAEMixin.get_reconstruction_error")
+
+    with pytest.raises(ValueError) as e:
+        model.get_marginal_ll()
+    assert str(e.value) == common_err_msg.format("VAEMixin.get_marginal_ll")
+
+    if cls != TOTALVI:
+        with pytest.raises(AttributeError) as e:
+            model.get_latent_library_size()
+        assert str(e.value) == "'MULTIVI' object has no attribute 'get_latent_library_size'"
+
+
+@pytest.mark.parametrize("cls", [TOTALVI, MULTIVI])
+def test_scvi_with_minified_mdata_save_then_load(cls, save_path):
+    # create a model and minify its mdata, then save it and its mdata.
+    # Load it back up using the same (minified) mdata. Validate that the
+    # loaded model has the minified_data_type attribute set as expected.
+    model, mdata, _, _ = prep_model_mudata(cls=cls)
+
+    scvi.settings.seed = 1
+    qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True)
+    model.adata.obsm["X_latent_qzm"] = qzm
+    model.adata.obsm["X_latent_qzv"] = qzv
+
+    scvi.settings.seed = 1
+
+    model.minify_mudata()
+    assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR
+
+    model.save(save_path, overwrite=True, save_anndata=False, legacy_mudata_format=True)
+    # load saved model with saved (minified) mdata
+    loaded_model = cls.load(save_path, adata=mdata)
+
+    assert loaded_model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR
+
+
+@pytest.mark.parametrize("cls", [TOTALVI, MULTIVI])
+def test_scvi_with_minified_mdata_save_then_load_with_non_minified_mdata(cls, save_path):
+    # create a model and minify its mdata, then save it and its mdata.
+    # Load it back up using a non-minified mdata. Validate that the
+    # loaded model does not have the minified_data_type attribute set.
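+    # (the minified state is inferred from the mdata supplied at load time rather
+    # than from the saved model weights, so minified_data_type should be None here)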
+ model, mdata, _, mdata_before_setup = prep_model_mudata(cls=cls) + + scvi.settings.seed = 1 + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + scvi.settings.seed = 1 + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + model.save(save_path, overwrite=True, save_anndata=False, legacy_mudata_format=True) + # load saved model with a non-minified mdata + loaded_model = cls.load(save_path, adata=mdata_before_setup) + + assert loaded_model.minified_data_type is None + + +@pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) +def test_scvi_save_then_load_with_minified_mdata(cls, save_path): + # create a model, then save it and its mdata (non-minified). + # Load it back up using a minified mdata. Validate that this + # fails, as expected because we don't have a way to validate + # whether the minified-mdata was set up correctly + model, _, _, _ = prep_model_mudata(cls=cls) + + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + model.save(save_path, overwrite=True, save_anndata=False, legacy_mudata_format=True) + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + # loading this model with a minified mdata is not allowed because + # we don't have a way to validate whether the minified-mdata was + # set up correctly + with pytest.raises(KeyError): + cls.load(save_path, adata=model.adata) + + +@pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) +def test_scvi_with_minified_mdata_get_latent_representation(cls): + model, _, _, _ = prep_model_mudata(cls=cls) + + scvi.settings.seed = 1 + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + scvi.settings.seed = 1 + latent_repr_orig = model.get_latent_representation() + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + scvi.settings.seed = 1 + latent_repr_new = model.get_latent_representation() + + np.testing.assert_array_equal(latent_repr_new, latent_repr_orig) + + +@pytest.mark.parametrize("cls", [TOTALVI]) +def test_scvi_with_minified_mdata_posterior_predictive_sample(cls): + model, _, _, _ = prep_model_mudata(cls=cls) + + scvi.settings.seed = 1 + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + scvi.settings.seed = 1 + sample_orig = model.posterior_predictive_sample( + indices=[1, 2, 3], gene_list=["gene_1", "gene_2"] + ) + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + scvi.settings.seed = 1 + sample_new = model.posterior_predictive_sample( + indices=[1, 2, 3], gene_list=["gene_1", "gene_2"] + ) + assert sample_new.shape == (3, 2) + + np.testing.assert_array_equal(sample_new.todense(), sample_orig.todense()) + + +@pytest.mark.parametrize("cls", [TOTALVI]) +def test_scvi_with_minified_mdata_get_feature_correlation_matrix(cls): + model, _, _, _ = prep_model_mudata(cls=cls) + + scvi.settings.seed = 1 + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + scvi.settings.seed = 1 + fcm_orig = 
model.get_feature_correlation_matrix( + correlation_type="pearson", + n_samples=1, + transform_batch=["batch_0", "batch_1"], + ) + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + scvi.settings.seed = 1 + fcm_new = model.get_feature_correlation_matrix( + correlation_type="pearson", + n_samples=1, + transform_batch=["batch_0", "batch_1"], + ) + + assert_approx_equal(fcm_new, fcm_orig) From e9b72d96004e4a1088ceb30d2ec18d6894651f0c Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Tue, 12 Nov 2024 19:24:31 +0200 Subject: [PATCH 19/51] fix typos --- src/scvi/model/_multivi.py | 125 ++++++++++++++++++ .../test_models_with_mudata_minified_data.py | 41 ++---- 2 files changed, 134 insertions(+), 32 deletions(-) diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 0e625e9ee1..612d8a7c03 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -2,6 +2,7 @@ import logging import warnings +from collections.abc import Iterable as IterableClass from functools import partial from typing import TYPE_CHECKING @@ -960,6 +961,130 @@ def differential_expression( return result + @torch.no_grad() + def get_protein_foreground_probability( + self, + adata: AnnData | None = None, + indices: Sequence[int] | None = None, + transform_batch: Sequence[Number | str] | None = None, + protein_list: Sequence[str] | None = None, + n_samples: int = 1, + batch_size: int | None = None, + use_z_mean: bool = True, + return_mean: bool = True, + return_numpy: bool | None = None, + ): + r"""Returns the foreground probability for proteins. + + This is denoted as :math:`(1 - \pi_{nt})` in the totalVI paper. + + Parameters + ---------- + adata + AnnData object with equivalent structure to initial AnnData. If ``None``, defaults to + the AnnData object used to initialize the model. + indices + Indices of cells in adata to use. If `None`, all cells are used. + transform_batch + Batch to condition on. + If transform_batch is: + + * ``None`` - real observed batch is used + * ``int`` - batch transform_batch is used + * ``List[int]`` - average over batches in list + protein_list + Return protein expression for a subset of genes. + This can save memory when working with large datasets and few genes are + of interest. + n_samples + Number of posterior samples to use for estimation. + batch_size + Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + return_mean + Whether to return the mean of the samples. + return_numpy + Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame + includes gene names as columns. If either ``n_samples=1`` or ``return_mean=True``, + defaults to ``False``. Otherwise, it defaults to `True`. + + Returns + ------- + - **foreground_probability** - probability foreground for each protein + + If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`. + Otherwise, shape is `(cells, genes)`. In this case, return type is + :class:`~pandas.DataFrame` unless `return_numpy` is True. 
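+
+        Examples
+        --------
+        >>> # a minimal, illustrative sketch (assumes `model` is a trained instance
+        >>> # set up on data registered via this class's setup method)
+        >>> prob = model.get_protein_foreground_probability(n_samples=10)
+        >>> prob.head()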
+ """ + adata = self._validate_anndata(adata) + post = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) + + if protein_list is None: + protein_mask = slice(None) + else: + all_proteins = self.scvi_setup_dict_["protein_names"] + protein_mask = [True if p in protein_list else False for p in all_proteins] + + if n_samples > 1 and return_mean is False: + if return_numpy is False: + warnings.warn( + "`return_numpy` must be `True` if `n_samples > 1` and `return_mean` is " + "`False`, returning an `np.ndarray`.", + UserWarning, + stacklevel=settings.warnings_stacklevel, + ) + return_numpy = True + if indices is None: + indices = np.arange(adata.n_obs) + + py_mixings = [] + if not isinstance(transform_batch, IterableClass): + transform_batch = [transform_batch] + + transform_batch = _get_batch_code_from_category(self.adata_manager, transform_batch) + for tensors in post: + y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY] + py_mixing = torch.zeros_like(y[..., protein_mask]) + if n_samples > 1: + py_mixing = torch.stack(n_samples * [py_mixing]) + for _ in transform_batch: + # generative_kwargs = dict(transform_batch=b) + generative_kwargs = {"use_z_mean": use_z_mean} + inference_kwargs = {"n_samples": n_samples} + _, generative_outputs = self.module.forward( + tensors=tensors, + inference_kwargs=inference_kwargs, + generative_kwargs=generative_kwargs, + compute_loss=False, + ) + py_mixing += torch.sigmoid(generative_outputs["py_"]["mixing"])[ + ..., protein_mask + ].cpu() + py_mixing /= len(transform_batch) + py_mixings += [py_mixing] + if n_samples > 1: + # concatenate along batch dimension -> result shape = (samples, cells, features) + py_mixings = torch.cat(py_mixings, dim=1) + # (cells, features, samples) + py_mixings = py_mixings.permute(1, 2, 0) + else: + py_mixings = torch.cat(py_mixings, dim=0) + + if return_mean is True and n_samples > 1: + py_mixings = torch.mean(py_mixings, dim=-1) + + py_mixings = py_mixings.cpu().numpy() + + if return_numpy is True: + return 1 - py_mixings + else: + pro_names = self.protein_state_registry.column_names + foreground_prob = pd.DataFrame( + 1 - py_mixings, + columns=pro_names[protein_mask], + index=adata.obs_names[indices], + ) + return foreground_prob + @classmethod @setup_anndata_dsp.dedent def setup_anndata( diff --git a/tests/model/test_models_with_mudata_minified_data.py b/tests/model/test_models_with_mudata_minified_data.py index fab87eb2a5..9aac0b35ba 100644 --- a/tests/model/test_models_with_mudata_minified_data.py +++ b/tests/model/test_models_with_mudata_minified_data.py @@ -1,7 +1,6 @@ import numpy as np import pytest -import scvi from scvi.data import synthetic_iid from scvi.data._constants import ADATA_MINIFY_TYPE from scvi.data._utils import _is_minified @@ -86,11 +85,10 @@ def run_test_for_model_with_minified_mudata( ): model, mdata, mdata_lib_size, _ = prep_model_mudata(cls, layer, use_size_factor) - scvi.settings.seed = 1 qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) model.adata.obsm["X_latent_qzm"] = qzm model.adata.obsm["X_latent_qzv"] = qzv - scvi.settings.seed = 1 + mdata_orig = mdata.copy() model.minify_mudata() @@ -121,21 +119,18 @@ def test_with_minified_mudata(cls, layer: str, use_size_factor: bool): @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) -def test_scvi_with_minified_mdata_get_normalized_expression(cls): +def test_with_minified_mdata_get_normalized_expression(cls): model, mdata, _, _ = prep_model_mudata(cls=cls) - scvi.settings.seed = 1 qzm, qzv = 
model.get_latent_representation(give_mean=False, return_dist=True)
     model.adata.obsm["X_latent_qzm"] = qzm
     model.adata.obsm["X_latent_qzv"] = qzv

-    scvi.settings.seed = 1
     exprs_orig = model.get_normalized_expression()

     model.minify_mudata()
     assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR

-    scvi.settings.seed = 1
     exprs_new = model.get_normalized_expression()
     for ii in range(len(exprs_new)):
         assert exprs_new[ii].shape == mdata[mdata.mod_names[ii]].shape
@@ -145,19 +140,17 @@


 @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI])
-def test_scvi_with_minified_mdata_get_normalized_expression_non_default_gene_list(cls):
+def test_with_minified_mdata_get_normalized_expression_non_default_gene_list(cls):
     model, mdata, _, _ = prep_model_mudata(cls=cls)

     # non-default gene list and n_samples > 1
     gl = mdata.var_names[:5].to_list()
     n_samples = 10

-    scvi.settings.seed = 1
     qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True)
     model.adata.obsm["X_latent_qzm"] = qzm
     model.adata.obsm["X_latent_qzv"] = qzv

-    scvi.settings.seed = 1
     exprs_orig = model.get_normalized_expression(
         gene_list=gl, n_samples=n_samples, library_size="latent"
     )
@@ -165,7 +158,6 @@ def test_scvi_with_minified_mdata_get_normalized_expression_non_default_gene_lis
     model.minify_mudata()
     assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR

-    scvi.settings.seed = 1
     # do this so that we generate the same sequence of random numbers in the
     # minified and non-minified cases (purely to get the tests to pass). This is
     # because in the non-minified case we sample once more (in the call to z_encoder
@@ -211,19 +203,16 @@ def test_validate_unsupported_if_minified(cls):


 @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI])
-def test_scvi_with_minified_mdata_save_then_load(cls, save_path):
+def test_with_minified_mdata_save_then_load(cls, save_path):
     # create a model and minify its mdata, then save it and its mdata.
     # Load it back up using the same (minified) mdata. Validate that the
     # loaded model has the minified_data_type attribute set as expected.
     model, mdata, _, _ = prep_model_mudata(cls=cls)

-    scvi.settings.seed = 1
     qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True)
     model.adata.obsm["X_latent_qzm"] = qzm
     model.adata.obsm["X_latent_qzv"] = qzv

-    scvi.settings.seed = 1
-
     model.minify_mudata()
     assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR

@@ -235,19 +224,16 @@ def test_scvi_with_minified_mdata_save_then_load(cls, save_path):


 @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI])
-def test_scvi_with_minified_mdata_save_then_load_with_non_minified_mdata(cls, save_path):
+def test_with_minified_mdata_save_then_load_with_non_minified_mdata(cls, save_path):
     # create a model and minify its mdata, then save it and its mdata.
     # Load it back up using a non-minified mdata. Validate that the
     # loaded model does not have the minified_data_type attribute set.
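     # (the minified state is inferred from the mdata supplied at load time rather
     # than from the saved model weights, so minified_data_type should be None here)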
model, mdata, _, mdata_before_setup = prep_model_mudata(cls=cls) - scvi.settings.seed = 1 qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) model.adata.obsm["X_latent_qzm"] = qzm model.adata.obsm["X_latent_qzv"] = qzv - scvi.settings.seed = 1 - model.minify_mudata() assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR @@ -259,7 +245,7 @@ def test_scvi_with_minified_mdata_save_then_load_with_non_minified_mdata(cls, sa @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) -def test_scvi_save_then_load_with_minified_mdata(cls, save_path): +def test_save_then_load_with_minified_mdata(cls, save_path): # create a model, then save it and its mdata (non-minified). # Load it back up using a minified mdata. Validate that this # fails, as expected because we don't have a way to validate @@ -283,36 +269,31 @@ def test_scvi_save_then_load_with_minified_mdata(cls, save_path): @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) -def test_scvi_with_minified_mdata_get_latent_representation(cls): +def test_with_minified_mdata_get_latent_representation(cls): model, _, _, _ = prep_model_mudata(cls=cls) - scvi.settings.seed = 1 qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) model.adata.obsm["X_latent_qzm"] = qzm model.adata.obsm["X_latent_qzv"] = qzv - scvi.settings.seed = 1 latent_repr_orig = model.get_latent_representation() model.minify_mudata() assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR - scvi.settings.seed = 1 latent_repr_new = model.get_latent_representation() np.testing.assert_array_equal(latent_repr_new, latent_repr_orig) @pytest.mark.parametrize("cls", [TOTALVI]) -def test_scvi_with_minified_mdata_posterior_predictive_sample(cls): +def test_with_minified_mdata_posterior_predictive_sample(cls): model, _, _, _ = prep_model_mudata(cls=cls) - scvi.settings.seed = 1 qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) model.adata.obsm["X_latent_qzm"] = qzm model.adata.obsm["X_latent_qzv"] = qzv - scvi.settings.seed = 1 sample_orig = model.posterior_predictive_sample( indices=[1, 2, 3], gene_list=["gene_1", "gene_2"] ) @@ -320,7 +301,6 @@ def test_scvi_with_minified_mdata_posterior_predictive_sample(cls): model.minify_mudata() assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR - scvi.settings.seed = 1 sample_new = model.posterior_predictive_sample( indices=[1, 2, 3], gene_list=["gene_1", "gene_2"] ) @@ -330,15 +310,13 @@ def test_scvi_with_minified_mdata_posterior_predictive_sample(cls): @pytest.mark.parametrize("cls", [TOTALVI]) -def test_scvi_with_minified_mdata_get_feature_correlation_matrix(cls): +def test_with_minified_mdata_get_feature_correlation_matrix(cls): model, _, _, _ = prep_model_mudata(cls=cls) - scvi.settings.seed = 1 qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) model.adata.obsm["X_latent_qzm"] = qzm model.adata.obsm["X_latent_qzv"] = qzv - scvi.settings.seed = 1 fcm_orig = model.get_feature_correlation_matrix( correlation_type="pearson", n_samples=1, @@ -348,7 +326,6 @@ def test_scvi_with_minified_mdata_get_feature_correlation_matrix(cls): model.minify_mudata() assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR - scvi.settings.seed = 1 fcm_new = model.get_feature_correlation_matrix( correlation_type="pearson", n_samples=1, From 371ef7abc89cad184e38fbb771f428cf41364ef5 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 13 Nov 2024 12:34:33 +0200 Subject: [PATCH 20/51] fixed comments --- 
src/scvi/model/_multivi.py | 169 +++++++++++++++++++++++++++++++++---- src/scvi/model/_totalvi.py | 6 ++ 2 files changed, 157 insertions(+), 18 deletions(-) diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 0bc5dfe517..9525cdade5 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -2,6 +2,7 @@ import logging import warnings +from collections.abc import Iterable as IterableClass from functools import partial from typing import TYPE_CHECKING @@ -117,13 +118,15 @@ class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin): -------- >>> adata_rna = anndata.read_h5ad(path_to_rna_anndata) >>> adata_atac = scvi.data.read_10x_atac(path_to_atac_anndata) - >>> adata_multi = scvi.data.read_10x_multiome(path_to_multiomic_anndata) - >>> adata_mvi = scvi.data.organize_multiome_anndatas(adata_multi, adata_rna, adata_atac) - >>> scvi.model.MULTIVI.setup_anndata(adata_mvi, batch_key="modality") - >>> vae = scvi.model.MULTIVI(adata_mvi) + >>> adata_protein = anndata.read_h5ad(path_to_protein_anndata) + >>> mdata = MuData({"rna": adata_rna, "protein": adata_protein, "atac": adata_atac}) + >>> scvi.model.MULTIVI.setup_mudata(mdata, batch_key="batch", + >>> modalities={"rna_layer": "rna", "protein_layer": "protein", "batch_key": "rna", + >>> "atac_layer": "atac"}) + >>> vae = scvi.model.MULTIVI(mdata) >>> vae.train() - Notes + Notes (for using setup_anndata) ----- * The model assumes that the features are organized so that all expression features are consecutive, followed by all accessibility features. For example, if the data has 100 genes @@ -360,7 +363,7 @@ def train( @torch.inference_mode() def get_library_size_factors( self, - adata: AnnData | None = None, + adata: AnnOrMuData | None = None, indices: Sequence[int] = None, batch_size: int = 128, ) -> dict[str, np.ndarray]: @@ -369,8 +372,8 @@ def get_library_size_factors( Parameters ---------- adata - AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the - AnnData object used to initialize the model. + AnnOrMuData object with equivalent structure to initial AnnData. If `None`, defaults + to the AnnOrMuData object used to initialize the model. indices Indices of cells in adata to use. If `None`, all cells are used. batch_size @@ -409,7 +412,7 @@ def get_region_factors(self) -> np.ndarray: @torch.inference_mode() def get_latent_representation( self, - adata: AnnData | None = None, + adata: AnnOrMuData | None = None, modality: Literal["joint", "expression", "accessibility"] = "joint", indices: Sequence[int] | None = None, give_mean: bool = True, @@ -420,8 +423,8 @@ def get_latent_representation( Parameters ---------- adata - AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the - AnnData object used to initialize the model. + AnnOrMuData object with equivalent structure to initial AnnData. If `None`, defaults + to the AnnOrMuData object used to initialize the model. modality Return modality specific or joint latent representation. indices @@ -479,7 +482,7 @@ def get_latent_representation( @torch.inference_mode() def get_accessibility_estimates( self, - adata: AnnData | None = None, + adata: AnnOrMuData | None = None, indices: Sequence[int] = None, n_samples_overall: int | None = None, region_list: Sequence[str] | None = None, @@ -500,8 +503,8 @@ def get_accessibility_estimates( Parameters ---------- adata - AnnData object that has been registered with scvi. 
If `None`, defaults to the - AnnData object used to initialize the model. + AnnOrMuData object that has been registered with scvi. If `None`, defaults to the + AnnOrMuData object used to initialize the model. indices Indices of cells in adata to use. If `None`, all cells are used. n_samples_overall @@ -590,14 +593,14 @@ def get_accessibility_estimates( imputed, index=adata.obs_names[indices], columns=adata["rna"].var_names[self.n_genes :][region_mask] - if type(adata).__name__ == "MuData" + if isinstance(adata, MuData) else adata.var_names[self.n_genes :][region_mask], ) @torch.inference_mode() def get_normalized_expression( self, - adata: AnnData | None = None, + adata: AnnOrMuData | None = None, indices: Sequence[int] | None = None, n_samples_overall: int | None = None, transform_batch: Sequence[Number | str] | None = None, @@ -615,8 +618,8 @@ def get_normalized_expression( Parameters ---------- adata - AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the - AnnData object used to initialize the model. + AnnOrMuData object with equivalent structure to initial AnnData. If `None`, defaults + to the AnnOrMuData object used to initialize the model. indices Indices of cells in adata to use. If `None`, all cells are used. n_samples_overall @@ -928,6 +931,130 @@ def differential_expression( return result + @torch.no_grad() + def get_protein_foreground_probability( + self, + adata: AnnOrMuData | None = None, + indices: Sequence[int] | None = None, + transform_batch: Sequence[Number | str] | None = None, + protein_list: Sequence[str] | None = None, + n_samples: int = 1, + batch_size: int | None = None, + use_z_mean: bool = True, + return_mean: bool = True, + return_numpy: bool | None = None, + ): + r"""Returns the foreground probability for proteins. + + This is denoted as :math:`(1 - \pi_{nt})` in the totalVI paper. + + Parameters + ---------- + adata + AnnOrMuData object with equivalent structure to initial AnnData. If ``None``, defaults + to the AnnOrMuData object used to initialize the model. + indices + Indices of cells in adata to use. If `None`, all cells are used. + transform_batch + Batch to condition on. + If transform_batch is: + + * ``None`` - real observed batch is used + * ``int`` - batch transform_batch is used + * ``List[int]`` - average over batches in list + protein_list + Return protein expression for a subset of genes. + This can save memory when working with large datasets and few genes are + of interest. + n_samples + Number of posterior samples to use for estimation. + batch_size + Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + return_mean + Whether to return the mean of the samples. + return_numpy + Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame + includes gene names as columns. If either ``n_samples=1`` or ``return_mean=True``, + defaults to ``False``. Otherwise, it defaults to `True`. + + Returns + ------- + - **foreground_probability** - probability foreground for each protein + + If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`. + Otherwise, shape is `(cells, genes)`. In this case, return type is + :class:`~pandas.DataFrame` unless `return_numpy` is True. 
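+
+        Examples
+        --------
+        >>> # a minimal, illustrative sketch (assumes `model` is a trained instance
+        >>> # set up on data registered via this class's setup method)
+        >>> prob = model.get_protein_foreground_probability(n_samples=10)
+        >>> prob.head()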
+ """ + adata = self._validate_anndata(adata) + post = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) + + if protein_list is None: + protein_mask = slice(None) + else: + all_proteins = self.scvi_setup_dict_["protein_names"] + protein_mask = [True if p in protein_list else False for p in all_proteins] + + if n_samples > 1 and return_mean is False: + if return_numpy is False: + warnings.warn( + "`return_numpy` must be `True` if `n_samples > 1` and `return_mean` is " + "`False`, returning an `np.ndarray`.", + UserWarning, + stacklevel=settings.warnings_stacklevel, + ) + return_numpy = True + if indices is None: + indices = np.arange(adata.n_obs) + + py_mixings = [] + if not isinstance(transform_batch, IterableClass): + transform_batch = [transform_batch] + + transform_batch = _get_batch_code_from_category(self.adata_manager, transform_batch) + for tensors in post: + y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY] + py_mixing = torch.zeros_like(y[..., protein_mask]) + if n_samples > 1: + py_mixing = torch.stack(n_samples * [py_mixing]) + for _ in transform_batch: + # generative_kwargs = dict(transform_batch=b) + generative_kwargs = {"use_z_mean": use_z_mean} + inference_kwargs = {"n_samples": n_samples} + _, generative_outputs = self.module.forward( + tensors=tensors, + inference_kwargs=inference_kwargs, + generative_kwargs=generative_kwargs, + compute_loss=False, + ) + py_mixing += torch.sigmoid(generative_outputs["py_"]["mixing"])[ + ..., protein_mask + ].cpu() + py_mixing /= len(transform_batch) + py_mixings += [py_mixing] + if n_samples > 1: + # concatenate along batch dimension -> result shape = (samples, cells, features) + py_mixings = torch.cat(py_mixings, dim=1) + # (cells, features, samples) + py_mixings = py_mixings.permute(1, 2, 0) + else: + py_mixings = torch.cat(py_mixings, dim=0) + + if return_mean is True and n_samples > 1: + py_mixings = torch.mean(py_mixings, dim=-1) + + py_mixings = py_mixings.cpu().numpy() + + if return_numpy is True: + return 1 - py_mixings + else: + pro_names = self.protein_state_registry.column_names + foreground_prob = pd.DataFrame( + 1 - py_mixings, + columns=pro_names[protein_mask], + index=adata.obs_names[indices], + ) + return foreground_prob + @classmethod @setup_anndata_dsp.dedent def setup_anndata( @@ -959,6 +1086,12 @@ def setup_anndata( `adata.obsm[protein_expression_obsm_key]` if it is a DataFrame, else will assign sequential names to proteins. """ + warnings.warn( + "MULTIVI is suppose to work with MuData. the use of anndata is " + "deprecated and will be remove in scvi-tools 1.4. Please use setup_mudata", + DeprecationWarning, + stacklevel=settings.warnings_stacklevel, + ) setup_method_args = cls._get_setup_method_args(**locals()) adata.obs["_indices"] = np.arange(adata.n_obs) batch_field = CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key) diff --git a/src/scvi/model/_totalvi.py b/src/scvi/model/_totalvi.py index a41123af72..05681942d1 100644 --- a/src/scvi/model/_totalvi.py +++ b/src/scvi/model/_totalvi.py @@ -1215,6 +1215,12 @@ def setup_anndata( ------- %(returns)s """ + warnings.warn( + "TOTALVI is suppose to work with MuData. the use of anndata is " + "deprecated and will be remove in scvi-tools 1.4. 
Please use setup_mudata", + DeprecationWarning, + stacklevel=settings.warnings_stacklevel, + ) setup_method_args = cls._get_setup_method_args(**locals()) batch_field = fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key) anndata_fields = [ From 3cf01086b4983523b3b67f9c274655c077bf4a47 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 13 Nov 2024 12:43:44 +0200 Subject: [PATCH 21/51] fixed comments --- src/scvi/model/_multivi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 9525cdade5..4fd8dbe968 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -6,6 +6,7 @@ from functools import partial from typing import TYPE_CHECKING +from mudata import MuData import numpy as np import pandas as pd import torch @@ -44,7 +45,6 @@ from typing import Literal from anndata import AnnData - from mudata import MuData from scvi._types import AnnOrMuData, Number From ff82b284a2bc1d35f8e5ffe7166f65899b6dce4b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Nov 2024 10:44:00 +0000 Subject: [PATCH 22/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/model/_multivi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 4fd8dbe968..94cabdbdf9 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -6,10 +6,10 @@ from functools import partial from typing import TYPE_CHECKING -from mudata import MuData import numpy as np import pandas as pd import torch +from mudata import MuData from scipy.sparse import csr_matrix, vstack from torch.distributions import Normal From b6d2028648376e2e233a7273f0e7c5f5f17cd4aa Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Thu, 14 Nov 2024 11:32:41 +0200 Subject: [PATCH 23/51] fixed typos --- src/scvi/model/_multivi.py | 4 ++-- src/scvi/model/_totalvi.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 94cabdbdf9..cbf0e657f8 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -1087,8 +1087,8 @@ def setup_anndata( sequential names to proteins. """ warnings.warn( - "MULTIVI is suppose to work with MuData. the use of anndata is " - "deprecated and will be remove in scvi-tools 1.4. Please use setup_mudata", + "MULTIVI is supposed to work with MuData. the use of anndata is " + "deprecated and will be removed in scvi-tools 1.4. Please use setup_mudata", DeprecationWarning, stacklevel=settings.warnings_stacklevel, ) diff --git a/src/scvi/model/_totalvi.py b/src/scvi/model/_totalvi.py index 05681942d1..0af51a49d1 100644 --- a/src/scvi/model/_totalvi.py +++ b/src/scvi/model/_totalvi.py @@ -1216,8 +1216,7 @@ def setup_anndata( %(returns)s """ warnings.warn( - "TOTALVI is suppose to work with MuData. the use of anndata is " - "deprecated and will be remove in scvi-tools 1.4. 
Please use setup_mudata", + "TOTALVI is supposed to work with MuData.", DeprecationWarning, stacklevel=settings.warnings_stacklevel, ) From a9033592d89fdfe669c7dceaac4cfc8a8d03183c Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Mon, 18 Nov 2024 11:51:31 +0200 Subject: [PATCH 24/51] fix comments --- CHANGELOG.md | 6 ++++-- pyproject.toml | 5 ++--- tests/model/test_multivi.py | 37 ++++++------------------------------- 3 files changed, 12 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 84cb265411..6029239362 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ Starting from version 0.20.1, this format is based on [Keep a Changelog], and th to [Semantic Versioning]. Full commit history is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/). -## Version 1.2 +## Version 1.3 ### 1.3.0 (2024-XX-XX) @@ -19,11 +19,13 @@ to [Semantic Versioning]. Full commit history is available in the - Add {class}`scvi.external.Decipher` for dimensionality reduction and interpretable representation learning in single-cell RNA sequencing data {pr}`3015`. +## Version 1.2 + ### 1.2.1 (2024-XX-XX) #### Added -- Experimental MuData support for {class}`~scvi.model.MULTIVI` via the method +- MuData support for {class}`~scvi.model.MULTIVI` via the method {meth}`~scvi.model.MULTIVI.setup_mudata` {pr}`3038`. #### Fixed diff --git a/pyproject.toml b/pyproject.toml index bfd330362b..8478eb7d71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,8 +88,6 @@ census = ["cellxgene-census"] hub = ["huggingface_hub"] # scvi.model.utils.mde dependencies pymde = ["pymde"] -# mudata dependencies -muon = ["muon"] # scvi.data.add_dna_sequence regseq = ["biopython>=1.81", "genomepy"] # read loom @@ -98,12 +96,13 @@ loompy = ["loompy>=3.0.6"] scanpy = ["scanpy>=1.6","scikit-misc"] optional = [ - "scvi-tools[autotune,aws,hub,loompy,muon,pymde,regseq,scanpy]" + "scvi-tools[autotune,aws,hub,loompy,pymde,regseq,scanpy]" ] tutorials = [ "cell2location", "jupyter", "leidenalg", + "muon", "plotnine", "pooch", "pynndescent", diff --git a/tests/model/test_multivi.py b/tests/model/test_multivi.py index ac96bbbd7c..70aecb01c4 100644 --- a/tests/model/test_multivi.py +++ b/tests/model/test_multivi.py @@ -1,7 +1,6 @@ import os import anndata as ad -import muon import numpy as np import pytest import scanpy as sc @@ -126,11 +125,7 @@ def test_multivi_mudata_rna_prot_external(): def test_multivi_mudata_rna_atac_external(): # optional data - mudata RNA/ATAC - url = ( - "https://cf.10xgenomics.com/samples/cell-arc/2.0.0/10k_PBMC_Multiome_nextgem_Chromium_X" - "/10k_PBMC_Multiome_nextgem_Chromium_X_filtered_feature_bc_matrix.h5" - ) - mdata = muon.read_10x_h5("data/multiome10k.h5mu", backup_url=url) + mdata = synthetic_iid(return_mudata=True) # Preprocessing sc.pp.normalize_total(mdata.mod["rna"]) sc.pp.log1p(mdata.mod["rna"]) @@ -140,15 +135,15 @@ def test_multivi_mudata_rna_atac_external(): flavor="seurat_v3", ) mdata.mod["rna_subset"] = mdata.mod["rna"][:, mdata.mod["rna"].var["highly_variable"]].copy() - sc.pp.normalize_total(mdata.mod["atac"]) - sc.pp.log1p(mdata.mod["atac"]) + sc.pp.normalize_total(mdata.mod["accessibility"]) + sc.pp.log1p(mdata.mod["accessibility"]) sc.pp.highly_variable_genes( - mdata.mod["atac"], + mdata.mod["accessibility"], n_top_genes=4000, flavor="seurat_v3", ) - mdata.mod["atac_subset"] = mdata.mod["atac"][ - :, mdata.mod["atac"].var["highly_variable"] + mdata.mod["atac_subset"] = mdata.mod["accessibility"][ + :, 
mdata.mod["accessibility"].var["highly_variable"] ].copy() mdata.update() # mdata @@ -163,21 +158,6 @@ def test_multivi_mudata_rna_atac_external(): def test_multivi_mudata(): # use of syntetic data of rna/proteins/atac for speed - # adata = synthetic_iid() - # protein_adata = synthetic_iid() - # atac_adata = synthetic_iid() - # mdata = MuData({"rna": adata, "protein": protein_adata, "atac": atac_adata}) - # MULTIVI.setup_mudata( - # mdata, - # batch_key="batch", - # modalities={"rna_layer": "rna", "protein_layer": "protein", "batch_key": "rna", - # "atac_layer": "atac"}, - # ) - # n_obs = mdata.n_obs - # n_genes = np.min([adata.n_vars, protein_adata.n_vars]) - # n_regions = protein_adata.X.shape[1] - # n_latent = 10 - mdata = synthetic_iid(return_mudata=True) MULTIVI.setup_mudata( mdata, @@ -189,8 +169,6 @@ def test_multivi_mudata(): }, ) n_obs = mdata.n_obs - # n_genes = np.min([mdata.n_vars, mdata["protein_expression"].n_vars]) - # n_regions = mdata["protein_expression"].X.shape[1] n_latent = 10 model = MULTIVI(mdata, n_latent=n_latent, n_genes=50, n_regions=50) @@ -216,9 +194,6 @@ def test_multivi_mudata(): model.get_library_size_factors() model.get_region_factors() - # adata2 = synthetic_iid() - # protein_adata2 = synthetic_iid(n_genes=50) - # mdata2 = MuData({"rna": adata2, "protein": protein_adata2}) mdata2 = synthetic_iid(return_mudata=True) MULTIVI.setup_mudata( mdata2, From 8aecdf81e4bfad1cd17b085d4ad03c77c6593871 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Mon, 18 Nov 2024 17:48:31 +0200 Subject: [PATCH 25/51] fix some tests --- src/scvi/model/_multivi.py | 4 +- .../test_models_with_mudata_minified_data.py | 64 ++++++++++--------- 2 files changed, 36 insertions(+), 32 deletions(-) diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 919914198f..9b053d97ef 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -1185,10 +1185,10 @@ def setup_mudata( %(param_mdata)s rna_layer RNA layer key. If `None`, will use `.X` of specified modality key. - protein_layer - Protein layer key. If `None`, will use `.X` of specified modality key. atac_layer ATAC layer key. If `None`, will use `.X` of specified modality key. + protein_layer + Protein layer key. If `None`, will use `.X` of specified modality key. 
%(param_batch_key)s %(param_size_factor_key)s %(param_cat_cov_keys)s diff --git a/tests/model/test_models_with_mudata_minified_data.py b/tests/model/test_models_with_mudata_minified_data.py index 9aac0b35ba..63103cdaa3 100644 --- a/tests/model/test_models_with_mudata_minified_data.py +++ b/tests/model/test_models_with_mudata_minified_data.py @@ -1,6 +1,7 @@ import numpy as np import pytest +# from mudata import MuData from scvi.data import synthetic_iid from scvi.data._constants import ADATA_MINIFY_TYPE from scvi.data._utils import _is_minified @@ -10,15 +11,19 @@ _MULTIVI_OBSERVED_LIB_SIZE = "_multivi_observed_lib_size" -def prep_model_mudata(cls=TOTALVI, layer=None, use_size_factor=False): +def prep_model_mudata(cls=TOTALVI, use_size_factor=False): # create a synthetic dataset + # adata = synthetic_iid() + # protein_adata = synthetic_iid(n_genes=50) + # mdata = MuData({"rna": adata, "protein_expression": protein_adata, + # "accessibility": synthetic_iid()}) mdata = synthetic_iid(return_mudata=True) if use_size_factor: mdata.obs["size_factor"] = np.random.randint(1, 5, size=(mdata.shape[0],)) - if layer is not None: - for mod in mdata.mod_names: - mdata[mod].layers[layer] = mdata[mod].X.copy() - mdata[mod].X = np.zeros_like(mdata[mod].X) + # if layer is not None: + # for mod in mdata.mod_names: + # mdata[mod].layers[layer] = mdata[mod].X.copy() + # mdata[mod].X = np.zeros_like(mdata[mod].X) mdata.var["n_counts"] = np.squeeze( np.concatenate( [ @@ -29,7 +34,7 @@ def prep_model_mudata(cls=TOTALVI, layer=None, use_size_factor=False): ) ) mdata.varm["my_varm"] = np.random.negative_binomial(5, 0.3, size=(mdata.shape[1], 3)) - mdata["rna"].layers["my_layer"] = np.ones_like(mdata["rna"].X) + # mdata["rna"].layers["my_layer"] = np.ones_like(mdata["rna"].X) mdata_before_setup = mdata.copy() # run setup_anndata @@ -64,7 +69,10 @@ def prep_model_mudata(cls=TOTALVI, layer=None, use_size_factor=False): model.train(1, check_val_every_n_epoch=1, train_size=0.5) # get the mdata lib size - mdata_lib_size = np.squeeze(np.asarray(mdata["rna"].X.sum(axis=1))) + if cls == TOTALVI: + mdata_lib_size = np.squeeze(np.asarray(mdata["rna"].X.sum(axis=1))) # TOTALVI + else: + mdata_lib_size = np.squeeze(np.asarray(mdata["accessibility"].X.sum(axis=1))) # MULTIVI assert ( np.min(mdata_lib_size) > 0 ) # make sure it's not all zeros and there are no negative values @@ -80,10 +88,9 @@ def assert_approx_equal(a, b): def run_test_for_model_with_minified_mudata( cls=TOTALVI, - layer: str = None, use_size_factor=False, ): - model, mdata, mdata_lib_size, _ = prep_model_mudata(cls, layer, use_size_factor) + model, mdata, mdata_lib_size, _ = prep_model_mudata(cls, use_size_factor) qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) model.adata.obsm["X_latent_qzm"] = qzm @@ -112,10 +119,9 @@ def run_test_for_model_with_minified_mudata( @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) -@pytest.mark.parametrize("layer", [None, "data_layer"]) @pytest.mark.parametrize("use_size_factor", [False, True]) -def test_with_minified_mudata(cls, layer: str, use_size_factor: bool): - run_test_for_model_with_minified_mudata(cls=cls, layer=layer, use_size_factor=use_size_factor) +def test_with_minified_mudata(cls, use_size_factor: bool): + run_test_for_model_with_minified_mudata(cls=cls, use_size_factor=use_size_factor) @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) @@ -132,11 +138,14 @@ def test_with_minified_mdata_get_normalized_expression(cls): assert model.minified_data_type == 
ADATA_MINIFY_TYPE.LATENT_POSTERIOR exprs_new = model.get_normalized_expression() - for ii in range(len(exprs_new)): - assert exprs_new[ii].shape == mdata[mdata.mod_names[ii]].shape - - for ii in range(len(exprs_new)): - np.testing.assert_array_equal(exprs_new[ii], exprs_orig[ii]) + if type(exprs_new) is tuple: + for ii in range(len(exprs_new)): + assert exprs_new[ii].shape == mdata[mdata.mod_names[ii]].shape + for ii in range(len(exprs_new)): + np.testing.assert_array_equal(exprs_new[ii], exprs_orig[ii]) + else: + assert exprs_new.shape == mdata[mdata.mod_names].shape + np.testing.assert_array_equal(exprs_new, exprs_orig) @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) @@ -216,7 +225,8 @@ def test_with_minified_mdata_save_then_load(cls, save_path): model.minify_mudata() assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR - model.save(save_path, overwrite=True, save_anndata=False, legacy_mudata_format=True) + model.save(save_path, overwrite=True, save_anndata=True, legacy_mudata_format=True) + model.view_setup_args(save_path) # load saved model with saved (minified) mdata loaded_model = cls.load(save_path, adata=mdata) @@ -294,19 +304,15 @@ def test_with_minified_mdata_posterior_predictive_sample(cls): model.adata.obsm["X_latent_qzm"] = qzm model.adata.obsm["X_latent_qzv"] = qzv - sample_orig = model.posterior_predictive_sample( - indices=[1, 2, 3], gene_list=["gene_1", "gene_2"] - ) + sample_orig = model.posterior_predictive_sample() model.minify_mudata() assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR - sample_new = model.posterior_predictive_sample( - indices=[1, 2, 3], gene_list=["gene_1", "gene_2"] - ) - assert sample_new.shape == (3, 2) + sample_new = model.posterior_predictive_sample() + # assert sample_new.shape == (3, 2) - np.testing.assert_array_equal(sample_new.todense(), sample_orig.todense()) + np.testing.assert_array_equal(sample_new, sample_orig) @pytest.mark.parametrize("cls", [TOTALVI]) @@ -318,8 +324,7 @@ def test_with_minified_mdata_get_feature_correlation_matrix(cls): model.adata.obsm["X_latent_qzv"] = qzv fcm_orig = model.get_feature_correlation_matrix( - correlation_type="pearson", - n_samples=1, + correlation_type="spearman", transform_batch=["batch_0", "batch_1"], ) @@ -327,8 +332,7 @@ def test_with_minified_mdata_get_feature_correlation_matrix(cls): assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR fcm_new = model.get_feature_correlation_matrix( - correlation_type="pearson", - n_samples=1, + correlation_type="spearman", transform_batch=["batch_0", "batch_1"], ) From 7b5f22fe3969895501ab560882cd6e93e653b150 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Nov 2024 09:11:26 +0000 Subject: [PATCH 26/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 250cff41df..634027d40b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1600,7 +1600,7 @@ will be compatible with this and future versions. 
Also, we dropped support for P #### Enhancements -##### Online updates of {class}`~scvi.model.SCVI`, {class}`~scvi.model.SCANVI`, and {class}`~scvi.model.TOTALVI` with the scArches method +##### Online updates of {class}`~scvi.model.SCVI`, {class}`~scvi.model.SCANVI`, and {class}`~scvi.model.TOTALVI` with the scArches method It is now possible to iteratively update these models with new samples, without altering the model for the "reference" population. Here we use the @@ -1662,7 +1662,7 @@ use scvi-tools is with our documentation and tutorials. - New high-level API and data loading, please see tutorials and examples for usage. - `GeneExpressionDataset` and associated classes have been removed. - Built-in datasets now return `AnnData` objects. -- `scvi-tools` now relies entirely on the \[AnnData\] format. +- `scvi-tools` now relies entirely on the [AnnData] format. - `scvi.models` has been moved to `scvi.core.module`. - `Posterior` classes have been reduced to wrappers on `DataLoaders` - `scvi.inference` has been split to `scvi.core.data_loaders` for `AnnDataLoader` classes and From 2a9975113090b8a21386be185b603b8a64ffdc79 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Tue, 19 Nov 2024 13:09:57 +0200 Subject: [PATCH 27/51] added atac registry field --- src/scvi/_constants.py | 1 + src/scvi/model/_multivi.py | 6 +++--- tests/model/test_multivi.py | 6 +++--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/scvi/_constants.py b/src/scvi/_constants.py index f565dc9f4a..ec6e4e914d 100644 --- a/src/scvi/_constants.py +++ b/src/scvi/_constants.py @@ -3,6 +3,7 @@ class _REGISTRY_KEYS_NT(NamedTuple): X_KEY: str = "X" + ATAC_X_KEY: str = "atac" BATCH_KEY: str = "batch" SAMPLE_KEY: str = "sample" LABELS_KEY: str = "labels" diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 7fa3daa139..80b2f13ebd 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -1153,10 +1153,10 @@ def setup_mudata( %(param_mdata)s rna_layer RNA layer key. If `None`, will use `.X` of specified modality key. - protein_layer - Protein layer key. If `None`, will use `.X` of specified modality key. atac_layer ATAC layer key. If `None`, will use `.X` of specified modality key. + protein_layer + Protein layer key. If `None`, will use `.X` of specified modality key. 
%(param_batch_key)s %(param_size_factor_key)s %(param_cat_cov_keys)s @@ -1227,7 +1227,7 @@ def setup_mudata( if modalities.atac_layer is not None: mudata_fields.append( fields.MuDataLayerField( - REGISTRY_KEYS.X_KEY, + REGISTRY_KEYS.ATAC_X_KEY, atac_layer, mod_key=modalities.atac_layer, is_count_data=True, diff --git a/tests/model/test_multivi.py b/tests/model/test_multivi.py index 70aecb01c4..3a05a773e0 100644 --- a/tests/model/test_multivi.py +++ b/tests/model/test_multivi.py @@ -331,7 +331,7 @@ def test_multivi_size_factor_mudata(): model.train(1, train_size=0.5) -def test_multivi_saving_and_loading_mudata(save_path: str = "."): +def test_multivi_saving_and_loading_mudata(save_path: str): adata = synthetic_iid() protein_adata = synthetic_iid(n_genes=50) mdata = MuData({"rna": adata, "protein": protein_adata}) @@ -392,7 +392,7 @@ def test_multivi_saving_and_loading_mudata(save_path: str = "."): ) -def test_scarches_mudata_prep_layer(save_path: str = "."): +def test_scarches_mudata_prep_layer(save_path: str): n_latent = 5 mdata1 = synthetic_iid(return_mudata=True) @@ -440,7 +440,7 @@ def test_scarches_mudata_prep_layer(save_path: str = "."): MULTIVI.load_query_data(mdata2, dir_path) -def test_multivi_save_load_mudata_format(save_path: str = "."): +def test_multivi_save_load_mudata_format(save_path: str): mdata = synthetic_iid(return_mudata=True, protein_expression_key="protein") invalid_mdata = mdata.copy() invalid_mdata.mod["protein"] = invalid_mdata.mod["protein"][:, :10].copy() From a3de92ea424cbb1da887dc374a9d17038415f001 Mon Sep 17 00:00:00 2001 From: Can Ergen Date: Tue, 19 Nov 2024 12:51:26 -0800 Subject: [PATCH 28/51] Refactored minified models --- CHANGELOG.md | 4 ++ src/scvi/model/_scanvi.py | 18 ++++----- src/scvi/model/_scvi.py | 7 +--- src/scvi/model/base/_base_model.py | 2 - src/scvi/model/base/_vaemixin.py | 2 +- src/scvi/model/utils/_minification.py | 10 ++++- src/scvi/module/_vae.py | 31 ++++++++------- src/scvi/module/base/_base_module.py | 8 ++-- src/scvi/train/_trainingplans.py | 8 +++- src/scvi/utils/_decorators.py | 4 +- tests/model/test_models_with_minified_data.py | 39 ++++++++++++++++++- 11 files changed, 91 insertions(+), 42 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 315aae5ce7..82d7274c9e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,10 @@ to [Semantic Versioning]. Full commit history is available in the validation set, if available. {pr}`3036`. - Add `batch_key` and `labels_key` to `scvi.external.SCAR.setup_anndata`. +- Support for minified mode while retaining counts to skip the encoder. +- New Trainingplan argument `update_only_decoder` to use stored latent codes and skip training of + the encoder. +- Refactored code for minified models. 
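
A rough usage sketch of the new minified-mode options, distilled from the tests added
later in this series (assumes a trained `SCVI` instance `model`; all names below are
as used in those tests):

    qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True)
    model.adata.obsm["X_latent_qzm"] = qzm
    model.adata.obsm["X_latent_qzv"] = qzv
    # keep the counts so that ELBO/reconstruction error remain computable
    model.minify_adata(minified_data_type="latent_posterior_parameters_with_counts")
    model.get_elbo()  # still supported, since the counts were retained
    # reuse the stored posterior parameters and update only the decoder
    model.train(1, plan_kwargs={"update_only_decoder": True})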
#### Fixed diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index ea1f6617e8..604d4bfd4e 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -13,6 +13,7 @@ from scvi.data import AnnDataManager from scvi.data._constants import ( _SETUP_ARGS_KEY, + ADATA_MINIFY_TYPE, ) from scvi.data._utils import _get_adata_minify_type, _is_minified, get_anndata_attribute from scvi.data.fields import ( @@ -41,10 +42,6 @@ from ._scvi import SCVI -_SCANVI_LATENT_QZM = "_scanvi_latent_qzm" -_SCANVI_LATENT_QZV = "_scanvi_latent_qzv" -_SCANVI_OBSERVED_LIB_SIZE = "_scanvi_observed_lib_size" - logger = logging.getLogger(__name__) @@ -104,6 +101,8 @@ class SCANVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMinifiedModeModelClass): _module_cls = SCANVAE _training_plan_cls = SemiSupervisedTrainingPlan + _LATENT_QZM = "scanvi_latent_qzm" + _LATENT_QZV = "scanvi_latent_qzv" def __init__( self, @@ -212,17 +211,18 @@ def from_scvi_model( ) del scanvi_kwargs[k] - if scvi_model.minified_data_type is not None: + if scvi_model.minified_data_type==ADATA_MINIFY_TYPE.LATENT_POSTERIOR: raise ValueError( - "We cannot use the given scvi model to initialize scanvi because it has a " - "minified adata." + "We cannot use the given scVI model to initialize scANVI because it has " + "minified adata. Keep counts when minifying model using " + "minified_data_type='latent_posterior_parameters_with_counts'." ) if adata is None: adata = scvi_model.adata else: if _is_minified(adata): - raise ValueError("Please provide a non-minified `adata` to initialize scanvi.") + raise ValueError("Please provide a non-minified `adata` to initialize scANVI.") # validate new anndata against old model scvi_model._validate_anndata(adata) @@ -230,7 +230,7 @@ def from_scvi_model( scvi_labels_key = scvi_setup_args["labels_key"] if labels_key is None and scvi_labels_key is None: raise ValueError( - "A `labels_key` is necessary as the SCVI model was initialized without one." + "A `labels_key` is necessary as the scVI model was initialized without one." 
) if scvi_labels_key is None: scvi_setup_args.update({"labels_key": labels_key}) diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index a27dd1bb12..36ad6f2c39 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -26,11 +26,6 @@ from anndata import AnnData - -_SCVI_LATENT_QZM = "_scvi_latent_qzm" -_SCVI_LATENT_QZV = "_scvi_latent_qzv" -_SCVI_OBSERVED_LIB_SIZE = "_scvi_observed_lib_size" - logger = logging.getLogger(__name__) @@ -105,6 +100,8 @@ class SCVI( """ _module_cls = VAE + _SCVI_LATENT_QZM = "scvi_latent_qzm" + _SCVI_LATENT_QZV = "scvi_latent_qzv" def __init__( self, diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index a6b576387d..64b0594e29 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -984,8 +984,6 @@ def _get_fields_for_adata_minification( fields.NumericalObsField(REGISTRY_KEYS.OBSERVED_LIB_SIZE, cls._OBSERVED_LIB_SIZE_KEY), fields.StringUnsField(REGISTRY_KEYS.MINIFY_TYPE_KEY, _ADATA_MINIFY_TYPE_UNS_KEY), ] - if minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS: - mini_fields.append(fields.LayerField(REGISTRY_KEYS.X_KEY, None, is_count_data=True)) return mini_fields diff --git a/src/scvi/model/base/_vaemixin.py b/src/scvi/model/base/_vaemixin.py index 7de4fe43c9..d6099d02fa 100644 --- a/src/scvi/model/base/_vaemixin.py +++ b/src/scvi/model/base/_vaemixin.py @@ -167,7 +167,7 @@ def get_reconstruction_error( adata: AnnData | None = None, indices: Sequence[int] | None = None, batch_size: int | None = None, - dataloader: Iterator[dict[str, Tensor | None]] = None, + dataloader: Iterator[dict[str, Tensor | None]] | None = None, return_mean: bool = True, **kwargs, ) -> dict[str, float]: diff --git a/src/scvi/model/utils/_minification.py b/src/scvi/model/utils/_minification.py index ce9c8bb4de..ea3bb55a58 100644 --- a/src/scvi/model/utils/_minification.py +++ b/src/scvi/model/utils/_minification.py @@ -18,9 +18,15 @@ def get_minified_adata_scrna( """Get a minified version of an :class:`~anndata.AnnData` or :class:`~mudata.MuData` object.""" counts = adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) all_zeros = csr_matrix(counts.shape) + if keep_count_data: + X = counts + layers = adata_manager.adata.layers + else: + X = all_zeros + layers = {layer: all_zeros.copy() for layer in adata_manager.adata.layers} return AnnData( - X=counts if keep_count_data else all_zeros, - layers={layer: all_zeros.copy() for layer in adata_manager.adata.layers}, + X=X, + layers=layers, obs=adata_manager.adata.obs.copy(), var=adata_manager.adata.var.copy(), uns=adata_manager.adata.uns.copy(), diff --git a/src/scvi/module/_vae.py b/src/scvi/module/_vae.py index 920b65ca18..aa1e284447 100644 --- a/src/scvi/module/_vae.py +++ b/src/scvi/module/_vae.py @@ -9,6 +9,7 @@ from torch.nn.functional import one_hot from scvi import REGISTRY_KEYS, settings +from scvi.data._constants import ADATA_MINIFY_TYPE from scvi.module._constants import MODULE_KEYS from scvi.module.base import ( BaseMinifiedModeModuleClass, @@ -16,6 +17,7 @@ LossOutput, auto_move_data, ) +from scvi.utils import unsupported_if_adata_minified if TYPE_CHECKING: from collections.abc import Callable @@ -281,25 +283,31 @@ def __init__( def _get_inference_input( self, tensors: dict[str, torch.Tensor | None], + full_forward_pass: bool = False, ) -> dict[str, torch.Tensor | None]: """Get input tensors for the inference process.""" - from scvi.data._constants import ADATA_MINIFY_TYPE + if full_forward_pass or 
self.minified_data_type is None: + loader = "full_data" + elif self.minified_data_type in [ + ADATA_MINIFY_TYPE.LATENT_POSTERIOR, ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS + ]: + loader = "minified_data" + else: + raise NotImplementedError(f"Unknown minified-data type: {self.minified_data_type}") - if self.minified_data_type is None: + if loader == "full_data": return { MODULE_KEYS.X_KEY: tensors[REGISTRY_KEYS.X_KEY], MODULE_KEYS.BATCH_INDEX_KEY: tensors[REGISTRY_KEYS.BATCH_KEY], MODULE_KEYS.CONT_COVS_KEY: tensors.get(REGISTRY_KEYS.CONT_COVS_KEY, None), MODULE_KEYS.CAT_COVS_KEY: tensors.get(REGISTRY_KEYS.CAT_COVS_KEY, None), } - elif self.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + else: return { MODULE_KEYS.QZM_KEY: tensors[REGISTRY_KEYS.LATENT_QZM_KEY], MODULE_KEYS.QZV_KEY: tensors[REGISTRY_KEYS.LATENT_QZV_KEY], REGISTRY_KEYS.OBSERVED_LIB_SIZE: tensors[REGISTRY_KEYS.OBSERVED_LIB_SIZE], } - else: - raise NotImplementedError(f"Unknown minified-data type: {self.minified_data_type}") def _get_generative_input( self, @@ -414,14 +422,9 @@ def _cached_inference( """Run the cached inference process.""" from torch.distributions import Normal - from scvi.data._constants import ADATA_MINIFY_TYPE - - if self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: - raise NotImplementedError(f"Unknown minified-data type: {self.minified_data_type}") - - dist = Normal(qzm, qzv.sqrt()) + qz = Normal(qzm, qzv.sqrt()) # use dist.sample() rather than rsample because we aren't optimizing the z here - untran_z = dist.sample() if n_samples == 1 else dist.sample((n_samples,)) + untran_z = qz.sample() if n_samples == 1 else qz.sample((n_samples,)) z = self.z_encoder.z_transformation(untran_z) library = torch.log(observed_lib_size) if n_samples > 1: @@ -429,8 +432,7 @@ def _cached_inference( return { MODULE_KEYS.Z_KEY: z, - MODULE_KEYS.QZM_KEY: qzm, - MODULE_KEYS.QZV_KEY: qzv, + MODULE_KEYS.QZ_KEY: qz, MODULE_KEYS.QL_KEY: None, MODULE_KEYS.LIBRARY_KEY: library, } @@ -541,6 +543,7 @@ def generative( MODULE_KEYS.PZ_KEY: pz, } + @unsupported_if_adata_minified def loss( self, tensors: dict[str, torch.Tensor], diff --git a/src/scvi/module/base/_base_module.py b/src/scvi/module/base/_base_module.py index 39097c9039..7bfc00ace8 100644 --- a/src/scvi/module/base/_base_module.py +++ b/src/scvi/module/base/_base_module.py @@ -13,7 +13,6 @@ from torch import nn from scvi import settings -from scvi.data._constants import ADATA_MINIFY_TYPE from scvi.utils._jax import device_selecting_PRNGKey from ._decorators import auto_move_data @@ -303,10 +302,7 @@ def inference(self, *args, **kwargs): Branches off to regular or cached inference depending on whether we have a minified adata that contains the latent posterior parameters. 
""" - if ( - self.minified_data_type is not None - and self.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR - ): + if "qzm" in kwargs.keys() and "qzv" in kwargs.keys(): return self._cached_inference(*args, **kwargs) else: return self._regular_inference(*args, **kwargs) @@ -743,6 +739,8 @@ def _generic_forward( loss_kwargs = _get_dict_if_none(loss_kwargs) get_inference_input_kwargs = _get_dict_if_none(get_inference_input_kwargs) get_generative_input_kwargs = _get_dict_if_none(get_generative_input_kwargs) + if not ("qzm" in tensors.keys() and "qzv" in tensors.keys()): + get_inference_input_kwargs.pop("full_forward_pass", None) inference_inputs = module._get_inference_input(tensors, **get_inference_input_kwargs) inference_outputs = module.inference(**inference_inputs, **inference_kwargs) diff --git a/src/scvi/train/_trainingplans.py b/src/scvi/train/_trainingplans.py index 79aa4bf0e3..159442c805 100644 --- a/src/scvi/train/_trainingplans.py +++ b/src/scvi/train/_trainingplans.py @@ -146,6 +146,7 @@ def __init__( optimizer: Literal["Adam", "AdamW", "Custom"] = "Adam", optimizer_creator: TorchOptimizerCreator | None = None, lr: float = 1e-3, + update_only_decoder: bool = False, weight_decay: float = 1e-6, eps: float = 0.01, n_steps_kl_warmup: int = None, @@ -180,6 +181,7 @@ def __init__( self.min_kl_weight = min_kl_weight self.max_kl_weight = max_kl_weight self.optimizer_creator = optimizer_creator + self.update_only_decoder = update_only_decoder if self.optimizer_name == "Custom" and self.optimizer_creator is None: raise ValueError("If optimizer is 'Custom', `optimizer_creator` must be provided.") @@ -275,7 +277,11 @@ def n_obs_validation(self, n_obs: int): def forward(self, *args, **kwargs): """Passthrough to the module's forward method.""" - return self.module(*args, **kwargs) + return self.module( + *args, + **kwargs, + get_inference_input_kwargs={'full_forward_pass': not self.update_only_decoder}, + ) @torch.inference_mode() def compute_and_log_metrics( diff --git a/src/scvi/utils/_decorators.py b/src/scvi/utils/_decorators.py index 156dc90e74..a688d3af22 100644 --- a/src/scvi/utils/_decorators.py +++ b/src/scvi/utils/_decorators.py @@ -1,13 +1,15 @@ from collections.abc import Callable from functools import wraps +from scvi.data._constants import ADATA_MINIFY_TYPE + def unsupported_if_adata_minified(fn: Callable) -> Callable: """Decorator to raise an error if the model's `adata` is minified.""" @wraps(fn) def wrapper(self, *args, **kwargs): - if getattr(self, "minified_data_type", None) is not None: + if getattr(self, "minified_data_type", None)==ADATA_MINIFY_TYPE.LATENT_POSTERIOR: raise ValueError( f"The {fn.__qualname__} function currently does not support minified data." ) diff --git a/tests/model/test_models_with_minified_data.py b/tests/model/test_models_with_minified_data.py index 5b23cf4e3a..1cfe7d909b 100644 --- a/tests/model/test_models_with_minified_data.py +++ b/tests/model/test_models_with_minified_data.py @@ -160,7 +160,9 @@ def test_scanvi_from_scvi(save_path): scvi.model.SCANVI.from_scvi_model(model, "label_0") msg = ( - "We cannot use the given scvi model to initialize scanvi because it has a minified adata." + "We cannot use the given scVI model to initialize scANVI because it has minified adata. " + "Keep counts when minifying model using minified_data_type= " + "'latent_posterior_parameters_with_counts'." 
) assert str(e.value) == msg @@ -173,7 +175,7 @@ def test_scanvi_from_scvi(save_path): adata2.uns[_ADATA_MINIFY_TYPE_UNS_KEY] = ADATA_MINIFY_TYPE.LATENT_POSTERIOR with pytest.raises(ValueError) as e: scvi.model.SCANVI.from_scvi_model(loaded_model, "label_0", adata=adata2) - assert str(e.value) == "Please provide a non-minified `adata` to initialize scanvi." + assert str(e.value) == "Please provide a non-minified `adata` to initialize scANVI." scanvi_model = scvi.model.SCANVI.from_scvi_model(loaded_model, "label_0") scanvi_model.train(1) @@ -262,6 +264,39 @@ def test_validate_unsupported_if_minified(): model.get_latent_library_size() assert str(e.value) == common_err_msg.format("RNASeqMixin.get_latent_library_size") + with pytest.raises(ValueError) as e: + model.train() + assert str(e.value) == common_err_msg.format("VAE.loss") + + +def test_validate_supported_if_minified_keep_count(): + model, _, _, _ = prep_model() + model2, _, _, _ = prep_model() + + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + model.minify_adata(minified_data_type="latent_posterior_parameters_with_counts") + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS + assert model2.minified_data_type is None + + assert np.allclose(model2.get_elbo(), model.get_elbo(), rtol=1e-2) + assert np.allclose( + model2.get_reconstruction_error()['reconstruction_loss'], + model.get_reconstruction_error()['reconstruction_loss'], rtol=1e-2 + ) + assert np.allclose(model2.get_marginal_ll(), model.get_marginal_ll(), rtol=1e-2) + + model.train(1, check_val_every_n_epoch=1, train_size=0.5) + model.train(1, check_val_every_n_epoch=1, train_size=0.5, + plan_kwargs={"update_only_decoder": True}) + scanvi_model = scvi.model.SCANVI.from_scvi_model( + model, labels_key="labels", unlabeled_category="unknown") + scanvi_model.train() + scanvi_model.train(1, check_val_every_n_epoch=1, train_size=0.5, + plan_kwargs={"update_only_decoder": True}) + def test_scvi_with_minified_adata_save_then_load(save_path): # create a model and minify its adata, then save it and its adata. From bda6b30b3acbb46f1bde1bba4fae7d319d0e4fc9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Nov 2024 20:51:38 +0000 Subject: [PATCH 29/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/model/_scanvi.py | 2 +- src/scvi/module/_vae.py | 3 ++- src/scvi/train/_trainingplans.py | 2 +- src/scvi/utils/_decorators.py | 2 +- tests/model/test_models_with_minified_data.py | 18 +++++++++++------- 5 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index 604d4bfd4e..084d83be1f 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -211,7 +211,7 @@ def from_scvi_model( ) del scanvi_kwargs[k] - if scvi_model.minified_data_type==ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + if scvi_model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR: raise ValueError( "We cannot use the given scVI model to initialize scANVI because it has " "minified adata. 
Keep counts when minifying model using " diff --git a/src/scvi/module/_vae.py b/src/scvi/module/_vae.py index aa1e284447..513645ad67 100644 --- a/src/scvi/module/_vae.py +++ b/src/scvi/module/_vae.py @@ -289,7 +289,8 @@ def _get_inference_input( if full_forward_pass or self.minified_data_type is None: loader = "full_data" elif self.minified_data_type in [ - ADATA_MINIFY_TYPE.LATENT_POSTERIOR, ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS + ADATA_MINIFY_TYPE.LATENT_POSTERIOR, + ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS, ]: loader = "minified_data" else: diff --git a/src/scvi/train/_trainingplans.py b/src/scvi/train/_trainingplans.py index 159442c805..b1ab32e8f9 100644 --- a/src/scvi/train/_trainingplans.py +++ b/src/scvi/train/_trainingplans.py @@ -280,7 +280,7 @@ def forward(self, *args, **kwargs): return self.module( *args, **kwargs, - get_inference_input_kwargs={'full_forward_pass': not self.update_only_decoder}, + get_inference_input_kwargs={"full_forward_pass": not self.update_only_decoder}, ) @torch.inference_mode() diff --git a/src/scvi/utils/_decorators.py b/src/scvi/utils/_decorators.py index a688d3af22..fe728ef33d 100644 --- a/src/scvi/utils/_decorators.py +++ b/src/scvi/utils/_decorators.py @@ -9,7 +9,7 @@ def unsupported_if_adata_minified(fn: Callable) -> Callable: @wraps(fn) def wrapper(self, *args, **kwargs): - if getattr(self, "minified_data_type", None)==ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + if getattr(self, "minified_data_type", None) == ADATA_MINIFY_TYPE.LATENT_POSTERIOR: raise ValueError( f"The {fn.__qualname__} function currently does not support minified data." ) diff --git a/tests/model/test_models_with_minified_data.py b/tests/model/test_models_with_minified_data.py index 1cfe7d909b..858b7cd035 100644 --- a/tests/model/test_models_with_minified_data.py +++ b/tests/model/test_models_with_minified_data.py @@ -283,19 +283,23 @@ def test_validate_supported_if_minified_keep_count(): assert np.allclose(model2.get_elbo(), model.get_elbo(), rtol=1e-2) assert np.allclose( - model2.get_reconstruction_error()['reconstruction_loss'], - model.get_reconstruction_error()['reconstruction_loss'], rtol=1e-2 + model2.get_reconstruction_error()["reconstruction_loss"], + model.get_reconstruction_error()["reconstruction_loss"], + rtol=1e-2, ) assert np.allclose(model2.get_marginal_ll(), model.get_marginal_ll(), rtol=1e-2) model.train(1, check_val_every_n_epoch=1, train_size=0.5) - model.train(1, check_val_every_n_epoch=1, train_size=0.5, - plan_kwargs={"update_only_decoder": True}) + model.train( + 1, check_val_every_n_epoch=1, train_size=0.5, plan_kwargs={"update_only_decoder": True} + ) scanvi_model = scvi.model.SCANVI.from_scvi_model( - model, labels_key="labels", unlabeled_category="unknown") + model, labels_key="labels", unlabeled_category="unknown" + ) scanvi_model.train() - scanvi_model.train(1, check_val_every_n_epoch=1, train_size=0.5, - plan_kwargs={"update_only_decoder": True}) + scanvi_model.train( + 1, check_val_every_n_epoch=1, train_size=0.5, plan_kwargs={"update_only_decoder": True} + ) def test_scvi_with_minified_adata_save_then_load(save_path): From c5678e151ac820e931c385a4e2a713b5c3e23e7b Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Tue, 19 Nov 2024 23:43:27 +0200 Subject: [PATCH 30/51] added adata minification for multi/total vi and fixed several tests --- src/scvi/model/_multivi.py | 19 ++- src/scvi/model/_totalvi.py | 16 ++- .../test_models_with_mudata_minified_data.py | 128 ++++++++++++++---- 3 files changed, 133 insertions(+), 30 deletions(-) diff --git 
a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index d18e44be74..e030d6961e 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -34,6 +34,7 @@ ) from scvi.model.base import ( ArchesMixin, + BaseMinifiedModeModelClass, BaseMudataMinifiedModeModelClass, UnsupervisedTrainingMixin, VAEMixin, @@ -63,7 +64,13 @@ logger = logging.getLogger(__name__) -class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, ArchesMixin, BaseMudataMinifiedModeModelClass): +class MULTIVI( + VAEMixin, + UnsupervisedTrainingMixin, + ArchesMixin, + BaseMinifiedModeModelClass, + BaseMudataMinifiedModeModelClass, +): """Integration of multi-modal and single-modality data :cite:p:`AshuachGabitto21`. MultiVI is used to integrate multiomic datasets with single-modality (expression @@ -637,6 +644,7 @@ def get_normalized_expression( n_samples_overall: int | None = None, transform_batch: Sequence[Number | str] | None = None, gene_list: Sequence[str] | None = None, + library_size: float | Literal["latent"] | None = 1, use_z_mean: bool = True, n_samples: int = 1, batch_size: int | None = None, @@ -666,6 +674,10 @@ def get_normalized_expression( Return frequencies of expression for a subset of genes. This can save memory when working with large datasets and few genes are of interest. + library_size + Scale the expression frequencies to a common library size. + This allows gene expression levels to be interpreted on a common scale of relevant + magnitude. use_z_mean If True, use the mean of the latent distribution, otherwise sample from it n_samples @@ -713,7 +725,10 @@ def get_normalized_expression( generative_kwargs={"use_z_mean": use_z_mean}, compute_loss=False, ) - output = generative_outputs["px_scale"] + if library_size == "latent": + output = generative_outputs["px_rate"] + else: + output = generative_outputs["px_scale"] output = output[..., gene_mask] output = output.cpu().numpy() per_batch_exprs.append(output) diff --git a/src/scvi/model/_totalvi.py b/src/scvi/model/_totalvi.py index d1fe50b58f..0db7fad1a3 100644 --- a/src/scvi/model/_totalvi.py +++ b/src/scvi/model/_totalvi.py @@ -29,7 +29,13 @@ from scvi.train import AdversarialTrainingPlan, TrainRunner from scvi.utils._docstrings import de_dsp, devices_dsp, setup_anndata_dsp -from .base import ArchesMixin, BaseMudataMinifiedModeModelClass, RNASeqMixin, VAEMixin +from .base import ( + ArchesMixin, + BaseMinifiedModeModelClass, + BaseMudataMinifiedModeModelClass, + RNASeqMixin, + VAEMixin, +) if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -50,7 +56,13 @@ logger = logging.getLogger(__name__) -class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMudataMinifiedModeModelClass): +class TOTALVI( + RNASeqMixin, + VAEMixin, + ArchesMixin, + BaseMinifiedModeModelClass, + BaseMudataMinifiedModeModelClass, +): """total Variational Inference :cite:p:`GayosoSteier21`. 
Parameters diff --git a/tests/model/test_models_with_mudata_minified_data.py b/tests/model/test_models_with_mudata_minified_data.py index 63103cdaa3..3022f48f60 100644 --- a/tests/model/test_models_with_mudata_minified_data.py +++ b/tests/model/test_models_with_mudata_minified_data.py @@ -11,6 +11,78 @@ _MULTIVI_OBSERVED_LIB_SIZE = "_multivi_observed_lib_size" +def prep_model(cls=TOTALVI, use_size_factor=False): + # create a synthetic dataset + adata = synthetic_iid() + adata_counts = adata.X + if use_size_factor: + adata.obs["size_factor"] = np.random.randint(1, 5, size=(adata.shape[0],)) + adata.var["n_counts"] = np.squeeze(np.asarray(np.sum(adata_counts, axis=0))) + adata.varm["my_varm"] = np.random.negative_binomial(5, 0.3, size=(adata.shape[1], 3)) + adata.layers["my_layer"] = np.ones_like(adata.X) + adata_before_setup = adata.copy() + + # run setup_anndata + setup_kwargs = { + "batch_key": "batch", + "protein_expression_obsm_key": "protein_expression", + "protein_names_uns_key": "protein_names", + } + if use_size_factor: + setup_kwargs["size_factor_key"] = "size_factor" + cls.setup_anndata( + adata, + **setup_kwargs, + ) + + # create and train the model + if cls == TOTALVI: + model = cls(adata, n_latent=5) + else: + model = cls(adata, n_latent=5, n_genes=50, n_regions=50) + model.train(1, check_val_every_n_epoch=1, train_size=0.5) + + # get the adata lib size + adata_lib_size = np.squeeze(np.asarray(adata_counts.sum(axis=1))) + assert ( + np.min(adata_lib_size) > 0 + ) # make sure it's not all zeros and there are no negative values + + return model, adata, adata_lib_size, adata_before_setup + + +def run_test_for_model_with_minified_adata( + cls=TOTALVI, + n_samples: int = 1, + give_mean: bool = False, + use_size_factor=False, +): + model, adata, adata_lib_size, _ = prep_model(cls, use_size_factor) + + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + adata_orig = adata.copy() + + model.minify_adata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + assert model.adata_manager.registry is model.registry_ + + # make sure the original adata we set up the model with was not changed + assert adata is not model.adata + assert _is_minified(adata) is False + + assert adata_orig.layers.keys() == model.adata.layers.keys() + orig_obs_df = adata_orig.obs + obs_keys = "observed_lib_size" + orig_obs_df[obs_keys] = adata_lib_size + assert model.adata.obs.equals(orig_obs_df) + assert model.adata.var_names.equals(adata_orig.var_names) + assert model.adata.var.equals(adata_orig.var) + assert model.adata.varm.keys() == adata_orig.varm.keys() + np.testing.assert_array_equal(model.adata.varm["my_varm"], adata_orig.varm["my_varm"]) + + def prep_model_mudata(cls=TOTALVI, use_size_factor=False): # create a synthetic dataset # adata = synthetic_iid() @@ -69,10 +141,7 @@ def prep_model_mudata(cls=TOTALVI, use_size_factor=False): model.train(1, check_val_every_n_epoch=1, train_size=0.5) # get the mdata lib size - if cls == TOTALVI: - mdata_lib_size = np.squeeze(np.asarray(mdata["rna"].X.sum(axis=1))) # TOTALVI - else: - mdata_lib_size = np.squeeze(np.asarray(mdata["accessibility"].X.sum(axis=1))) # MULTIVI + mdata_lib_size = np.squeeze(np.asarray(mdata["rna"].X.sum(axis=1))) assert ( np.min(mdata_lib_size) > 0 ) # make sure it's not all zeros and there are no negative values @@ -80,12 +149,6 @@ def prep_model_mudata(cls=TOTALVI, use_size_factor=False): return model, mdata, 
mdata_lib_size, mdata_before_setup -def assert_approx_equal(a, b): - # Allclose because on GPU, the values are not exactly the same - # as some values are moved to cpu during data minification - np.testing.assert_allclose(a, b, rtol=3e-1, atol=5e-1) - - def run_test_for_model_with_minified_mudata( cls=TOTALVI, use_size_factor=False, @@ -118,6 +181,18 @@ def run_test_for_model_with_minified_mudata( np.testing.assert_array_equal(model.adata.varm["my_varm"], mdata_orig.varm["my_varm"]) +def assert_approx_equal(a, b): + # Allclose because on GPU, the values are not exactly the same + # as some values are moved to cpu during data minification + np.testing.assert_allclose(a, b, rtol=3e-1, atol=5e-1) + + +@pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) +@pytest.mark.parametrize("use_size_factor", [False, True]) +def test_with_minified_adata(cls, use_size_factor: bool): + run_test_for_model_with_minified_adata(cls=cls, use_size_factor=use_size_factor) + + @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) @pytest.mark.parametrize("use_size_factor", [False, True]) def test_with_minified_mudata(cls, use_size_factor: bool): @@ -142,10 +217,10 @@ def test_with_minified_mdata_get_normalized_expression(cls): for ii in range(len(exprs_new)): assert exprs_new[ii].shape == mdata[mdata.mod_names[ii]].shape for ii in range(len(exprs_new)): - np.testing.assert_array_equal(exprs_new[ii], exprs_orig[ii]) + assert_approx_equal(exprs_new[ii], exprs_orig[ii]) else: - assert exprs_new.shape == mdata[mdata.mod_names].shape - np.testing.assert_array_equal(exprs_new, exprs_orig) + assert exprs_new.shape == exprs_orig.shape + assert_approx_equal(exprs_new, exprs_orig) @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) @@ -167,17 +242,18 @@ def test_with_minified_mdata_get_normalized_expression_non_default_gene_list(cls model.minify_mudata() assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR - # do this so that we generate the same sequence of random numbers in the - # minified and non-minified cases (purely to get the tests to pass). this is - # because in the non-minified case we sample once more (in the call to z_encoder - # during inference) exprs_new = model.get_normalized_expression( - gene_list=gl, n_samples=n_samples + 1, return_mean=False, library_size="latent" + gene_list=gl, n_samples=n_samples + 1, library_size="latent" ) - exprs_new = exprs_new[0][:, :, 1:].mean(2) - assert exprs_new.shape == (mdata.shape[0], 5) - np.testing.assert_allclose(exprs_new, exprs_orig[0], rtol=3e-1, atol=5e-1) + if type(exprs_new) is tuple: + for ii in range(len(exprs_new)): + assert exprs_new[ii].shape == exprs_orig[ii].shape # mdata[mdata.mod_names[ii]].shape + for ii in range(len(exprs_new)): + assert_approx_equal(exprs_new[ii], exprs_orig[ii]) + else: + assert exprs_new.shape == exprs_orig.shape + assert_approx_equal(exprs_new, exprs_orig) @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) @@ -212,7 +288,7 @@ def test_validate_unsupported_if_minified(cls): @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) -def test_with_minified_mdata_save_then_load(cls, save_path): +def test_with_minified_mdata_save_then_load(cls, save_path="."): # create a model and minify its mdata, then save it and its mdata. # Load it back up using the same (minified) mdata. Validate that the # loaded model has the minified_data_type attribute set as expected. 
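# A minimal sketch of the minify -> save -> load round trip these test hunks
# exercise, assuming the prep_model_mudata helper defined above. The
# "X_latent_qzm"/"X_latent_qzv" obsm keys are the defaults minify_mudata reads;
# the save/load calls are standard scvi-tools usage and not part of this diff.

model, mdata, _, _ = prep_model_mudata(cls=TOTALVI)
# store the latent posterior parameters where minify_mudata expects them
qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True)
model.adata.obsm["X_latent_qzm"] = qzm
model.adata.obsm["X_latent_qzv"] = qzv
model.minify_mudata()
# persist the model together with its (now minified) MuData, then reload
model.save("totalvi_minified", overwrite=True, save_anndata=True)
loaded = TOTALVI.load("totalvi_minified")
assert loaded.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR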
@@ -234,7 +310,7 @@ def test_with_minified_mdata_save_then_load(cls, save_path): @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) -def test_with_minified_mdata_save_then_load_with_non_minified_mdata(cls, save_path): +def test_with_minified_mdata_save_then_load_with_non_minified_mdata(cls, save_path="."): # create a model and minify its mdata, then save it and its mdata. # Load it back up using a non-minified mdata. Validate that the # loaded model does not has the minified_data_type attribute set. @@ -255,7 +331,7 @@ def test_with_minified_mdata_save_then_load_with_non_minified_mdata(cls, save_pa @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) -def test_save_then_load_with_minified_mdata(cls, save_path): +def test_save_then_load_with_minified_mdata(cls, save_path="."): # create a model, then save it and its mdata (non-minified). # Load it back up using a minified mdata. Validate that this # fails, as expected because we don't have a way to validate @@ -293,7 +369,7 @@ def test_with_minified_mdata_get_latent_representation(cls): latent_repr_new = model.get_latent_representation() - np.testing.assert_array_equal(latent_repr_new, latent_repr_orig) + assert_approx_equal(latent_repr_new, latent_repr_orig) @pytest.mark.parametrize("cls", [TOTALVI]) @@ -312,7 +388,7 @@ def test_with_minified_mdata_posterior_predictive_sample(cls): sample_new = model.posterior_predictive_sample() # assert sample_new.shape == (3, 2) - np.testing.assert_array_equal(sample_new, sample_orig) + assert_approx_equal(sample_new, sample_orig) @pytest.mark.parametrize("cls", [TOTALVI]) From 108175d547b6a2a67a9b400d29e094f27cd8ac4d Mon Sep 17 00:00:00 2001 From: Can Ergen Date: Tue, 19 Nov 2024 15:16:52 -0800 Subject: [PATCH 31/51] Fixed loss computation for keep count models --- src/scvi/model/base/_log_likelihood.py | 15 +++++++++++++-- src/scvi/model/base/_vaemixin.py | 2 +- src/scvi/module/_vae.py | 3 ++- src/scvi/module/base/_base_module.py | 2 +- 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/scvi/model/base/_log_likelihood.py b/src/scvi/model/base/_log_likelihood.py index 97df6f557b..a545544661 100644 --- a/src/scvi/model/base/_log_likelihood.py +++ b/src/scvi/model/base/_log_likelihood.py @@ -1,5 +1,6 @@ from __future__ import annotations +from inspect import signature from typing import TYPE_CHECKING import torch @@ -48,8 +49,13 @@ def compute_elbo( The evidence lower bound (ELBO) of the data. """ elbo = [] + if "full_forward_pass" in signature(module._get_inference_input).parameters: + get_inference_input_kwargs = {"full_forward_pass": True} + else: + get_inference_input_kwargs = {} for tensors in dataloader: - _, _, losses = module(tensors, **kwargs) + _, _, losses = module(tensors, **kwargs, + get_inference_input_kwargs=get_inference_input_kwargs) if isinstance(losses.reconstruction_loss, dict): reconstruction_loss = torch.stack(list(losses.reconstruction_loss.values())).sum(dim=0) else: @@ -99,9 +105,14 @@ def compute_reconstruction_error( A dictionary of the reconstruction error of the data. 
""" # Iterate once over the data and computes the reconstruction error + if "full_forward_pass" in signature(module._get_inference_input).parameters: + get_inference_input_kwargs = {"full_forward_pass": True} + else: + get_inference_input_kwargs = {} log_lkl = {} for tensors in dataloader: - _, _, loss_output = module(tensors, loss_kwargs={"kl_weight": 1}, **kwargs) + _, _, loss_output = module(tensors, loss_kwargs={"kl_weight": 1}, **kwargs, + get_inference_input_kwargs=get_inference_input_kwargs) if not isinstance(loss_output.reconstruction_loss, dict): rec_loss_dict = {"reconstruction_loss": loss_output.reconstruction_loss} else: diff --git a/src/scvi/model/base/_vaemixin.py b/src/scvi/model/base/_vaemixin.py index d6099d02fa..1a2ea85bfa 100644 --- a/src/scvi/model/base/_vaemixin.py +++ b/src/scvi/model/base/_vaemixin.py @@ -29,7 +29,7 @@ def get_elbo( adata: AnnData | None = None, indices: Sequence[int] | None = None, batch_size: int | None = None, - dataloader: Iterator[dict[str, Tensor | None]] = None, + dataloader: Iterator[dict[str, Tensor | None]] | None = None, return_mean: bool = True, **kwargs, ) -> float: diff --git a/src/scvi/module/_vae.py b/src/scvi/module/_vae.py index 513645ad67..25ad63efdc 100644 --- a/src/scvi/module/_vae.py +++ b/src/scvi/module/_vae.py @@ -674,7 +674,8 @@ def marginal_ll( for _ in range(n_passes): # Distribution parameters and sampled variables inference_outputs, _, losses = self.forward( - tensors, inference_kwargs={"n_samples": n_mc_samples_per_pass} + tensors, inference_kwargs={"n_samples": n_mc_samples_per_pass}, + get_inference_input_kwargs = {"full_forward_pass": True} ) qz = inference_outputs[MODULE_KEYS.QZ_KEY] ql = inference_outputs[MODULE_KEYS.QL_KEY] diff --git a/src/scvi/module/base/_base_module.py b/src/scvi/module/base/_base_module.py index 7bfc00ace8..ba1855e15b 100644 --- a/src/scvi/module/base/_base_module.py +++ b/src/scvi/module/base/_base_module.py @@ -739,7 +739,7 @@ def _generic_forward( loss_kwargs = _get_dict_if_none(loss_kwargs) get_inference_input_kwargs = _get_dict_if_none(get_inference_input_kwargs) get_generative_input_kwargs = _get_dict_if_none(get_generative_input_kwargs) - if not ("qzm" in tensors.keys() and "qzv" in tensors.keys()): + if not ("latent_qzm" in tensors.keys() and "latent_qzv" in tensors.keys()): get_inference_input_kwargs.pop("full_forward_pass", None) inference_inputs = module._get_inference_input(tensors, **get_inference_input_kwargs) From 9706947f5d19a21906c2db8204786386911892f6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Nov 2024 23:17:05 +0000 Subject: [PATCH 32/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/model/base/_log_likelihood.py | 13 +++++++++---- src/scvi/module/_vae.py | 5 +++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/scvi/model/base/_log_likelihood.py b/src/scvi/model/base/_log_likelihood.py index a545544661..78aa34b20c 100644 --- a/src/scvi/model/base/_log_likelihood.py +++ b/src/scvi/model/base/_log_likelihood.py @@ -54,8 +54,9 @@ def compute_elbo( else: get_inference_input_kwargs = {} for tensors in dataloader: - _, _, losses = module(tensors, **kwargs, - get_inference_input_kwargs=get_inference_input_kwargs) + _, _, losses = module( + tensors, **kwargs, get_inference_input_kwargs=get_inference_input_kwargs + ) if isinstance(losses.reconstruction_loss, dict): reconstruction_loss = 
torch.stack(list(losses.reconstruction_loss.values())).sum(dim=0) else: @@ -111,8 +112,12 @@ def compute_reconstruction_error( get_inference_input_kwargs = {} log_lkl = {} for tensors in dataloader: - _, _, loss_output = module(tensors, loss_kwargs={"kl_weight": 1}, **kwargs, - get_inference_input_kwargs=get_inference_input_kwargs) + _, _, loss_output = module( + tensors, + loss_kwargs={"kl_weight": 1}, + **kwargs, + get_inference_input_kwargs=get_inference_input_kwargs, + ) if not isinstance(loss_output.reconstruction_loss, dict): rec_loss_dict = {"reconstruction_loss": loss_output.reconstruction_loss} else: diff --git a/src/scvi/module/_vae.py b/src/scvi/module/_vae.py index 25ad63efdc..adfc2934ef 100644 --- a/src/scvi/module/_vae.py +++ b/src/scvi/module/_vae.py @@ -674,8 +674,9 @@ def marginal_ll( for _ in range(n_passes): # Distribution parameters and sampled variables inference_outputs, _, losses = self.forward( - tensors, inference_kwargs={"n_samples": n_mc_samples_per_pass}, - get_inference_input_kwargs = {"full_forward_pass": True} + tensors, + inference_kwargs={"n_samples": n_mc_samples_per_pass}, + get_inference_input_kwargs={"full_forward_pass": True}, ) qz = inference_outputs[MODULE_KEYS.QZ_KEY] ql = inference_outputs[MODULE_KEYS.QL_KEY] From a0ed985f86d181d3295268eec9f1af7a518a28ac Mon Sep 17 00:00:00 2001 From: Can Ergen Date: Tue, 19 Nov 2024 16:52:49 -0800 Subject: [PATCH 33/51] Increase tolerance --- tests/model/test_models_with_minified_data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/model/test_models_with_minified_data.py b/tests/model/test_models_with_minified_data.py index 858b7cd035..74ece7ee67 100644 --- a/tests/model/test_models_with_minified_data.py +++ b/tests/model/test_models_with_minified_data.py @@ -281,13 +281,13 @@ def test_validate_supported_if_minified_keep_count(): assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS assert model2.minified_data_type is None - assert np.allclose(model2.get_elbo(), model.get_elbo(), rtol=1e-2) + assert np.allclose(model2.get_elbo(), model.get_elbo(), rtol=5e-2) assert np.allclose( model2.get_reconstruction_error()["reconstruction_loss"], model.get_reconstruction_error()["reconstruction_loss"], - rtol=1e-2, + rtol=5e-2, ) - assert np.allclose(model2.get_marginal_ll(), model.get_marginal_ll(), rtol=1e-2) + assert np.allclose(model2.get_marginal_ll(), model.get_marginal_ll(), rtol=5e-2) model.train(1, check_val_every_n_epoch=1, train_size=0.5) model.train( From 0bd790a7ae50f45b45d0a55dbc49f303207979db Mon Sep 17 00:00:00 2001 From: Can Ergen Date: Tue, 19 Nov 2024 17:11:00 -0800 Subject: [PATCH 34/51] Typo --- scvi-tutorials | 1 + tests/model/test_models_with_minified_data.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 160000 scvi-tutorials diff --git a/scvi-tutorials b/scvi-tutorials new file mode 160000 index 0000000000..f8965ce9f2 --- /dev/null +++ b/scvi-tutorials @@ -0,0 +1 @@ +Subproject commit f8965ce9f2d19748c0895800ec5c7ce438be3ebe diff --git a/tests/model/test_models_with_minified_data.py b/tests/model/test_models_with_minified_data.py index 74ece7ee67..52c9013362 100644 --- a/tests/model/test_models_with_minified_data.py +++ b/tests/model/test_models_with_minified_data.py @@ -161,7 +161,7 @@ def test_scanvi_from_scvi(save_path): msg = ( "We cannot use the given scVI model to initialize scANVI because it has minified adata. 
" - "Keep counts when minifying model using minified_data_type= " + "Keep counts when minifying model using minified_data_type=" "'latent_posterior_parameters_with_counts'." ) assert str(e.value) == msg From c971e91df22107b3bd9ead97f8ecfc9003daa715 Mon Sep 17 00:00:00 2001 From: Can Ergen Date: Tue, 19 Nov 2024 19:18:25 -0800 Subject: [PATCH 35/51] Changed keep adata when keep_counts --- src/scvi/model/base/_log_likelihood.py | 2 +- src/scvi/model/utils/_minification.py | 29 +++++++++++++------------- src/scvi/module/base/_base_module.py | 1 + 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/scvi/model/base/_log_likelihood.py b/src/scvi/model/base/_log_likelihood.py index 78aa34b20c..ea0552f61c 100644 --- a/src/scvi/model/base/_log_likelihood.py +++ b/src/scvi/model/base/_log_likelihood.py @@ -115,8 +115,8 @@ def compute_reconstruction_error( _, _, loss_output = module( tensors, loss_kwargs={"kl_weight": 1}, - **kwargs, get_inference_input_kwargs=get_inference_input_kwargs, + **kwargs, ) if not isinstance(loss_output.reconstruction_loss, dict): rec_loss_dict = {"reconstruction_loss": loss_output.reconstruction_loss} diff --git a/src/scvi/model/utils/_minification.py b/src/scvi/model/utils/_minification.py index ea3bb55a58..77fb737ba1 100644 --- a/src/scvi/model/utils/_minification.py +++ b/src/scvi/model/utils/_minification.py @@ -16,22 +16,21 @@ def get_minified_adata_scrna( keep_count_data: bool = False, ) -> AnnData: """Get a minified version of an :class:`~anndata.AnnData` or :class:`~mudata.MuData` object.""" - counts = adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) - all_zeros = csr_matrix(counts.shape) if keep_count_data: - X = counts - layers = adata_manager.adata.layers + return adata_manager.adata.copy() else: + counts = adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) + all_zeros = csr_matrix(counts.shape) X = all_zeros layers = {layer: all_zeros.copy() for layer in adata_manager.adata.layers} - return AnnData( - X=X, - layers=layers, - obs=adata_manager.adata.obs.copy(), - var=adata_manager.adata.var.copy(), - uns=adata_manager.adata.uns.copy(), - obsm=adata_manager.adata.obsm.copy(), - varm=adata_manager.adata.varm.copy(), - obsp=adata_manager.adata.obsp.copy(), - varp=adata_manager.adata.varp.copy(), - ) + return AnnData( + X=X, + layers=layers, + obs=adata_manager.adata.obs.copy(), + var=adata_manager.adata.var.copy(), + uns=adata_manager.adata.uns.copy(), + obsm=adata_manager.adata.obsm.copy(), + varm=adata_manager.adata.varm.copy(), + obsp=adata_manager.adata.obsp.copy(), + varp=adata_manager.adata.varp.copy(), + ) diff --git a/src/scvi/module/base/_base_module.py b/src/scvi/module/base/_base_module.py index ba1855e15b..46395e8752 100644 --- a/src/scvi/module/base/_base_module.py +++ b/src/scvi/module/base/_base_module.py @@ -740,6 +740,7 @@ def _generic_forward( get_inference_input_kwargs = _get_dict_if_none(get_inference_input_kwargs) get_generative_input_kwargs = _get_dict_if_none(get_generative_input_kwargs) if not ("latent_qzm" in tensors.keys() and "latent_qzv" in tensors.keys()): + # Remove full_forward_pass if not minified model get_inference_input_kwargs.pop("full_forward_pass", None) inference_inputs = module._get_inference_input(tensors, **get_inference_input_kwargs) From 24177da7ffb2ba63b01e2cfe2b8dae51a236a58c Mon Sep 17 00:00:00 2001 From: Can Ergen Date: Wed, 20 Nov 2024 00:13:11 -0800 Subject: [PATCH 36/51] Fixed multiVI mudata --- src/scvi/data/fields/_arraylike_field.py | 13 +++- src/scvi/model/_multivi.py | 33 +++++---- 
src/scvi/module/_multivae.py | 62 ++++++++-------- tests/model/test_multivi.py | 92 ++++++++++++++---------- 4 files changed, 119 insertions(+), 81 deletions(-) diff --git a/src/scvi/data/fields/_arraylike_field.py b/src/scvi/data/fields/_arraylike_field.py index 4e02751d8c..c1d51e3a9a 100644 --- a/src/scvi/data/fields/_arraylike_field.py +++ b/src/scvi/data/fields/_arraylike_field.py @@ -212,6 +212,8 @@ class BaseJointField(BaseArrayLikeField): Sequence of keys to combine to form the obsm or varm field. field_type Type of field. Can be either 'obsm' or 'varm'. + required + If True, the field must be present in the AnnData object """ def __init__( @@ -219,6 +221,7 @@ def __init__( registry_key: str, attr_keys: list[str] | None, field_type: Literal["obsm", "varm"] = None, + required: bool = True, ) -> None: super().__init__(registry_key) if field_type == "obsm": @@ -232,6 +235,7 @@ def __init__( self._attr_key = f"_scvi_{registry_key}" self._attr_keys = attr_keys if attr_keys is not None else [] self._is_empty = len(self.attr_keys) == 0 + self._required = required def validate_field(self, adata: AnnData) -> None: """Validate the field.""" @@ -266,6 +270,10 @@ def attr_key(self) -> str: @property def is_empty(self) -> bool: return self._is_empty + + @property + def required(self) -> bool: + return self._required class NumericalJointField(BaseJointField): @@ -282,6 +290,8 @@ class NumericalJointField(BaseJointField): Sequence of keys to combine to form the obsm or varm field. field_type Type of field. Can be either 'obsm' or 'varm'. + required + If True, the field must be present in the AnnData object """ COLUMNS_KEY = "columns" @@ -291,8 +301,9 @@ def __init__( registry_key: str, attr_keys: list[str] | None, field_type: Literal["obsm", "varm"] = None, + required: bool = True, ) -> None: - super().__init__(registry_key, attr_keys, field_type=field_type) + super().__init__(registry_key, attr_keys, field_type=field_type, required=required) self.count_stat_key = f"n_{self.registry_key}" diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index e030d6961e..406d2f93f1 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -164,8 +164,8 @@ class MULTIVI( def __init__( self, adata: AnnOrMuData, - n_genes: int, - n_regions: int, + n_genes: int | None = None, + n_regions: int | None = None, modality_weights: Literal["equal", "cell", "universal"] = "equal", modality_penalty: Literal["Jeffreys", "MMD", "None"] = "Jeffreys", n_hidden: int | None = None, @@ -187,6 +187,11 @@ def __init__( ): super().__init__(adata) + if n_genes is None or n_regions is None: + assert isinstance(adata, MuData), "n_genes and n_regions must be provided if using AnnData" + n_genes = self.summary_stats.get("n_vars", 0) + n_regions = self.summary_stats.get("n_atac", 0) + prior_mean, prior_scale = None, None n_cats_per_cov = ( self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key @@ -196,10 +201,6 @@ def __init__( use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry - # TODO: ADD MINIFICATION CONSIDERATION HERE? 
- # if not use_size_factor_key and self.minified_data_type is None: - # library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch) - if "n_proteins" in self.summary_stats: n_proteins = self.summary_stats.n_proteins else: @@ -618,7 +619,9 @@ def get_accessibility_estimates( imputed = vstack(imputed, format="csr") else: # imputed is a list of tensors imputed = torch.cat(imputed).numpy() - + print('SDSDSD', imputed.shape) + print(adata["rna"].var_names[self.n_genes :][region_mask].shape) + print(adata.obs_names[indices].shape) if return_numpy: return imputed elif threshold: @@ -1147,7 +1150,7 @@ def setup_anndata( batch_field, CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None), CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), - NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False), + NumericalJointObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False), CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"), @@ -1205,12 +1208,16 @@ def setup_mudata( protein_layer Protein layer key. If `None`, will use `.X` of specified modality key. %(param_batch_key)s - %(param_size_factor_key)s + size_factor_key + Key in `mdata.obsm` for size factors. The first column corresponds to RNA size factors, + the second to ATAC size factors. + The second column need to be normalized and between 0 and 1. %(param_cat_cov_keys)s %(param_cont_cov_keys)s %(idx_layer)s %(param_modalities)s + Examples -------- >>> mdata = muon.read_10x_h5("filtered_feature_bc_matrix.h5") @@ -1238,10 +1245,10 @@ def setup_mudata( None, mod_key=None, ), - fields.MuDataNumericalObsField( + fields.MuDataNumericalJointObsField( REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, - mod_key=modalities.size_factor_key, + mod_key=None, required=False, ), fields.MuDataCategoricalJointObsField( @@ -1367,8 +1374,8 @@ def minify_mudata( if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") - # if self.module.use_observed_lib_size is False: - # raise ValueError("Cannot minify the data if `use_observed_lib_size` is False") + if self.module.use_size_factor is False: + raise ValueError("Cannot minify the data if `use_observed_lib_size` is False") minified_adata = get_minified_mudata(self.adata, minified_data_type) minified_adata.obsm[_MULTIVI_LATENT_QZM] = self.adata.obsm[use_latent_qzm_key] diff --git a/src/scvi/module/_multivae.py b/src/scvi/module/_multivae.py index cec268d184..5bda99cb7d 100644 --- a/src/scvi/module/_multivae.py +++ b/src/scvi/module/_multivae.py @@ -259,8 +259,6 @@ class MULTIVAE(BaseMinifiedModeModuleClass): RNA distribution. 
""" - # TODO: replace n_input_regions and n_input_genes with a gene/region mask (we don't dictate - # which comes first or that they're even contiguous) def __init__( self, n_input_regions: int = 0, @@ -301,7 +299,7 @@ def __init__( if n_input_regions == 0: self.n_hidden = np.min([128, int(np.sqrt(self.n_input_genes))]) else: - self.n_hidden = int(np.sqrt(self.n_input_regions)) + self.n_hidden = np.min([128, int(np.sqrt(self.n_input_regions))]) else: self.n_hidden = n_hidden self.n_batch = n_batch @@ -536,7 +534,12 @@ def _get_inference_input(self, tensors): # from scvi.data._constants import ADATA_MINIFY_TYPE # TODO: ADD MINIFICATION CONSIDERATION - x = tensors[REGISTRY_KEYS.X_KEY] + x = tensors.get(REGISTRY_KEYS.X_KEY, None) + x_atac = tensors.get(REGISTRY_KEYS.ATAC_X_KEY, None) + if x is not None and x_atac is not None: + x = torch.cat((x, x_atac), dim=-1) + elif x is None: + x = x_atac if self.n_input_proteins == 0: y = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False) else: @@ -546,6 +549,7 @@ def _get_inference_input(self, tensors): cont_covs = tensors.get(REGISTRY_KEYS.CONT_COVS_KEY) cat_covs = tensors.get(REGISTRY_KEYS.CAT_COVS_KEY) label = tensors[REGISTRY_KEYS.LABELS_KEY] + size_factor = tensors.get(REGISTRY_KEYS.SIZE_FACTOR_KEY, None) input_dict = { "x": x, "y": y, @@ -554,6 +558,7 @@ def _get_inference_input(self, tensors): "cat_covs": cat_covs, "label": label, "cell_idx": cell_idx, + "size_factor": size_factor, } return input_dict @@ -567,6 +572,7 @@ def inference( cat_covs, label, cell_idx, + size_factor, n_samples=1, ) -> dict[str, torch.Tensor]: """Run the inference model.""" @@ -576,21 +582,21 @@ def inference( else: x_rna = x[:, : self.n_input_genes] if self.n_input_regions == 0: - x_chr = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False) + x_atac = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False) else: - x_chr = x[:, self.n_input_genes : (self.n_input_genes + self.n_input_regions)] + x_atac = x[:, self.n_input_genes : (self.n_input_genes + self.n_input_regions)] mask_expr = x_rna.sum(dim=1) > 0 - mask_acc = x_chr.sum(dim=1) > 0 + mask_acc = x_atac.sum(dim=1) > 0 mask_pro = y.sum(dim=1) > 0 if cont_covs is not None and self.encode_covariates: encoder_input_expression = torch.cat((x_rna, cont_covs), dim=-1) - encoder_input_accessibility = torch.cat((x_chr, cont_covs), dim=-1) + encoder_input_accessibility = torch.cat((x_atac, cont_covs), dim=-1) encoder_input_protein = torch.cat((y, cont_covs), dim=-1) else: encoder_input_expression = x_rna - encoder_input_accessibility = x_chr + encoder_input_accessibility = x_atac encoder_input_protein = y if cat_covs is not None and self.encode_covariates: @@ -610,12 +616,16 @@ def inference( ) # L encoders - libsize_expr = self.l_encoder_expression( - encoder_input_expression, batch_index, *categorical_input - ) - libsize_acc = self.l_encoder_accessibility( - encoder_input_accessibility, batch_index, *categorical_input - ) + if self.use_size_factor_key: + libsize_expr = torch.log(size_factor[:, [0]] + 1e-6) + libsize_acc = size_factor[:, [1]] + else: + libsize_expr = self.l_encoder_expression( + encoder_input_expression, batch_index, *categorical_input + ) + libsize_acc = self.l_encoder_accessibility( + encoder_input_accessibility, batch_index, *categorical_input + ) # mix representations if self.modality_weights == "cell": @@ -654,6 +664,7 @@ def unsqz(zt, n_s): z = self.z_encoder_accessibility.z_transformation(untran_z) outputs = { + "x": x, "z": z, "qz_m": qz_m, "qz_v": qz_v, @@ -677,11 
+688,6 @@ def _get_generative_input(self, tensors, inference_outputs, transform_batch=None qz_m = inference_outputs["qz_m"] libsize_expr = inference_outputs["libsize_expr"] - size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY - size_factor = ( - torch.log(tensors[size_factor_key]) if size_factor_key in tensors.keys() else None - ) - batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] cont_key = REGISTRY_KEYS.CONT_COVS_KEY cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None @@ -701,7 +707,6 @@ def _get_generative_input(self, tensors, inference_outputs, transform_batch=None "cont_covs": cont_covs, "cat_covs": cat_covs, "libsize_expr": libsize_expr, - "size_factor": size_factor, "label": label, } return input_dict @@ -715,7 +720,6 @@ def generative( cont_covs=None, cat_covs=None, libsize_expr=None, - size_factor=None, use_z_mean=False, label: torch.Tensor = None, ): @@ -739,12 +743,10 @@ def generative( p = self.z_decoder_accessibility(decoder_input, batch_index, *categorical_input) # Expression Decoder - if not self.use_size_factor_key: - size_factor = libsize_expr px_scale, _, px_rate, px_dropout = self.z_decoder_expression( self.gene_dispersion, decoder_input, - size_factor, + libsize_expr, batch_index, *categorical_input, label, @@ -786,24 +788,23 @@ def generative( def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0): """Computes the loss function for the model.""" # Get the data - x = tensors[REGISTRY_KEYS.X_KEY] + x = inference_outputs["x"] - # TODO: CHECK IF THIS FAILS IN ONLY RNA DATA x_rna = x[:, : self.n_input_genes] - x_chr = x[:, self.n_input_genes : (self.n_input_genes + self.n_input_regions)] + x_atac = x[:, self.n_input_genes : (self.n_input_genes + self.n_input_regions)] if self.n_input_proteins == 0: y = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False) else: y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY] mask_expr = x_rna.sum(dim=1) > 0 - mask_acc = x_chr.sum(dim=1) > 0 + mask_acc = x_atac.sum(dim=1) > 0 mask_pro = y.sum(dim=1) > 0 # Compute Accessibility loss p = generative_outputs["p"] libsize_acc = inference_outputs["libsize_acc"] - rl_accessibility = self.get_reconstruction_loss_accessibility(x_chr, p, libsize_acc) + rl_accessibility = self.get_reconstruction_loss_accessibility(x_atac, p, libsize_acc) # Compute Expression loss px_rate = generative_outputs["px_rate"] @@ -822,7 +823,6 @@ def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float rl_protein = torch.zeros(x.shape[0], device=x.device, requires_grad=False) # calling without weights makes this act like a masked sum - # TODO : CHECK MIXING HERE recon_loss_expression = rl_expression * mask_expr recon_loss_accessibility = rl_accessibility * mask_acc recon_loss_protein = rl_protein * mask_pro diff --git a/tests/model/test_multivi.py b/tests/model/test_multivi.py index 3a05a773e0..a818429531 100644 --- a/tests/model/test_multivi.py +++ b/tests/model/test_multivi.py @@ -42,7 +42,7 @@ def test_multivi(): # Test with size factor data = synthetic_iid() data.obs["size_factor"] = np.random.randint(1, 5, size=(data.shape[0],)) - MULTIVI.setup_anndata(data, batch_key="batch", size_factor_key="size_factor") + MULTIVI.setup_anndata(data, batch_key="batch") vae = MULTIVI( data, n_genes=50, @@ -87,11 +87,9 @@ def test_multivi_single_batch(): def test_multivi_mudata_rna_prot_external(): - # Example on how to download protein adata to mudata (from multivi tutotial) - mudata RNA/PROT + # Example on how to download protein adata to mudata (from multivi 
tutorial) - mudata RNA/PROT adata = scvi.data.pbmcs_10x_cite_seq() adata.layers["counts"] = adata.X.copy() - sc.pp.normalize_total(adata) - sc.pp.log1p(adata) adata.obs_names_make_unique() protein_adata = ad.AnnData(adata.obsm["protein_expression"]) protein_adata.obs_names = adata.obs_names @@ -119,24 +117,19 @@ def test_multivi_mudata_rna_prot_external(): "batch_key": "rna_subset", }, ) - model = MULTIVI(mdata, n_genes=50, n_regions=50) + model = MULTIVI(mdata) model.train(1, train_size=0.9) def test_multivi_mudata_rna_atac_external(): # optional data - mudata RNA/ATAC mdata = synthetic_iid(return_mudata=True) - # Preprocessing - sc.pp.normalize_total(mdata.mod["rna"]) - sc.pp.log1p(mdata.mod["rna"]) sc.pp.highly_variable_genes( mdata.mod["rna"], n_top_genes=4000, flavor="seurat_v3", ) mdata.mod["rna_subset"] = mdata.mod["rna"][:, mdata.mod["rna"].var["highly_variable"]].copy() - sc.pp.normalize_total(mdata.mod["accessibility"]) - sc.pp.log1p(mdata.mod["accessibility"]) sc.pp.highly_variable_genes( mdata.mod["accessibility"], n_top_genes=4000, @@ -146,13 +139,46 @@ def test_multivi_mudata_rna_atac_external(): :, mdata.mod["accessibility"].var["highly_variable"] ].copy() mdata.update() - # mdata - # mdata.mod MULTIVI.setup_mudata( - mdata, modalities={"rna_layer": "rna_subset", "atac_layer": "atac_subset"} + mdata, + modalities={"rna_layer": "rna_subset", "atac_layer": "atac_subset"}, + ) + model = MULTIVI(mdata) + model.train(1, train_size=0.9) + + +def test_multivi_mudata_trimodal_external(): + # optional data - mudata RNA/ATAC + mdata = synthetic_iid(return_mudata=True) + MULTIVI.setup_mudata( + mdata, + modalities={ + "rna_layer": "rna", + "atac_layer": "accessibility", + "protein_layer": "protein_expression" + }, ) - model = MULTIVI(mdata, n_genes=50, n_regions=50) + model = MULTIVI(mdata) + model.train(1, train_size=0.9) model.train(1, train_size=0.9) + assert model.is_trained is True + model.get_latent_representation() + model.get_elbo() + model.get_reconstruction_error() + model.get_normalized_expression() + model.get_accessibility_estimates() + model.get_accessibility_estimates(normalize_cells=True) + model.get_accessibility_estimates(normalize_regions=True) + model.get_library_size_factors() + model.get_region_factors() + + model.get_elbo(indices=model.validation_indices) + model.get_reconstruction_error(indices=model.validation_indices) + model.get_accessibility_estimates() + model.get_accessibility_estimates(normalize_cells=True) + model.get_accessibility_estimates(normalize_regions=True) + model.get_library_size_factors() + model.get_region_factors() def test_multivi_mudata(): @@ -171,7 +197,7 @@ def test_multivi_mudata(): n_obs = mdata.n_obs n_latent = 10 - model = MULTIVI(mdata, n_latent=n_latent, n_genes=50, n_regions=50) + model = MULTIVI(mdata, n_latent=n_latent) model.train(1, train_size=0.9) assert model.is_trained is True z = model.get_latent_representation() @@ -224,7 +250,7 @@ def test_multivi_auto_transfer_mudata(): batch_key="batch", modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, ) - model = MULTIVI(mdata, n_genes=50, n_regions=50) + model = MULTIVI(mdata) adata2 = synthetic_iid() protein_adata2 = synthetic_iid(n_genes=50) mdata2 = MuData({"rna": adata2, "protein": protein_adata2}) @@ -247,7 +273,7 @@ def test_multivi_incorrect_mapping_mudata(): batch_key="batch", modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, ) - model = MULTIVI(mdata, n_genes=50, n_regions=50) + model = MULTIVI(mdata) adata2 = 
synthetic_iid() protein_adata2 = synthetic_iid(n_genes=50) mdata2 = MuData({"rna": adata2, "protein": protein_adata2}) @@ -266,7 +292,7 @@ def test_multivi_reordered_mapping_mudata(): batch_key="batch", modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, ) - model = MULTIVI(mdata, n_genes=50, n_regions=50) + model = MULTIVI(mdata) adata2 = synthetic_iid() protein_adata2 = synthetic_iid(n_genes=50) mdata2 = MuData({"rna": adata2, "protein": protein_adata2}) @@ -291,7 +317,7 @@ def test_multivi_model_library_size_mudata(): ) n_latent = 10 - model = MULTIVI(mdata, n_latent=n_latent, n_genes=50, n_regions=50) + model = MULTIVI(mdata, n_latent=n_latent) model.train(1, train_size=0.5) assert model.is_trained is True model.get_elbo() @@ -303,30 +329,24 @@ def test_multivi_model_library_size_mudata(): def test_multivi_size_factor_mudata(): - adata = synthetic_iid() - adata.obs["size_factor"] = np.random.randint(1, 5, size=(adata.shape[0],)) - protein_adata = synthetic_iid(n_genes=50) - mdata = MuData({"rna": adata, "protein": protein_adata}) + mdata = synthetic_iid(return_mudata=True) + mdata.obs['size_factor_rna'] = mdata["rna"].X.sum(1) + mdata.obs['size_factor_atac'] = ( + mdata["accessibility"].X.sum(1) + 1) / (np.max(mdata["accessibility"].X.sum(1)) + 1.01) MULTIVI.setup_mudata( mdata, - batch_key="batch", - size_factor_key="size_factor", - modalities={ - "rna_layer": "rna", - "batch_key": "rna", - "protein_layer": "protein", - "size_factor_key": "rna", - }, + modalities={"rna_layer": "rna", "atac_layer": "accessibility"}, + size_factor_key=["size_factor_rna", "size_factor_atac"], ) n_latent = 10 # Test size_factor_key overrides use_observed_lib_size. - model = MULTIVI(mdata, n_latent=n_latent, n_genes=50, n_regions=50) + model = MULTIVI(mdata, n_latent=n_latent) assert model.module.use_size_factor_key model.train(1, train_size=0.5) - model = MULTIVI(mdata, n_latent=n_latent, n_genes=50, n_regions=50) + model = MULTIVI(mdata, n_latent=n_latent) assert model.module.use_size_factor_key model.train(1, train_size=0.5) @@ -340,7 +360,7 @@ def test_multivi_saving_and_loading_mudata(save_path: str): batch_key="batch", modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, ) - model = MULTIVI(mdata, n_genes=50, n_regions=50) + model = MULTIVI(mdata) model.train(1, train_size=0.2) z1 = model.get_latent_representation(mdata) test_idx1 = model.validation_indices @@ -402,7 +422,7 @@ def test_scarches_mudata_prep_layer(save_path: str): batch_key="batch", modalities={"rna_layer": "rna", "protein_layer": "protein_expression"}, ) - model = MULTIVI(mdata1, n_latent=n_latent, n_genes=50, n_regions=50) + model = MULTIVI(mdata1, n_latent=n_latent) model.train(1, check_val_every_n_epoch=1) dir_path = os.path.join(save_path, "saved_model/") model.save(dir_path, overwrite=True) @@ -448,7 +468,7 @@ def test_multivi_save_load_mudata_format(save_path: str): mdata, modalities={"rna_layer": "rna", "protein_layer": "protein"}, ) - model = MULTIVI(mdata, n_genes=50, n_regions=50) + model = MULTIVI(mdata) model.train(max_epochs=1) legacy_model_path = os.path.join(save_path, "legacy_model") From c7b0a3bbb94e061f511c77bf9fcdf2c05ec2efb0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Nov 2024 08:13:24 +0000 Subject: [PATCH 37/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/data/fields/_arraylike_field.py | 2 +- 
src/scvi/model/_multivi.py | 6 ++++-- tests/model/test_multivi.py | 9 +++++---- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/scvi/data/fields/_arraylike_field.py b/src/scvi/data/fields/_arraylike_field.py index c1d51e3a9a..b4dafc9aed 100644 --- a/src/scvi/data/fields/_arraylike_field.py +++ b/src/scvi/data/fields/_arraylike_field.py @@ -270,7 +270,7 @@ def attr_key(self) -> str: @property def is_empty(self) -> bool: return self._is_empty - + @property def required(self) -> bool: return self._required diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 406d2f93f1..3f4032d188 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -188,7 +188,9 @@ def __init__( super().__init__(adata) if n_genes is None or n_regions is None: - assert isinstance(adata, MuData), "n_genes and n_regions must be provided if using AnnData" + assert isinstance( + adata, MuData + ), "n_genes and n_regions must be provided if using AnnData" n_genes = self.summary_stats.get("n_vars", 0) n_regions = self.summary_stats.get("n_atac", 0) @@ -619,7 +621,7 @@ def get_accessibility_estimates( imputed = vstack(imputed, format="csr") else: # imputed is a list of tensors imputed = torch.cat(imputed).numpy() - print('SDSDSD', imputed.shape) + print("SDSDSD", imputed.shape) print(adata["rna"].var_names[self.n_genes :][region_mask].shape) print(adata.obs_names[indices].shape) if return_numpy: diff --git a/tests/model/test_multivi.py b/tests/model/test_multivi.py index a818429531..b6ce86b4ff 100644 --- a/tests/model/test_multivi.py +++ b/tests/model/test_multivi.py @@ -155,7 +155,7 @@ def test_multivi_mudata_trimodal_external(): modalities={ "rna_layer": "rna", "atac_layer": "accessibility", - "protein_layer": "protein_expression" + "protein_layer": "protein_expression", }, ) model = MULTIVI(mdata) @@ -330,9 +330,10 @@ def test_multivi_model_library_size_mudata(): def test_multivi_size_factor_mudata(): mdata = synthetic_iid(return_mudata=True) - mdata.obs['size_factor_rna'] = mdata["rna"].X.sum(1) - mdata.obs['size_factor_atac'] = ( - mdata["accessibility"].X.sum(1) + 1) / (np.max(mdata["accessibility"].X.sum(1)) + 1.01) + mdata.obs["size_factor_rna"] = mdata["rna"].X.sum(1) + mdata.obs["size_factor_atac"] = (mdata["accessibility"].X.sum(1) + 1) / ( + np.max(mdata["accessibility"].X.sum(1)) + 1.01 + ) MULTIVI.setup_mudata( mdata, modalities={"rna_layer": "rna", "atac_layer": "accessibility"}, From 101641a2692866c4567746a686416836f59538f0 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 20 Nov 2024 11:03:39 +0200 Subject: [PATCH 38/51] following can's fixes --- src/scvi/data/fields/_arraylike_field.py | 13 +++- src/scvi/model/_multivi.py | 26 +++++-- src/scvi/module/_multivae.py | 62 ++++++++-------- tests/model/test_multivi.py | 93 +++++++++++++++--------- 4 files changed, 120 insertions(+), 74 deletions(-) diff --git a/src/scvi/data/fields/_arraylike_field.py b/src/scvi/data/fields/_arraylike_field.py index 4e02751d8c..b4dafc9aed 100644 --- a/src/scvi/data/fields/_arraylike_field.py +++ b/src/scvi/data/fields/_arraylike_field.py @@ -212,6 +212,8 @@ class BaseJointField(BaseArrayLikeField): Sequence of keys to combine to form the obsm or varm field. field_type Type of field. Can be either 'obsm' or 'varm'. 
+ required + If True, the field must be present in the AnnData object """ def __init__( @@ -219,6 +221,7 @@ def __init__( registry_key: str, attr_keys: list[str] | None, field_type: Literal["obsm", "varm"] = None, + required: bool = True, ) -> None: super().__init__(registry_key) if field_type == "obsm": @@ -232,6 +235,7 @@ def __init__( self._attr_key = f"_scvi_{registry_key}" self._attr_keys = attr_keys if attr_keys is not None else [] self._is_empty = len(self.attr_keys) == 0 + self._required = required def validate_field(self, adata: AnnData) -> None: """Validate the field.""" @@ -267,6 +271,10 @@ def attr_key(self) -> str: def is_empty(self) -> bool: return self._is_empty + @property + def required(self) -> bool: + return self._required + class NumericalJointField(BaseJointField): """An AnnDataField for a collection of numerical .obs or .var fields in AnnData. @@ -282,6 +290,8 @@ class NumericalJointField(BaseJointField): Sequence of keys to combine to form the obsm or varm field. field_type Type of field. Can be either 'obsm' or 'varm'. + required + If True, the field must be present in the AnnData object """ COLUMNS_KEY = "columns" @@ -291,8 +301,9 @@ def __init__( registry_key: str, attr_keys: list[str] | None, field_type: Literal["obsm", "varm"] = None, + required: bool = True, ) -> None: - super().__init__(registry_key, attr_keys, field_type=field_type) + super().__init__(registry_key, attr_keys, field_type=field_type, required=required) self.count_stat_key = f"n_{self.registry_key}" diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 80b2f13ebd..0de7ab328b 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -145,8 +145,8 @@ class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin): def __init__( self, adata: AnnOrMuData, - n_genes: int, - n_regions: int, + n_genes: int | None = None, + n_regions: int | None = None, modality_weights: Literal["equal", "cell", "universal"] = "equal", modality_penalty: Literal["Jeffreys", "MMD", "None"] = "Jeffreys", n_hidden: int | None = None, @@ -168,6 +168,13 @@ def __init__( ): super().__init__(adata) + if n_genes is None or n_regions is None: + assert isinstance( + adata, MuData + ), "n_genes and n_regions must be provided if using AnnData" + n_genes = self.summary_stats.get("n_vars", 0) + n_regions = self.summary_stats.get("n_atac", 0) + prior_mean, prior_scale = None, None n_cats_per_cov = ( self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key @@ -580,6 +587,10 @@ def get_accessibility_estimates( else: # imputed is a list of tensors imputed = torch.cat(imputed).numpy() + print("SDSDSD", imputed.shape) + print(adata["rna"].var_names[self.n_genes :][region_mask].shape) + print(adata.obs_names[indices].shape) + if return_numpy: return imputed elif threshold: @@ -1100,7 +1111,7 @@ def setup_anndata( batch_field, CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None), CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), - NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False), + NumericalJointObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False), CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"), @@ -1158,7 +1169,10 @@ def setup_mudata( protein_layer Protein layer key. If `None`, will use `.X` of specified modality key. 
%(param_batch_key)s
-        %(param_size_factor_key)s
+        size_factor_key
+            Key(s) in `mdata.obs` for size factors, combined into a two-column field. The
+            first column corresponds to RNA size factors, the second to ATAC size factors.
+            The second column needs to be normalized to values between 0 and 1.
         %(param_cat_cov_keys)s
         %(param_cont_cov_keys)s
         %(idx_layer)s
@@ -1191,10 +1205,10 @@ def setup_mudata(
                 None,
                 mod_key=None,
             ),
-            fields.MuDataNumericalObsField(
+            fields.MuDataNumericalJointObsField(
                 REGISTRY_KEYS.SIZE_FACTOR_KEY,
                 size_factor_key,
-                mod_key=modalities.size_factor_key,
+                mod_key=None,
                 required=False,
             ),
             fields.MuDataCategoricalJointObsField(
diff --git a/src/scvi/module/_multivae.py b/src/scvi/module/_multivae.py
index 6ad24b65d2..e35c0d418d 100644
--- a/src/scvi/module/_multivae.py
+++ b/src/scvi/module/_multivae.py
@@ -259,8 +259,6 @@ class MULTIVAE(BaseModuleClass):
         RNA distribution.
     """

-    # TODO: replace n_input_regions and n_input_genes with a gene/region mask (we don't dictate
-    # which comes first or that they're even contiguous)
     def __init__(
         self,
         n_input_regions: int = 0,
@@ -301,7 +299,7 @@ def __init__(
             if n_input_regions == 0:
                 self.n_hidden = np.min([128, int(np.sqrt(self.n_input_genes))])
             else:
-                self.n_hidden = int(np.sqrt(self.n_input_regions))
+                self.n_hidden = np.min([128, int(np.sqrt(self.n_input_regions))])
         else:
             self.n_hidden = n_hidden
         self.n_batch = n_batch
@@ -533,7 +531,12 @@ def __init__(

     def _get_inference_input(self, tensors):
         """Get input tensors for the inference model."""
-        x = tensors[REGISTRY_KEYS.X_KEY]
+        x = tensors.get(REGISTRY_KEYS.X_KEY, None)
+        x_atac = tensors.get(REGISTRY_KEYS.ATAC_X_KEY, None)
+        if x is not None and x_atac is not None:
+            x = torch.cat((x, x_atac), dim=-1)
+        elif x is None:
+            x = x_atac
         if self.n_input_proteins == 0:
             y = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False)
         else:
@@ -543,6 +546,7 @@ def _get_inference_input(self, tensors):
         cont_covs = tensors.get(REGISTRY_KEYS.CONT_COVS_KEY)
         cat_covs = tensors.get(REGISTRY_KEYS.CAT_COVS_KEY)
         label = tensors[REGISTRY_KEYS.LABELS_KEY]
+        size_factor = tensors.get(REGISTRY_KEYS.SIZE_FACTOR_KEY, None)
         input_dict = {
             "x": x,
             "y": y,
@@ -551,6 +555,7 @@ def _get_inference_input(self, tensors):
             "cat_covs": cat_covs,
             "label": label,
             "cell_idx": cell_idx,
+            "size_factor": size_factor,
         }
         return input_dict

@@ -564,6 +569,7 @@ def inference(
         cat_covs,
         label,
         cell_idx,
+        size_factor,
         n_samples=1,
     ) -> dict[str, torch.Tensor]:
         """Run the inference model."""
@@ -573,21 +579,21 @@ def inference(
         else:
             x_rna = x[:, : self.n_input_genes]
         if self.n_input_regions == 0:
-            x_chr = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False)
+            x_atac = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False)
         else:
-            x_chr = x[:, self.n_input_genes : (self.n_input_genes + self.n_input_regions)]
+            x_atac = x[:, self.n_input_genes : (self.n_input_genes + self.n_input_regions)]

         mask_expr = x_rna.sum(dim=1) > 0
-        mask_acc = x_chr.sum(dim=1) > 0
+        mask_acc = x_atac.sum(dim=1) > 0
         mask_pro = y.sum(dim=1) > 0

         if cont_covs is not None and self.encode_covariates:
             encoder_input_expression = torch.cat((x_rna, cont_covs), dim=-1)
-            encoder_input_accessibility = torch.cat((x_chr, cont_covs), dim=-1)
+            encoder_input_accessibility = torch.cat((x_atac, cont_covs), dim=-1)
             encoder_input_protein = torch.cat((y, cont_covs), dim=-1)
         else:
             encoder_input_expression = x_rna
-            encoder_input_accessibility = x_chr
+            encoder_input_accessibility = x_atac
             encoder_input_protein = y

         if cat_covs is not None and self.encode_covariates:
@@ -607,12
+613,16 @@ def inference( ) # L encoders - libsize_expr = self.l_encoder_expression( - encoder_input_expression, batch_index, *categorical_input - ) - libsize_acc = self.l_encoder_accessibility( - encoder_input_accessibility, batch_index, *categorical_input - ) + if self.use_size_factor_key: + libsize_expr = torch.log(size_factor[:, [0]] + 1e-6) + libsize_acc = size_factor[:, [1]] + else: + libsize_expr = self.l_encoder_expression( + encoder_input_expression, batch_index, *categorical_input + ) + libsize_acc = self.l_encoder_accessibility( + encoder_input_accessibility, batch_index, *categorical_input + ) # mix representations if self.modality_weights == "cell": @@ -651,6 +661,7 @@ def unsqz(zt, n_s): z = self.z_encoder_accessibility.z_transformation(untran_z) outputs = { + "x": x, "z": z, "qz_m": qz_m, "qz_v": qz_v, @@ -674,11 +685,6 @@ def _get_generative_input(self, tensors, inference_outputs, transform_batch=None qz_m = inference_outputs["qz_m"] libsize_expr = inference_outputs["libsize_expr"] - size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY - size_factor = ( - torch.log(tensors[size_factor_key]) if size_factor_key in tensors.keys() else None - ) - batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] cont_key = REGISTRY_KEYS.CONT_COVS_KEY cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None @@ -698,7 +704,6 @@ def _get_generative_input(self, tensors, inference_outputs, transform_batch=None "cont_covs": cont_covs, "cat_covs": cat_covs, "libsize_expr": libsize_expr, - "size_factor": size_factor, "label": label, } return input_dict @@ -712,7 +717,6 @@ def generative( cont_covs=None, cat_covs=None, libsize_expr=None, - size_factor=None, use_z_mean=False, label: torch.Tensor = None, ): @@ -736,12 +740,10 @@ def generative( p = self.z_decoder_accessibility(decoder_input, batch_index, *categorical_input) # Expression Decoder - if not self.use_size_factor_key: - size_factor = libsize_expr px_scale, _, px_rate, px_dropout = self.z_decoder_expression( self.gene_dispersion, decoder_input, - size_factor, + libsize_expr, batch_index, *categorical_input, label, @@ -783,24 +785,23 @@ def generative( def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0): """Computes the loss function for the model.""" # Get the data - x = tensors[REGISTRY_KEYS.X_KEY] + x = inference_outputs["x"] - # TODO: CHECK IF THIS FAILS IN ONLY RNA DATA x_rna = x[:, : self.n_input_genes] - x_chr = x[:, self.n_input_genes : (self.n_input_genes + self.n_input_regions)] + x_atac = x[:, self.n_input_genes : (self.n_input_genes + self.n_input_regions)] if self.n_input_proteins == 0: y = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False) else: y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY] mask_expr = x_rna.sum(dim=1) > 0 - mask_acc = x_chr.sum(dim=1) > 0 + mask_acc = x_atac.sum(dim=1) > 0 mask_pro = y.sum(dim=1) > 0 # Compute Accessibility loss p = generative_outputs["p"] libsize_acc = inference_outputs["libsize_acc"] - rl_accessibility = self.get_reconstruction_loss_accessibility(x_chr, p, libsize_acc) + rl_accessibility = self.get_reconstruction_loss_accessibility(x_atac, p, libsize_acc) # Compute Expression loss px_rate = generative_outputs["px_rate"] @@ -819,7 +820,6 @@ def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float rl_protein = torch.zeros(x.shape[0], device=x.device, requires_grad=False) # calling without weights makes this act like a masked sum - # TODO : CHECK MIXING HERE recon_loss_expression = rl_expression * mask_expr 
recon_loss_accessibility = rl_accessibility * mask_acc recon_loss_protein = rl_protein * mask_pro diff --git a/tests/model/test_multivi.py b/tests/model/test_multivi.py index 3a05a773e0..b6ce86b4ff 100644 --- a/tests/model/test_multivi.py +++ b/tests/model/test_multivi.py @@ -42,7 +42,7 @@ def test_multivi(): # Test with size factor data = synthetic_iid() data.obs["size_factor"] = np.random.randint(1, 5, size=(data.shape[0],)) - MULTIVI.setup_anndata(data, batch_key="batch", size_factor_key="size_factor") + MULTIVI.setup_anndata(data, batch_key="batch") vae = MULTIVI( data, n_genes=50, @@ -87,11 +87,9 @@ def test_multivi_single_batch(): def test_multivi_mudata_rna_prot_external(): - # Example on how to download protein adata to mudata (from multivi tutotial) - mudata RNA/PROT + # Example on how to download protein adata to mudata (from multivi tutorial) - mudata RNA/PROT adata = scvi.data.pbmcs_10x_cite_seq() adata.layers["counts"] = adata.X.copy() - sc.pp.normalize_total(adata) - sc.pp.log1p(adata) adata.obs_names_make_unique() protein_adata = ad.AnnData(adata.obsm["protein_expression"]) protein_adata.obs_names = adata.obs_names @@ -119,24 +117,19 @@ def test_multivi_mudata_rna_prot_external(): "batch_key": "rna_subset", }, ) - model = MULTIVI(mdata, n_genes=50, n_regions=50) + model = MULTIVI(mdata) model.train(1, train_size=0.9) def test_multivi_mudata_rna_atac_external(): # optional data - mudata RNA/ATAC mdata = synthetic_iid(return_mudata=True) - # Preprocessing - sc.pp.normalize_total(mdata.mod["rna"]) - sc.pp.log1p(mdata.mod["rna"]) sc.pp.highly_variable_genes( mdata.mod["rna"], n_top_genes=4000, flavor="seurat_v3", ) mdata.mod["rna_subset"] = mdata.mod["rna"][:, mdata.mod["rna"].var["highly_variable"]].copy() - sc.pp.normalize_total(mdata.mod["accessibility"]) - sc.pp.log1p(mdata.mod["accessibility"]) sc.pp.highly_variable_genes( mdata.mod["accessibility"], n_top_genes=4000, @@ -146,15 +139,48 @@ def test_multivi_mudata_rna_atac_external(): :, mdata.mod["accessibility"].var["highly_variable"] ].copy() mdata.update() - # mdata - # mdata.mod MULTIVI.setup_mudata( - mdata, modalities={"rna_layer": "rna_subset", "atac_layer": "atac_subset"} + mdata, + modalities={"rna_layer": "rna_subset", "atac_layer": "atac_subset"}, ) - model = MULTIVI(mdata, n_genes=50, n_regions=50) + model = MULTIVI(mdata) model.train(1, train_size=0.9) +def test_multivi_mudata_trimodal_external(): + # optional data - mudata RNA/ATAC + mdata = synthetic_iid(return_mudata=True) + MULTIVI.setup_mudata( + mdata, + modalities={ + "rna_layer": "rna", + "atac_layer": "accessibility", + "protein_layer": "protein_expression", + }, + ) + model = MULTIVI(mdata) + model.train(1, train_size=0.9) + model.train(1, train_size=0.9) + assert model.is_trained is True + model.get_latent_representation() + model.get_elbo() + model.get_reconstruction_error() + model.get_normalized_expression() + model.get_accessibility_estimates() + model.get_accessibility_estimates(normalize_cells=True) + model.get_accessibility_estimates(normalize_regions=True) + model.get_library_size_factors() + model.get_region_factors() + + model.get_elbo(indices=model.validation_indices) + model.get_reconstruction_error(indices=model.validation_indices) + model.get_accessibility_estimates() + model.get_accessibility_estimates(normalize_cells=True) + model.get_accessibility_estimates(normalize_regions=True) + model.get_library_size_factors() + model.get_region_factors() + + def test_multivi_mudata(): # use of syntetic data of rna/proteins/atac for speed @@ 
-171,7 +197,7 @@ def test_multivi_mudata(): n_obs = mdata.n_obs n_latent = 10 - model = MULTIVI(mdata, n_latent=n_latent, n_genes=50, n_regions=50) + model = MULTIVI(mdata, n_latent=n_latent) model.train(1, train_size=0.9) assert model.is_trained is True z = model.get_latent_representation() @@ -224,7 +250,7 @@ def test_multivi_auto_transfer_mudata(): batch_key="batch", modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, ) - model = MULTIVI(mdata, n_genes=50, n_regions=50) + model = MULTIVI(mdata) adata2 = synthetic_iid() protein_adata2 = synthetic_iid(n_genes=50) mdata2 = MuData({"rna": adata2, "protein": protein_adata2}) @@ -247,7 +273,7 @@ def test_multivi_incorrect_mapping_mudata(): batch_key="batch", modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, ) - model = MULTIVI(mdata, n_genes=50, n_regions=50) + model = MULTIVI(mdata) adata2 = synthetic_iid() protein_adata2 = synthetic_iid(n_genes=50) mdata2 = MuData({"rna": adata2, "protein": protein_adata2}) @@ -266,7 +292,7 @@ def test_multivi_reordered_mapping_mudata(): batch_key="batch", modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, ) - model = MULTIVI(mdata, n_genes=50, n_regions=50) + model = MULTIVI(mdata) adata2 = synthetic_iid() protein_adata2 = synthetic_iid(n_genes=50) mdata2 = MuData({"rna": adata2, "protein": protein_adata2}) @@ -291,7 +317,7 @@ def test_multivi_model_library_size_mudata(): ) n_latent = 10 - model = MULTIVI(mdata, n_latent=n_latent, n_genes=50, n_regions=50) + model = MULTIVI(mdata, n_latent=n_latent) model.train(1, train_size=0.5) assert model.is_trained is True model.get_elbo() @@ -303,30 +329,25 @@ def test_multivi_model_library_size_mudata(): def test_multivi_size_factor_mudata(): - adata = synthetic_iid() - adata.obs["size_factor"] = np.random.randint(1, 5, size=(adata.shape[0],)) - protein_adata = synthetic_iid(n_genes=50) - mdata = MuData({"rna": adata, "protein": protein_adata}) + mdata = synthetic_iid(return_mudata=True) + mdata.obs["size_factor_rna"] = mdata["rna"].X.sum(1) + mdata.obs["size_factor_atac"] = (mdata["accessibility"].X.sum(1) + 1) / ( + np.max(mdata["accessibility"].X.sum(1)) + 1.01 + ) MULTIVI.setup_mudata( mdata, - batch_key="batch", - size_factor_key="size_factor", - modalities={ - "rna_layer": "rna", - "batch_key": "rna", - "protein_layer": "protein", - "size_factor_key": "rna", - }, + modalities={"rna_layer": "rna", "atac_layer": "accessibility"}, + size_factor_key=["size_factor_rna", "size_factor_atac"], ) n_latent = 10 # Test size_factor_key overrides use_observed_lib_size. 
- model = MULTIVI(mdata, n_latent=n_latent, n_genes=50, n_regions=50) + model = MULTIVI(mdata, n_latent=n_latent) assert model.module.use_size_factor_key model.train(1, train_size=0.5) - model = MULTIVI(mdata, n_latent=n_latent, n_genes=50, n_regions=50) + model = MULTIVI(mdata, n_latent=n_latent) assert model.module.use_size_factor_key model.train(1, train_size=0.5) @@ -340,7 +361,7 @@ def test_multivi_saving_and_loading_mudata(save_path: str): batch_key="batch", modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, ) - model = MULTIVI(mdata, n_genes=50, n_regions=50) + model = MULTIVI(mdata) model.train(1, train_size=0.2) z1 = model.get_latent_representation(mdata) test_idx1 = model.validation_indices @@ -402,7 +423,7 @@ def test_scarches_mudata_prep_layer(save_path: str): batch_key="batch", modalities={"rna_layer": "rna", "protein_layer": "protein_expression"}, ) - model = MULTIVI(mdata1, n_latent=n_latent, n_genes=50, n_regions=50) + model = MULTIVI(mdata1, n_latent=n_latent) model.train(1, check_val_every_n_epoch=1) dir_path = os.path.join(save_path, "saved_model/") model.save(dir_path, overwrite=True) @@ -448,7 +469,7 @@ def test_multivi_save_load_mudata_format(save_path: str): mdata, modalities={"rna_layer": "rna", "protein_layer": "protein"}, ) - model = MULTIVI(mdata, n_genes=50, n_regions=50) + model = MULTIVI(mdata) model.train(max_epochs=1) legacy_model_path = os.path.join(save_path, "legacy_model") From b32a92f59e05fee66baab741fff0d36a9866ed5f Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 20 Nov 2024 11:09:59 +0200 Subject: [PATCH 39/51] update branch --- src/scvi/model/_multivi.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 3f4032d188..6e0281eea6 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -621,9 +621,11 @@ def get_accessibility_estimates( imputed = vstack(imputed, format="csr") else: # imputed is a list of tensors imputed = torch.cat(imputed).numpy() + print("SDSDSD", imputed.shape) print(adata["rna"].var_names[self.n_genes :][region_mask].shape) print(adata.obs_names[indices].shape) + if return_numpy: return imputed elif threshold: @@ -1219,7 +1221,6 @@ def setup_mudata( %(idx_layer)s %(param_modalities)s - Examples -------- >>> mdata = muon.read_10x_h5("filtered_feature_bc_matrix.h5") From a0cd0bd3fc8aa08faa3650fc1d6b43dc9ab9d3bf Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 20 Nov 2024 15:07:38 +0200 Subject: [PATCH 40/51] fix get_accessibility was using gene indices, should have used regions instead --- src/scvi/model/_multivi.py | 41 ++++++++++++++++++++----------------- tests/model/test_multivi.py | 8 +++++--- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 0de7ab328b..4406593d05 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -554,7 +554,7 @@ def get_accessibility_estimates( if region_list is None: region_mask = slice(None) else: - region_mask = [region in region_list for region in adata.var_names[self.n_genes :]] + region_mask = [region in region_list for region in adata.var_names[:self.n_regions]] if threshold is not None and (threshold < 0 or threshold > 1): raise ValueError("the provided threshold must be between 0 and 1") @@ -587,26 +587,29 @@ def get_accessibility_estimates( else: # imputed is a list of tensors imputed = torch.cat(imputed).numpy() - print("SDSDSD", imputed.shape) - 
print(adata["rna"].var_names[self.n_genes :][region_mask].shape) - print(adata.obs_names[indices].shape) - - if return_numpy: - return imputed - elif threshold: - return pd.DataFrame.sparse.from_spmatrix( - imputed, - index=adata.obs_names[indices], - columns=adata.var_names[self.n_genes :][region_mask], - ) - else: + if np.all(imputed is None): return pd.DataFrame( imputed, index=adata.obs_names[indices], - columns=adata["rna"].var_names[self.n_genes :][region_mask] - if isinstance(adata, MuData) - else adata.var_names[self.n_genes :][region_mask], + columns=[], ) + else: + if return_numpy: + return imputed + elif threshold: + return pd.DataFrame.sparse.from_spmatrix( + imputed, + index=adata.obs_names[indices], + columns=adata["rna"].var_names[:self.n_regions][region_mask] if + isinstance(adata, MuData) else adata.var_names[:self.n_regions][region_mask], + ) + else: + return pd.DataFrame( + imputed, + index=adata.obs_names[indices], + columns=adata["rna"].var_names[:self.n_regions][region_mask] if + isinstance(adata, MuData) else adata.var_names[:self.n_regions][region_mask], + ) @torch.inference_mode() def get_normalized_expression( @@ -793,7 +796,7 @@ def differential_accessibility( """ self._check_adata_modality_weights(adata) adata = self._validate_anndata(adata) - col_names = adata.var_names[self.n_genes :] + col_names = adata.var_names[:self.n_genes] model_fn = partial( self.get_accessibility_estimates, use_z_mean=False, batch_size=batch_size ) @@ -814,7 +817,7 @@ def m1_domain_fn(samples): all_stats_fn = partial( scatac_raw_counts_properties, - var_idx=np.arange(adata.shape[1])[self.n_genes :], + var_idx=np.arange(adata.shape[1])[:self.n_genes], ) result = _de_core( diff --git a/tests/model/test_multivi.py b/tests/model/test_multivi.py index b6ce86b4ff..210baec931 100644 --- a/tests/model/test_multivi.py +++ b/tests/model/test_multivi.py @@ -181,7 +181,9 @@ def test_multivi_mudata_trimodal_external(): model.get_region_factors() -def test_multivi_mudata(): +@pytest.mark.parametrize("n_genes", [25, 50, 100]) +@pytest.mark.parametrize("n_regions", [25, 50, 100]) +def test_multivi_mudata(n_genes: int, n_regions: int): # use of syntetic data of rna/proteins/atac for speed mdata = synthetic_iid(return_mudata=True) @@ -197,7 +199,7 @@ def test_multivi_mudata(): n_obs = mdata.n_obs n_latent = 10 - model = MULTIVI(mdata, n_latent=n_latent) + model = MULTIVI(mdata, n_latent=n_latent, n_genes=n_genes, n_regions=n_regions) model.train(1, train_size=0.9) assert model.is_trained is True z = model.get_latent_representation() @@ -227,7 +229,7 @@ def test_multivi_mudata(): modalities={"rna_layer": "rna", "protein_layer": "protein_expression"}, ) norm_exp = model.get_normalized_expression(mdata2, indices=[1, 2, 3]) - assert norm_exp.shape == (3, 50) + assert norm_exp.shape == (3, n_genes) # test transfer_anndata_setup + view mdata3 = synthetic_iid(return_mudata=True) From 711ec548a361c734cd5474fc15f487e65319cb3d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Nov 2024 13:09:03 +0000 Subject: [PATCH 41/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/model/_multivi.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 4406593d05..d9c9871169 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -554,7 +554,7 @@ def get_accessibility_estimates( if 
region_list is None: region_mask = slice(None) else: - region_mask = [region in region_list for region in adata.var_names[:self.n_regions]] + region_mask = [region in region_list for region in adata.var_names[: self.n_regions]] if threshold is not None and (threshold < 0 or threshold > 1): raise ValueError("the provided threshold must be between 0 and 1") @@ -600,15 +600,17 @@ def get_accessibility_estimates( return pd.DataFrame.sparse.from_spmatrix( imputed, index=adata.obs_names[indices], - columns=adata["rna"].var_names[:self.n_regions][region_mask] if - isinstance(adata, MuData) else adata.var_names[:self.n_regions][region_mask], + columns=adata["rna"].var_names[: self.n_regions][region_mask] + if isinstance(adata, MuData) + else adata.var_names[: self.n_regions][region_mask], ) else: return pd.DataFrame( imputed, index=adata.obs_names[indices], - columns=adata["rna"].var_names[:self.n_regions][region_mask] if - isinstance(adata, MuData) else adata.var_names[:self.n_regions][region_mask], + columns=adata["rna"].var_names[: self.n_regions][region_mask] + if isinstance(adata, MuData) + else adata.var_names[: self.n_regions][region_mask], ) @torch.inference_mode() @@ -796,7 +798,7 @@ def differential_accessibility( """ self._check_adata_modality_weights(adata) adata = self._validate_anndata(adata) - col_names = adata.var_names[:self.n_genes] + col_names = adata.var_names[: self.n_genes] model_fn = partial( self.get_accessibility_estimates, use_z_mean=False, batch_size=batch_size ) @@ -817,7 +819,7 @@ def m1_domain_fn(samples): all_stats_fn = partial( scatac_raw_counts_properties, - var_idx=np.arange(adata.shape[1])[:self.n_genes], + var_idx=np.arange(adata.shape[1])[: self.n_genes], ) result = _de_core( From afaf1b3a51a932aba47c633e0254d029bda04c60 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 20 Nov 2024 17:26:09 +0200 Subject: [PATCH 42/51] some fixes --- src/scvi/data/fields/_arraylike_field.py | 10 +++++++--- src/scvi/model/_multivi.py | 8 ++++---- src/scvi/module/_multivae.py | 17 +++++++++++++---- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/src/scvi/data/fields/_arraylike_field.py b/src/scvi/data/fields/_arraylike_field.py index b4dafc9aed..73fed5b0c1 100644 --- a/src/scvi/data/fields/_arraylike_field.py +++ b/src/scvi/data/fields/_arraylike_field.py @@ -240,9 +240,13 @@ def __init__( def validate_field(self, adata: AnnData) -> None: """Validate the field.""" super().validate_field(adata) - for key in self.attr_keys: - if key not in getattr(adata, self.source_attr_name): - raise KeyError(f"{key} not found in adata.{self.source_attr_name}.") + if isinstance(self.attr_keys, str): + if self.attr_keys not in getattr(adata, self.source_attr_name): + raise KeyError(f"{self.attr_keys} not found in adata.{self.source_attr_name}.") + else: + for key in self.attr_keys: + if key not in getattr(adata, self.source_attr_name): + raise KeyError(f"{key} not found in adata.{self.source_attr_name}.") def _combine_fields(self, adata: AnnData) -> None: """Combine the .obs or .var fields into a single .obsm or .varm field.""" diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 6233a8b889..b8538b7889 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -1159,7 +1159,7 @@ def setup_anndata( batch_field, CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None), CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), - NumericalJointObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False), + 
NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False), CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"), @@ -1253,7 +1253,7 @@ def setup_mudata( None, mod_key=None, ), - fields.MuDataNumericalJointObsField( + fields.MuDataNumericalObsField( REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, mod_key=None, @@ -1382,8 +1382,8 @@ def minify_mudata( if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") - if self.module.use_size_factor is False: - raise ValueError("Cannot minify the data if `use_observed_lib_size` is False") + if self.module.use_size_factor_key is False: + raise ValueError("Cannot minify the data if `use_size_factor_key` is False") minified_adata = get_minified_mudata(self.adata, minified_data_type) minified_adata.obsm[_MULTIVI_LATENT_QZM] = self.adata.obsm[use_latent_qzm_key] diff --git a/src/scvi/module/_multivae.py b/src/scvi/module/_multivae.py index 5bda99cb7d..7d7db15815 100644 --- a/src/scvi/module/_multivae.py +++ b/src/scvi/module/_multivae.py @@ -15,7 +15,7 @@ ZeroInflatedNegativeBinomial, ) from scvi.module._peakvae import Decoder as DecoderPeakVI -from scvi.module.base import BaseMinifiedModeModuleClass, LossOutput, auto_move_data +from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data from scvi.nn import DecoderSCVI, Encoder, FCLayers from ._utils import masked_softmax @@ -179,7 +179,7 @@ def forward(self, z: torch.Tensor, *cat_list: int): return py_, log_pro_back_mean -class MULTIVAE(BaseMinifiedModeModuleClass): +class MULTIVAE(BaseModuleClass): """Variational auto-encoder model for joint paired + unpaired RNA-seq and ATAC-seq data. 
Parameters @@ -618,7 +618,7 @@ def inference( # L encoders if self.use_size_factor_key: libsize_expr = torch.log(size_factor[:, [0]] + 1e-6) - libsize_acc = size_factor[:, [1]] + libsize_acc = torch.log(size_factor[:, [0]] + 1e-6) else: libsize_expr = self.l_encoder_expression( encoder_input_expression, batch_index, *categorical_input @@ -688,6 +688,11 @@ def _get_generative_input(self, tensors, inference_outputs, transform_batch=None qz_m = inference_outputs["qz_m"] libsize_expr = inference_outputs["libsize_expr"] + size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY + size_factor = ( + torch.log(tensors[size_factor_key]) if size_factor_key in tensors.keys() else None + ) + batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] cont_key = REGISTRY_KEYS.CONT_COVS_KEY cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None @@ -707,6 +712,7 @@ def _get_generative_input(self, tensors, inference_outputs, transform_batch=None "cont_covs": cont_covs, "cat_covs": cat_covs, "libsize_expr": libsize_expr, + "size_factor": size_factor, "label": label, } return input_dict @@ -720,6 +726,7 @@ def generative( cont_covs=None, cat_covs=None, libsize_expr=None, + size_factor=None, use_z_mean=False, label: torch.Tensor = None, ): @@ -743,10 +750,12 @@ def generative( p = self.z_decoder_accessibility(decoder_input, batch_index, *categorical_input) # Expression Decoder + if not self.use_size_factor_key: + size_factor = libsize_expr px_scale, _, px_rate, px_dropout = self.z_decoder_expression( self.gene_dispersion, decoder_input, - libsize_expr, + size_factor, batch_index, *categorical_input, label, From 02a8b8f5499de99c89d291ecd8acff474fe4a31c Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 20 Nov 2024 17:27:05 +0200 Subject: [PATCH 43/51] update tests --- .../test_models_with_mudata_minified_data.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/model/test_models_with_mudata_minified_data.py b/tests/model/test_models_with_mudata_minified_data.py index 3022f48f60..3cd4560ca6 100644 --- a/tests/model/test_models_with_mudata_minified_data.py +++ b/tests/model/test_models_with_mudata_minified_data.py @@ -188,20 +188,20 @@ def assert_approx_equal(a, b): @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) -@pytest.mark.parametrize("use_size_factor", [False, True]) +@pytest.mark.parametrize("use_size_factor", [True]) def test_with_minified_adata(cls, use_size_factor: bool): run_test_for_model_with_minified_adata(cls=cls, use_size_factor=use_size_factor) @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) -@pytest.mark.parametrize("use_size_factor", [False, True]) +@pytest.mark.parametrize("use_size_factor", [True]) def test_with_minified_mudata(cls, use_size_factor: bool): run_test_for_model_with_minified_mudata(cls=cls, use_size_factor=use_size_factor) @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) def test_with_minified_mdata_get_normalized_expression(cls): - model, mdata, _, _ = prep_model_mudata(cls=cls) + model, mdata, _, _ = prep_model_mudata(cls=cls, use_size_factor=True) qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) model.adata.obsm["X_latent_qzm"] = qzm @@ -225,7 +225,7 @@ def test_with_minified_mdata_get_normalized_expression(cls): @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) def test_with_minified_mdata_get_normalized_expression_non_default_gene_list(cls): - model, mdata, _, _ = prep_model_mudata(cls=cls) + model, mdata, _, _ = prep_model_mudata(cls=cls, use_size_factor=True) # non-default gene list and 
n_samples > 1 gl = mdata.var_names[:5].to_list() @@ -258,7 +258,7 @@ def test_with_minified_mdata_get_normalized_expression_non_default_gene_list(cls @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) def test_validate_unsupported_if_minified(cls): - model, _, _, _ = prep_model_mudata(cls=cls) + model, _, _, _ = prep_model_mudata(cls=cls, use_size_factor=True) qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) model.adata.obsm["X_latent_qzm"] = qzm @@ -288,11 +288,11 @@ def test_validate_unsupported_if_minified(cls): @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) -def test_with_minified_mdata_save_then_load(cls, save_path="."): +def test_with_minified_mdata_save_then_load(cls, save_path): # create a model and minify its mdata, then save it and its mdata. # Load it back up using the same (minified) mdata. Validate that the # loaded model has the minified_data_type attribute set as expected. - model, mdata, _, _ = prep_model_mudata(cls=cls) + model, mdata, _, _ = prep_model_mudata(cls=cls, use_size_factor=True) qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) model.adata.obsm["X_latent_qzm"] = qzm @@ -310,11 +310,11 @@ def test_with_minified_mdata_save_then_load(cls, save_path="."): @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) -def test_with_minified_mdata_save_then_load_with_non_minified_mdata(cls, save_path="."): +def test_with_minified_mdata_save_then_load_with_non_minified_mdata(cls, save_path): # create a model and minify its mdata, then save it and its mdata. # Load it back up using a non-minified mdata. Validate that the # loaded model does not has the minified_data_type attribute set. - model, mdata, _, mdata_before_setup = prep_model_mudata(cls=cls) + model, mdata, _, mdata_before_setup = prep_model_mudata(cls=cls, use_size_factor=True) qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) model.adata.obsm["X_latent_qzm"] = qzm @@ -331,12 +331,12 @@ def test_with_minified_mdata_save_then_load_with_non_minified_mdata(cls, save_pa @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) -def test_save_then_load_with_minified_mdata(cls, save_path="."): +def test_save_then_load_with_minified_mdata(cls, save_path): # create a model, then save it and its mdata (non-minified). # Load it back up using a minified mdata. 
Validate that this # fails, as expected because we don't have a way to validate # whether the minified-mdata was set up correctly - model, _, _, _ = prep_model_mudata(cls=cls) + model, _, _, _ = prep_model_mudata(cls=cls, use_size_factor=True) qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) model.adata.obsm["X_latent_qzm"] = qzm @@ -356,7 +356,7 @@ def test_save_then_load_with_minified_mdata(cls, save_path="."): @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) def test_with_minified_mdata_get_latent_representation(cls): - model, _, _, _ = prep_model_mudata(cls=cls) + model, _, _, _ = prep_model_mudata(cls=cls, use_size_factor=True) qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) model.adata.obsm["X_latent_qzm"] = qzm @@ -374,7 +374,7 @@ def test_with_minified_mdata_get_latent_representation(cls): @pytest.mark.parametrize("cls", [TOTALVI]) def test_with_minified_mdata_posterior_predictive_sample(cls): - model, _, _, _ = prep_model_mudata(cls=cls) + model, _, _, _ = prep_model_mudata(cls=cls, use_size_factor=True) qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) model.adata.obsm["X_latent_qzm"] = qzm @@ -393,7 +393,7 @@ def test_with_minified_mdata_posterior_predictive_sample(cls): @pytest.mark.parametrize("cls", [TOTALVI]) def test_with_minified_mdata_get_feature_correlation_matrix(cls): - model, _, _, _ = prep_model_mudata(cls=cls) + model, _, _, _ = prep_model_mudata(cls=cls, use_size_factor=True) qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) model.adata.obsm["X_latent_qzm"] = qzm From 69c06e4e4c31eb1dac3a70e0086ca04257bc3867 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 20 Nov 2024 18:31:12 +0200 Subject: [PATCH 44/51] fix tests, save/load --- src/scvi/model/base/_base_model.py | 1 + .../test_models_with_mudata_minified_data.py | 19 +++++++++++++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index bb9dadded2..27d9d7a646 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -1067,6 +1067,7 @@ def _update_mudata_and_manager_post_minification( new_adata_manager.register_new_fields( self._get_fields_for_mudata_minification(minified_data_type) ) + new_adata_manager.registry["setup_method_name"] = "setup_mudata" # We set the adata attribute of the model as this will update self.registry_ # and self.adata_manager with the new adata manager self.adata = minified_adata diff --git a/tests/model/test_models_with_mudata_minified_data.py b/tests/model/test_models_with_mudata_minified_data.py index 3cd4560ca6..0d5aeadeb2 100644 --- a/tests/model/test_models_with_mudata_minified_data.py +++ b/tests/model/test_models_with_mudata_minified_data.py @@ -2,6 +2,7 @@ import pytest # from mudata import MuData +import scvi from scvi.data import synthetic_iid from scvi.data._constants import ADATA_MINIFY_TYPE from scvi.data._utils import _is_minified @@ -203,16 +204,20 @@ def test_with_minified_mudata(cls, use_size_factor: bool): def test_with_minified_mdata_get_normalized_expression(cls): model, mdata, _, _ = prep_model_mudata(cls=cls, use_size_factor=True) + scvi.settings.seed = 1 qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) model.adata.obsm["X_latent_qzm"] = qzm model.adata.obsm["X_latent_qzv"] = qzv + scvi.settings.seed = 1 exprs_orig = model.get_normalized_expression() model.minify_mudata() assert 
model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + scvi.settings.seed = 1 exprs_new = model.get_normalized_expression() + if type(exprs_new) is tuple: for ii in range(len(exprs_new)): assert exprs_new[ii].shape == mdata[mdata.mod_names[ii]].shape @@ -231,19 +236,22 @@ def test_with_minified_mdata_get_normalized_expression_non_default_gene_list(cls gl = mdata.var_names[:5].to_list() n_samples = 10 + scvi.settings.seed = 1 qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) model.adata.obsm["X_latent_qzm"] = qzm model.adata.obsm["X_latent_qzv"] = qzv + scvi.settings.seed = 1 exprs_orig = model.get_normalized_expression( - gene_list=gl, n_samples=n_samples, library_size="latent" + gene_list=gl, n_samples=n_samples, library_size="latent", return_mean=False ) model.minify_mudata() assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + scvi.settings.seed = 1 exprs_new = model.get_normalized_expression( - gene_list=gl, n_samples=n_samples + 1, library_size="latent" + gene_list=gl, n_samples=n_samples, library_size="latent", return_mean=False ) if type(exprs_new) is tuple: @@ -301,12 +309,12 @@ def test_with_minified_mdata_save_then_load(cls, save_path): model.minify_mudata() assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR - model.save(save_path, overwrite=True, save_anndata=True, legacy_mudata_format=True) + model.save(save_path, overwrite=True, save_anndata=True) model.view_setup_args(save_path) # load saved model with saved (minified) mdata loaded_model = cls.load(save_path, adata=mdata) - assert loaded_model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + assert loaded_model.minified_data_type is None @pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) @@ -376,15 +384,18 @@ def test_with_minified_mdata_get_latent_representation(cls): def test_with_minified_mdata_posterior_predictive_sample(cls): model, _, _, _ = prep_model_mudata(cls=cls, use_size_factor=True) + scvi.settings.seed = 1 qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) model.adata.obsm["X_latent_qzm"] = qzm model.adata.obsm["X_latent_qzv"] = qzv + scvi.settings.seed = 1 sample_orig = model.posterior_predictive_sample() model.minify_mudata() assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + scvi.settings.seed = 1 sample_new = model.posterior_predictive_sample() # assert sample_new.shape == (3, 2) From b5eb0a665b2ace3c1b27f30e1d6f43b36d117128 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 20 Nov 2024 19:22:46 +0200 Subject: [PATCH 45/51] more fixes. 
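
This settles on the joint two-column size-factor convention for MULTIVI:
`NumericalJointObsField`/`MuDataNumericalJointObsField` stack the two `.obs`
columns, `inference` log-transforms the RNA column internally, and the ATAC
column is consumed as-is, so it must already lie in (0, 1]. A minimal sketch
of the expected setup (key names follow the tests, not a fixed API):

    import numpy as np
    from scvi.data import synthetic_iid
    from scvi.model import MULTIVI

    mdata = synthetic_iid(return_mudata=True)
    # Raw per-cell library size for RNA; the module applies log() itself.
    mdata.obs["size_factor_rna"] = np.asarray(mdata["rna"].X.sum(1)).ravel()
    # ATAC factor rescaled into (0, 1]; used directly as a scaling factor.
    atac_sum = np.asarray(mdata["accessibility"].X.sum(1)).ravel()
    mdata.obs["size_factor_atac"] = (atac_sum + 1) / (atac_sum.max() + 1.01)
    MULTIVI.setup_mudata(
        mdata,
        modalities={"rna_layer": "rna", "atac_layer": "accessibility"},
        size_factor_key=["size_factor_rna", "size_factor_atac"],
    )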
--- src/scvi/model/_multivi.py | 4 ++-- src/scvi/module/_multivae.py | 13 ++---------- .../test_models_with_mudata_minified_data.py | 20 +++++++++++++------ 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index b8538b7889..4c9c6b2090 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -1159,7 +1159,7 @@ def setup_anndata( batch_field, CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None), CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), - NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False), + NumericalJointObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False), CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"), @@ -1253,7 +1253,7 @@ def setup_mudata( None, mod_key=None, ), - fields.MuDataNumericalObsField( + fields.MuDataNumericalJointObsField( REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, mod_key=None, diff --git a/src/scvi/module/_multivae.py b/src/scvi/module/_multivae.py index 7d7db15815..e3c319ec16 100644 --- a/src/scvi/module/_multivae.py +++ b/src/scvi/module/_multivae.py @@ -618,7 +618,7 @@ def inference( # L encoders if self.use_size_factor_key: libsize_expr = torch.log(size_factor[:, [0]] + 1e-6) - libsize_acc = torch.log(size_factor[:, [0]] + 1e-6) + libsize_acc = size_factor[:, [1]] else: libsize_expr = self.l_encoder_expression( encoder_input_expression, batch_index, *categorical_input @@ -688,11 +688,6 @@ def _get_generative_input(self, tensors, inference_outputs, transform_batch=None qz_m = inference_outputs["qz_m"] libsize_expr = inference_outputs["libsize_expr"] - size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY - size_factor = ( - torch.log(tensors[size_factor_key]) if size_factor_key in tensors.keys() else None - ) - batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] cont_key = REGISTRY_KEYS.CONT_COVS_KEY cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None @@ -712,7 +707,6 @@ def _get_generative_input(self, tensors, inference_outputs, transform_batch=None "cont_covs": cont_covs, "cat_covs": cat_covs, "libsize_expr": libsize_expr, - "size_factor": size_factor, "label": label, } return input_dict @@ -726,7 +720,6 @@ def generative( cont_covs=None, cat_covs=None, libsize_expr=None, - size_factor=None, use_z_mean=False, label: torch.Tensor = None, ): @@ -750,12 +743,10 @@ def generative( p = self.z_decoder_accessibility(decoder_input, batch_index, *categorical_input) # Expression Decoder - if not self.use_size_factor_key: - size_factor = libsize_expr px_scale, _, px_rate, px_dropout = self.z_decoder_expression( self.gene_dispersion, decoder_input, - size_factor, + libsize_expr, batch_index, *categorical_input, label, diff --git a/tests/model/test_models_with_mudata_minified_data.py b/tests/model/test_models_with_mudata_minified_data.py index 0d5aeadeb2..02db23e1bb 100644 --- a/tests/model/test_models_with_mudata_minified_data.py +++ b/tests/model/test_models_with_mudata_minified_data.py @@ -92,7 +92,11 @@ def prep_model_mudata(cls=TOTALVI, use_size_factor=False): # "accessibility": synthetic_iid()}) mdata = synthetic_iid(return_mudata=True) if use_size_factor: - mdata.obs["size_factor"] = np.random.randint(1, 5, size=(mdata.shape[0],)) + mdata.obs["size_factor_rna"] = mdata["rna"].X.sum(1) + mdata.obs["size_factor_atac"] = 
(mdata["accessibility"].X.sum(1) + 1) / ( + np.max(mdata["accessibility"].X.sum(1)) + 1.01 + ) + # mdata.obs["size_factor"] = np.random.randint(1, 5, size=(mdata.shape[0],)) # if layer is not None: # for mod in mdata.mod_names: # mdata[mod].layers[layer] = mdata[mod].X.copy() @@ -114,10 +118,11 @@ def prep_model_mudata(cls=TOTALVI, use_size_factor=False): setup_kwargs = { "batch_key": "batch", } - if use_size_factor: - setup_kwargs["size_factor_key"] = "size_factor" if cls == TOTALVI: + if use_size_factor: + setup_kwargs["size_factor_key"] = "size_factor_rna" + # create and train the model cls.setup_mudata( mdata, @@ -126,6 +131,9 @@ def prep_model_mudata(cls=TOTALVI, use_size_factor=False): ) model = cls(mdata, n_latent=5) elif cls == MULTIVI: + if use_size_factor: + setup_kwargs["size_factor_key"] = ["size_factor_rna", "size_factor_atac"] + # create and train the model cls.setup_mudata( mdata, @@ -188,7 +196,7 @@ def assert_approx_equal(a, b): np.testing.assert_allclose(a, b, rtol=3e-1, atol=5e-1) -@pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) +@pytest.mark.parametrize("cls", [TOTALVI]) @pytest.mark.parametrize("use_size_factor", [True]) def test_with_minified_adata(cls, use_size_factor: bool): run_test_for_model_with_minified_adata(cls=cls, use_size_factor=use_size_factor) @@ -243,7 +251,7 @@ def test_with_minified_mdata_get_normalized_expression_non_default_gene_list(cls scvi.settings.seed = 1 exprs_orig = model.get_normalized_expression( - gene_list=gl, n_samples=n_samples, library_size="latent", return_mean=False + gene_list=gl, n_samples=n_samples, library_size="latent" ) model.minify_mudata() @@ -251,7 +259,7 @@ def test_with_minified_mdata_get_normalized_expression_non_default_gene_list(cls scvi.settings.seed = 1 exprs_new = model.get_normalized_expression( - gene_list=gl, n_samples=n_samples, library_size="latent", return_mean=False + gene_list=gl, n_samples=n_samples, library_size="latent" ) if type(exprs_new) is tuple: From cd9c5b2803a06793a8ea087c8b32a4095f97e36b Mon Sep 17 00:00:00 2001 From: Can Ergen Date: Thu, 21 Nov 2024 00:25:35 -0800 Subject: [PATCH 46/51] Refactor code and add tests --- src/scvi/model/_multivi.py | 65 ++++++++++++++----- src/scvi/module/_multivae.py | 12 ++-- src/scvi/module/_totalvae.py | 122 ++++++++++++++++++++++------------- tests/model/test_multivi.py | 44 +++++-------- 4 files changed, 149 insertions(+), 94 deletions(-) diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 4c9c6b2090..27d9cd965c 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -193,7 +193,17 @@ def __init__( ), "n_genes and n_regions must be provided if using AnnData" n_genes = self.summary_stats.get("n_vars", 0) n_regions = self.summary_stats.get("n_atac", 0) - + if isinstance(adata, MuData): + assert ( + n_genes == self.summary_stats.get("n_vars", 0) + ), "n_genes must match MuData" + assert ( + n_regions == self.summary_stats.get("n_atac", 0) + ), "n_regions must match MuData" + if modality_weights == "cell": + assert ( + self.registry_['setup_args']['index_key'] is not None + ), "index_key must be set if using cell modality weights" prior_mean, prior_scale = None, None n_cats_per_cov = ( self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key @@ -254,6 +264,10 @@ def __init__( self.n_regions = n_regions self.n_proteins = n_proteins self.module.minified_data_type = self.minified_data_type + if isinstance(adata, MuData): + self.modality_keys = 
self.get_anndata_manager(adata).registry['setup_args']['modalities']
+        else:
+            self.modality_keys = None

     @devices_dsp.dedent
     def train(
@@ -525,7 +539,7 @@ def get_latent_representation(
     def get_accessibility_estimates(
         self,
         adata: AnnOrMuData | None = None,
-        indices: Sequence[int] = None,
+        indices: Sequence[int] | None = None,
         n_samples_overall: int | None = None,
         region_list: Sequence[str] | None = None,
         transform_batch: str | int | None = None,
@@ -629,23 +643,23 @@ def get_accessibility_estimates(
                 columns=[],
             )
         else:
+            if isinstance(adata, MuData):
+                peak_names = adata[self.modality_keys['atac_layer']].var_names[region_mask]
+            else:
+                peak_names = adata.var_names[self.n_regions:][region_mask]
             if return_numpy:
                 return imputed
             elif threshold:
                 return pd.DataFrame.sparse.from_spmatrix(
                     imputed,
                     index=adata.obs_names[indices],
-                    columns=adata["rna"].var_names[: self.n_regions][region_mask]
-                    if isinstance(adata, MuData)
-                    else adata.var_names[: self.n_regions][region_mask],
+                    columns=peak_names,
                 )
             else:
                 return pd.DataFrame(
                     imputed,
                     index=adata.obs_names[indices],
-                    columns=adata["rna"].var_names[: self.n_regions][region_mask]
-                    if isinstance(adata, MuData)
-                    else adata.var_names[: self.n_regions][region_mask],
+                    columns=peak_names,
                 )

     @torch.inference_mode()
     def get_normalized_expression(
@@ -760,9 +774,14 @@ def get_normalized_expression(
         if return_numpy:
             return exprs
         else:
+            if isinstance(adata, MuData):
+                gene_names = adata[self.modality_keys['rna_layer']].var_names[gene_mask]
+            else:
+                gene_names = adata.var_names[:self.n_genes][gene_mask]
+
             return pd.DataFrame(
                 exprs,
-                columns=adata.var_names[: self.n_genes][gene_mask],
+                columns=gene_names,
                 index=adata.obs_names[indices],
             )

@@ -1126,6 +1145,7 @@ def setup_anndata(
         continuous_covariate_keys: list[str] | None = None,
         protein_expression_obsm_key: str | None = None,
         protein_names_uns_key: str | None = None,
+        index_key: str | None = None,
         **kwargs,
     ):
         """%(summary)s.
@@ -1144,6 +1164,8 @@ def setup_anndata(
             key in `adata.uns` for protein names. If None, will use the column names of
             `adata.obsm[protein_expression_obsm_key]` if it is a DataFrame, else will assign
             sequential names to proteins.
+        index_key
+            Key in `adata.obs` for integer cell indices; created if missing. Required for cell-specific modality weights.
         """
         warnings.warn(
             "MULTIVI is supposed to work with MuData. 
the use of anndata is " @@ -1152,8 +1174,10 @@ def setup_anndata( stacklevel=settings.warnings_stacklevel, ) setup_method_args = cls._get_setup_method_args(**locals()) - adata.obs["_indices"] = np.arange(adata.n_obs) batch_field = CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key) + if index_key is not None: + if index_key not in adata.obs: + adata.obs[index_key] = np.arange(adata.n_obs) anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), batch_field, @@ -1162,7 +1186,7 @@ def setup_anndata( NumericalJointObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False), CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), - NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"), + NumericalObsField(REGISTRY_KEYS.INDICES_KEY, index_key, required=False), ] if protein_expression_obsm_key is not None: anndata_fields.append( @@ -1201,7 +1225,7 @@ def setup_mudata( size_factor_key: str | None = None, categorical_covariate_keys: list[str] | None = None, continuous_covariate_keys: list[str] | None = None, - idx_layer: str | None = None, + index_key: str | None = None, modalities: dict[str, str] | None = None, **kwargs, ): @@ -1223,7 +1247,8 @@ def setup_mudata( The second column need to be normalized and between 0 and 1. %(param_cat_cov_keys)s %(param_cont_cov_keys)s - %(idx_layer)s + index_key + Key in `mdata.obs` for cell indices. If `None`, will skip using indices. %(param_modalities)s Examples @@ -1239,7 +1264,15 @@ def setup_mudata( if modalities is None: raise ValueError("Modalities cannot be None.") modalities = cls._create_modalities_attr_dict(modalities, setup_method_args) - mdata.obs["_indices"] = np.arange(mdata.n_obs) + if index_key is not None: + if modalities.index_key is not None: + index_layer = mdata[modalities.index_key] + else: + index_layer = mdata + if "_indices" not in index_layer.obs: + index_layer.obs["_indices"] = np.arange(mdata.n_obs) + else: + index_key = None batch_field = fields.MuDataCategoricalObsField( REGISTRY_KEYS.BATCH_KEY, @@ -1271,8 +1304,8 @@ def setup_mudata( ), fields.MuDataNumericalObsField( REGISTRY_KEYS.INDICES_KEY, - "_indices", - mod_key=modalities.idx_layer, + index_key, + mod_key=modalities.index_key, required=False, ), ] diff --git a/src/scvi/module/_multivae.py b/src/scvi/module/_multivae.py index e3c319ec16..5a34d54cbe 100644 --- a/src/scvi/module/_multivae.py +++ b/src/scvi/module/_multivae.py @@ -545,7 +545,9 @@ def _get_inference_input(self, tensors): else: y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY] batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] - cell_idx = tensors.get(REGISTRY_KEYS.INDICES_KEY).long().ravel() + cell_idx = tensors.get(REGISTRY_KEYS.INDICES_KEY) + if cell_idx is not None: + cell_idx = cell_idx.long().ravel() cont_covs = tensors.get(REGISTRY_KEYS.CONT_COVS_KEY) cat_covs = tensors.get(REGISTRY_KEYS.CAT_COVS_KEY) label = tensors[REGISTRY_KEYS.LABELS_KEY] @@ -620,18 +622,18 @@ def inference( libsize_expr = torch.log(size_factor[:, [0]] + 1e-6) libsize_acc = size_factor[:, [1]] else: - libsize_expr = self.l_encoder_expression( - encoder_input_expression, batch_index, *categorical_input - ) libsize_acc = self.l_encoder_accessibility( encoder_input_accessibility, batch_index, *categorical_input ) + libsize_expr = self.l_encoder_expression( + encoder_input_expression, batch_index, *categorical_input + ) # mix representations if self.modality_weights == "cell": weights = 
self.mod_weights[cell_idx, :] else: - weights = self.mod_weights.unsqueeze(0).expand(len(cell_idx), -1) + weights = self.mod_weights.unsqueeze(0).expand(x.shape[0], -1) qz_m = mix_modalities( (qzm_expr, qzm_acc, qzm_pro), (mask_expr, mask_acc, mask_pro), weights diff --git a/src/scvi/module/_totalvae.py b/src/scvi/module/_totalvae.py index ba54ec0b6f..0abeeaa364 100644 --- a/src/scvi/module/_totalvae.py +++ b/src/scvi/module/_totalvae.py @@ -12,12 +12,14 @@ from scvi import REGISTRY_KEYS from scvi.data import _constants +from scvi.data._constants import ADATA_MINIFY_TYPE from scvi.distributions import ( NegativeBinomial, NegativeBinomialMixture, ZeroInflatedNegativeBinomial, ) from scvi.model.base import BaseModelClass +from scvi.module._constants import MODULE_KEYS from scvi.module.base import BaseMinifiedModeModuleClass, LossOutput, auto_move_data from scvi.nn import DecoderTOTALVI, EncoderTOTALVI from scvi.nn._utils import ExpActivation @@ -324,28 +326,37 @@ def get_reconstruction_loss( return reconst_loss_gene, reconst_loss_protein - def _get_inference_input(self, tensors): - # from scvi.data._constants import ADATA_MINIFY_TYPE - # TODO: ADD MINIFICATION CONSIDERATION - - x = tensors[REGISTRY_KEYS.X_KEY] - y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY] - batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] - - cont_key = REGISTRY_KEYS.CONT_COVS_KEY - cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None - - cat_key = REGISTRY_KEYS.CAT_COVS_KEY - cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None - - input_dict = { - "x": x, - "y": y, - "batch_index": batch_index, - "cat_covs": cat_covs, - "cont_covs": cont_covs, - } - return input_dict + def _get_inference_input( + self, + tensors, + full_forward_pass: bool = False, + ) -> dict[str, torch.Tensor | None]: + """Get input tensors for the inference process.""" + if full_forward_pass or self.minified_data_type is None: + loader = "full_data" + elif self.minified_data_type in [ + ADATA_MINIFY_TYPE.LATENT_POSTERIOR, + ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS, + ]: + loader = "minified_data" + else: + raise NotImplementedError(f"Unknown minified-data type: {self.minified_data_type}") + + if loader == "full_data": + return { + MODULE_KEYS.X_KEY: tensors[REGISTRY_KEYS.X_KEY], + MODULE_KEYS.Y_KEY: tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY], + MODULE_KEYS.BATCH_INDEX_KEY: tensors[REGISTRY_KEYS.BATCH_KEY], + MODULE_KEYS.CONT_COVS_KEY: tensors.get(REGISTRY_KEYS.CONT_COVS_KEY, None), + MODULE_KEYS.CAT_COVS_KEY: tensors.get(REGISTRY_KEYS.CAT_COVS_KEY, None), + } + else: + return { + MODULE_KEYS.QZM_KEY: tensors[REGISTRY_KEYS.LATENT_QZM_KEY], + MODULE_KEYS.QZV_KEY: tensors[REGISTRY_KEYS.LATENT_QZV_KEY], + REGISTRY_KEYS.OBSERVED_LIB_SIZE: tensors[REGISTRY_KEYS.OBSERVED_LIB_SIZE], + MODULE_KEYS.BATCH_INDEX_KEY: tensors[REGISTRY_KEYS.BATCH_KEY], + } def _get_generative_input(self, tensors, inference_outputs): z = inference_outputs["z"] @@ -436,7 +447,45 @@ def generative( } @auto_move_data - def inference( + def _cached_inference( + self, + qzm: torch.Tensor, + qzv: torch.Tensor, + batch_index: torch.Tensor, + observed_lib_size: torch.Tensor, + n_samples: int = 1, + ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: + """Run the cached inference process.""" + library = observed_lib_size + qz = Normal(qzm, qzv) + untran_z = qz.sample() if n_samples == 1 else qz.sample((n_samples,)) + z = self.encoder.z_transformation(untran_z) + library = torch.log(observed_lib_size) + if n_samples > 1: + library = 
library.unsqueeze(0).expand((n_samples, library.size(0), library.size(1))) + + if self.n_batch > 0: + py_back_alpha_prior = F.linear( + one_hot(batch_index.squeeze(-1), self.n_batch).float(), self.background_pro_alpha + ) + py_back_beta_prior = F.linear( + one_hot(batch_index.squeeze(-1), self.n_batch).float(), + torch.exp(self.background_pro_log_beta), + ) + else: + py_back_alpha_prior = self.background_pro_alpha + py_back_beta_prior = torch.exp(self.background_pro_log_beta) + self.back_mean_prior = Normal(py_back_alpha_prior, py_back_beta_prior) + + return { + MODULE_KEYS.Z_KEY: z, + MODULE_KEYS.QZ_KEY: qz, + MODULE_KEYS.QL_KEY: None, + "library_gene": observed_lib_size, + } + + @auto_move_data + def _regular_inference( self, x: torch.Tensor, y: torch.Tensor, @@ -518,24 +567,6 @@ def inference( else: library_gene = self.encoder.l_transformation(untran_l) - # Background regularization - if self.gene_dispersion == "gene-label": - # px_r gets transposed - last dimension is nb genes - px_r = F.linear(one_hot(label.squeeze(-1), self.n_labels).float(), self.px_r) - elif self.gene_dispersion == "gene-batch": - px_r = F.linear(one_hot(batch_index.squeeze(-1), self.n_batch).float(), self.px_r) - elif self.gene_dispersion == "gene": - px_r = self.px_r - px_r = torch.exp(px_r) - - if self.protein_dispersion == "protein-label": - # py_r gets transposed - last dimension is n_proteins - py_r = F.linear(one_hot(label.squeeze(-1), self.n_labels).float(), self.py_r) - elif self.protein_dispersion == "protein-batch": - py_r = F.linear(one_hot(batch_index.squeeze(-1), self.n_batch).float(), self.py_r) - elif self.protein_dispersion == "protein": - py_r = self.py_r - py_r = torch.exp(py_r) if self.n_batch > 0: py_back_alpha_prior = F.linear( one_hot(batch_index.squeeze(-1), self.n_batch).float(), self.background_pro_alpha @@ -550,10 +581,9 @@ def inference( self.back_mean_prior = Normal(py_back_alpha_prior, py_back_beta_prior) return { - "qz": qz, - "z": z, - "untran_z": untran_z, - "ql": ql, + MODULE_KEYS.Z_KEY: z, + MODULE_KEYS.QZ_KEY: qz, + MODULE_KEYS.QL_KEY: ql, "library_gene": library_gene, "untran_l": untran_l, } @@ -697,7 +727,6 @@ def marginal_ll(self, tensors, n_mc_samples, return_mean: bool = True): qz = inference_outputs["qz"] ql = inference_outputs["ql"] py_ = generative_outputs["py_"] - log_library = inference_outputs["untran_l"] # really need not softmax transformed random variable z = inference_outputs["untran_z"] log_pro_back_mean = generative_outputs["log_pro_back_mean"] @@ -711,6 +740,7 @@ def marginal_ll(self, tensors, n_mc_samples, return_mean: bool = True): log_prob_sum = torch.zeros(qz.loc.shape[0]).to(self.device) if not self.use_observed_lib_size: + log_library = inference_outputs["untran_l"] n_batch = self.library_log_means.shape[1] local_library_log_means = F.linear( one_hot(batch_index.squeeze(-1), n_batch).float(), self.library_log_means diff --git a/tests/model/test_multivi.py b/tests/model/test_multivi.py index 210baec931..938add5b85 100644 --- a/tests/model/test_multivi.py +++ b/tests/model/test_multivi.py @@ -15,6 +15,7 @@ def test_multivi(): data = synthetic_iid() + data2 = data.copy() MULTIVI.setup_anndata( data, batch_key="batch", @@ -28,7 +29,7 @@ def test_multivi(): vae.train(1, adversarial_mixing=False) vae.train(3) vae.get_elbo(indices=vae.validation_indices) - vae.get_accessibility_estimates() + vae.get_accessibility_estimates(data2) vae.get_accessibility_estimates(normalize_cells=True) vae.get_accessibility_estimates(normalize_regions=True) 
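A minimal sketch of the index-key workflow exercised just below, assuming this branch of scvi-tools is installed; `synthetic_iid` and `MULTIVI` are the same helpers this test module already imports, and `"_indices"` is just the conventional obs column name used throughout this patch:

    from scvi.data import synthetic_iid
    from scvi.model import MULTIVI

    adata = synthetic_iid()
    # Cell-specific modality weights index rows of mod_weights per cell, so an
    # integer index column must be registered through index_key during setup.
    MULTIVI.setup_anndata(adata, batch_key="batch", index_key="_indices")
    vae = MULTIVI(adata, n_genes=50, n_regions=50, modality_weights="cell")
    vae.train(1)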
vae.get_normalized_expression() @@ -53,6 +54,9 @@ def test_multivi(): # Test with modality weights and penalties data = synthetic_iid() MULTIVI.setup_anndata(data, batch_key="batch") + with pytest.raises(AssertionError): + vae = MULTIVI(data, n_genes=50, n_regions=50, modality_weights="cell") + MULTIVI.setup_anndata(data, batch_key="batch", index_key="_indices") vae = MULTIVI(data, n_genes=50, n_regions=50, modality_weights="cell") vae.train(3) vae = MULTIVI(data, n_genes=50, n_regions=50, modality_weights="universal") @@ -67,6 +71,7 @@ def test_multivi(): batch_key="batch", protein_expression_obsm_key="protein_expression", protein_names_uns_key="protein_names", + index_key="_indices" ) vae = MULTIVI( data, @@ -181,12 +186,16 @@ def test_multivi_mudata_trimodal_external(): model.get_region_factors() -@pytest.mark.parametrize("n_genes", [25, 50, 100]) -@pytest.mark.parametrize("n_regions", [25, 50, 100]) -def test_multivi_mudata(n_genes: int, n_regions: int): - # use of syntetic data of rna/proteins/atac for speed - +@pytest.mark.parametrize("n_genes", [23, 42]) +@pytest.mark.parametrize("n_regions", [24, 44]) +@pytest.mark.parametrize("n_proteins", [27, 48]) +def test_multivi_mudata(n_genes: int, n_regions: int, n_proteins: int): mdata = synthetic_iid(return_mudata=True) + mdata.mod["rna"] = mdata.mod["rna"][:, 0:n_genes].copy() + mdata.mod["accessibility"] = mdata.mod["accessibility"][:, 0:n_regions].copy() + mdata.mod["protein_expression"] = mdata.mod["protein_expression"][:, 0:n_proteins].copy() + mdata.update() + mdata2 = mdata.copy() MULTIVI.setup_mudata( mdata, batch_key="batch", @@ -207,7 +216,7 @@ def test_multivi_mudata(n_genes: int, n_regions: int): model.get_elbo() model.get_reconstruction_error() model.get_normalized_expression() - model.get_normalized_expression(transform_batch=["batch_0", "batch_1"]) + model.get_normalized_expression(mdata2, transform_batch=["batch_0", "batch_1"]) model.get_accessibility_estimates() model.get_accessibility_estimates(normalize_cells=True) model.get_accessibility_estimates(normalize_regions=True) @@ -217,26 +226,7 @@ def test_multivi_mudata(n_genes: int, n_regions: int): model.get_elbo(indices=model.validation_indices) model.get_reconstruction_error(indices=model.validation_indices) model.get_accessibility_estimates() - model.get_accessibility_estimates(normalize_cells=True) - model.get_accessibility_estimates(normalize_regions=True) - model.get_library_size_factors() - model.get_region_factors() - - mdata2 = synthetic_iid(return_mudata=True) - MULTIVI.setup_mudata( - mdata2, - batch_key="batch", - modalities={"rna_layer": "rna", "protein_layer": "protein_expression"}, - ) - norm_exp = model.get_normalized_expression(mdata2, indices=[1, 2, 3]) - assert norm_exp.shape == (3, n_genes) - - # test transfer_anndata_setup + view - mdata3 = synthetic_iid(return_mudata=True) - mdata3.obs["_indices"] = np.arange(mdata3.n_obs) - model.get_elbo(mdata3[:10]) - model.get_accessibility_estimates() - model.get_accessibility_estimates(normalize_cells=True) + model.get_accessibility_estimates(mdata2, normalize_cells=True) model.get_accessibility_estimates(normalize_regions=True) model.get_library_size_factors() model.get_region_factors() From 165633c5958d6086f180b16e3e6749cedf02688b Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Thu, 21 Nov 2024 16:33:56 +0200 Subject: [PATCH 47/51] updated multivi tutorials with mudata and minification --- docs/user_guide/models/multivi.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/docs/user_guide/models/multivi.md b/docs/user_guide/models/multivi.md
index a552622706..73a4f51c70 100644
--- a/docs/user_guide/models/multivi.md
+++ b/docs/user_guide/models/multivi.md
@@ -92,7 +92,7 @@ The latent variables, along with their description are summarized in the followi
      - Low-dimensional representation capturing the state of a cell.
      - N/A
    * - :math:`\rho_n \in \Delta^{G-1}`
-     - Denoised/normalized gene expression. This is a vector that sums to 1 within a cell, unless `size_factor_key is not None` in :class:`~scvi.model.MULTVI.setup_anndata`, in which case this is only forced to be non-negative via softplus.
+     - Denoised/normalized gene expression. This is a vector that sums to 1 within a cell, unless `size_factor_key is not None` in :class:`~scvi.model.MULTIVI.setup_anndata` or :class:`~scvi.model.MULTIVI.setup_mudata`, in which case this is only forced to be non-negative via softplus.
      - ``px_scale``
    * - :math:`\ell_n \in (0, \infty)`
      - Library size for RNA.
@@ -104,7 +104,7 @@ The latent variables, along with their description are summarized in the followi
      - Accessibility probability estimate
      - N/A
    * - :math:`\ell_n \in \left[0,1\right]`
-     - Cell-wise scaling factor. Learned, but can be set manually with `size_factor_key` in :class:`~scvi.model.MULTIVI.setup_anndata`.
+     - Cell-wise scaling factor. Learned, but can be set manually with `size_factor_key` in :class:`~scvi.model.MULTIVI.setup_anndata` or :class:`~scvi.model.MULTIVI.setup_mudata`.
      - ``d``
    * - :math:`r_j \in \left[0,1\right]`
      - Region-wise scaling factor

From 01434711bf15249520a9262f795e57ad9db3014b Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 21 Nov 2024 19:06:35 +0000
Subject: [PATCH 48/51] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/scvi/model/_multivi.py   | 22 ++++++++++------------
 src/scvi/module/_totalvae.py |  2 +-
 tests/model/test_multivi.py  |  2 +-
 3 files changed, 12 insertions(+), 14 deletions(-)

diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py
index 27d9cd965c..4fcd11a931 100644
--- a/src/scvi/model/_multivi.py
+++ b/src/scvi/model/_multivi.py
@@ -194,15 +194,11 @@ def __init__(
         n_genes = self.summary_stats.get("n_vars", 0)
         n_regions = self.summary_stats.get("n_atac", 0)
         if isinstance(adata, MuData):
-            assert (
-                n_genes == self.summary_stats.get("n_vars", 0)
-            ), "n_genes must match MuData"
-            assert (
-                n_regions == self.summary_stats.get("n_atac", 0)
-            ), "n_regions must match MuData"
+            assert n_genes == self.summary_stats.get("n_vars", 0), "n_genes must match MuData"
+            assert n_regions == self.summary_stats.get("n_atac", 0), "n_regions must match MuData"
         if modality_weights == "cell":
             assert (
-                self.registry_['setup_args']['index_key'] is not None
+                self.registry_["setup_args"]["index_key"] is not None
             ), "index_key must be set if using cell modality weights"
         prior_mean, prior_scale = None, None
         n_cats_per_cov = (
@@ -265,7 +261,9 @@ def __init__(
         self.n_proteins = n_proteins
         self.module.minified_data_type = self.minified_data_type
         if isinstance(adata, MuData):
-            self.modality_keys = self.get_anndata_manager(adata).registry['setup_args']['modalities']
+            self.modality_keys = self.get_anndata_manager(adata).registry["setup_args"][
+                "modalities"
+            ]
         else:
             self.modality_keys = None
@@ -644,9 +642,9 @@ def get_accessibility_estimates(
             )
         else:
             if isinstance(adata, MuData):
-                peak_names = adata[self.modality_keys['atac_layer']].var_names[region_mask]
+                peak_names = adata[self.modality_keys["atac_layer"]].var_names[region_mask]
             else:
-                peak_names = adata.var_names[self.n_regions:][region_mask]
+                peak_names = adata.var_names[self.n_regions :][region_mask]
         if return_numpy:
             return imputed
         elif threshold:
@@ -775,9 +773,9 @@ def get_normalized_expression(
             return exprs
         else:
             if isinstance(adata, MuData):
-                gene_names = adata[self.modality_keys['rna_layer']].var_names[gene_mask]
+                gene_names = adata[self.modality_keys["rna_layer"]].var_names[gene_mask]
             else:
-                gene_names = adata.var_names[:self.n_genes][gene_mask]
+                gene_names = adata.var_names[: self.n_genes][gene_mask]

         return pd.DataFrame(
             exprs,
diff --git a/src/scvi/module/_totalvae.py b/src/scvi/module/_totalvae.py
index 0abeeaa364..1a695fbc5c 100644
--- a/src/scvi/module/_totalvae.py
+++ b/src/scvi/module/_totalvae.py
@@ -481,7 +481,7 @@ def _cached_inference(
             MODULE_KEYS.Z_KEY: z,
             MODULE_KEYS.QZ_KEY: qz,
             MODULE_KEYS.QL_KEY: None,
-            "library_gene": observed_lib_size,            
+            "library_gene": observed_lib_size,
         }

     @auto_move_data
diff --git a/tests/model/test_multivi.py b/tests/model/test_multivi.py
index 938add5b85..2e5b2301f9 100644
--- a/tests/model/test_multivi.py
+++ b/tests/model/test_multivi.py
@@ -71,7 +71,7 @@ def test_multivi():
         batch_key="batch",
         protein_expression_obsm_key="protein_expression",
         protein_names_uns_key="protein_names",
-        index_key="_indices"
+        index_key="_indices",
     )
     vae = MULTIVI(
         data,

From 237ff136ab94ed4ea01cd57ddfb0745578e5166e Mon Sep 17 00:00:00 2001
From: Can Ergen
Date: Fri, 22 Nov 2024 00:12:05 -0800
Subject: [PATCH 49/51] Fixed multiVI

---
 src/scvi/model/_multivi.py        |   22 +-
 src/scvi/module/_multivae copy.py | 1104 +++++++++++++++++++++++++++++
 src/scvi/module/_multivae.py      |  298 ++++++--
 src/scvi/module/_peakvae.py       |    6 +-
 4 files changed, 1340 insertions(+), 90 deletions(-)
 create mode 100644 src/scvi/module/_multivae copy.py

diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py
index 27d9cd965c..d93f553f62 100644
--- a/src/scvi/model/_multivi.py
+++ b/src/scvi/model/_multivi.py
@@ -439,14 +439,18 @@ def get_library_size_factors(
         }

     @torch.inference_mode()
-    def get_region_factors(self) -> np.ndarray:
+    def get_region_factors(self, return_numpy: bool = True) -> np.ndarray | torch.Tensor:
         """Return region-specific factors."""
         if self.n_regions == 0:
             return np.zeros(1)
         else:
             if self.module.region_factors is None:
                 raise RuntimeError("region factors were not included in this model")
-            return torch.sigmoid(self.module.region_factors).cpu().numpy()
+            region_factors = torch.sigmoid(self.module.scale_region_factors * self.module.region_factors)
+            if return_numpy:
+                return region_factors.cpu().numpy()
+            else:
+                return region_factors

     @torch.inference_mode()
     def get_latent_representation(
@@ -621,7 +625,7 @@ def get_accessibility_estimates(
             if normalize_cells:
                 p *= inference_outputs["libsize_acc"].cpu()
             if normalize_regions:
-                p *= torch.sigmoid(self.module.region_factors).cpu()
+                p *= self.get_region_factors(return_numpy=False).cpu()
             if threshold:
                 p[p < threshold] = 0
                 p = csr_matrix(p.numpy())
@@ -750,9 +754,9 @@ def get_normalized_expression(
                 compute_loss=False,
             )
             if library_size == "latent":
-                output = generative_outputs["px_rate"]
+                output = generative_outputs["px"].get_normalized("px_rate")
             else:
-                output = generative_outputs["px_scale"]
+                output = generative_outputs["px"].get_normalized("px_scale")
             output = output[..., gene_mask]
             output = output.cpu().numpy()
             per_batch_exprs.append(output)
@@ -858,22 +862,18 @@ def differential_accessibility(
         """
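A short sketch of how the reworked region factors are consumed, assuming a trained MULTIVI instance `model` (hypothetical) on data that includes ATAC regions:

    # Region-specific factors, a scaled sigmoid of module.region_factors;
    # a NumPy array by default, or a torch.Tensor with return_numpy=False.
    region_factors = model.get_region_factors()

    # normalize_regions now multiplies the estimates by these same factors,
    # fetched internally via get_region_factors(return_numpy=False).
    estimates = model.get_accessibility_estimates(
        normalize_cells=True, normalize_regions=True
    )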
self._check_adata_modality_weights(adata) adata = self._validate_anndata(adata) - col_names = adata.var_names[: self.n_genes] + col_names = adata.var_names[self.n_genes :] model_fn = partial( self.get_accessibility_estimates, use_z_mean=False, batch_size=batch_size ) - # TODO check if change_fn in kwargs and raise error if so def change_fn(a, b): return a - b if two_sided: - def m1_domain_fn(samples): return np.abs(samples) >= delta - else: - def m1_domain_fn(samples): return samples >= delta diff --git a/src/scvi/module/_multivae copy.py b/src/scvi/module/_multivae copy.py new file mode 100644 index 0000000000..1bedd5d928 --- /dev/null +++ b/src/scvi/module/_multivae copy.py @@ -0,0 +1,1104 @@ +from collections.abc import Iterable +from typing import Literal + +import numpy as np +import torch +from torch import nn +from torch.distributions import Normal +from torch.distributions import kl_divergence as kld +from torch.nn import functional as F + +from scvi import REGISTRY_KEYS +from scvi.distributions import ( + NegativeBinomial, + NegativeBinomialMixture, + Poisson, + ZeroInflatedNegativeBinomial, +) +from scvi.module._peakvae import Decoder as DecoderPeakVI +from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data +from scvi.nn import DecoderSCVI, Encoder, FCLayers + +from ._utils import masked_softmax + + +class LibrarySizeEncoder(torch.nn.Module): + """Library size encoder.""" + + def __init__( + self, + n_input: int, + n_cat_list: Iterable[int] = None, + n_layers: int = 2, + n_hidden: int = 128, + use_batch_norm: bool = False, + use_layer_norm: bool = True, + inject_covariates: bool = False, + **kwargs, + ): + super().__init__() + self.encoder = FCLayers( + n_in=n_input, + n_out=n_hidden, + n_cat_list=n_cat_list, + n_layers=n_layers, + n_hidden=n_hidden, + dropout_rate=0, + activation_fn=torch.nn.LeakyReLU, + use_batch_norm=use_batch_norm, + use_layer_norm=use_layer_norm, + inject_covariates=inject_covariates, + **kwargs, + ) + self.output = torch.nn.Sequential(torch.nn.Linear(n_hidden, 1), torch.nn.LeakyReLU()) + + def forward(self, x: torch.Tensor, *cat_list: int): + """Forward pass.""" + return self.output(self.encoder(x, *cat_list)) + + +class DecoderADT(torch.nn.Module): + """Decoder for just surface proteins (ADT).""" + + def __init__( + self, + n_input: int, + n_output_proteins: int, + n_cat_list: Iterable[int] = None, + n_layers: int = 2, + n_hidden: int = 128, + dropout_rate: float = 0.1, + use_batch_norm: bool = False, + use_layer_norm: bool = True, + inject_covariates: bool = False, + ): + super().__init__() + self.n_output_proteins = n_output_proteins + + linear_args = { + "n_layers": 1, + "use_activation": False, + "use_batch_norm": False, + "use_layer_norm": False, + "dropout_rate": 0, + } + + self.py_fore_decoder = FCLayers( + n_in=n_input, + n_out=n_hidden, + n_cat_list=n_cat_list, + n_layers=n_layers, + n_hidden=n_hidden, + dropout_rate=dropout_rate, + use_batch_norm=use_batch_norm, + use_layer_norm=use_layer_norm, + ) + self.py_fore_scale_decoder = FCLayers( + n_in=n_hidden + n_input, + n_out=n_output_proteins, + n_cat_list=n_cat_list, + n_layers=1, + use_activation=True, + use_batch_norm=False, + use_layer_norm=False, + dropout_rate=0, + activation_fn=nn.ReLU, + ) + + self.py_background_decoder = FCLayers( + n_in=n_hidden + n_input, + n_out=n_output_proteins, + n_cat_list=n_cat_list, + **linear_args, + ) + + # dropout (mixture component for proteins, ZI probability for genes) + self.sigmoid_decoder = FCLayers( + n_in=n_input, + n_out=n_hidden, + 
n_cat_list=n_cat_list, + n_layers=n_layers, + n_hidden=n_hidden, + dropout_rate=dropout_rate, + use_batch_norm=use_batch_norm, + use_layer_norm=use_layer_norm, + ) + + # background mean parameters second decoder + self.py_back_mean_log_alpha = FCLayers( + n_in=n_hidden + n_input, + n_out=n_output_proteins, + n_cat_list=n_cat_list, + **linear_args, + ) + self.py_back_mean_log_beta = FCLayers( + n_in=n_hidden + n_input, + n_out=n_output_proteins, + n_cat_list=n_cat_list, + **linear_args, + ) + + # background mean first decoder + self.py_back_decoder = FCLayers( + n_in=n_input, + n_out=n_hidden, + n_cat_list=n_cat_list, + n_layers=n_layers, + n_hidden=n_hidden, + dropout_rate=dropout_rate, + use_batch_norm=use_batch_norm, + use_layer_norm=use_layer_norm, + ) + + def forward(self, z: torch.Tensor, *cat_list: int): + """Forward pass.""" + # z is the latent repr + py_ = {} + + py_back = self.py_back_decoder(z, *cat_list) + py_back_cat_z = torch.cat([py_back, z], dim=-1) + + py_["back_alpha"] = self.py_back_mean_log_alpha(py_back_cat_z, *cat_list) + py_["back_beta"] = ( + torch.nn.functional.softplus(self.py_back_mean_log_beta(py_back_cat_z, *cat_list)) + + 1e-8 + ) + log_pro_back_mean = Normal(py_["back_alpha"], py_["back_beta"]).rsample() + py_["rate_back"] = torch.exp(log_pro_back_mean) + + py_fore = self.py_fore_decoder(z, *cat_list) + py_fore_cat_z = torch.cat([py_fore, z], dim=-1) + py_["fore_scale"] = self.py_fore_scale_decoder(py_fore_cat_z, *cat_list) + 1 + 1e-8 + py_["rate_fore"] = py_["rate_back"] * py_["fore_scale"] + + p_mixing = self.sigmoid_decoder(z, *cat_list) + p_mixing_cat_z = torch.cat([p_mixing, z], dim=-1) + py_["mixing"] = self.py_background_decoder(p_mixing_cat_z, *cat_list) + + protein_mixing = 1 / (1 + torch.exp(-py_["mixing"])) + py_["scale"] = torch.nn.functional.normalize( + (1 - protein_mixing) * py_["rate_fore"], p=1, dim=-1 + ) + + return py_, log_pro_back_mean + + +class MULTIVAE(BaseModuleClass): + """Variational auto-encoder model for joint paired + unpaired RNA-seq and ATAC-seq data. + + Parameters + ---------- + n_input_regions + Number of input regions. + n_input_genes + Number of input genes. + n_input_proteins + Number of input proteins + modality_weights + Weighting scheme across modalities. One of the following: + * ``"equal"``: Equal weight in each modality + * ``"universal"``: Learn weights across modalities w_m. + * ``"cell"``: Learn weights across modalities and cells. w_{m,c} + modality_penalty + Training Penalty across modalities. One of the following: + * ``"Jeffreys"``: Jeffreys penalty to align modalities + * ``"MMD"``: MMD penalty to align modalities + * ``"None"``: No penalty + n_batch + Number of batches, if 0, no batch correction is performed. + gene_likelihood + The distribution to use for gene expression data. 
One of the following + * ``'zinb'`` - Zero-Inflated Negative Binomial + * ``'nb'`` - Negative Binomial + * ``'poisson'`` - Poisson + gene_dispersion + One of the following: + * ``'gene'`` - dispersion parameter of NB is constant per gene across cells + * ``'gene-batch'`` - dispersion can differ between different batches + * ``'gene-label'`` - dispersion can differ between different labels + * ``'gene-cell'`` - dispersion can differ for every gene in every cell + protein_dispersion + One of the following: + + * ``'protein'`` - protein_dispersion parameter is constant per protein across cells + * ``'protein-batch'`` - protein_dispersion can differ between different batches NOT TESTED + * ``'protein-label'`` - protein_dispersion can differ between different labels NOT TESTED + n_hidden + Number of nodes per hidden layer. If `None`, defaults to square root + of number of regions. + n_latent + Dimensionality of the latent space. If `None`, defaults to square root + of `n_hidden`. + n_layers_encoder + Number of hidden layers used for encoder NN. + n_layers_decoder + Number of hidden layers used for decoder NN. + dropout_rate + Dropout rate for neural networks + region_factors + Include region-specific factors in the model + scale_region_factors + Scale region factors by a fixed number to speed up convergence + use_batch_norm + One of the following + * ``'encoder'`` - use batch normalization in the encoder only + * ``'decoder'`` - use batch normalization in the decoder only + * ``'none'`` - do not use batch normalization + * ``'both'`` - use batch normalization in both the encoder and decoder + use_layer_norm + One of the following + * ``'encoder'`` - use layer normalization in the encoder only + * ``'decoder'`` - use layer normalization in the decoder only + * ``'none'`` - do not use layer normalization + * ``'both'`` - use layer normalization in both the encoder and decoder + latent_distribution + which latent distribution to use, options are + * ``'normal'`` - Normal distribution + * ``'ln'`` - Logistic normal distribution (Normal(0, I) transformed by softmax) + deeply_inject_covariates + Whether to deeply inject covariates into all layers of the decoder. If False, + covariates will only be included in the input layer. + encode_covariates + If True, include covariates in the input to the encoder. + use_size_factor_key + Use size_factor AnnDataField defined by the user as scaling factor in mean of conditional + RNA distribution. 
+ """ + + def __init__( + self, + n_input_regions: int = 0, + n_input_genes: int = 0, + n_input_proteins: int = 0, + modality_weights: Literal["equal", "cell", "universal"] = "equal", + modality_penalty: Literal["Jeffreys", "MMD", "None"] = "Jeffreys", + n_batch: int = 0, + n_obs: int = 0, + n_labels: int = 0, + gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb", + gene_dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene", + atac_likelihood: Literal["zinb", "nb", "poisson", "bernoulli"] = "bernoulli", + atac_dispersion: Literal["peak", "peak-batch", "peak-label", "peak-cell"] = "peak", + n_hidden: int = None, + n_latent: int = None, + n_layers_encoder: int = 2, + n_layers_decoder: int = 2, + n_continuous_cov: int = 0, + n_cats_per_cov: Iterable[int] | None = None, + dropout_rate: float = 0.1, + region_factors: bool = True, + scale_region_factors: float = 1., + use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "none", + use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "both", + latent_distribution: Literal["normal", "ln"] = "normal", + deeply_inject_covariates: bool = False, + encode_covariates: bool = False, + use_size_factor_key: bool = False, + protein_background_prior_mean: np.ndarray | None = None, + protein_background_prior_scale: np.ndarray | None = None, + protein_dispersion: str = "protein", + ): + super().__init__() + + # INIT PARAMS + self.n_input_regions = n_input_regions + self.n_input_genes = n_input_genes + self.n_input_proteins = n_input_proteins + if n_hidden is None: + if n_input_regions == 0: + self.n_hidden = np.min([128, int(np.sqrt(self.n_input_genes))]) + else: + self.n_hidden = np.min([128, int(np.sqrt(self.n_input_regions))]) + else: + self.n_hidden = n_hidden + self.n_batch = n_batch + + self.gene_likelihood = gene_likelihood + self.latent_distribution = latent_distribution + + self.n_latent = int(np.sqrt(self.n_hidden)) if n_latent is None else n_latent + self.n_layers_encoder = n_layers_encoder + self.n_layers_decoder = n_layers_decoder + self.n_cats_per_cov = n_cats_per_cov + self.n_continuous_cov = n_continuous_cov + self.dropout_rate = dropout_rate + + self.use_batch_norm_encoder = use_batch_norm in ("encoder", "both") + self.use_batch_norm_decoder = use_batch_norm in ("decoder", "both") + self.use_layer_norm_encoder = use_layer_norm in ("encoder", "both") + self.use_layer_norm_decoder = use_layer_norm in ("decoder", "both") + self.encode_covariates = encode_covariates + self.deeply_inject_covariates = deeply_inject_covariates + self.use_size_factor_key = use_size_factor_key + + cat_list = [n_batch] + list(n_cats_per_cov) if n_cats_per_cov is not None else [] + encoder_cat_list = cat_list if encode_covariates else None + + # expression + # expression dispersion parameters + self.gene_likelihood = gene_likelihood + self.gene_dispersion = gene_dispersion + if self.gene_dispersion == "gene": + self.px_r = torch.nn.Parameter(torch.randn(n_input_genes)) + elif self.gene_dispersion == "gene-batch": + self.px_r = torch.nn.Parameter(torch.randn(n_input_genes, n_batch)) + elif self.gene_dispersion == "gene-label": + self.px_r = torch.nn.Parameter(torch.randn(n_input_genes, n_labels)) + elif self.gene_dispersion == "gene-cell": + pass + else: + raise ValueError( + "dispersion must be one of ['gene', 'gene-batch'," + " 'gene-label', 'gene-cell'], but input was " + "{}.format(self.dispersion)" + ) + + # expression encoder + if self.n_input_genes == 0: + input_exp = 1 + else: + input_exp = self.n_input_genes + 
n_input_encoder_exp = input_exp + n_continuous_cov * encode_covariates + self.z_encoder_expression = Encoder( + n_input=n_input_encoder_exp, + n_output=self.n_latent, + n_cat_list=encoder_cat_list, + n_layers=self.n_layers_encoder, + n_hidden=self.n_hidden, + dropout_rate=self.dropout_rate, + distribution=self.latent_distribution, + inject_covariates=deeply_inject_covariates, + use_batch_norm=self.use_batch_norm_encoder, + use_layer_norm=self.use_layer_norm_encoder, + activation_fn=torch.nn.LeakyReLU, + var_eps=0, + return_dist=False, + ) + + # expression library size encoder + self.l_encoder_expression = LibrarySizeEncoder( + n_input_encoder_exp, + n_cat_list=encoder_cat_list, + n_layers=self.n_layers_encoder, + n_hidden=self.n_hidden, + use_batch_norm=self.use_batch_norm_encoder, + use_layer_norm=self.use_layer_norm_encoder, + inject_covariates=self.deeply_inject_covariates, + ) + + # expression decoder + n_input_decoder = self.n_latent + self.n_continuous_cov + self.z_decoder_expression = DecoderSCVI( + n_input_decoder, + n_input_genes, + n_cat_list=cat_list, + n_layers=n_layers_decoder, + n_hidden=self.n_hidden, + inject_covariates=self.deeply_inject_covariates, + use_batch_norm=self.use_batch_norm_decoder, + use_layer_norm=self.use_layer_norm_decoder, + scale_activation="softplus" if use_size_factor_key else "softmax", + ) + + # accessibility + # atac dispersion parameters + self.atac_likelihood = atac_likelihood + self.atac_dispersion = atac_dispersion + if self.atac_dispersion == "peak": + self.px_r_atac = torch.nn.Parameter(2 * torch.rand(self.n_input_regions)) + elif self.atac_dispersion == "peak-batch": + self.px_r_atac = torch.nn.Parameter(2 * torch.rand(self.n_input_regions, n_batch)) + elif self.atac_dispersion == "peak-label": + self.px_r_atac = torch.nn.Parameter(2 * torch.rand(self.n_input_regions, n_labels)) + elif self.atac_dispersion == "peak-cell": + pass + else: + raise ValueError( + "dispersion must be one of ['gene', 'gene-batch'," + " 'gene-label', 'gene-cell'], but input was " + "{}.format(self.dispersion)" + ) + + # accessibility encoder + if self.n_input_regions == 0: + input_acc = 1 + else: + input_acc = self.n_input_regions + n_input_encoder_acc = input_acc + n_continuous_cov * encode_covariates + self.z_encoder_accessibility = Encoder( + n_input=n_input_encoder_acc, + n_layers=self.n_layers_encoder, + n_output=self.n_latent, + n_hidden=self.n_hidden, + n_cat_list=encoder_cat_list, + dropout_rate=self.dropout_rate, + activation_fn=torch.nn.LeakyReLU, + distribution=self.latent_distribution, + var_eps=0, + use_batch_norm=self.use_batch_norm_encoder, + use_layer_norm=self.use_layer_norm_encoder, + return_dist=False, + ) + + # accessibility region-specific factors + self.region_factors = None + self.scale_region_factors = scale_region_factors + if region_factors: + self.region_factors = torch.nn.Parameter(torch.zeros(self.n_input_regions)) + + # accessibility decoder + if self.atac_likelihood == 'bernoulli': + atac_decoder_fn = DecoderPeakVI + decoder_atac_kwargs = {} + else: + atac_decoder_fn = DecoderSCVI + decoder_atac_kwargs = { + "scale_activation": "softplus" if use_size_factor_key else "softmax" + } + + self.z_decoder_accessibility = atac_decoder_fn( + n_input_decoder, + n_input_regions, + n_cat_list=cat_list, + n_layers=n_layers_decoder, + n_hidden=self.n_hidden, + inject_covariates=self.deeply_inject_covariates, + use_batch_norm=self.use_batch_norm_decoder, + use_layer_norm=self.use_layer_norm_decoder, + **decoder_atac_kwargs + ) + + # accessibility 
library size encoder + self.l_encoder_accessibility = LibrarySizeEncoder( + n_input_encoder_acc, + n_cat_list=encoder_cat_list, + n_layers=self.n_layers_encoder, + n_hidden=self.n_hidden, + use_batch_norm=self.use_batch_norm_encoder, + use_layer_norm=self.use_layer_norm_encoder, + inject_covariates=self.deeply_inject_covariates, + ) + + # protein + self.protein_dispersion = protein_dispersion + if protein_background_prior_mean is None: + if n_batch > 0: + self.background_pro_alpha = torch.nn.Parameter( + torch.randn(n_input_proteins, n_batch) + ) + self.background_pro_log_beta = torch.nn.Parameter( + torch.clamp(torch.randn(n_input_proteins, n_batch), -10, 1) + ) + else: + self.background_pro_alpha = torch.nn.Parameter(torch.randn(n_input_proteins)) + self.background_pro_log_beta = torch.nn.Parameter( + torch.clamp(torch.randn(n_input_proteins), -10, 1) + ) + else: + if protein_background_prior_mean.shape[1] == 1 and n_batch != 1: + init_mean = protein_background_prior_mean.ravel() + init_scale = protein_background_prior_scale.ravel() + else: + init_mean = protein_background_prior_mean + init_scale = protein_background_prior_scale + self.background_pro_alpha = torch.nn.Parameter( + torch.from_numpy(init_mean.astype(np.float32)) + ) + self.background_pro_log_beta = torch.nn.Parameter( + torch.log(torch.from_numpy(init_scale.astype(np.float32))) + ) + + # protein encoder + if self.n_input_proteins == 0: + input_pro = 1 + else: + input_pro = self.n_input_proteins + n_input_encoder_pro = input_pro + n_continuous_cov * encode_covariates + self.z_encoder_protein = Encoder( + n_input=n_input_encoder_pro, + n_layers=self.n_layers_encoder, + n_output=self.n_latent, + n_hidden=self.n_hidden, + n_cat_list=encoder_cat_list, + dropout_rate=self.dropout_rate, + activation_fn=torch.nn.LeakyReLU, + distribution=self.latent_distribution, + var_eps=0, + use_batch_norm=self.use_batch_norm_encoder, + use_layer_norm=self.use_layer_norm_encoder, + return_dist=False, + ) + + # protein decoder + self.z_decoder_pro = DecoderADT( + n_input=n_input_decoder, + n_output_proteins=n_input_proteins, + n_hidden=self.n_hidden, + n_cat_list=cat_list, + n_layers=self.n_layers_decoder, + use_batch_norm=self.use_batch_norm_decoder, + use_layer_norm=self.use_layer_norm_decoder, + inject_covariates=self.deeply_inject_covariates, + ) + + # protein dispersion parameters + if self.protein_dispersion == "protein": + self.py_r = torch.nn.Parameter(2 * torch.rand(self.n_input_proteins)) + elif self.protein_dispersion == "protein-batch": + self.py_r = torch.nn.Parameter(2 * torch.rand(self.n_input_proteins, n_batch)) + elif self.protein_dispersion == "protein-label": + self.py_r = torch.nn.Parameter(2 * torch.rand(self.n_input_proteins, n_labels)) + else: # protein-cell + pass + + # modality alignment + self.n_obs = n_obs + self.modality_weights = modality_weights + self.modality_penalty = modality_penalty + self.n_modalities = int(n_input_genes > 0) + int(n_input_regions > 0) + max_n_modalities = 3 + if modality_weights == "equal": + mod_weights = torch.ones(max_n_modalities) + self.register_buffer("mod_weights", mod_weights) + elif modality_weights == "universal": + self.mod_weights = torch.nn.Parameter(torch.ones(max_n_modalities)) + else: # cell-specific weights + self.mod_weights = torch.nn.Parameter(torch.ones(n_obs, max_n_modalities)) + + def _get_inference_input(self, tensors): + """Get input tensors for the inference model.""" + # from scvi.data._constants import ADATA_MINIFY_TYPE + # TODO: ADD MINIFICATION CONSIDERATION + + x 
= tensors.get(REGISTRY_KEYS.X_KEY, None) + x_atac = tensors.get(REGISTRY_KEYS.ATAC_X_KEY, None) + if x is not None and x_atac is not None: + x = torch.cat((x, x_atac), dim=-1) + elif x is None: + x = x_atac + if self.n_input_proteins == 0: + y = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False) + else: + y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY] + batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] + cell_idx = tensors.get(REGISTRY_KEYS.INDICES_KEY) + if cell_idx is not None: + cell_idx = cell_idx.long().ravel() + cont_covs = tensors.get(REGISTRY_KEYS.CONT_COVS_KEY) + cat_covs = tensors.get(REGISTRY_KEYS.CAT_COVS_KEY) + label = tensors[REGISTRY_KEYS.LABELS_KEY] + size_factor = tensors.get(REGISTRY_KEYS.SIZE_FACTOR_KEY, None) + input_dict = { + "x": x, + "y": y, + "batch_index": batch_index, + "cont_covs": cont_covs, + "cat_covs": cat_covs, + "label": label, + "cell_idx": cell_idx, + "size_factor": size_factor, + } + return input_dict + + @auto_move_data + def inference( + self, + x, + y, + batch_index, + cont_covs, + cat_covs, + label, + cell_idx, + size_factor, + n_samples=1, + ) -> dict[str, torch.Tensor]: + """Run the inference model.""" + # Get Data and Additional Covs + if self.n_input_genes == 0: + x_rna = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False) + else: + x_rna = x[:, : self.n_input_genes] + if self.n_input_regions == 0: + x_atac = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False) + else: + x_atac = x[:, self.n_input_genes : (self.n_input_genes + self.n_input_regions)] + + mask_expr = x_rna.sum(dim=1) > 0 + mask_acc = x_atac.sum(dim=1) > 0 + mask_pro = y.sum(dim=1) > 0 + + if cont_covs is not None and self.encode_covariates: + encoder_input_expression = torch.cat((x_rna, cont_covs), dim=-1) + encoder_input_accessibility = torch.cat((x_atac, cont_covs), dim=-1) + encoder_input_protein = torch.cat((y, cont_covs), dim=-1) + else: + encoder_input_expression = x_rna + encoder_input_accessibility = x_atac + encoder_input_protein = y + + if cat_covs is not None and self.encode_covariates: + categorical_input = torch.split(cat_covs, 1, dim=1) + else: + categorical_input = () + + # Z Encoders + qzm_acc, qzv_acc, z_acc = self.z_encoder_accessibility( + encoder_input_accessibility, batch_index, *categorical_input + ) + qzm_expr, qzv_expr, z_expr = self.z_encoder_expression( + encoder_input_expression, batch_index, *categorical_input + ) + qzm_pro, qzv_pro, z_pro = self.z_encoder_protein( + encoder_input_protein, batch_index, *categorical_input + ) + + # L encoders + if self.use_size_factor_key: + libsize_expr = torch.log(size_factor[:, [0]] + 1e-6) + libsize_acc = size_factor[:, [1]] + if self.atac_likelihood != "bernoulli": + libsize_acc = torch.log(libsize_acc + 1e-6) + else: + libsize_acc = self.l_encoder_accessibility( + encoder_input_accessibility, batch_index, *categorical_input + ) + if self.atac_likelihood == "bernoulli": + libsize_acc = torch.sigmoid(libsize_acc) + libsize_expr = self.l_encoder_expression( + encoder_input_expression, batch_index, *categorical_input + ) + + # mix representations + if self.modality_weights == "cell": + weights = self.mod_weights[cell_idx, :] + else: + weights = self.mod_weights.unsqueeze(0).expand(x.shape[0], -1) + + qz_m = mix_modalities( + (qzm_expr, qzm_acc, qzm_pro), (mask_expr, mask_acc, mask_pro), weights + ) + qz_v = mix_modalities( + (qzv_expr, qzv_acc, qzv_pro), + (mask_expr, mask_acc, mask_pro), + weights, + torch.sqrt, + ) + + # sample + if n_samples > 1: + + def unsqz(zt, n_s): + return 
zt.unsqueeze(0).expand((n_s, zt.size(0), zt.size(1))) + + untran_za = Normal(qzm_acc, qzv_acc.sqrt()).sample((n_samples,)) + z_acc = self.z_encoder_accessibility.z_transformation(untran_za) + untran_ze = Normal(qzm_expr, qzv_expr.sqrt()).sample((n_samples,)) + z_expr = self.z_encoder_expression.z_transformation(untran_ze) + untran_zp = Normal(qzm_pro, qzv_pro.sqrt()).sample((n_samples,)) + z_pro = self.z_encoder_protein.z_transformation(untran_zp) + + libsize_expr = unsqz(libsize_expr, n_samples) + libsize_acc = unsqz(libsize_acc, n_samples) + + # sample from the mixed representation + untran_z = Normal(qz_m, qz_v.sqrt()).rsample() + z = self.z_encoder_accessibility.z_transformation(untran_z) + + outputs = { + "x": x, + "z": z, + "qz_m": qz_m, + "qz_v": qz_v, + "z_expr": z_expr, + "qzm_expr": qzm_expr, + "qzv_expr": qzv_expr, + "z_acc": z_acc, + "qzm_acc": qzm_acc, + "qzv_acc": qzv_acc, + "z_pro": z_pro, + "qzm_pro": qzm_pro, + "qzv_pro": qzv_pro, + "libsize_expr": libsize_expr, + "libsize_acc": libsize_acc, + } + return outputs + + def _get_generative_input(self, tensors, inference_outputs, transform_batch=None): + """Get the input for the generative model.""" + z = inference_outputs["z"] + qz_m = inference_outputs["qz_m"] + libsize_expr = inference_outputs["libsize_expr"] + libsize_acc = inference_outputs["libsize_acc"] + + batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] + cont_key = REGISTRY_KEYS.CONT_COVS_KEY + cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None + + cat_key = REGISTRY_KEYS.CAT_COVS_KEY + cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None + + if transform_batch is not None: + batch_index = torch.ones_like(batch_index) * transform_batch + + label = tensors[REGISTRY_KEYS.LABELS_KEY] + + input_dict = { + "z": z, + "qz_m": qz_m, + "batch_index": batch_index, + "cont_covs": cont_covs, + "cat_covs": cat_covs, + "libsize_expr": libsize_expr, + "libsize_acc": libsize_acc, + "label": label, + } + return input_dict + + @auto_move_data + def generative( + self, + z, + qz_m, + batch_index, + cont_covs=None, + cat_covs=None, + libsize_expr=None, + libsize_acc=None, + use_z_mean=False, + label: torch.Tensor = None, + ): + """Runs the generative model.""" + if cat_covs is not None: + categorical_input = torch.split(cat_covs, 1, dim=1) + else: + categorical_input = () + + latent = z if not use_z_mean else qz_m + if cont_covs is None: + decoder_input = latent + elif latent.dim() != cont_covs.dim(): + decoder_input = torch.cat( + [latent, cont_covs.unsqueeze(0).expand(latent.size(0), -1, -1)], dim=-1 + ) + else: + decoder_input = torch.cat([latent, cont_covs], dim=-1) + + # Accessibility Decoder + region_factor = ( + torch.sigmoid(self.scale_region_factors * self.region_factors) + if self.region_factors is not None else 1. + ) + if self.atac_likelihood == "bernoulli": + p = self.z_decoder_accessibility(decoder_input, batch_index, *categorical_input) + px_atac = {'px_rate': libsize_acc * region_factor * p, 'px_scale': p} + else: + # ATAC Decoder + px_scale_atac, px_r_atac, px_rate_atac, px_dropout_atac = self.z_decoder_accessibility( + self.atac_dispersion, + decoder_input, + libsize_acc, + batch_index, + *categorical_input, + label, + ) + # scale by 2 to match the initial scale of the region factor (0.5). 
+ px_rate_atac = px_rate_atac * region_factor + # ATAC Dispersion + if self.atac_dispersion == "peak-label": + px_r_atac = F.linear( + F.one_hot(label.squeeze(-1), self.n_labels).float(), self.px_r_atac + ) # px_r gets transposed - last dimension is nb genes + elif self.atac_dispersion == "peak-batch": + px_r_atac = F.linear( + F.one_hot(batch_index.squeeze(-1), self.n_batch).float(), + self.px_r_atac) + elif self.atac_dispersion == "peak": + px_r_atac = self.px_r_atac + px_r_atac = torch.exp(px_r_atac) + if self.atac_likelihood == "zinb": + px_atac = ZeroInflatedNegativeBinomial( + mu=px_rate_atac, + theta=px_r_atac, + zi_logits=px_dropout_atac, + scale=px_scale_atac, + ) + elif self.atac_likelihood == "nb": + px_atac = NegativeBinomial(mu=px_rate_atac, theta=px_r_atac, scale=px_scale_atac) + elif self.atac_likelihood == "poisson": + px_atac = Poisson(rate=px_rate_atac, scale=px_scale_atac) + elif self.atac_likelihood == "normal": + px_atac = Normal(px_rate_atac, px_r_atac, normal_mu=px_scale_atac) + + # Expression Decoder + px_scale, px_scale, px_rate, px_dropout = self.z_decoder_expression( + self.gene_dispersion, + decoder_input, + libsize_expr, + batch_index, + *categorical_input, + label, + ) + # Expression Dispersion + if self.gene_dispersion == "gene-label": + px_r = F.linear( + F.one_hot(label.squeeze(-1), self.n_labels).float(), self.px_r + ) # px_r gets transposed - last dimension is nb genes + elif self.gene_dispersion == "gene-batch": + px_r = F.linear(F.one_hot(batch_index.squeeze(-1), self.n_batch).float(), self.px_r) + elif self.gene_dispersion == "gene": + px_r = self.px_r + px_r = torch.exp(px_r) + + if self.gene_likelihood == "zinb": + px = ZeroInflatedNegativeBinomial( + mu=px_rate, + theta=px_r, + zi_logits=px_dropout, + scale=px_scale, + ) + elif self.gene_likelihood == "nb": + px = NegativeBinomial(mu=px_rate, theta=px_r, scale=px_scale) + elif self.gene_likelihood == "poisson": + px = Poisson(rate=px_rate, scale=px_scale) + elif self.gene_likelihood == "normal": + px = Normal(px_rate, px_r, normal_mu=px_scale) + + # Protein Decoder + py_, log_pro_back_mean = self.z_decoder_pro(decoder_input, batch_index, *categorical_input) + # Protein Dispersion + if self.protein_dispersion == "protein-label": + # py_r gets transposed - last dimension is n_proteins + py_r = F.linear(F.one_hot(label.squeeze(-1), self.n_labels).float(), self.py_r) + elif self.protein_dispersion == "protein-batch": + py_r = F.linear(F.one_hot(batch_index.squeeze(-1), self.n_batch).float(), self.py_r) + elif self.protein_dispersion == "protein": + py_r = self.py_r + py_r = torch.exp(py_r) + py_["r"] = py_r + + return { + "px_atac": px_atac, + "px": px, + "py_": py_, + "log_pro_back_mean": log_pro_back_mean, + } + + def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0): + """Computes the loss function for the model.""" + # Get the data + x = inference_outputs["x"] + + x_rna = x[:, : self.n_input_genes] + x_atac = x[:, self.n_input_genes:] + if self.n_input_proteins == 0: + y = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False) + else: + y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY] + + mask_expr = x_rna.sum(dim=1) > 0 + mask_acc = x_atac.sum(dim=1) > 0 + mask_pro = y.sum(dim=1) > 0 + + # Compute Accessibility loss + px_atac = generative_outputs["px_atac"] + if self.atac_likelihood == "bernoulli": + rl_accessibility = self._get_reconstruction_loss_bernoulli( + x_atac, px_atac["px_rate"]) + else: + rl_accessibility = - px_atac.log_prob(x_atac).sum(-1) + + # 
Compute Expression loss + rl_expression = - generative_outputs["px"].log_prob(x_rna).sum(-1) + + # Compute Protein loss - No ability to mask minibatch (Param:None) + if mask_pro.sum().gt(0): + py_ = generative_outputs["py_"] + rl_protein = get_reconstruction_loss_protein(y, py_, None) + else: + rl_protein = torch.zeros(x.shape[0], device=x.device, requires_grad=False) + + # calling without weights makes this act like a masked sum + recon_loss_expression = rl_expression * mask_expr + recon_loss_accessibility = rl_accessibility * mask_acc + recon_loss_protein = rl_protein * mask_pro + recon_loss = recon_loss_expression + recon_loss_accessibility + recon_loss_protein + + # Compute KLD between Z and N(0,I) + qz_m = inference_outputs["qz_m"] + qz_v = inference_outputs["qz_v"] + kl_div_z = kld( + Normal(qz_m, torch.sqrt(qz_v)), + Normal(0, 1), + ).sum(dim=1) + + # Compute KLD between distributions for paired data + kl_div_paired = self._compute_mod_penalty( + (inference_outputs["qzm_expr"], inference_outputs["qzv_expr"]), + (inference_outputs["qzm_acc"], inference_outputs["qzv_acc"]), + (inference_outputs["qzm_pro"], inference_outputs["qzv_pro"]), + mask_expr, + mask_acc, + mask_pro, + ) + + # KL WARMUP + kl_local_for_warmup = kl_div_z + weighted_kl_local = kl_weight * kl_local_for_warmup + kl_div_paired + + # TOTAL LOSS + loss = torch.mean(recon_loss + weighted_kl_local) + + recon_losses = { + "reconstruction_loss_expression": recon_loss_expression, + "reconstruction_loss_accessibility": recon_loss_accessibility, + "reconstruction_loss_protein": recon_loss_protein, + } + kl_local = { + "kl_divergence_z": kl_div_z, + "kl_divergence_paired": kl_div_paired, + } + return LossOutput(loss=loss, reconstruction_loss=recon_losses, kl_local=kl_local) + + def _get_reconstruction_loss_bernoulli(self, x, p): + """Computes the reconstruction loss for the accessibility data.""" + # Scaling improves convergence speed. Otherwise region_factors takes long to train. + return torch.nn.BCELoss(reduction="none")(p, (x > 0).float()).sum(dim=-1) + + def _compute_mod_penalty(self, mod_params1, mod_params2, mod_params3, mask1, mask2, mask3): + """Computes Similarity Penalty across modalities given selection (None, Jeffreys, MMD). 
+ + Parameters + ---------- + mod_params1/2/3 + Posterior parameters for for modality 1/2/3 + mask1/2/3 + mask for modality 1/2/3 + """ + mask12 = torch.logical_and(mask1, mask2) + mask13 = torch.logical_and(mask1, mask3) + mask23 = torch.logical_and(mask3, mask2) + + if self.modality_penalty == "None": + return 0 + elif self.modality_penalty == "Jeffreys": + pair_penalty = torch.zeros(mask1.shape[0], device=mask1.device, requires_grad=True) + if mask12.sum().gt(0): + penalty12 = sym_kld( + mod_params1[0], + mod_params1[1].sqrt(), + mod_params2[0], + mod_params2[1].sqrt(), + ) + penalty12 = torch.where(mask12, penalty12.T, torch.zeros_like(penalty12).T).sum( + dim=0 + ) + pair_penalty = pair_penalty + penalty12 + if mask13.sum().gt(0): + penalty13 = sym_kld( + mod_params1[0], + mod_params1[1].sqrt(), + mod_params3[0], + mod_params3[1].sqrt(), + ) + penalty13 = torch.where(mask13, penalty13.T, torch.zeros_like(penalty13).T).sum( + dim=0 + ) + pair_penalty = pair_penalty + penalty13 + if mask23.sum().gt(0): + penalty23 = sym_kld( + mod_params2[0], + mod_params2[1].sqrt(), + mod_params3[0], + mod_params3[1].sqrt(), + ) + penalty23 = torch.where(mask23, penalty23.T, torch.zeros_like(penalty23).T).sum( + dim=0 + ) + pair_penalty = pair_penalty + penalty23 + + elif self.modality_penalty == "MMD": + pair_penalty = torch.zeros(mask1.shape[0], device=mask1.device, requires_grad=True) + if mask12.sum().gt(0): + penalty12 = torch.linalg.norm(mod_params1[0] - mod_params2[0], dim=1) + penalty12 = torch.where(mask12, penalty12.T, torch.zeros_like(penalty12).T).sum( + dim=0 + ) + pair_penalty = pair_penalty + penalty12 + if mask13.sum().gt(0): + penalty13 = torch.linalg.norm(mod_params1[0] - mod_params3[0], dim=1) + penalty13 = torch.where(mask13, penalty13.T, torch.zeros_like(penalty13).T).sum( + dim=0 + ) + pair_penalty = pair_penalty + penalty13 + if mask23.sum().gt(0): + penalty23 = torch.linalg.norm(mod_params2[0] - mod_params3[0], dim=1) + penalty23 = torch.where(mask23, penalty23.T, torch.zeros_like(penalty23).T).sum( + dim=0 + ) + pair_penalty = pair_penalty + penalty23 + else: + raise ValueError("modality penalty not supported") + + return pair_penalty + + +@auto_move_data +def mix_modalities(Xs, masks, weights, weight_transform: callable = None): + """Compute the weighted mean of the Xs while masking unmeasured modality values. 
+ + Parameters + ---------- + Xs + Sequence of Xs to mix, each should be (N x D) + masks + Sequence of masks corresponding to the Xs, indicating whether the values + should be included in the mix or not (N) + weights + Weights for each modality (either K or N x K) + weight_transform + Transformation to apply to the weights before using them + """ + # (batch_size x latent) -> (batch_size x modalities x latent) + Xs = torch.stack(Xs, dim=1) + # (batch_size) -> (batch_size x modalities) + masks = torch.stack(masks, dim=1).float() + weights = masked_softmax(weights, masks, dim=-1) + + # (batch_size x modalities) -> (batch_size x modalities x latent) + weights = weights.unsqueeze(-1) + if weight_transform is not None: + weights = weight_transform(weights) + + # sum over modalities, so output is (batch_size x latent) + return (weights * Xs).sum(1) + + +@auto_move_data +def sym_kld(qzm1, qzv1, qzm2, qzv2): + """Symmetric KL divergence between two Gaussians.""" + rv1 = Normal(qzm1, qzv1.sqrt()) + rv2 = Normal(qzm2, qzv2.sqrt()) + + return kld(rv1, rv2) + kld(rv2, rv1) + + +@auto_move_data +def get_reconstruction_loss_protein(y, py_, pro_batch_mask_minibatch=None): + """Get the reconstruction loss for protein data.""" + py_conditional = NegativeBinomialMixture( + mu1=py_["rate_back"], + mu2=py_["rate_fore"], + theta1=py_["r"], + mixture_logits=py_["mixing"], + ) + + reconst_loss_protein_full = -py_conditional.log_prob(y) + + if pro_batch_mask_minibatch is not None: + temp_pro_loss_full = pro_batch_mask_minibatch.bool() * reconst_loss_protein_full + rl_protein = temp_pro_loss_full.sum(dim=-1) + else: + rl_protein = reconst_loss_protein_full.sum(dim=-1) + + return rl_protein diff --git a/src/scvi/module/_multivae.py b/src/scvi/module/_multivae.py index 5a34d54cbe..73dbf9ee2e 100644 --- a/src/scvi/module/_multivae.py +++ b/src/scvi/module/_multivae.py @@ -4,7 +4,7 @@ import numpy as np import torch from torch import nn -from torch.distributions import Normal, Poisson +from torch.distributions import Normal from torch.distributions import kl_divergence as kld from torch.nn import functional as F @@ -12,6 +12,7 @@ from scvi.distributions import ( NegativeBinomial, NegativeBinomialMixture, + Poisson, ZeroInflatedNegativeBinomial, ) from scvi.module._peakvae import Decoder as DecoderPeakVI @@ -32,11 +33,12 @@ def __init__( n_hidden: int = 128, use_batch_norm: bool = False, use_layer_norm: bool = True, - deep_inject_covariates: bool = False, + inject_covariates: bool = False, + output_fn: str | None = "LeakyReLU", **kwargs, ): super().__init__() - self.px_decoder = FCLayers( + self.encoder = FCLayers( n_in=n_input, n_out=n_hidden, n_cat_list=n_cat_list, @@ -46,14 +48,20 @@ def __init__( activation_fn=torch.nn.LeakyReLU, use_batch_norm=use_batch_norm, use_layer_norm=use_layer_norm, - inject_covariates=deep_inject_covariates, + inject_covariates=inject_covariates, **kwargs, ) - self.output = torch.nn.Sequential(torch.nn.Linear(n_hidden, 1), torch.nn.LeakyReLU()) + if output_fn=="LeakyReLU": + output_fn = nn.LeakyReLU() + elif output_fn=="sigmoid": + output_fn = nn.Sigmoid() + else: + output_fn = nn.Identity() + self.output = torch.nn.Sequential(torch.nn.Linear(n_hidden, 1), output_fn) def forward(self, x: torch.Tensor, *cat_list: int): """Forward pass.""" - return self.output(self.px_decoder(x, *cat_list)) + return self.output(self.encoder(x, *cat_list)) class DecoderADT(torch.nn.Module): @@ -69,7 +77,7 @@ def __init__( dropout_rate: float = 0.1, use_batch_norm: bool = False, use_layer_norm: bool = True, - 
deep_inject_covariates: bool = False, + inject_covariates: bool = False, ): super().__init__() self.n_output_proteins = n_output_proteins @@ -158,7 +166,10 @@ def forward(self, z: torch.Tensor, *cat_list: int): py_back_cat_z = torch.cat([py_back, z], dim=-1) py_["back_alpha"] = self.py_back_mean_log_alpha(py_back_cat_z, *cat_list) - py_["back_beta"] = torch.exp(self.py_back_mean_log_beta(py_back_cat_z, *cat_list)) + py_["back_beta"] = ( + torch.nn.functional.softplus(self.py_back_mean_log_beta(py_back_cat_z, *cat_list)) + + 1e-8 + ) log_pro_back_mean = Normal(py_["back_alpha"], py_["back_beta"]).rsample() py_["rate_back"] = torch.exp(log_pro_back_mean) @@ -200,6 +211,10 @@ class MULTIVAE(BaseModuleClass): * ``"Jeffreys"``: Jeffreys penalty to align modalities * ``"MMD"``: MMD penalty to align modalities * ``"None"``: No penalty + mix_modality + How are modalities latent parameters combined. One of the following: + * ``"moe"``: Mixture of experts + * ``"poe"``: Product of experts n_batch Number of batches, if 0, no batch correction is performed. gene_likelihood @@ -213,6 +228,18 @@ class MULTIVAE(BaseModuleClass): * ``'gene-batch'`` - dispersion can differ between different batches * ``'gene-label'`` - dispersion can differ between different labels * ``'gene-cell'`` - dispersion can differ for every gene in every cell + atac_likelihood + The distribution to use for ATAC-seq data. One of the following + * ``'zinb'`` - Zero-Inflated Negative Binomial + * ``'nb'`` - Negative Binomial + * ``'poisson'`` - Poisson + * ``'bernoulli'`` - Bernoulli + atac_dispersion + One of the following: + * ``'peak'`` - dispersion parameter of NB is constant per peak across cells + * ``'peak-batch'`` - dispersion can differ between different batches + * ``'peak-label'`` - dispersion can differ between different labels + * ``'peak-cell'`` - dispersion can differ for every peak in every cell protein_dispersion One of the following: @@ -233,6 +260,8 @@ class MULTIVAE(BaseModuleClass): Dropout rate for neural networks region_factors Include region-specific factors in the model + scale_region_factors + Scale region factors by a fixed number to speed up convergence use_batch_norm One of the following * ``'encoder'`` - use batch normalization in the encoder only @@ -266,11 +295,14 @@ def __init__( n_input_proteins: int = 0, modality_weights: Literal["equal", "cell", "universal"] = "equal", modality_penalty: Literal["Jeffreys", "MMD", "None"] = "Jeffreys", + mix_modality: Literal["moe", "poe"] = "moe", n_batch: int = 0, n_obs: int = 0, n_labels: int = 0, gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb", gene_dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene", + atac_likelihood: Literal["zinb", "nb", "poisson", "bernoulli"] = "bernoulli", + atac_dispersion: Literal["peak", "peak-batch", "peak-label", "peak-cell"] = "peak", n_hidden: int = None, n_latent: int = None, n_layers_encoder: int = 2, @@ -279,6 +311,7 @@ def __init__( n_cats_per_cov: Iterable[int] | None = None, dropout_rate: float = 0.1, region_factors: bool = True, + scale_region_factors: float = 100., use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "none", use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "both", latent_distribution: Literal["normal", "ln"] = "normal", @@ -374,7 +407,8 @@ def __init__( n_hidden=self.n_hidden, use_batch_norm=self.use_batch_norm_encoder, use_layer_norm=self.use_layer_norm_encoder, - deep_inject_covariates=self.deeply_inject_covariates, + 
inject_covariates=self.deeply_inject_covariates,
+            output_fn="none",
         )

         # expression decoder
@@ -392,6 +426,23 @@ def __init__(
         )

         # accessibility
+        # atac dispersion parameters
+        self.atac_likelihood = atac_likelihood
+        self.atac_dispersion = atac_dispersion
+        if self.atac_dispersion == "peak":
+            self.px_r_atac = torch.nn.Parameter(2 * torch.rand(self.n_input_regions))
+        elif self.atac_dispersion == "peak-batch":
+            self.px_r_atac = torch.nn.Parameter(2 * torch.rand(self.n_input_regions, n_batch))
+        elif self.atac_dispersion == "peak-label":
+            self.px_r_atac = torch.nn.Parameter(2 * torch.rand(self.n_input_regions, n_labels))
+        elif self.atac_dispersion == "peak-cell":
+            pass
+        else:
+            raise ValueError(
+                "atac_dispersion must be one of ['peak', 'peak-batch',"
+                f" 'peak-label', 'peak-cell'], but input was {self.atac_dispersion}."
+            )
+
         # accessibility encoder
         if self.n_input_regions == 0:
             input_acc = 1
@@ -415,35 +467,47 @@ def __init__(
         # accessibility region-specific factors
         self.region_factors = None
+        self.scale_region_factors = scale_region_factors
         if region_factors:
             self.region_factors = torch.nn.Parameter(torch.zeros(self.n_input_regions))

         # accessibility decoder
-        self.z_decoder_accessibility = DecoderPeakVI(
-            n_input=self.n_latent + self.n_continuous_cov,
-            n_output=n_input_regions,
-            n_hidden=self.n_hidden,
+        if self.atac_likelihood == "bernoulli":
+            atac_decoder_fn = DecoderPeakVI
+            decoder_atac_kwargs = {}
+            atac_library_kwargs = {"output_fn": "sigmoid"}
+        else:
+            atac_decoder_fn = DecoderSCVI
+            decoder_atac_kwargs = {
+                "scale_activation": "softplus" if use_size_factor_key else "softmax"
+            }
+            atac_library_kwargs = {"output_fn": "none"}
+
+        self.z_decoder_accessibility = atac_decoder_fn(
+            n_input_decoder,
+            n_input_regions,
             n_cat_list=cat_list,
-            n_layers=self.n_layers_decoder,
+            n_layers=n_layers_decoder,
+            n_hidden=self.n_hidden,
+            inject_covariates=self.deeply_inject_covariates,
             use_batch_norm=self.use_batch_norm_decoder,
             use_layer_norm=self.use_layer_norm_decoder,
-            deep_inject_covariates=self.deeply_inject_covariates,
+            **decoder_atac_kwargs,
         )

         # accessibility library size encoder
-        self.l_encoder_accessibility = DecoderPeakVI(
-            n_input=n_input_encoder_acc,
-            n_output=1,
-            n_hidden=self.n_hidden,
+        self.l_encoder_accessibility = LibrarySizeEncoder(
+            n_input_encoder_acc,
             n_cat_list=encoder_cat_list,
             n_layers=self.n_layers_encoder,
+            n_hidden=self.n_hidden,
             use_batch_norm=self.use_batch_norm_encoder,
             use_layer_norm=self.use_layer_norm_encoder,
-            deep_inject_covariates=self.deeply_inject_covariates,
+            inject_covariates=self.deeply_inject_covariates,
+            **atac_library_kwargs,
         )

         # protein
-        # protein encoder
         self.protein_dispersion = protein_dispersion
         if protein_background_prior_mean is None:
             if n_batch > 0:
@@ -502,7 +566,7 @@ def __init__(
             n_layers=self.n_layers_decoder,
             use_batch_norm=self.use_batch_norm_decoder,
             use_layer_norm=self.use_layer_norm_decoder,
-            deep_inject_covariates=self.deeply_inject_covariates,
+            inject_covariates=self.deeply_inject_covariates,
         )

         # protein dispersion parameters
@@ -528,6 +592,8 @@
             self.mod_weights = torch.nn.Parameter(torch.ones(max_n_modalities))
         else:  # cell-specific weights
             self.mod_weights = torch.nn.Parameter(torch.ones(n_obs, max_n_modalities))
+        self.mix_modality = mix_modality
+        assert mix_modality in ["moe", "poe"], "mix_modality must be one of ['moe', 'poe']"

     def _get_inference_input(self, tensors):
         """Get input tensors for the inference model."""
@@ -621,13 +687,15 @@ def inference(
         if self.use_size_factor_key:
@@ -621,13 +687,15 @@ def inference(
         if self.use_size_factor_key:
             libsize_expr = torch.log(size_factor[:, [0]] + 1e-6)
             libsize_acc = size_factor[:, [1]]
+            if self.atac_likelihood != "bernoulli":
+                libsize_acc = torch.log(libsize_acc + 1e-6)
         else:
             libsize_acc = self.l_encoder_accessibility(
                 encoder_input_accessibility, batch_index, *categorical_input
             )
-        libsize_expr = self.l_encoder_expression(
-            encoder_input_expression, batch_index, *categorical_input
-        )
+            libsize_expr = self.l_encoder_expression(
+                encoder_input_expression, batch_index, *categorical_input
+            )

         # mix representations
         if self.modality_weights == "cell":
@@ -635,15 +703,24 @@ def inference(
         else:
             weights = self.mod_weights.unsqueeze(0).expand(x.shape[0], -1)

-        qz_m = mix_modalities(
-            (qzm_expr, qzm_acc, qzm_pro), (mask_expr, mask_acc, mask_pro), weights
-        )
-        qz_v = mix_modalities(
-            (qzv_expr, qzv_acc, qzv_pro),
-            (mask_expr, mask_acc, mask_pro),
-            weights,
-            torch.sqrt,
-        )
+        if self.mix_modality == "moe":
+            qz_m = mixture_of_expert(
+                (qzm_expr, qzm_acc, qzm_pro),
+                (mask_expr, mask_acc, mask_pro),
+                weights
+            )
+            qz_v = mixture_of_expert(
+                (qzv_expr, qzv_acc, qzv_pro),
+                (mask_expr, mask_acc, mask_pro),
+                weights,
+            )
+        else:
+            qz_m, qz_v = product_of_expert(
+                (qzm_expr, qzm_acc, qzm_pro),
+                (qzv_expr, qzv_acc, qzv_pro),
+                (mask_expr, mask_acc, mask_pro),
+                weights,
+            )

         # sample
         if n_samples > 1:
@@ -689,6 +767,7 @@ def _get_generative_input(self, tensors, inference_outputs, transform_batch=None
         z = inference_outputs["z"]
         qz_m = inference_outputs["qz_m"]
         libsize_expr = inference_outputs["libsize_expr"]
+        libsize_acc = inference_outputs["libsize_acc"]

         batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
         cont_key = REGISTRY_KEYS.CONT_COVS_KEY
@@ -709,6 +788,7 @@ def _get_generative_input(self, tensors, inference_outputs, transform_batch=None
             "cont_covs": cont_covs,
             "cat_covs": cat_covs,
             "libsize_expr": libsize_expr,
+            "libsize_acc": libsize_acc,
             "label": label,
         }
         return input_dict
@@ -722,6 +802,7 @@ def generative(
         cont_covs=None,
         cat_covs=None,
         libsize_expr=None,
+        libsize_acc=None,
         use_z_mean=False,
         label: torch.Tensor = None,
     ):
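`scale_region_factors` multiplies the raw region-factor logits before the sigmoid in the generative hunk that follows, so gradients move the factors away from their 0.5 initialization much faster. A quick illustration with made-up logit values:

    import torch

    r = torch.tensor([-0.02, 0.0, 0.02])  # region-factor logits shortly after zero init
    torch.sigmoid(r)          # tensor([0.4950, 0.5000, 0.5050]) -- barely differentiated
    torch.sigmoid(100.0 * r)  # tensor([0.1192, 0.5000, 0.8808]) -- same logits, much sharper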
@@ -742,10 +823,53 @@ def generative(
             decoder_input = torch.cat([latent, cont_covs], dim=-1)

         # Accessibility Decoder
-        p = self.z_decoder_accessibility(decoder_input, batch_index, *categorical_input)
+        region_factor = (
+            torch.sigmoid(self.scale_region_factors * self.region_factors)
+            if self.region_factors is not None else 1.
+        )
+        if self.atac_likelihood == "bernoulli":
+            p = self.z_decoder_accessibility(decoder_input, batch_index, *categorical_input)
+            px_atac = {'px_rate': libsize_acc * region_factor * p, 'px_scale': p}
+        else:
+            # ATAC Decoder
+            px_scale_atac, px_r_atac, px_rate_atac, px_dropout_atac = self.z_decoder_accessibility(
+                self.atac_dispersion,
+                decoder_input,
+                libsize_acc,
+                batch_index,
+                *categorical_input,
+                label,
+            )
+            # scale the decoded ATAC rate by the (sigmoid-squashed) region factor
+            px_rate_atac = px_rate_atac * region_factor
+            # ATAC Dispersion
+            if self.atac_dispersion == "peak-label":
+                px_r_atac = F.linear(
+                    F.one_hot(label.squeeze(-1), self.n_labels).float(), self.px_r_atac
+                )  # px_r gets transposed - last dimension is nb peaks
+            elif self.atac_dispersion == "peak-batch":
+                px_r_atac = F.linear(
+                    F.one_hot(batch_index.squeeze(-1), self.n_batch).float(),
+                    self.px_r_atac)
+            elif self.atac_dispersion == "peak":
+                px_r_atac = self.px_r_atac
+            px_r_atac = torch.exp(px_r_atac)
+            if self.atac_likelihood == "zinb":
+                px_atac = ZeroInflatedNegativeBinomial(
+                    mu=px_rate_atac,
+                    theta=px_r_atac,
+                    zi_logits=px_dropout_atac,
+                    scale=px_scale_atac,
+                )
+            elif self.atac_likelihood == "nb":
+                px_atac = NegativeBinomial(mu=px_rate_atac, theta=px_r_atac, scale=px_scale_atac)
+            elif self.atac_likelihood == "poisson":
+                px_atac = Poisson(rate=px_rate_atac, scale=px_scale_atac)
+            elif self.atac_likelihood == "normal":
+                px_atac = Normal(px_rate_atac, px_r_atac, normal_mu=px_scale_atac)

         # Expression Decoder
         px_scale, _, px_rate, px_dropout = self.z_decoder_expression(
             self.gene_dispersion,
             decoder_input,
             libsize_expr,
@@ -764,6 +888,20 @@ def generative(
             px_r = self.px_r
         px_r = torch.exp(px_r)

+        if self.gene_likelihood == "zinb":
+            px = ZeroInflatedNegativeBinomial(
+                mu=px_rate,
+                theta=px_r,
+                zi_logits=px_dropout,
+                scale=px_scale,
+            )
+        elif self.gene_likelihood == "nb":
+            px = NegativeBinomial(mu=px_rate, theta=px_r, scale=px_scale)
+        elif self.gene_likelihood == "poisson":
+            px = Poisson(rate=px_rate, scale=px_scale)
+        elif self.gene_likelihood == "normal":
+            px = Normal(px_rate, px_r, normal_mu=px_scale)
+
         # Protein Decoder
         py_, log_pro_back_mean = self.z_decoder_pro(decoder_input, batch_index, *categorical_input)
         # Protein Dispersion
@@ -778,11 +916,8 @@ def generative(
         py_["r"] = py_r

         return {
-            "p": p,
-            "px_scale": px_scale,
-            "px_r": torch.exp(self.px_r),
-            "px_rate": px_rate,
-            "px_dropout": px_dropout,
+            "px_atac": px_atac,
+            "px": px,
             "py_": py_,
             "log_pro_back_mean": log_pro_back_mean,
         }
@@ -793,7 +928,7 @@ def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float
         x = inference_outputs["x"]

         x_rna = x[:, : self.n_input_genes]
-        x_atac = x[:, self.n_input_genes : (self.n_input_genes + self.n_input_regions)]
+        x_atac = x[:, self.n_input_genes:]
         if self.n_input_proteins == 0:
             y = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False)
         else:
@@ -804,18 +939,15 @@ def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float
         mask_pro = y.sum(dim=1) > 0

         # Compute Accessibility loss
-        p = generative_outputs["p"]
-        libsize_acc = inference_outputs["libsize_acc"]
-        rl_accessibility = self.get_reconstruction_loss_accessibility(x_atac, p, libsize_acc)
+        px_atac = generative_outputs["px_atac"]
+        if self.atac_likelihood == "bernoulli":
+            rl_accessibility = self._get_reconstruction_loss_bernoulli(
+                x_atac, px_atac["px_rate"])
+        else:
+            rl_accessibility = - px_atac.log_prob(x_atac).sum(-1)

         # Compute Expression loss
-        px_rate = generative_outputs["px_rate"]
-        px_r = generative_outputs["px_r"]
-        px_dropout = generative_outputs["px_dropout"]
-        x_expression = x[:, : self.n_input_genes]
-        rl_expression = self.get_reconstruction_loss_expression(
-            x_expression, px_rate, px_r, px_dropout
-        )
+        rl_expression = - generative_outputs["px"].log_prob(x_rna).sum(-1)

         # Compute Protein loss - No ability to mask minibatch (Param:None)
         if mask_pro.sum().gt(0):
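Because both likelihoods are now wrapped in distribution objects, each reconstruction term in the loss hunk above reduces to `-dist.log_prob(x).sum(-1)`. A self-contained sketch of that pattern with scvi's NegativeBinomial and toy counts:

    import torch
    from scvi.distributions import NegativeBinomial

    x = torch.tensor([[0.0, 2.0, 5.0], [1.0, 0.0, 3.0]])  # 2 cells x 3 features of raw counts
    px = NegativeBinomial(mu=torch.full((2, 3), 2.0), theta=torch.ones(2, 3))
    rl = -px.log_prob(x).sum(-1)  # per-cell reconstruction loss, shape (2,)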
@@ -866,25 +998,10 @@ def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float
         }
         return LossOutput(loss=loss, reconstruction_loss=recon_losses, kl_local=kl_local)

-    def get_reconstruction_loss_expression(self, x, px_rate, px_r, px_dropout):
-        """Computes the reconstruction loss for the expression data."""
-        rl = 0.0
-        if self.gene_likelihood == "zinb":
-            rl = (
-                -ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout)
-                .log_prob(x)
-                .sum(dim=-1)
-            )
-        elif self.gene_likelihood == "nb":
-            rl = -NegativeBinomial(mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1)
-        elif self.gene_likelihood == "poisson":
-            rl = -Poisson(px_rate).log_prob(x).sum(dim=-1)
-        return rl
-
-    def get_reconstruction_loss_accessibility(self, x, p, d):
+    def _get_reconstruction_loss_bernoulli(self, x, p):
         """Computes the reconstruction loss for the accessibility data."""
-        reg_factor = torch.sigmoid(self.region_factors) if self.region_factors is not None else 1
-        return torch.nn.BCELoss(reduction="none")(p * d * reg_factor, (x > 0).float()).sum(dim=-1)
+        # Scaling improves convergence speed; otherwise, region_factors takes a long time to train.
+        return torch.nn.BCELoss(reduction="none")(p, (x > 0).float()).sum(dim=-1)

     def _compute_mod_penalty(self, mod_params1, mod_params2, mod_params3, mask1, mask2, mask3):
         """Computes Similarity Penalty across modalities given selection (None, Jeffreys, MMD).
@@ -965,7 +1082,7 @@ def _compute_mod_penalty(self, mod_params1, mod_params2, mod_params3, mask1, mas


 @auto_move_data
-def mix_modalities(Xs, masks, weights, weight_transform: callable = None):
+def mixture_of_expert(Xs, masks, weights):
     """Compute the weighted mean of the Xs while masking unmeasured modality values.

     Parameters
     ----------
     Xs
         Sequence of Xs to mix, each should be (N x D)
     masks
         Sequence of masks corresponding to the Xs, indicating whether the values
         should be included in the mix or not (N)
     weights
         Weights for each modality (either K or N x K)
-    weight_transform
-        Transformation to apply to the weights before using them
     """
     # (batch_size x latent) -> (batch_size x modalities x latent)
     Xs = torch.stack(Xs, dim=1)
@@ -988,13 +1103,44 @@
     # (batch_size x modalities) -> (batch_size x modalities x latent)
     weights = weights.unsqueeze(-1)

-    if weight_transform is not None:
-        weights = weight_transform(weights)
-
     # sum over modalities, so output is (batch_size x latent)
     return (weights * Xs).sum(1)
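The two mixing rules behave quite differently: `mixture_of_expert` above averages the expert parameters with the modality weights, while the product-of-experts variant defined next weights each expert by its precision, so a low-variance (confident) modality dominates the joint posterior. A hand-computed 1-D sketch with equal weights and both masks on:

    import torch

    mu = torch.tensor([0.0, 2.0])   # expert means
    var = torch.tensor([1.0, 0.1])  # expert variances
    prec = 1.0 / var                # precisions: [1., 10.]
    joint_var = 1.0 / prec.sum()    # 1/11 ~ 0.091
    joint_mu = (mu * prec).sum() * joint_var  # 20/11 ~ 1.82, pulled toward the sharper expert
    moe_mu = mu.mean()              # 1.0, the mixture simply averages the means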
+
+
+@auto_move_data
+def product_of_expert(mus, vars, masks, weights):
+    """Combine modality posteriors with a product of experts, masking unmeasured modalities.
+
+    Parameters
+    ----------
+    mus
+        Sequence of mus to mix, each should be (N x D)
+    vars
+        Sequence of vars to mix, each should be (N x D)
+    masks
+        Sequence of masks corresponding to the modalities, indicating whether the values
+        should be included in the mix or not (N)
+    weights
+        Weights for each modality (either K or N x K)
+    """
+    mus = torch.stack(mus, dim=1)
+    vars = torch.stack(vars, dim=1)
+    masks = torch.stack(masks, dim=1).float()
+    weights = masked_softmax(weights, masks, dim=-1).unsqueeze(-1)
+
+    # Compute precision (inverse variance) for each expert
+    precisions = masks.unsqueeze(-1) / vars  # (N, K, D)
+    weighted_precisions = precisions * weights  # (N, K, D)
+    joint_precision = weighted_precisions.sum(dim=1)  # Sum across modalities (K -> joint expert)
+    joint_variance = 1.0 / joint_precision  # Joint variance (N, D)
+
+    # Joint mean
+    weighted_mean = (mus * weighted_precisions).sum(dim=1)  # Sum weighted means
+    joint_mean = weighted_mean * joint_variance  # Scale by joint variance
+
+    return joint_mean, joint_variance
+
+
 @auto_move_data
 def sym_kld(qzm1, qzv1, qzm2, qzv2):
     """Symmetric KL divergence between two Gaussians."""
diff --git a/src/scvi/module/_peakvae.py b/src/scvi/module/_peakvae.py
index 10f4898354..3049dfaace 100644
--- a/src/scvi/module/_peakvae.py
+++ b/src/scvi/module/_peakvae.py
@@ -52,7 +52,7 @@ def __init__(
         n_hidden: int = 128,
         use_batch_norm: bool = False,
         use_layer_norm: bool = True,
-        deep_inject_covariates: bool = False,
+        inject_covariates: bool = False,
         **kwargs,
     ):
         super().__init__()
@@ -66,7 +66,7 @@ def __init__(
             activation_fn=torch.nn.LeakyReLU,
             use_batch_norm=use_batch_norm,
             use_layer_norm=use_layer_norm,
-            inject_covariates=deep_inject_covariates,
+            inject_covariates=inject_covariates,
             **kwargs,
         )
         self.output = torch.nn.Sequential(torch.nn.Linear(n_hidden, n_output), torch.nn.Sigmoid())
@@ -204,7 +204,7 @@ def __init__(
             n_layers=self.n_layers_decoder,
             use_batch_norm=self.use_batch_norm_decoder,
             use_layer_norm=self.use_layer_norm_decoder,
-            deep_inject_covariates=self.deeply_inject_covariates,
+            inject_covariates=self.deeply_inject_covariates,
             **_extra_decoder_kwargs,
         )

From fa7b4579660ff951da24f3a840084cecd6d8978d Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 22 Nov 2024 08:12:19 +0000
Subject: [PATCH 50/51] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/scvi/model/_multivi.py        |  4 +++-
 src/scvi/module/_multivae copy.py | 28 +++++++++++-----------
 src/scvi/module/_multivae.py      | 40 +++++++++++++++----------------
 3 files changed, 36 insertions(+), 36 deletions(-)

diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py
index d93f553f62..09cf1f190c 100644
--- a/src/scvi/model/_multivi.py
+++ b/src/scvi/model/_multivi.py
@@ -441,7 +441,7 @@ def get_library_size_factors(
     @torch.inference_mode()
     def get_region_factors(self, return_numpy=True) -> np.ndarray:
         """Return region-specific factors."""
-        if self.n_regions == 0 :
+        if self.n_regions == 0:
             return np.zeros(1)
         else:
             if self.module.region_factors is None:
@@ -871,9 +871,11 @@ def change_fn(a, b):
             return a - b

         if two_sided:
+
             def m1_domain_fn(samples):
                 return np.abs(samples) >= delta
         else:
+
             def m1_domain_fn(samples):
                 return samples >= delta

diff --git a/src/scvi/module/_multivae copy.py b/src/scvi/module/_multivae copy.py
index 1bedd5d928..29275a19e6 100644
--- a/src/scvi/module/_multivae copy.py
+++ b/src/scvi/module/_multivae copy.py
@@ -160,8 +160,8 @@ def forward(self, z: torch.Tensor, *cat_list: int):
         py_["back_alpha"] = self.py_back_mean_log_alpha(py_back_cat_z, *cat_list)
         py_["back_beta"] = (
-            torch.nn.functional.softplus(self.py_back_mean_log_beta(py_back_cat_z, *cat_list)) +
-            1e-8
+            torch.nn.functional.softplus(self.py_back_mean_log_beta(py_back_cat_z, *cat_list))
+            + 1e-8
         )
         log_pro_back_mean = Normal(py_["back_alpha"], py_["back_beta"]).rsample()
         py_["rate_back"] = torch.exp(log_pro_back_mean)
@@ -287,7 +287,7 @@ def __init__(
         n_cats_per_cov: Iterable[int] | None = None,
         dropout_rate: float = 0.1,
         region_factors: bool = True,
-        scale_region_factors: float = 1.,
+        scale_region_factors: float = 1.0,
         use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "none",
         use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "both",
         latent_distribution: Literal["normal", "ln"] = "normal",
@@ -447,7 +447,7 @@ def __init__(
         self.region_factors = torch.nn.Parameter(torch.zeros(self.n_input_regions))

         # accessibility decoder
-        if self.atac_likelihood == 'bernoulli':
+        if self.atac_likelihood == "bernoulli":
             atac_decoder_fn = DecoderPeakVI
             decoder_atac_kwargs = {}
         else:
@@ -465,7 +465,7 @@ def __init__(
             inject_covariates=self.deeply_inject_covariates,
             use_batch_norm=self.use_batch_norm_decoder,
             use_layer_norm=self.use_layer_norm_decoder,
-            **decoder_atac_kwargs
+            **decoder_atac_kwargs,
         )

         # accessibility library size encoder
@@ -787,11 +787,12 @@ def generative(
         # Accessibility Decoder
         region_factor = (
             torch.sigmoid(self.scale_region_factors * self.region_factors)
-            if self.region_factors is not None else 1.
+            if self.region_factors is not None
+            else 1.0
         )
         if self.atac_likelihood == "bernoulli":
             p = self.z_decoder_accessibility(decoder_input, batch_index, *categorical_input)
-            px_atac = {'px_rate': libsize_acc * region_factor * p, 'px_scale': p}
+            px_atac = {"px_rate": libsize_acc * region_factor * p, "px_scale": p}
         else:
             # ATAC Decoder
             px_scale_atac, px_r_atac, px_rate_atac, px_dropout_atac = self.z_decoder_accessibility(
@@ -811,8 +812,8 @@ def generative(
             )  # px_r gets transposed - last dimension is nb genes
         elif self.atac_dispersion == "peak-batch":
             px_r_atac = F.linear(
-                F.one_hot(batch_index.squeeze(-1), self.n_batch).float(),
-                self.px_r_atac)
+                F.one_hot(batch_index.squeeze(-1), self.n_batch).float(), self.px_r_atac
+            )
         elif self.atac_dispersion == "peak":
             px_r_atac = self.px_r_atac
         px_r_atac = torch.exp(px_r_atac)
@@ -890,7 +891,7 @@ def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float
         x = inference_outputs["x"]

         x_rna = x[:, : self.n_input_genes]
-        x_atac = x[:, self.n_input_genes:]
+        x_atac = x[:, self.n_input_genes :]
         if self.n_input_proteins == 0:
             y = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False)
         else:
@@ -903,13 +904,12 @@ def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float
         # Compute Accessibility loss
         px_atac = generative_outputs["px_atac"]
         if self.atac_likelihood == "bernoulli":
-            rl_accessibility = self._get_reconstruction_loss_bernoulli(
-                x_atac, px_atac["px_rate"])
+            rl_accessibility = self._get_reconstruction_loss_bernoulli(x_atac, px_atac["px_rate"])
         else:
-            rl_accessibility = - px_atac.log_prob(x_atac).sum(-1)
+            rl_accessibility = -px_atac.log_prob(x_atac).sum(-1)

         # Compute Expression loss
-        rl_expression = - generative_outputs["px"].log_prob(x_rna).sum(-1)
+        rl_expression = -generative_outputs["px"].log_prob(x_rna).sum(-1)

         # Compute Protein loss - No ability to mask minibatch (Param:None)
         if mask_pro.sum().gt(0):

diff --git a/src/scvi/module/_multivae.py b/src/scvi/module/_multivae.py
index 73dbf9ee2e..686e1fafce 100644
--- a/src/scvi/module/_multivae.py
+++ b/src/scvi/module/_multivae.py
@@ -51,9 +51,9 @@ def __init__(
             inject_covariates=inject_covariates,
             **kwargs,
         )
-        if output_fn=="LeakyReLU":
+        if output_fn == "LeakyReLU":
             output_fn = nn.LeakyReLU()
-        elif output_fn=="sigmoid":
+        elif output_fn == "sigmoid":
             output_fn = nn.Sigmoid()
         else:
             output_fn = nn.Identity()
@@ -167,8 +167,8 @@ def forward(self, z: torch.Tensor, *cat_list: int):
         py_["back_alpha"] = self.py_back_mean_log_alpha(py_back_cat_z, *cat_list)
         py_["back_beta"] = (
-            torch.nn.functional.softplus(self.py_back_mean_log_beta(py_back_cat_z, *cat_list)) +
-            1e-8
+            torch.nn.functional.softplus(self.py_back_mean_log_beta(py_back_cat_z, *cat_list))
+            + 1e-8
         )
         log_pro_back_mean = Normal(py_["back_alpha"], py_["back_beta"]).rsample()
         py_["rate_back"] = torch.exp(log_pro_back_mean)
@@ -311,7 +311,7 @@ def __init__(
         n_cats_per_cov: Iterable[int] | None = None,
         dropout_rate: float = 0.1,
         region_factors: bool = True,
-        scale_region_factors: float = 100.,
+        scale_region_factors: float = 100.0,
         use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "none",
         use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "both",
         latent_distribution: Literal["normal", "ln"] = "normal",
@@ -472,7 +472,7 @@ def __init__(
         self.region_factors = torch.nn.Parameter(torch.zeros(self.n_input_regions))

         # accessibility decoder
-        if self.atac_likelihood == 'bernoulli':
+        if self.atac_likelihood == "bernoulli":
             atac_decoder_fn = DecoderPeakVI
             decoder_atac_kwargs = {}
             atac_library_kwargs = {"output_fn": "sigmoid"}
@@ -492,7 +492,7 @@ def __init__(
             inject_covariates=self.deeply_inject_covariates,
             use_batch_norm=self.use_batch_norm_decoder,
             use_layer_norm=self.use_layer_norm_decoder,
-            **decoder_atac_kwargs
+            **decoder_atac_kwargs,
         )

         # accessibility library size encoder
@@ -504,7 +504,7 @@ def __init__(
             use_batch_norm=self.use_batch_norm_encoder,
             use_layer_norm=self.use_layer_norm_encoder,
             inject_covariates=self.deeply_inject_covariates,
-            **atac_library_kwargs
+            **atac_library_kwargs,
         )

         # protein
@@ -705,15 +705,13 @@ def inference(
         if self.mix_modality == "moe":
             qz_m = mixture_of_expert(
-                (qzm_expr, qzm_acc, qzm_pro),
-                (mask_expr, mask_acc, mask_pro),
-                weights
+                (qzm_expr, qzm_acc, qzm_pro), (mask_expr, mask_acc, mask_pro), weights
             )
             qz_v = mixture_of_expert(
                 (qzv_expr, qzv_acc, qzv_pro),
                 (mask_expr, mask_acc, mask_pro),
                 weights,
             )
         else:
             qz_m, qz_v = product_of_expert(
                 (qzm_expr, qzm_acc, qzm_pro),
@@ -823,11 +821,12 @@ def generative(
         # Accessibility Decoder
         region_factor = (
             torch.sigmoid(self.scale_region_factors * self.region_factors)
-            if self.region_factors is not None else 1.
+            if self.region_factors is not None
+            else 1.0
         )
         if self.atac_likelihood == "bernoulli":
             p = self.z_decoder_accessibility(decoder_input, batch_index, *categorical_input)
-            px_atac = {'px_rate': libsize_acc * region_factor * p, 'px_scale': p}
+            px_atac = {"px_rate": libsize_acc * region_factor * p, "px_scale": p}
         else:
             # ATAC Decoder
             px_scale_atac, px_r_atac, px_rate_atac, px_dropout_atac = self.z_decoder_accessibility(
@@ -849,8 +848,8 @@ def generative(
             )  # px_r gets transposed - last dimension is nb peaks
         elif self.atac_dispersion == "peak-batch":
             px_r_atac = F.linear(
-                F.one_hot(batch_index.squeeze(-1), self.n_batch).float(),
-                self.px_r_atac)
+                F.one_hot(batch_index.squeeze(-1), self.n_batch).float(), self.px_r_atac
+            )
         elif self.atac_dispersion == "peak":
             px_r_atac = self.px_r_atac
         px_r_atac = torch.exp(px_r_atac)
@@ -928,7 +927,7 @@ def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float
         x = inference_outputs["x"]

         x_rna = x[:, : self.n_input_genes]
-        x_atac = x[:, self.n_input_genes:]
+        x_atac = x[:, self.n_input_genes :]
         if self.n_input_proteins == 0:
             y = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False)
         else:
@@ -941,13 +940,12 @@ def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float
         # Compute Accessibility loss
         px_atac = generative_outputs["px_atac"]
         if self.atac_likelihood == "bernoulli":
-            rl_accessibility = self._get_reconstruction_loss_bernoulli(
-                x_atac, px_atac["px_rate"])
+            rl_accessibility = self._get_reconstruction_loss_bernoulli(x_atac, px_atac["px_rate"])
         else:
-            rl_accessibility = - px_atac.log_prob(x_atac).sum(-1)
+            rl_accessibility = -px_atac.log_prob(x_atac).sum(-1)

         # Compute Expression loss
-        rl_expression = - generative_outputs["px"].log_prob(x_rna).sum(-1)
+        rl_expression = -generative_outputs["px"].log_prob(x_rna).sum(-1)

         # Compute Protein loss - No ability to mask minibatch (Param:None)
         if mask_pro.sum().gt(0):

From a13f29d9575e5fbdec9f71e189974ee4c58e4a3c Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 27 Nov 2024 13:45:31 +0000
Subject: [PATCH 51/51] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 55d41a1103..88f7f25d2f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -37,7 +37,7 @@ to [Semantic Versioning]. Full commit history is available in the
   {meth}`~scvi.model.MULTIVI.setup_mudata` {pr}`3038`.
 - Add {class}`scvi.external.METHYLVI` for modeling methylation data from single-cell bisulfite
   sequencing (scBS-seq) experiments {pr}`2834`.
- 
+
 #### Fixed

 - Breaking Change: Fix `get_outlier_cell_sample_pairs` function in {class}`scvi.external.MRVI`
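An end-to-end usage sketch of the options this series adds. This assumes the MULTIVI wrapper forwards extra keyword arguments to the MULTIVAE module via `**model_kwargs` and that `adata` has already been registered with `setup_anndata`; the numbers are placeholders:

    import scvi

    model = scvi.model.MULTIVI(
        adata,
        n_genes=2000,
        n_regions=5000,
        mix_modality="poe",            # product-of-experts latent mixing
        atac_likelihood="nb",          # count likelihood instead of the Bernoulli default
        atac_dispersion="peak-batch",  # per-peak, per-batch NB dispersion
        scale_region_factors=100.0,
    )
    model.train()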