diff --git a/CHANGELOG.md b/CHANGELOG.md
index 046502dbd8..7dd7c35377 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,19 @@ to [Semantic Versioning]. Full commit history is available in the
 
 ## Version 1.2
 
+### 1.3.0 (2024-XX-XX)
+
+#### Added
+
+- Add {class}`scvi.external.Decipher` for dimensionality reduction and interpretable
+  representation learning in single-cell RNA sequencing data {pr}`3015`.
+
+#### Fixed
+
+#### Changed
+
+#### Removed
+
 ### 1.2.1 (2024-XX-XX)
 
 #### Added
diff --git a/docs/api/developer.md b/docs/api/developer.md
index 3d3f6504c6..abbb1717b3 100644
--- a/docs/api/developer.md
+++ b/docs/api/developer.md
@@ -180,6 +180,7 @@ Module classes in the external API with respective generative and inference proc
    external.velovi.VELOVAE
    external.mrvi.MRVAE
    external.methylvi.METHYLVAE
+   external.decipher.DecipherPyroModule
 ```
 
diff --git a/docs/api/user.md b/docs/api/user.md
index 10da50816b..825815d64a 100644
--- a/docs/api/user.md
+++ b/docs/api/user.md
@@ -62,6 +62,7 @@ import scvi
    external.VELOVI
    external.MRVI
    external.METHYLVI
+   external.Decipher
 ```
 
 ## Data loading
diff --git a/docs/references.bib b/docs/references.bib
index 5a017924ff..3baa082b55 100644
--- a/docs/references.bib
+++ b/docs/references.bib
@@ -212,6 +212,16 @@ @article{Martens2023
   publisher={Nature Publishing Group}
 }
 
+@article{Nazaret23,
+  title={Deep generative model deciphers derailed trajectories in acute myeloid leukemia},
+  author={Nazaret, Achille and Fan, Joy Linyue and Lavallee, Vincent-Philippe and Cornish, Andrew E and Kiseliovas, Vaidotas and Masilionis, Ignas and Chun, Jaeyoung and Bowman, Robert L and Eisman, Shira E and Wang, James and others},
+  journal={bioRxiv},
+  pages={2023--11},
+  year={2023},
+  publisher={Cold Spring Harbor Laboratory},
+  doi = {10.1101/2023.11.11.566719}
+}
+
 @article{Sheng22,
     title = {Probabilistic machine learning ensures accurate ambient denoising in droplet-based single-cell omics},
     author = {Caibin Sheng and Rui Lopes and Gang Li and Sven Schuierer and Annick Waldt and Rachel Cuttat and Slavica Dimitrieva and Audrey Kauffmann and Eric Durand and Giorgio G. Galli and Guglielmo Roma and Antoine de Weck},
diff --git a/src/scvi/external/__init__.py b/src/scvi/external/__init__.py
index 4e46ca1846..c54253c2da 100644
--- a/src/scvi/external/__init__.py
+++ b/src/scvi/external/__init__.py
@@ -1,5 +1,6 @@
 from .cellassign import CellAssign
 from .contrastivevi import ContrastiveVI
+from .decipher import Decipher
 from .gimvi import GIMVI
 from .methylvi import METHYLVI
 from .mrvi import MRVI
@@ -15,6 +16,7 @@
     "SCAR",
     "SOLO",
     "GIMVI",
+    "Decipher",
     "RNAStereoscope",
     "SpatialStereoscope",
     "CellAssign",
diff --git a/src/scvi/external/decipher/__init__.py b/src/scvi/external/decipher/__init__.py
new file mode 100644
index 0000000000..d1a6049056
--- /dev/null
+++ b/src/scvi/external/decipher/__init__.py
@@ -0,0 +1,4 @@
+from ._model import Decipher
+from ._module import DecipherPyroModule
+
+__all__ = ["Decipher", "DecipherPyroModule"]
diff --git a/src/scvi/external/decipher/_components.py b/src/scvi/external/decipher/_components.py
new file mode 100644
index 0000000000..12900b29a6
--- /dev/null
+++ b/src/scvi/external/decipher/_components.py
@@ -0,0 +1,99 @@
+from collections.abc import Sequence
+
+import torch
+import torch.nn as nn
+
+
+class ConditionalDenseNN(nn.Module):
+    """Dense neural network with multiple outputs, optionally conditioned on a context variable.
+
+    (Derived from pyro.nn.dense_nn.ConditionalDenseNN with some modifications [1])
+
+    Parameters
+    ----------
+    input_dim
+        Dimension of the input
+    hidden_dims
+        Dimensions of the hidden layers (excluding the output layer)
+    output_dims
+        Dimensions of each output layer
+    context_dim
+        Dimension of the context input.
+    deep_context_injection
+        If True, inject the context into every hidden layer.
+        If False, only inject the context into the first hidden layer
+        (concatenated with the input).
+    activation
+        Activation function to use between hidden layers (not applied to the outputs).
+        Default: torch.nn.ReLU()
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dims: Sequence[int],
+        output_dims: Sequence = (1,),
+        context_dim: int = 0,
+        deep_context_injection: bool = False,
+        activation=torch.nn.ReLU(),
+    ):
+        super().__init__()
+
+        self.input_dim = input_dim
+        self.context_dim = context_dim
+        self.hidden_dims = hidden_dims
+        self.output_dims = output_dims
+        self.deep_context_injection = deep_context_injection
+        self.n_output_layers = len(self.output_dims)
+        self.output_total_dim = sum(self.output_dims)
+
+        # The multiple outputs are computed as a single output layer, and then split
+        last_output_end_idx = 0
+        self.output_slices = []
+        for dim in self.output_dims:
+            self.output_slices.append(slice(last_output_end_idx, last_output_end_idx + dim))
+            last_output_end_idx += dim
+
+        # Create the hidden and output layers, optionally injecting the context at each depth
+        deep_context_dim = self.context_dim if self.deep_context_injection else 0
+        layers = []
+        batch_norms = []
+        if len(hidden_dims):
+            layers.append(torch.nn.Linear(input_dim + context_dim, hidden_dims[0]))
+            batch_norms.append(nn.BatchNorm1d(hidden_dims[0]))
+            for i in range(1, len(hidden_dims)):
+                layers.append(
+                    torch.nn.Linear(hidden_dims[i - 1] + deep_context_dim, hidden_dims[i])
+                )
+                batch_norms.append(nn.BatchNorm1d(hidden_dims[i]))
+
+            layers.append(
+                torch.nn.Linear(hidden_dims[-1] + deep_context_dim, self.output_total_dim)
+            )
+        else:
+            layers.append(torch.nn.Linear(input_dim + context_dim, self.output_total_dim))
+
+        self.layers = torch.nn.ModuleList(layers)
+
+        self.activation_fn = activation
+        self.batch_norms = torch.nn.ModuleList(batch_norms)
+
+    def forward(self, x, context=None):
+        if context is not None:
+            # We must be able to broadcast the size of the context over the input
+            context = context.expand(x.size()[:-1] + (context.size(-1),))
+
+        h = x
+        for i, layer in enumerate(self.layers):
+            if self.context_dim > 0 and (self.deep_context_injection or i == 0):
+                h = torch.cat([context, h], dim=-1)
+            h = layer(h)
+            if i < len(self.layers) - 1:
+                h = self.batch_norms[i](h)
+                h = self.activation_fn(h)
+
+        if self.n_output_layers == 1:
+            return h
+
+        h = h.reshape(list(x.size()[:-1]) + [self.output_total_dim])
+        return tuple([h[..., s] for s in self.output_slices])
diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py
new file mode 100644
index 0000000000..50841b3ddc
--- /dev/null
+++ b/src/scvi/external/decipher/_model.py
@@ -0,0 +1,153 @@
+import logging
+from collections.abc import Sequence
+
+import numpy as np
+import pyro
+import torch
+from anndata import AnnData
+
+from scvi._constants import REGISTRY_KEYS
+from scvi.data import AnnDataManager
+from scvi.data.fields import LayerField
+from scvi.model.base import BaseModelClass, PyroSviTrainMixin
+from scvi.train import PyroTrainingPlan
+from scvi.utils import setup_anndata_dsp
+
+from ._module import DecipherPyroModule
+from ._trainingplan import DecipherTrainingPlan
+
+logger = logging.getLogger(__name__)
+
+
+class Decipher(PyroSviTrainMixin, BaseModelClass):
+    """Decipher model for single-cell data analysis :cite:p:`Nazaret23`.
+
+    Parameters
+    ----------
+    adata
+        AnnData object that has been registered via
+        :meth:`~scvi.external.Decipher.setup_anndata`.
+    dim_v
+        Dimension of the interpretable latent space v.
+    dim_z
+        Dimension of the intermediate latent space z.
+    layers_v_to_z
+        Hidden layer sizes for the v to z decoder network.
+    layers_z_to_x
+        Hidden layer sizes for the z to x decoder network.
+    beta
+        Regularization parameter for the KL divergence.
+    """
+
+    _module_cls = DecipherPyroModule
+    _training_plan_cls = DecipherTrainingPlan
+
+    def __init__(self, adata: AnnData, **kwargs):
+        pyro.clear_param_store()
+
+        super().__init__(adata)
+
+        dim_genes = self.summary_stats.n_vars
+
+        self.module = self._module_cls(
+            dim_genes,
+            **kwargs,
+        )
+
+        self.init_params_ = self._get_init_params(locals())
+
+    @classmethod
+    @setup_anndata_dsp.dedent
+    def setup_anndata(
+        cls,
+        adata: AnnData,
+        layer: str | None = None,
+        **kwargs,
+    ) -> AnnData | None:
+        """%(summary)s.
+
+        Parameters
+        ----------
+        %(param_adata)s
+        %(param_layer)s
+        """
+        setup_method_args = cls._get_setup_method_args(**locals())
+        anndata_fields = [
+            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
+        ]
+        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
+        adata_manager.register_fields(adata, **kwargs)
+        cls.register_manager(adata_manager)
+
+    def train(
+        self,
+        max_epochs: int | None = None,
+        accelerator: str = "auto",
+        device: int | str = "auto",
+        train_size: float = 0.9,
+        validation_size: float | None = None,
+        shuffle_set_split: bool = True,
+        batch_size: int = 128,
+        early_stopping: bool = False,
+        training_plan: PyroTrainingPlan | None = None,
+        datasplitter_kwargs: dict | None = None,
+        plan_kwargs: dict | None = None,
+        **trainer_kwargs,
+    ):
+        if "early_stopping_monitor" not in trainer_kwargs:
+            trainer_kwargs["early_stopping_monitor"] = "nll_validation"
+        datasplitter_kwargs = datasplitter_kwargs or {}
+        if "drop_last" not in datasplitter_kwargs:
+            datasplitter_kwargs["drop_last"] = True
+        super().train(
+            max_epochs=max_epochs,
+            accelerator=accelerator,
+            device=device,
+            train_size=train_size,
+            validation_size=validation_size,
+            shuffle_set_split=shuffle_set_split,
+            batch_size=batch_size,
+            early_stopping=early_stopping,
+            plan_kwargs=plan_kwargs,
+            training_plan=training_plan,
+            datasplitter_kwargs=datasplitter_kwargs,
+            **trainer_kwargs,
+        )
+
+    def get_latent_representation(
+        self,
+        adata: AnnData | None = None,
+        indices: Sequence[int] | None = None,
+        batch_size: int | None = None,
+        give_z: bool = False,
+    ) -> np.ndarray:
+        """Get the latent representation of the data.
+
+        Parameters
+        ----------
+        adata
+            AnnData object with the data to get the latent representation of.
+        indices
+            Indices of the data to get the latent representation of.
+        batch_size
+            Batch size to use for the data loader.
+        give_z
+            Whether to return the intermediate latent space z or the top-level
+            latent space v.
+        """
+        self._check_if_trained(warn=False)
+        adata = self._validate_anndata(adata)
+
+        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
+        latent_locs = []
+        for tensors in scdl:
+            x = tensors[REGISTRY_KEYS.X_KEY]
+            x = torch.log1p(x)
+            x = x.to(self.module.device)
+            z_loc, _ = self.module.encoder_x_to_z(x)
+            if give_z:
+                latent_locs.append(z_loc)
+            else:
+                v_loc, _ = self.module.encoder_zx_to_v(torch.cat([z_loc, x], dim=-1))
+                latent_locs.append(v_loc)
+        return torch.cat(latent_locs).detach().cpu().numpy()
diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py
new file mode 100644
index 0000000000..fb864ebe48
--- /dev/null
+++ b/src/scvi/external/decipher/_module.py
@@ -0,0 +1,175 @@
+from collections.abc import Iterable, Sequence
+
+import numpy as np
+import pyro
+import pyro.distributions as dist
+import pyro.poutine as poutine
+import torch
+import torch.nn as nn
+import torch.utils.data
+from torch.distributions import constraints
+from torch.nn.functional import softmax, softplus
+
+from scvi._constants import REGISTRY_KEYS
+from scvi.module.base import PyroBaseModuleClass, auto_move_data
+
+from ._components import ConditionalDenseNN
+
+TEST_BETA = 1.0
+
+
+class DecipherPyroModule(PyroBaseModuleClass):
+    """Pyro Module for the Decipher model.
+
+    This module implements the Decipher model for dimensionality reduction and
+    interpretable representation learning in single-cell RNA sequencing data.
+
+    Parameters
+    ----------
+    dim_genes
+        Number of genes (features) in the dataset.
+    dim_v
+        Dimension of the interpretable latent space v.
+    dim_z
+        Dimension of the intermediate latent space z.
+    layers_v_to_z
+        Hidden layer sizes for the v to z decoder network.
+    layers_z_to_x
+        Hidden layer sizes for the z to x decoder network.
+    beta
+        Regularization parameter for the KL divergence.
+    """
+
+    def __init__(
+        self,
+        dim_genes: int,
+        dim_v: int = 2,
+        dim_z: int = 10,
+        layers_v_to_z: Sequence[int] = (64,),
+        layers_z_to_x: Sequence[int] = (),
+        beta: float = 0.1,
+    ):
+        super().__init__()
+        self.dim_v = dim_v
+        self.dim_z = dim_z
+        self.dim_genes = dim_genes
+        self.layers_v_to_z = layers_v_to_z
+        self.layers_z_to_x = layers_z_to_x
+        self.beta = beta
+
+        self.decoder_v_to_z = ConditionalDenseNN(dim_v, layers_v_to_z, [dim_z] * 2)
+        self.decoder_z_to_x = ConditionalDenseNN(dim_z, layers_z_to_x, [dim_genes])
+        self.encoder_x_to_z = ConditionalDenseNN(dim_genes, [128], [dim_z] * 2)
+        self.encoder_zx_to_v = ConditionalDenseNN(
+            dim_genes + dim_z,
+            [128],
+            [dim_v] * 2,
+        )
+
+        self.theta = None
+
+        self._epsilon = 1e-5
+
+        self.n_obs = None  # Populated by PyroTrainingPlan
+
+        # Hack: to allow auto_move_data to infer device.
+        self._dummy_param = nn.Parameter(torch.empty(0))
+
+    @property
+    def device(self):
+        return self._dummy_param.device
+
+    @staticmethod
+    def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict:
+        x = tensor_dict[REGISTRY_KEYS.X_KEY]
+        return (x,), {}
+
+    @auto_move_data
+    def model(self, x: torch.Tensor, beta: float | None = None):
+        pyro.module("decipher", self)
+
+        self.theta = pyro.param(
+            "theta",
+            x.new_ones(self.dim_genes),
+            constraint=constraints.positive,
+        )
+
+        with (
+            pyro.plate("batch", x.shape[0]),
+            poutine.scale(scale=1.0),
+        ):
+            with poutine.scale(scale=beta or self.beta):
+                prior = dist.Normal(0, x.new_ones(self.dim_v)).to_event(1)
+                v = pyro.sample("v", prior)
+
+            z_loc, z_scale = self.decoder_v_to_z(v)
+            z_scale = softplus(z_scale)
+            z = pyro.sample("z", dist.Normal(z_loc, z_scale).to_event(1))
+
+            mu = self.decoder_z_to_x(z)
+            mu = softmax(mu, dim=-1)
+            library_size = x.sum(axis=-1, keepdim=True)
+            # Parametrization of Negative Binomial by the mean and inverse dispersion
+            # See https://github.com/pytorch/pytorch/issues/42449
+            # noinspection PyTypeChecker
+            logit = torch.log(library_size * mu + self._epsilon) - torch.log(
+                self.theta + self._epsilon
+            )
+            # noinspection PyUnresolvedReferences
+            x_dist = dist.NegativeBinomial(total_count=self.theta + self._epsilon, logits=logit)
+            pyro.sample("x", x_dist.to_event(1), obs=x)
+
+    @auto_move_data
+    def guide(self, x: torch.Tensor, beta: float | None = None):
+        pyro.module("decipher", self)
+        with (
+            pyro.plate("batch", x.shape[0]),
+            poutine.scale(scale=1.0),
+        ):
+            x = torch.log1p(x)
+
+            z_loc, z_scale = self.encoder_x_to_z(x)
+            z_scale = softplus(z_scale)
+            posterior_z = dist.Normal(z_loc, z_scale).to_event(1)
+            z = pyro.sample("z", posterior_z)
+
+            zx = torch.cat([z, x], dim=-1)
+            v_loc, v_scale = self.encoder_zx_to_v(zx)
+            v_scale = softplus(v_scale)
+            with poutine.scale(scale=beta or self.beta):
+                posterior_v = dist.Normal(v_loc, v_scale).to_event(1)
+                pyro.sample("v", posterior_v)
+        return z_loc, v_loc, z_scale, v_scale
+
+    def predictive_log_likelihood(self, x: torch.Tensor, n_samples=5):
+        """
+        Estimate the predictive log-likelihood of a batch under the Decipher module.
+
+        This function draws several samples from the guide for the given batch and
+        computes an importance-weighted estimate of the marginal log-likelihood,
+        averaging the per-sample log-weights in log space. The module's beta
+        parameter is overridden with ``TEST_BETA`` for the evaluation. Used by
+        default as an early stopping criterion.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Batch of data to compute the log-likelihood for.
+        n_samples : int, optional
+            Number of Monte Carlo samples to draw from the guide (default is 5).
+
+        Returns
+        -------
+        float
+            The estimated predictive log-likelihood of the batch.
+        """
+        log_weights = []
+        for _ in range(n_samples):
+            guide_trace = poutine.trace(self.guide).get_trace(x, beta=TEST_BETA)
+            model_trace = poutine.trace(poutine.replay(self.model, trace=guide_trace)).get_trace(
+                x, beta=TEST_BETA
+            )
+            log_weights.append(model_trace.log_prob_sum() - guide_trace.log_prob_sum())
+
+        log_z = torch.logsumexp(torch.tensor(log_weights) - np.log(n_samples), 0)
+        return log_z.item()
diff --git a/src/scvi/external/decipher/_trainingplan.py b/src/scvi/external/decipher/_trainingplan.py
new file mode 100644
index 0000000000..adcb9dbbbe
--- /dev/null
+++ b/src/scvi/external/decipher/_trainingplan.py
@@ -0,0 +1,143 @@
+import pyro
+import torch
+
+from scvi.module.base import (
+    PyroBaseModuleClass,
+)
+from scvi.train import LowLevelPyroTrainingPlan
+
+
+class DecipherTrainingPlan(LowLevelPyroTrainingPlan):
+    """Lightning module task to train the Decipher Pyro module.
+
+    Parameters
+    ----------
+    pyro_module
+        An instance of :class:`~scvi.module.base.PyroBaseModuleClass`. This object
+        should have callable `model` and `guide` attributes or methods.
+    loss_fn
+        A Pyro loss. Should be a subclass of :class:`~pyro.infer.ELBO`.
+        If `None`, defaults to :class:`~pyro.infer.Trace_ELBO`.
+    optim
+        A Pyro optimizer instance, e.g., :class:`~pyro.optim.Adam`. If `None`, defaults to
+        :class:`~pyro.optim.ClippedAdam` with a learning rate of `5e-3` and weight decay of `1e-4`.
+    optim_kwargs
+        Keyword arguments for the **default** optimizer :class:`~pyro.optim.ClippedAdam`.
+    """
+
+    def __init__(
+        self,
+        pyro_module: PyroBaseModuleClass,
+        loss_fn: pyro.infer.ELBO | None = None,
+        optim: pyro.optim.PyroOptim | None = None,
+        optim_kwargs: dict | None = None,
+    ):
+        super().__init__(
+            pyro_module=pyro_module,
+            loss_fn=loss_fn,
+        )
+        optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else {}
+        if "lr" not in optim_kwargs.keys():
+            optim_kwargs.update({"lr": 5e-3})
+        if "weight_decay" not in optim_kwargs.keys():
+            optim_kwargs.update({"weight_decay": 1e-4})
+        self.optim = pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim
+        # We let SVI take care of all optimization
+        self.automatic_optimization = False
+
+        self.svi = pyro.infer.SVI(
+            model=self.module.model,
+            guide=self.module.guide,
+            optim=self.optim,
+            loss=self.loss_fn,
+        )
+        self.validation_step_outputs = []
+
+        # See configure_optimizers for what this does
+        self._dummy_param = torch.nn.Parameter(torch.Tensor([0.0]))
+
+    def on_validation_model_train(self):
+        """Prepare the model for validation by switching to train mode."""
+        super().on_validation_model_train()
+        if self.current_epoch > 0:
+            # freeze the batch norm layers after the first epoch
+            # 1) the batch norm layers help with the initialization
+            # 2) but then, they seem to imply a strong normal prior on the latent space
+            for module in self.module.modules():
+                if isinstance(module, torch.nn.BatchNorm1d):
+                    module.eval()
+
+    def training_step(self, batch, batch_idx):
+        """Training step for Pyro training."""
+        args, kwargs = self.module._get_fn_args_from_batch(batch)
+
+        # pytorch lightning requires a Tensor object for loss
+        loss = torch.Tensor([self.svi.step(*args, **kwargs)])
+        n_obs = args[0].shape[0]
+
+        _opt = self.optimizers()
+        _opt.step()
+
+        out_dict = {"loss": loss, "n_obs": n_obs}
+        self.training_step_outputs.append(out_dict)
+        return out_dict
+
+    def on_train_epoch_end(self):
+        """Training epoch end for Pyro training."""
+        outputs = self.training_step_outputs
+        elbo = 0
+        n_obs = 0
+        for out in outputs:
+            elbo += out["loss"]
+            n_obs += out["n_obs"]
+        elbo /= n_obs
+        self.log("elbo_train", elbo, prog_bar=True)
+        self.training_step_outputs.clear()
+
+    def validation_step(self, batch, batch_idx):
+        """Validation step for Pyro training."""
+        args, kwargs = self.module._get_fn_args_from_batch(batch)
+        loss = self.differentiable_loss_fn(
+            self.scale_fn(self.module.model),
+            self.scale_fn(self.module.guide),
+            *args,
+            **kwargs,
+        )
+        nll = -self.module.predictive_log_likelihood(*args, **kwargs, n_samples=5)
+        out_dict = {"loss": loss, "nll": nll, "n_obs": args[0].shape[0]}
+        self.validation_step_outputs.append(out_dict)
+        return out_dict
+
+    def on_validation_epoch_end(self):
+        """Validation epoch end for Pyro training."""
+        outputs = self.validation_step_outputs
+        elbo = 0
+        nll = 0
+        n_obs = 0
+        for out in outputs:
+            elbo += out["loss"]
+            nll += out["nll"]
+            n_obs += out["n_obs"]
+        elbo /= n_obs
+        nll /= n_obs
+        self.log("elbo_validation", elbo, prog_bar=True)
+        self.log("nll_validation", nll, prog_bar=False)
+        self.validation_step_outputs.clear()
+
+    def configure_optimizers(self):
+        """Shim optimizer for PyTorch Lightning.
+
+        PyTorch Lightning wants to take steps on an optimizer
+        returned by this function in order to increment the global
+        step count. See the PyTorch Lightning manual optimization docs.
+
+        Here we provide a shim optimizer that we can take steps on
+        at minimal computational cost in order to keep Lightning happy :).
+        """
+        return torch.optim.Adam([self._dummy_param])
+
+    def optimizer_step(self, *args, **kwargs):
+        pass
+
+    def backward(self, *args, **kwargs):
+        pass
diff --git a/tests/external/decipher/test_decipher.py b/tests/external/decipher/test_decipher.py
new file mode 100644
index 0000000000..0056096eb0
--- /dev/null
+++ b/tests/external/decipher/test_decipher.py
@@ -0,0 +1,23 @@
+import pytest
+
+from scvi.data import synthetic_iid
+from scvi.external import Decipher
+
+
+@pytest.fixture(scope="session")
+def adata():
+    adata = synthetic_iid()
+    return adata
+
+
+def test_decipher_train(adata):
+    Decipher.setup_anndata(adata)
+    model = Decipher(adata)
+    model.train(
+        max_epochs=2,
+        check_val_every_n_epoch=1,
+        train_size=0.5,
+        early_stopping=True,
+    )
+    model.get_latent_representation(give_z=False)
+    model.get_latent_representation(give_z=True)
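
For reference, a minimal usage sketch of the API added in this diff, mirroring `test_decipher_train` above. The synthetic dataset and the parameter values (`dim_v`, `dim_z`, `max_epochs`) are illustrative stand-ins, not recommended settings for real data.

```python
# Sketch of the new Decipher external model API (assumes this branch is installed).
from scvi.data import synthetic_iid
from scvi.external import Decipher

adata = synthetic_iid()  # any AnnData with raw counts in .X (or a layer) would do

Decipher.setup_anndata(adata)                 # register the count matrix
model = Decipher(adata, dim_v=2, dim_z=10)    # kwargs are forwarded to DecipherPyroModule
model.train(max_epochs=2, early_stopping=True)  # early stopping monitors nll_validation

v = model.get_latent_representation()               # interpretable latent space v
z = model.get_latent_representation(give_z=True)    # intermediate latent space z
```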