diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 6ff469f5a9..3d913fce03 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 6ff469f5a9ec3e26324fcb27ac487d8486c6942f +Subproject commit 3d913fce03a15ac42f46844840cd831e9b29d8ab diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/scvi/external/__init__.py b/src/scvi/external/__init__.py index 4e46ca1846..c1457c2485 100644 --- a/src/scvi/external/__init__.py +++ b/src/scvi/external/__init__.py @@ -8,6 +8,7 @@ from .scbasset import SCBASSET from .solo import SOLO from .stereoscope import RNAStereoscope, SpatialStereoscope +from .sysvi import SysVI from .tangram import Tangram from .velovi import VELOVI @@ -22,6 +23,7 @@ "SCBASSET", "POISSONVI", "ContrastiveVI", + "SysVI", "VELOVI", "MRVI", "METHYLVI", diff --git a/src/scvi/external/sysvi/__init__.py b/src/scvi/external/sysvi/__init__.py new file mode 100644 index 0000000000..5bd0403a95 --- /dev/null +++ b/src/scvi/external/sysvi/__init__.py @@ -0,0 +1,4 @@ +from ._model import SysVI +from ._module import SysVAE + +__all__ = ["SysVI", "SysVAE"] diff --git a/src/scvi/external/sysvi/_base_components.py b/src/scvi/external/sysvi/_base_components.py new file mode 100644 index 0000000000..7553cb1bed --- /dev/null +++ b/src/scvi/external/sysvi/_base_components.py @@ -0,0 +1,407 @@ +from __future__ import annotations + +import collections +import warnings +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + from typing import Literal + +import numpy as np +import torch +from torch import nn +from torch.distributions import Normal +from torch.nn import ( + Linear, + Module, + Parameter, +) + +from scvi import settings + + +class EncoderDecoder(Module): + """Module that can be used as probabilistic encoder or decoder. + + Based on inputs and optional covariates predicts output mean and variance. + + Parameters + ---------- + n_input + The dimensionality of the main input. + n_output + The dimensionality of the output. + n_cat_list + A list containing the number of categories for each covariate. + n_cont + The dimensionality of the continuous covariates. + n_hidden + The number of nodes per hidden layer. + n_layers + The number of hidden layers. + var_mode + How to compute variance from model outputs, + see :class:`~scvi.external.sysvi.VarEncoder`. + One of the following: + * ```'sample_feature'``` - learn variance per sample and feature. + * ```'feature'``` - learn variance per feature, constant across samples. + var_activation + Function used to ensure positivity of the variance. + Defaults to :meth:`torch.exp`. + sample + Return samples from predicted distribution. + kwargs + Passed to :class:`~scvi.external.sysvi.Layers`. 
+ """ + + def __init__( + self, + n_input: int, + n_output: int, + n_cat_list: list[int], + n_cont: int, + n_hidden: int = 256, + n_layers: int = 3, + var_mode: Literal["sample_feature", "feature"] = "feature", + var_activation: Callable | None = None, + sample: bool = False, + **kwargs, + ): + super().__init__() + self.sample = sample + + self.decoder_y = FCLayers( + n_in=n_input, + n_cat_list=n_cat_list, + n_cont=n_cont, + n_out=n_hidden, + n_hidden=n_hidden, + n_layers=n_layers, + **kwargs, + ) + + self.mean_encoder = Linear(n_hidden, n_output) + self.var_encoder = VarEncoder(n_hidden, n_output, mode=var_mode, activation=var_activation) + + def forward( + self, + x: torch.Tensor, + cont: torch.Tensor | None = None, + cat_list: list[torch.Tensor] | None = None, + ) -> dict[str, torch.Tensor]: + """Forward pass. + + Parameters + ---------- + x + Main input (i.e. expression for encoder or + latent embedding for decoder.). + dim = n_samples * n_input + cont + Optional continuous covariates. + dim = n_samples * n_cont + cat_list + List of optional categorical covariates. + Will be one hot encoded in `~scvi.nn.FCLayers`. + Each list entry is of dim = n_samples * 1 + + Returns + ------- + Predicted mean (``'y_m'``) and variance (``'y_v'``) and + optionally samples (``'y'``) from normal distribution + parametrized with the predicted parameters. + """ + y = self.decoder_y(x=x, cont=cont, cat_list=cat_list) + y_m = self.mean_encoder(y) + if y_m.isnan().any() or y_m.isinf().any(): + warnings.warn( + "Predicted mean contains nan or inf values. " + "Setting to numerical.", + stacklevel=settings.warnings_stacklevel, + ) + y_m = torch.nan_to_num(y_m) + y_v = self.var_encoder(y) + + outputs = {"y_m": y_m, "y_v": y_v} + + if self.sample: + y = Normal(y_m, y_v.sqrt()).rsample() + outputs["y"] = y + + return outputs + + +class FCLayers(nn.Module): + """A helper class to build fully-connected layers for a neural network. + + FCLayers class of scvi-tools adapted to also inject continous covariates. + + The only adaptation is addition of `n_cont` parameter in init + and `cont` in forward, with the associated handling of the two. + The forward method signature is changed to account for optional `cont`. + + Parameters + ---------- + n_in + The dimensionality of the input + n_out + The dimensionality of the output + n_cat_list + The number of categorical covariates and + the number of category levels. + A list containing, for each covariate of interest, + the number of categories. Each covariate will be + included using a one-hot encoding. + n_cont + The number of continuous covariates. + n_layers + The number of fully-connected hidden layers + n_hidden + The number of nodes per hidden layer + dropout_rate + Dropout rate to apply to each of the hidden layers + use_batch_norm + Whether to have `BatchNorm` layers or not + use_layer_norm + Whether to have `LayerNorm` layers or not + use_activation + Whether to have layer activation or not + bias + Whether to learn bias in linear layers or not + inject_covariates + Whether to inject covariates in each layer, or just the first (default). 
+ activation_fn + Which activation function to use + """ + + def __init__( + self, + n_in: int, + n_out: int, + n_cat_list: Iterable[int] = None, + n_cont: int = 0, + n_layers: int = 1, + n_hidden: int = 128, + dropout_rate: float = 0.1, + use_batch_norm: bool = True, + use_layer_norm: bool = False, + use_activation: bool = True, + bias: bool = True, + inject_covariates: bool = True, + activation_fn: nn.Module = nn.ReLU, + ): + super().__init__() + self.inject_covariates = inject_covariates + layers_dim = [n_in] + (n_layers - 1) * [n_hidden] + [n_out] + + if n_cat_list is not None: + # n_cat = 1 will be ignored + self.n_cat_list = [n_cat if n_cat > 1 else 0 for n_cat in n_cat_list] + else: + self.n_cat_list = [] + + self.n_cov = sum(self.n_cat_list) + n_cont + self.fc_layers = nn.Sequential( + collections.OrderedDict( + [ + ( + f"Layer {i}", + nn.Sequential( + nn.Linear( + n_in + self.n_cov * self.inject_into_layer(i), + n_out, + bias=bias, + ), + # non-default params come from defaults + # in original Tensorflow + # implementation + nn.BatchNorm1d(n_out, momentum=0.01, eps=0.001) + if use_batch_norm + else None, + nn.LayerNorm(n_out, elementwise_affine=False) + if use_layer_norm + else None, + activation_fn() if use_activation else None, + nn.Dropout(p=dropout_rate) if dropout_rate > 0 else None, + ), + ) + for i, (n_in, n_out) in enumerate( + zip(layers_dim[:-1], layers_dim[1:], strict=True) + ) + ] + ) + ) + + def inject_into_layer(self, layer_num) -> bool: + """Helper to determine if covariates should be injected.""" + user_cond = layer_num == 0 or (layer_num > 0 and self.inject_covariates) + return user_cond + + def set_online_update_hooks(self, hook_first_layer=True): + """Set online update hooks.""" + self.hooks = [] + + def _hook_fn_weight(grad): + new_grad = torch.zeros_like(grad) + if self.n_cov > 0: + new_grad[:, -self.n_cov :] = grad[:, -self.n_cov :] + return new_grad + + def _hook_fn_zero_out(grad): + return grad * 0 + + for i, layers in enumerate(self.fc_layers): + for layer in layers: + if i == 0 and not hook_first_layer: + continue + if isinstance(layer, nn.Linear): + if self.inject_into_layer(i): + w = layer.weight.register_hook(_hook_fn_weight) + else: + w = layer.weight.register_hook(_hook_fn_zero_out) + self.hooks.append(w) + b = layer.bias.register_hook(_hook_fn_zero_out) + self.hooks.append(b) + + def forward( + self, x: torch.Tensor, cont: torch.Tensor | None = None, cat_list: list | None = None + ) -> torch.Tensor: + """Forward computation on ``x``. + + Parameters + ---------- + x + tensor of values with shape ``(n_in,)`` + cont + continuous covariates for this sample, + tensor of values with shape ``(n_cont,)`` + cat_list + list of category membership(s) for this sample + + Returns + ------- + :class:`torch.Tensor` + tensor of shape ``(n_out,)`` + """ + one_hot_cat_list = [] # for generality in this list many idxs useless. + cont_list = [cont] if cont is not None else [] + cat_list = cat_list or [] + + if len(self.n_cat_list) > len(cat_list): + raise ValueError("nb. categorical args provided doesn't match init. params.") + for n_cat, cat in zip(self.n_cat_list, cat_list, strict=False): + if n_cat and cat is None: + raise ValueError("cat not provided while n_cat != 0 in init. 
params.") + if n_cat > 1: # n_cat = 1 will be ignored - no additional info + if cat.size(1) != n_cat: + one_hot_cat = nn.functional.one_hot(cat.squeeze(-1), n_cat) + else: + one_hot_cat = cat # cat has already been one_hot encoded + one_hot_cat_list += [one_hot_cat] + for i, layers in enumerate(self.fc_layers): + for layer in layers: + if layer is not None: + if isinstance(layer, nn.BatchNorm1d): + if x.dim() == 3: + x = torch.cat([(layer(slice_x)).unsqueeze(0) for slice_x in x], dim=0) + else: + x = layer(x) + else: + if isinstance(layer, nn.Linear) and self.inject_into_layer(i): + if x.dim() == 3: + cov_list_layer = [ + o.unsqueeze(0).expand((x.size(0), o.size(0), o.size(1))) + for o in one_hot_cat_list + ] + else: + cov_list_layer = one_hot_cat_list + x = torch.cat((x, *cov_list_layer, *cont_list), dim=-1) + x = layer(x) + return x + + +class VarEncoder(Module): + """Encode variance (strictly positive). + + Parameters + ---------- + n_input + Number of input dimensions. + Used if mode is ``'sample_feature'`` to construct a network predicting + variance from input features. + n_output + Number of variances to predict, matching the desired number of features + (e.g. latent dimensions for variational encoding or output features + for variational decoding). + mode + How to compute variance. + One of the following: + * ``'sample_feature'`` - learn variance per sample and feature. + * ``'feature'`` - learn variance per feature, constant across samples. + activation + Activation function. If empty it is set to softplus. + exp_clip + Perform clipping before activation to prevent inf values in e**x. + This should be useful for any activation function using e**x, + such as exp, softplus, etc. + """ + + def __init__( + self, + n_input: int, + n_output: int, + mode: Literal["sample_feature", "feature", "linear"], + activation: Callable | None = None, + exp_clip: bool = True, + ): + super().__init__() + + self.clip_exp_thr = np.log(torch.finfo(torch.get_default_dtype()).max) - 1e-4 + self.exp_clip = exp_clip + self.mode = mode + if self.mode == "sample_feature": + self.encoder = Linear(n_input, n_output) + elif self.mode == "feature": + self.var_param = Parameter(torch.zeros(1, n_output)) + else: + raise ValueError("Mode not recognised.") + self.activation = torch.nn.Softplus() if activation is None else activation + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + """Forward pass through model. + + Parameters + ---------- + x + Used to encode variance if mode is ``'sample_feature'``. + dim = n_samples x n_input + + Returns + ------- + Predicted variance + dim = n_samples * 1 + """ + if self.mode == "sample_feature": + v = self.encoder(x) + elif self.mode == "feature": + v = self.var_param.expand(x.shape[0], -1) # Broadcast to input size + + # Ensure that var is strictly positive via exp - + # Bring back to non-log scale + # Clip to range that will not be inf after exp + # This should be useful for any activation that uses e**x + # such as exp, softplus, etc. + if self.exp_clip: + v = torch.clip(v, min=-self.clip_exp_thr, max=self.clip_exp_thr) + v = self.activation(v) + if v.isnan().any(): + warnings.warn( + "Predicted variance contains nan values. 
Setting to 0.", + stacklevel=settings.warnings_stacklevel, + ) + v = torch.nan_to_num(v) + + return v diff --git a/src/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py new file mode 100644 index 0000000000..d75109b4d6 --- /dev/null +++ b/src/scvi/external/sysvi/_model.py @@ -0,0 +1,290 @@ +from __future__ import annotations + +import logging +import warnings +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import Literal + + from anndata import AnnData + +import numpy as np +import torch + +from scvi import REGISTRY_KEYS, settings +from scvi.data import AnnDataManager +from scvi.data.fields import ( + CategoricalJointObsField, + CategoricalObsField, + LayerField, + NumericalJointObsField, + NumericalObsField, +) +from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin +from scvi.utils import setup_anndata_dsp + +from ._module import SysVAE + +logger = logging.getLogger(__name__) + + +class SysVI(UnsupervisedTrainingMixin, BaseModelClass): + """Integration with cVAE & optional VampPrior and latent cycle-consistency. + + Described in + `Hrovatin et al. (2023) `_. + + Parameters + ---------- + adata + AnnData object that has been registered via + :meth:`~scvi.external.SysVI.setup_anndata`. + prior + The prior distribution to be used. + You can choose between ``"standard_normal"`` and ``"vamp"``. + n_prior_components + Number of prior components (i.e. modes) to use in VampPrior. + pseudoinputs_data_indices + By default, VampPrior pseudoinputs are randomly selected from data. + Alternatively, one can specify pseudoinput indices using this parameter. + The number of specified indices in the input 1D array should match + ``n_prior_components``. + **model_kwargs + Keyword args for :class:`~scvi.external.sysvi.SysVAE` + """ + + def __init__( + self, + adata: AnnData, + prior: Literal["standard_normal", "vamp"] = "vamp", + n_prior_components: int = 5, + pseudoinputs_data_indices: np.array | None = None, + **model_kwargs, + ): + super().__init__(adata) + + if prior == "vamp": + if pseudoinputs_data_indices is None: + pseudoinputs_data_indices = np.random.randint( + 0, self.summary_stats.n_vars, n_prior_components + ) + assert pseudoinputs_data_indices.shape[0] == n_prior_components + assert pseudoinputs_data_indices.ndim == 1 + pseudoinput_data = next( + iter( + self._make_data_loader( + adata=adata, + indices=pseudoinputs_data_indices, + batch_size=n_prior_components, + shuffle=False, + ) + ) + ) + else: + pseudoinput_data = None + + n_cats_per_cov = ( + self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key + if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry + else None + ) + + self.module = SysVAE( + n_input=self.summary_stats.n_vars, + n_batch=self.summary_stats.n_batch, + n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0), + n_cats_per_cov=n_cats_per_cov, + prior=prior, + n_prior_components=n_prior_components, + pseudoinput_data=pseudoinput_data, + **model_kwargs, + ) + + self._model_summary_string = ( + "SysVI - cVAE model with optional VampPrior " + "and latent cycle-consistency loss." + ) + self.init_params_ = self._get_init_params(locals()) + + logger.info("The model has been initialized") + + def train( + self, + *args, + plan_kwargs: dict | None = None, + **kwargs, + ): + """Train the models. 
+ + Overwrites the ``train`` method of + class:`~scvi.model.base.UnsupervisedTrainingMixin` + to prevent the use of KL loss warmup (specified in ``plan_kwargs``). + This is disabled as our experiments showed poor integration in the + cycle model when using KL loss warmup. + + Parameters + ---------- + args + Training args. + plan_kwargs + Training plan kwargs. + kwargs + Training kwargs. + """ + plan_kwargs = plan_kwargs or {} + kl_weight_defaults = {"n_epochs_kl_warmup": 0, "n_steps_kl_warmup": 0} + if any(v != plan_kwargs.get(k, v) for k, v in kl_weight_defaults.items()): + warnings.warn( + "The use of KL weight warmup is not recommended in SysVI. " + + "The n_epochs_kl_warmup and n_steps_kl_warmup " + + "will be reset to 0.", + stacklevel=settings.warnings_stacklevel, + ) + # Overwrite plan kwargs with kl weight defaults + plan_kwargs = {**plan_kwargs, **kl_weight_defaults} + + # Pass to parent + kwargs = kwargs or {} + kwargs["plan_kwargs"] = plan_kwargs + super().train(*args, **kwargs) + + @torch.inference_mode() + def get_latent_representation( + self, + adata: AnnData, + indices: Sequence[int] | None = None, + give_mean: bool = True, + batch_size: int | None = None, + return_dist: bool = False, + ) -> np.ndarray | tuple[np.ndarray, np.ndarray]: + """Return the latent representation for each cell. + + Parameters + ---------- + adata + Input adata for which latent representation should be obtained. + indices + Data indices to embed. If None embedd all samples. + give_mean + Return the posterior latent distribution mean + instead of a sample from it. + Ignored if `return_dist` is ``True``. + batch_size + Minibatch size for data loading into model. + Defaults to `scvi.settings.batch_size`. + return_dist + If ``True``, returns the mean and variance of the posterior + latent distribution. + Otherwise, returns its mean or a sample from it. + + Returns + ------- + Latent representation of a cell. + If ``return_dist`` is ``True``, returns the mean and variance + of the posterior latent distribution. + Else, returns the mean or a sample, depending on ``give_mean``. + """ + self._check_if_trained(warn=False) + adata = self._validate_anndata(adata) + if indices is None: + indices = np.arange(adata.n_obs) + # Do not shuffle to retain order + tensors_fwd = self._make_data_loader( + adata=adata, indices=indices, batch_size=batch_size, shuffle=False + ) + predicted_m = [] + predicted_v = [] + for tensors in tensors_fwd: + inference_inputs = self.module._get_inference_input(tensors) + inference_outputs = self.module.inference(**inference_inputs) + if give_mean or return_dist: + predicted_m += [inference_outputs["z_m"]] + else: + predicted_m += [inference_outputs["z"]] + if return_dist: + predicted_v += [inference_outputs["z_v"]] + + predicted_m = torch.cat(predicted_m).cpu().numpy() + if return_dist: + predicted_v = torch.cat(predicted_v).cpu().numpy() + + if return_dist: + return predicted_m, predicted_v + else: + return predicted_m + + @classmethod + @setup_anndata_dsp.dedent + def setup_anndata( + cls, + adata: AnnData, + batch_key: str, + layer: str | None = None, + categorical_covariate_keys: list[str] | None = None, + continuous_covariate_keys: list[str] | None = None, + weight_batches: bool = False, + **kwargs, + ): + """Prepare adata for input to SysVI model. 
+ + Setup distinguishes between two main types of covariates that can be + corrected for: + + - batch - referred to as "system" in the original publication + Hrovatin, et al., 2023): + Single categorical covariate that will be corrected via cycle + consistency loss. + It will be also used as a condition in cVAE. + This covariate is expected to correspond to stronger batch effects, + such as between datasets from different sequencing technology or + model systems (animal species, in-vitro models and tissue, etc.). + - covariate (includes both continous and categorical covariates): + Additional covariates to be used only + as a condition in cVAE, but not corrected via cycle loss. + These covariates are expected to correspond to weaker batch effects, + such as between datasets from the same sequencing technology and + system (animal, in-vitro, etc.) or between samples within a dataset. + + Parameters + ---------- + adata + Adata object - will be modified in place. + batch_key + Name of the obs column with the substantial batch effect covariate, + referred to as batch in the original publication + (Hrovatin, et al., 2023). + Should be categorical. + layer + AnnData layer to use, default is X. + Should contain normalized and log+1 transformed expression. + categorical_covariate_keys + Name of obs columns with additional categorical + covariate information. + Will be one hot encoded or embedded, as later defined in the + ``SysVI`` model. + continuous_covariate_keys + Name of obs columns with additional continuous + covariate information. + """ + setup_method_args = cls._get_setup_method_args(**locals()) + + anndata_fields = [ + LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=False), + CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), + CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), + NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), + ] + if weight_batches: + warnings.warn( + "The use of inverse batch proportion weights " + "is experimental.", + stacklevel=settings.warnings_stacklevel, + ) + batch_weights_key = "batch_weights" + adata.obs[batch_weights_key] = adata.obs[batch_key].map( + {cat: 1 / n for cat, n in adata.obs[batch_key].value_counts().items()} + ) + anndata_fields.append(NumericalObsField(batch_weights_key, batch_weights_key)) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) + adata_manager.register_fields(adata, **kwargs) + cls.register_manager(adata_manager) diff --git a/src/scvi/external/sysvi/_module.py b/src/scvi/external/sysvi/_module.py new file mode 100644 index 0000000000..fd773a7b01 --- /dev/null +++ b/src/scvi/external/sysvi/_module.py @@ -0,0 +1,747 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Literal + +import torch + +from scvi import REGISTRY_KEYS +from scvi.module.base import BaseModuleClass, EmbeddingModuleMixin, LossOutput, auto_move_data + +from ._base_components import EncoderDecoder +from ._priors import StandardPrior, VampPrior + +torch.backends.cudnn.benchmark = True + + +class SysVAE(BaseModuleClass, EmbeddingModuleMixin): + """CVAE with optional VampPrior and latent cycle consistency loss. + + Described in + `Hrovatin et al. (2023) `_. + + Parameters + ---------- + n_input + Number of input features. + n_batch + Number of batches. + n_continuous_cov + Number of continuous covariates. 
+ n_cats_per_cov + A list of integers containing the number of categories + for each categorical covariate. + embed_cat + If ``True`` embeds categorical covariates and batches + into continuously-valued vectors instead of using one-hot encoding. + prior + Which prior distribution to use. + * ``'standard_normal'``: Standard normal distribution. + * ``'vamp'``: VampPrior. + n_prior_components + Number of prior components for VampPrior. + trainable_priors + Should prior components of VampPrior be trainable. + pseudoinput_data + Initialisation data for VampPrior. + Should match input tensors structure. + n_latent + Numer of latent space dimensions. + n_hidden + Numer of hidden nodes per layer for encoder and decoder. + n_layers + Number of hidden layers for encoder and decoder. + dropout_rate + Dropout rate for encoder and decoder. + out_var_mode + How variance is predicted in decoder, + see :class:`~scvi.external.sysvi.nn.VarEncoder`. + One of the following: + * ``'sample_feature'`` - learn variance per sample and feature. + * ``'feature'`` - learn variance per feature, constant across samples. + enc_dec_kwargs + Additional kwargs passed to encoder and decoder. + For both encoder and decoder + :class:`~scvi.external.sysvi.nn.EncoderDecoder` is used. + embedding_kwargs + Keyword arguments passed into :class:`~scvi.nn.Embedding` + if ``embed_cat`` is set to ``True``. + """ + + # TODO could disable computation of cycle if predefined + # that cycle loss will not be used. + # Cycle loss is not expected to be disabled in practice + # for typical use cases. + # As the use of cycle is currently only based on loss kwargs, + # which are specified only later, it can not be inferred here. + + def __init__( + self, + n_input: int, + n_batch: int, + n_continuous_cov: int = 0, + n_cats_per_cov: list[int] | None = None, + embed_cat: bool = False, + prior: Literal["standard_normal", "vamp"] = "vamp", + n_prior_components: int = 5, + trainable_priors: bool = True, + pseudoinput_data: dict[str, torch.Tensor] | None = None, + n_latent: int = 15, + n_hidden: int = 256, + n_layers: int = 2, + dropout_rate: float = 0.1, + out_var_mode: Literal["sample_feature", "feature"] = "feature", + enc_dec_kwargs: dict | None = None, + embedding_kwargs: dict | None = None, + ): + super().__init__() + + self.embed_cat = embed_cat + + enc_dec_kwargs = enc_dec_kwargs or {} + embedding_kwargs = embedding_kwargs or {} + + self.n_batch = n_batch + n_cat_list = [n_batch] + n_cont = n_continuous_cov + if n_cats_per_cov is not None: + if self.embed_cat: + for cov, n in enumerate(n_cats_per_cov): + cov = self._cov_idx_name(cov=cov) + self.init_embedding(cov, n, **embedding_kwargs) + n_cont += self.get_embedding(cov).embedding_dim + else: + n_cat_list.extend(n_cats_per_cov) + + self.encoder = EncoderDecoder( + n_input=n_input, + n_output=n_latent, + n_cat_list=n_cat_list, + n_cont=n_cont, + n_hidden=n_hidden, + n_layers=n_layers, + dropout_rate=dropout_rate, + sample=True, + var_mode="sample_feature", + **enc_dec_kwargs, + ) + + self.decoder = EncoderDecoder( + n_input=n_latent, + n_output=n_input, + n_cat_list=n_cat_list, + n_cont=n_cont, + n_hidden=n_hidden, + n_layers=n_layers, + dropout_rate=dropout_rate, + sample=True, + var_mode=out_var_mode, + **enc_dec_kwargs, + ) + + if prior == "standard_normal": + self.prior = StandardPrior() + elif prior == "vamp": + assert ( + pseudoinput_data is not None + ), "Pseudoinput data must be specified if using VampPrior" + pseudoinput_data = self._get_inference_input(pseudoinput_data) + self.prior 
= VampPrior( + n_components=n_prior_components, + encoder=self.encoder, + data_x=pseudoinput_data["expr"], + n_cat_list=n_cat_list, + data_cat=self._merge_batch_cov( + cat=pseudoinput_data["cat"], batch=pseudoinput_data["batch"] + ), + data_cont=pseudoinput_data["cont"], + trainable_priors=trainable_priors, + ) + else: + raise ValueError("Prior not recognised") + + @staticmethod + def _cov_idx_name(cov: int) -> str: + """Convert covariate index into a name used for embedding. + + Parameters + ---------- + cov + Covariate index. + + Returns + ------- + Covariate name. + + """ + return "cov" + str(cov) + + def _get_inference_input( + self, tensors: dict[str, torch.Tensor], **kwargs + ) -> dict[str, torch.Tensor | list[torch.Tensor] | None]: + """Parse the input tensors to get inference inputs. + + Parameters + ---------- + tensors + Input tensors. + kwargs + Not used. Added for inheritance compatibility. + + Returns + ------- + Tensors that can be used for inference. + Keys: + * ``'expr'``: Expression. + * ``'batch'``: Batch covariate. + * ``'cat'``: All covariates that require one-hot encoding. + List of tensors with dim = n_samples * 1. + If absent returns empty list. + * ``'cont'``: All covariates that are already continous. + Includes continous and embedded + categorical covariates. + Single tensor of + dim = n_samples * n_concatenated_cov_features. + If absent returns None. + """ + cov = self._get_cov(tensors=tensors) + input_dict = { + "expr": tensors[REGISTRY_KEYS.X_KEY], + "batch": tensors[REGISTRY_KEYS.BATCH_KEY], + "cat": cov["cat"], + "cont": cov["cont"], + } + return input_dict + + def _get_inference_cycle_input( + self, + tensors: dict[str, torch.Tensor], + generative_outputs: dict[str, torch.Tensor], + selected_batch: torch.Tensor, + **kwargs, + ) -> dict[str, torch.Tensor | list[torch.Tensor] | None]: + """In. tensors, gen. outputs, and cycle batch -> cycle inference inputs. + + Parameters + ---------- + tensors + Input tensors. + generative_outputs + Outputs of the generative pass. + selected_batch + Batch covariate to be used for the cycle inference. + dim = n_samples * 1 + kwargs + Not used. Added for inheritance compatibility. + + Returns + ------- + Tensors that can be used for cycle inference. + Keys: + * ``'expr'``: Expression. + * ``'batch'``: Batch covariate. + * ``'cat'``: All covariates that require one-hot encoding. + List of tensors with dim = n_samples * 1. + If absent returns empty list. + * ``'cont'``: All covariates that are already continous. + Includes continous and embedded categorical covariates. + Single tensor of + dim = n_samples * n_concatenated_cov_features. + If absent returns None. + + Note: cycle covariates differ from the original publication. + Instead of mock covariates the real input covaiates are used in cycle. + + """ + cov = self._get_cov(tensors=tensors) + input_dict = { + "expr": generative_outputs["y_m"], + "batch": selected_batch, + "cat": cov["cat"], + "cont": cov["cont"], + } + return input_dict + + def _get_generative_input( + self, + tensors: dict[str, torch.Tensor], + inference_outputs: dict[str, torch.Tensor], + selected_batch: torch.Tensor, + **kwargs, + ) -> dict[str, torch.Tensor | dict[str, torch.Tensor | list[torch.Tensor] | None]]: + """In. tensors, inf. outputs, and cycle batch info -> generative inputs. + + Parameters + ---------- + tensors + Input tensors. + inference_outputs + Outputs of the inference pass. + selected_batch + Batch covariate to be used for the cycle expression generation. 
+ dim = n_samples * 1 + kwargs + Not used. Added for inheritance compatibility. + + Returns + ------- + Tensors that can be used for normal and cycle generation. + Keys: + * ``'z'``: Latent representation. + * ``'batch'``: Batch covariates. + Dict with keys ``'x'`` for normal and + ``'y'`` for cycle pass. + * ``'cat'``: All covariates that require one-hot encoding. + List of tensors with dim = n_samples * 1. + If absent returns empty list. + * ``'cont'``: All covariates that are already continous. + Includes continous and embedded categorical covariates. + Single tensor of + dim = n_samples * n_concatenated_cov_features. + If absent returns None. + + Note: cycle covariates differ from the original publication. + Instead of mock covariates the real input covaiates are used in cycle. + + """ + z = inference_outputs["z"] + cov = self._get_cov(tensors=tensors) + batch = {"x": tensors["batch"], "y": selected_batch} + + input_dict = {"z": z, "batch": batch, "cat": cov["cat"], "cont": cov["cont"]} + return input_dict + + @auto_move_data # TODO remove? + def _get_cov( + self, + tensors: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor | list[torch.Tensor] | None]: + """Process all covs into continuous and categorical components for cVAE. + + Parameters + ---------- + tensors + Input tensors. + + Returns + ------- + Covariates that can be used for decoder and encoder. + Keys: + * ``'cat'``: All covariates that require one-hot encoding. + List of tensors with dim = n_samples * 1. + If absent returns empty list. + * ``'cont'``: All covariates that are already continous. + Includes continous and embedded categorical covariates. + Single tensor of + dim = n_samples * n_concatenated_cov_features. + If absent returns None. + + """ + cat_parts = [] + cont_parts = [] + if REGISTRY_KEYS.CONT_COVS_KEY in tensors: + cont_parts.append(tensors[REGISTRY_KEYS.CONT_COVS_KEY]) + if REGISTRY_KEYS.CAT_COVS_KEY in tensors: + cat = torch.split(tensors[REGISTRY_KEYS.CAT_COVS_KEY], 1, dim=1) + if self.embed_cat: + for idx, tensor in enumerate(cat): + cont_parts.append(self.compute_embedding(self._cov_idx_name(idx), tensor)) + else: + cat_parts.extend(cat) + cov = { + "cat": cat_parts, + "cont": torch.concat(cont_parts, dim=1) if len(cont_parts) > 0 else None, + } + return cov + + @staticmethod + def _merge_batch_cov( + cat: list[torch.Tensor], + batch: torch.Tensor, + ) -> list[torch.Tensor]: + """Merge batch and continuous covs for input into encoder and decoder. + + Parameters + ---------- + cat + Categorical covariates. + List of tensors with dim = n_samples * 1. + batch + Batch covariate. + dim = n_samples * 1 + + Returns + ------- + Single list with batch and categorical covariates. + + """ + return [batch] + cat + + @auto_move_data + def inference( + self, + expr: torch.Tensor, + batch: torch.Tensor, + cat: list[torch.Tensor], + cont: torch.Tensor | None, + ) -> dict[str, torch.Tensor]: + """Inference: expression & cov -> latent representation. + + Parameters + ---------- + expr + Expression data. + batch + Batch covariate. + cat + All covariates that require one-hot encoding. + cont + All covariates that are already continous. + Includes continous and embedded categorical covariates. + + Returns + ------- + Predicted mean (``'z_m'``) and variance (``'z_v'``) + of the latent distribution as wll as a sample (``'z'``) from it. 
+ + """ + z = self.encoder(x=expr, cat_list=self._merge_batch_cov(cat=cat, batch=batch), cont=cont) + return {"z": z["y"], "z_m": z["y_m"], "z_v": z["y_v"]} + + @auto_move_data + def generative( + self, + z: torch.Tensor, + batch: dict[str, torch.Tensor], + cat: list[torch.Tensor], + cont: torch.Tensor | None, + x_x: bool = True, + x_y: bool = True, + ) -> dict[str, torch.Tensor]: + """Generation: latent representation & cov -> expression. + + Parameters + ---------- + z + Latent representation. + batch + Batch covariate for normal (``'x'``) and cycle (``'y'``) generation. + cat + All covariates that require one-hot encoding. + cont + All covariates that are already continous. + Includes continous and embedded categorical covariates. + x_x + Decode to original batch. + x_y + Decode to cycle batch. + + Returns + ------- + Predicted mean (``'x_m'``) and variance (``'x_v'``) + of the expression distribution as wll as a sample (``'x'``) from it. + Same outputs are returned for the cycle generation with ``'x'`` + in keys being replaced by ``'y'``. + """ + + def outputs( + name: str, + res: dict, + x: torch.Tensor, + batch: torch.Tensor, + cat: list[torch.Tensor], + cont: torch.Tensor | None, + ): + """Helper to compute generative outputs for normal and cycle pass. + + Adds generative outputs directly to the ``res`` dict. + + Parameters + ---------- + name + Name prepended to the keys added to the ``res`` dict. + res + Dict to store generative outputs in. + Mean is stored in ``'name_m'``, variance to ``'name_v'`` + and sample to ``'name'``. + x + Latent representation. + batch + Batch covariate. + cat + All covariates that require one-hot encoding. + cont + All covariates that are already continous. + Includes continous and embedded categorical covariates. + """ + res_sub = self.decoder( + x=x, cat_list=self._merge_batch_cov(cat=cat, batch=batch), cont=cont + ) + res[name] = res_sub["y"] + res[name + "_m"] = res_sub["y_m"] + res[name + "_v"] = res_sub["y_v"] + + res = {} + if x_x: + outputs(name="x", res=res, x=z, batch=batch["x"], cat=cat, cont=cont) + if x_y: + outputs(name="y", res=res, x=z, batch=batch["y"], cat=cat, cont=cont) + return res + + @auto_move_data + def forward( + self, + tensors: dict[str, torch.Tensor], + get_inference_input_kwargs: dict | None = None, + get_generative_input_kwargs: dict | None = None, + inference_kwargs: dict | None = None, + generative_kwargs: dict | None = None, + loss_kwargs: dict | None = None, + compute_loss: bool = True, + ) -> ( + tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]] + | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], LossOutput] + ): + """Forward pass through the network. + + Parameters + ---------- + tensors + Input tensors. + get_inference_input_kwargs + Keyword args for ``_get_inference_input()`` + get_generative_input_kwargs + Keyword args for ``_get_generative_input()`` + inference_kwargs + Keyword args for ``inference()`` + generative_kwargs + Keyword args for ``generative()`` + loss_kwargs + Keyword args for ``loss()`` + compute_loss + Whether to compute loss on forward pass. + + Returns + ------- + Inference outputs, generative outputs of the normal pass, + and optionally loss components. + Inference normal and cycle outputs are combined into a single dict. + Thus, the keys of cycle inference outputs are modified by replacing + ``'z'`` with ``'z_cyc'``. + """ + # TODO could disable computation of cycle if cycle loss + # will not be used (weight = 0). 
+ # Cycle loss is not expected to be disabled in practice + # for typical use cases. + + # Parse kwargs + inference_kwargs = inference_kwargs or {} + generative_kwargs = generative_kwargs or {} + loss_kwargs = loss_kwargs or {} + get_inference_input_kwargs = get_inference_input_kwargs or {} + get_generative_input_kwargs = get_generative_input_kwargs or {} + + # Inference + inference_inputs = self._get_inference_input(tensors, **get_inference_input_kwargs) + inference_outputs = self.inference(**inference_inputs, **inference_kwargs) + # Generative + selected_batch = self.random_select_batch(tensors[REGISTRY_KEYS.BATCH_KEY]) + generative_inputs = self._get_generative_input( + tensors, + inference_outputs, + selected_batch=selected_batch, + **get_generative_input_kwargs, + ) + generative_outputs = self.generative( + **generative_inputs, x_x=True, x_y=True, **generative_kwargs + ) + # Inference cycle + inference_cycle_inputs = self._get_inference_cycle_input( + tensors=tensors, + generative_outputs=generative_outputs, + selected_batch=selected_batch, + **get_inference_input_kwargs, + ) + inference_cycle_outputs = self.inference(**inference_cycle_inputs, **inference_kwargs) + + # Combine outputs of all forward pass components + # (first and cycle pass) into a single dict, + # separately for inference and generative outputs + # Rename keys in outputs of cycle pass + # to be distinguishable from the first pass + # for the merging into a single dict + inference_outputs_merged = dict(**inference_outputs) + inference_outputs_merged.update( + **{k.replace("z", "z_cyc"): v for k, v in inference_cycle_outputs.items()} + ) + generative_outputs_merged = dict(**generative_outputs) + + if compute_loss: + losses = self.loss( + tensors=tensors, + inference_outputs=inference_outputs_merged, + generative_outputs=generative_outputs_merged, + **loss_kwargs, + ) + return inference_outputs_merged, generative_outputs_merged, losses + else: + return inference_outputs_merged, generative_outputs_merged + + def loss( + self, + tensors: dict[str, torch.Tensor], + inference_outputs: dict[str, torch.Tensor], + generative_outputs: dict[str, torch.Tensor], + kl_weight: float = 1.0, + reconstruction_weight: float = 1.0, + z_distance_cycle_weight: float = 2.0, + ) -> LossOutput: + """Compute loss of forward pass. + + Parameters + ---------- + tensors + Input tensors. + inference_outputs + Outputs of normal and cycle inference pass. + generative_outputs + Outputs of the normal generative pass. + kl_weight + Weight for KL loss. + reconstruction_weight + Weight for reconstruction loss. + z_distance_cycle_weight + Weight for cycle loss. + + Returns + ------- + Loss components: + Cycle loss is added to extra metrics as ``'cycle_loss'``. + """ + # Reconstruction loss + x_true = tensors[REGISTRY_KEYS.X_KEY] + reconst_loss_x = torch.nn.GaussianNLLLoss(reduction="none")( + generative_outputs["x_m"], x_true, generative_outputs["x_v"] + ).sum(dim=1) + reconst_loss = reconst_loss_x + + # Kl divergence on latent space + kl_divergence_z = self.prior.kl( + m_q=inference_outputs["z_m"], + v_q=inference_outputs["z_v"], + z=inference_outputs["z"], + ) + + def z_dist( + z_x_m: torch.Tensor, + z_y_m: torch.Tensor, + ) -> torch.Tensor: + """MSE loss between standardised inputs. + + MSE loss should be computed on standardized latent representations + as else model can learn to cheat the MSE loss + by setting the latent representations to smaller numbers. + Standardizer is fitted on concatenation of both inputs. 
+ + Parameters + ---------- + z_x_m + First input. + z_y_m + Second input. + + Returns + ------- + The loss. + dim = n_samples * 1 + """ + # Standardise data (jointly both z-s) before MSE calculation + z = torch.concat([z_x_m, z_y_m]) + means = z.mean(dim=0, keepdim=True) + stds = z.std(dim=0, keepdim=True) + + def standardize(x: torch.Tensor) -> torch.Tensor: + """Helper function to standardize a tensor. + + Mean and variance from the outer scope are used for standardization. + + Parameters + ---------- + x + Input tensor. + + Returns + ------- + Standardized tensor. + """ + return (x - means) / stds + + return torch.nn.MSELoss(reduction="none")(standardize(z_x_m), standardize(z_y_m)).sum( + dim=1 + ) + + z_distance_cyc = z_dist(z_x_m=inference_outputs["z_m"], z_y_m=inference_outputs["z_cyc_m"]) + if "batch_weights" in tensors.keys(): + z_distance_cyc *= tensors["batch_weights"].flatten() + + loss = ( + reconst_loss * reconstruction_weight + + kl_divergence_z * kl_weight + + z_distance_cyc * z_distance_cycle_weight + ) + + return LossOutput( + loss=loss.mean(), + reconstruction_loss=reconst_loss, + kl_local=kl_divergence_z, + extra_metrics={"cycle_loss": z_distance_cyc.mean()}, + ) + + def random_select_batch(self, batch: torch.Tensor) -> torch.Tensor: + """For each cell randomly selects new batch different from the real one. + + Parameters + ---------- + batch + Real batch information for each cell. + + Returns + ------- + Newly selected batch for each cell. + """ + # Get available batches - + # those that are zero will become nonzero and vice versa + batch = torch.nn.functional.one_hot(batch.squeeze(-1), self.n_batch) + available_batches = 1 - batch + # Get nonzero indices for each cell - + # batches that differ from the real batch and are thus available + row_indices, col_indices = torch.nonzero(available_batches, as_tuple=True) + col_pairs = col_indices.view(-1, batch.shape[1] - 1) + # Select batch for every cell from available batches + randomly_selected_indices = col_pairs.gather( + 1, + torch.randint( + 0, + batch.shape[1] - 1, + size=(col_pairs.size(0), 1), + device=col_pairs.device, + dtype=col_pairs.dtype, + ), + ) + + return randomly_selected_indices + + @torch.inference_mode() + def sample(self, *args, **kwargs): + """Generate expression samples from posterior generative distribution. + + Not implemented as the use of decoded expression + is not recommended for SysVI. + + Raises + ------ + NotImplementedError + """ + raise NotImplementedError("The use of decoded expression is not recommended for SysVI.") diff --git a/src/scvi/external/sysvi/_priors.py b/src/scvi/external/sysvi/_priors.py new file mode 100644 index 0000000000..caece4a64d --- /dev/null +++ b/src/scvi/external/sysvi/_priors.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +import abc +from abc import abstractmethod + +import torch +from torch.distributions import Normal, kl_divergence + + +class Prior(torch.nn.Module, abc.ABC): + """Abstract class for prior distributions.""" + + @abstractmethod + def kl( + self, + m_q: torch.Tensor, + v_q: torch.Tensor, + z: torch.Tensor, + ) -> torch.Tensor: + """Compute KL divergence between prior and posterior distribution. + + Parameters + ---------- + m_q + Posterior distribution mean. + v_q + Posterior distribution variance. + z + Sample from the posterior distribution. + + Returns + ------- + KL divergence. 
+ """ + pass + + +class StandardPrior(Prior): + """Standard prior distribution.""" + + def kl(self, m_q: torch.Tensor, v_q: torch.Tensor, z: None = None) -> torch.Tensor: + """Compute KL div between std. normal prior and the posterior distn. + + Parameters + ---------- + m_q + Posterior distribution mean. + v_q + Posterior distribution variance. + z + Ignored. + + Returns + ------- + KL divergence. + """ + # 1 x N + return kl_divergence( + Normal(m_q, v_q.sqrt()), Normal(torch.zeros_like(m_q), torch.ones_like(v_q)) + ).sum(dim=1) + + +class VampPrior(Prior): + """VampPrior. + + Adapted from a + `blog post + `_ + of the original VampPrior author. + + Parameters + ---------- + n_components + Number of prior components. + encoder + The encoder. + data_x + Expression data for pseudoinputs initialisation. + n_cat_list + The number of categorical covariates and + the number of category levels. + A list containing, for each covariate of interest, + the number of categories. + data_cat + List of categorical covariates for pseudoinputs initialisation. + Includes all covariates that will be one-hot encoded by the ``encoder``, + including the batch. + data_cont + Continuous covariates for pseudoinputs initialisation. + trainable_priors + Are pseudoinput parameters trainable or fixed. + """ + + # K - components, I - inputs, L - latent, N - samples + + def __init__( + self, + n_components: int, + encoder: torch.nn.Module, + data_x: torch.Tensor, + n_cat_list: list[int], + data_cat: list[torch.Tensor], + data_cont: torch.Tensor | None = None, + trainable_priors: bool = True, + ): + super().__init__() + + self.encoder = encoder + + # Make pseudoinputs into parameters + # X + assert n_components == data_x.shape[0] + self.u = torch.nn.Parameter(data_x, requires_grad=trainable_priors) # K x I + # Cat + assert all(cat.shape[0] == n_components for cat in data_cat) + # For categorical covariates, since scvi-tools one-hot encodes + # them in the layers, we need to create a multinomial distn + # from which we can sample categories for layers input + # Initialise the multinomial distn weights based on + # one-hot encoding of pseudoinput categories + self.u_cat = torch.nn.ParameterList( + [ + torch.nn.Parameter( + torch.nn.functional.one_hot(cat.squeeze(-1), n).float(), + # K x C_cat_onehot + requires_grad=trainable_priors, + ) + for cat, n in zip(data_cat, n_cat_list, strict=False) + # K x C_cat + ] + ) + # Cont + if data_cont is None: + self.u_cont = None + else: + assert n_components == data_cont.shape[0] + self.u_cont = torch.nn.Parameter( + data_cont, requires_grad=trainable_priors + ) # K x C_cont + + # mixing weights + self.w = torch.nn.Parameter(torch.zeros(n_components, 1, 1)) # K x 1 x 1 + + def get_params(self) -> tuple[torch.Tensor, torch.Tensor]: + """Get posterior of pseudoinputs. + + Returns + ------- + Posterior representation mean and variance for each pseudoinput. + """ + # u, u_cov -> encoder -> mean, var + original_mode = self.encoder.training + self.encoder.train(False) + # Convert category weights to categories + cat_list = [torch.multinomial(cat, num_samples=1) for cat in self.u_cat] + z = self.encoder(x=self.u, cat_list=cat_list, cont=self.u_cont) + self.encoder.train(original_mode) + return z["y_m"], z["y_v"] # (K x L), (K x L) + + def log_prob(self, z: torch.Tensor) -> torch.Tensor: + """Log probability of posterior sample under the prior. + + Parameters + ---------- + z + Latent embedding of samples. + + Returns + ------- + Log probability of every sample. 
+ dim = n_samples * n_latent_dimensions + """ + # Mixture of gaussian computed on K x N x L + z = z.unsqueeze(0) # 1 x N x L + + # Get pseudoinputs posteriors which are prior params + m_p, v_p = self.get_params() # (K x L), (K x L) + m_p = m_p.unsqueeze(1) # K x 1 x L + v_p = v_p.unsqueeze(1) # K x 1 x L + + # mixing probabilities + w = torch.nn.functional.softmax(self.w, dim=0) # K x 1 x 1 + + # sum of log_p across components weighted by w + log_prob = Normal(m_p, v_p.sqrt()).log_prob(z) + torch.log(w) # K x N x L + log_prob = torch.logsumexp(log_prob, dim=0, keepdim=False) # N x L + + return log_prob # N x L + + def kl(self, m_q: torch.Tensor, v_q: torch.Tensor, z: torch.Tensor) -> torch.Tensor: + """Compute KL div. between VampPrior and the posterior distribution. + + Parameters + ---------- + m_q + Posterior distribution mean. + v_q + Posterior distribution variance. + z + Sample from the posterior distribution. + + Returns + ------- + KL divergence. + """ + return (Normal(m_q, v_q.sqrt()).log_prob(z) - self.log_prob(z)).sum(1) diff --git a/tests/external/sysvi/test_sysvi.py b/tests/external/sysvi/test_sysvi.py new file mode 100644 index 0000000000..c47b22a9d1 --- /dev/null +++ b/tests/external/sysvi/test_sysvi.py @@ -0,0 +1,208 @@ +import math + +import numpy as np +import pandas as pd +import pytest +from anndata import AnnData +from numpy.testing import assert_raises +from scipy import sparse + +from scvi.external import SysVI + + +def mock_adata(): + """Mock adata for testing.""" + adata = AnnData( + sparse.csr_matrix( + np.exp( + np.concatenate( + [ + np.random.normal(1, 0.5, (200, 5)), + np.random.normal(1.1, 0.00237, (200, 5)), + np.random.normal(1.3, 0.35, (200, 5)), + np.random.normal(2, 0.111, (200, 5)), + np.random.normal(2.2, 0.3, (200, 5)), + np.random.normal(2.7, 0.01, (200, 5)), + np.random.normal(1, 0.001, (200, 5)), + np.random.normal(0.00001, 0.4, (200, 5)), + np.random.normal(0.2, 0.91, (200, 5)), + np.random.normal(0.1, 0.0234, (200, 5)), + np.random.normal(0.00005, 0.1, (200, 5)), + np.random.normal(0.05, 0.001, (200, 5)), + np.random.normal(0.023, 0.3, (200, 5)), + np.random.normal(0.6, 0.13, (200, 5)), + np.random.normal(0.9, 0.5, (200, 5)), + np.random.normal(1, 0.0001, (200, 5)), + np.random.normal(1.5, 0.05, (200, 5)), + np.random.normal(2, 0.009, (200, 5)), + np.random.normal(1, 0.0001, (200, 5)), + ], + axis=1, + ) + ) + ), + var=pd.DataFrame(index=[str(i) for i in range(95)]), + ) + adata.obs["covariate_cont"] = list(range(200)) + adata.obs["covariate_cat"] = ["a"] * 50 + ["b"] * 50 + ["c"] * 50 + ["d"] * 50 + adata.obs["batch"] = ["a"] * 100 + ["b"] * 50 + ["c"] * 50 + + return adata + + +@pytest.mark.parametrize( + ( + "categorical_covariate_keys", + "continuous_covariate_keys", + "pseudoinputs_data_indices", + "embed_cat", + "weight_batches", + ), + [ + # Check different covariate combinations + (["covariate_cat"], ["covariate_cont"], None, False, False), + (["covariate_cat"], ["covariate_cont"], None, True, False), + (["covariate_cat"], None, None, False, False), + (["covariate_cat"], None, None, True, False), + (None, ["covariate_cont"], None, False, False), + # Check pre-specifying pseudoinputs + (None, None, np.array(list(range(5))), False, False), + ], +) +def test_sysvi_model( + categorical_covariate_keys, + continuous_covariate_keys, + pseudoinputs_data_indices, + embed_cat, + weight_batches, +): + """Test model with different input and parameters settings.""" + adata = mock_adata() + + # Run adata setup + SysVI.setup_anndata( + adata, + 
batch_key="batch", + categorical_covariate_keys=categorical_covariate_keys, + continuous_covariate_keys=continuous_covariate_keys, + weight_batches=weight_batches, + ) + + # Model + + # Check that model runs through with standard normal prior + model = SysVI(adata=adata, prior="standard_normal", embed_cat=embed_cat) + model.train(max_epochs=2, batch_size=math.ceil(adata.n_obs / 2.0)) + + # Check that model runs through with vamp prior + model = SysVI( + adata=adata, + prior="vamp", + pseudoinputs_data_indices=pseudoinputs_data_indices, + n_prior_components=5, + embed_cat=embed_cat, + ) + model.train(max_epochs=2, batch_size=math.ceil(adata.n_obs / 2.0)) + + # Embedding + + # Check that embedding default works + assert ( + model.get_latent_representation( + adata=adata, + ).shape[0] + == adata.shape[0] + ) + + +def test_sysvi_latent_representation(): + """Test different parameters for computing later representation.""" + # Train model + adata = mock_adata() + SysVI.setup_anndata( + adata, + batch_key="batch", + categorical_covariate_keys=None, + continuous_covariate_keys=None, + weight_batches=False, + ) + model = SysVI(adata=adata, prior="standard_normal") + model.train(max_epochs=2, batch_size=math.ceil(adata.n_obs / 2.0)) + + # Check that specifying indices in embedding works + idx = [1, 2, 3] + embed = model.get_latent_representation( + adata=adata, + indices=idx, + give_mean=True, + ) + assert embed.shape[0] == 3 + + # Check predicting mean vs sample + np.testing.assert_allclose( + embed, + model.get_latent_representation( + adata=adata, + indices=idx, + give_mean=True, + ), + ) + with assert_raises(AssertionError): + np.testing.assert_allclose( + embed, + model.get_latent_representation( + adata=adata, + indices=idx, + give_mean=False, + ), + ) + + # Test returning distn + mean, var = model.get_latent_representation( + adata=adata, + indices=idx, + return_dist=True, + ) + np.testing.assert_allclose(embed, mean) + + +def test_sysvi_warnings(): + """Test that the most important warnings and exceptions are raised.""" + # Train model + adata = mock_adata() + SysVI.setup_anndata( + adata, + batch_key="batch", + categorical_covariate_keys=None, + continuous_covariate_keys=None, + weight_batches=False, + ) + model = SysVI(adata=adata, prior="standard_normal") + + # Assert that warning is printed if kl warmup is used + # Step warmup + with pytest.warns(Warning) as record: + model.train( + max_epochs=2, + batch_size=math.ceil(adata.n_obs / 2.0), + plan_kwargs={"n_steps_kl_warmup": 1}, + ) + assert any( + "The use of KL weight warmup is not recommended in SysVI." in str(rec.message) + for rec in record + ) + # Epoch warmup + with pytest.warns(Warning) as record: + model.train( + max_epochs=2, + batch_size=math.ceil(adata.n_obs / 2.0), + plan_kwargs={"n_epochs_kl_warmup": 1}, + ) + assert any( + "The use of KL weight warmup is not recommended in SysVI." in str(rec.message) + for rec in record + ) + + # Asert that sampling is disabled + with pytest.raises(NotImplementedError): + model.module.sample()