Skip to content

Commit

Permalink
feat: implement Decipher model in external (#3015)
Browse files Browse the repository at this point in the history
CC @ANazaret

Implements Decipher model (https://github.com/azizilab/decipher,
https://www.biorxiv.org/content/10.1101/2023.11.11.566719v1) into
external/

For now, it only includes the base implementation, without many of the
downstream workflows from the original implementation.
Includes minor non-breaking changes to the `LowLevelPyroTrainingPlan`.

Test: was able to approximately reproduce figures from the tutorial
(https://github.com/azizilab/decipher/blob/main/examples/1-tutorial.ipynb),
some of the v plots for several random seeds below:

Original implementation:

![decipher_orig_0](https://github.com/user-attachments/assets/9e3e45ff-b5dd-48fc-bbbf-f98bd2751ea4)

![decipher_orig_1030](https://github.com/user-attachments/assets/2e035c43-8f1a-4d2c-bf33-9ae4d660e7b2)

New implementation:

![decipher_tutorial_scvi_0](https://github.com/user-attachments/assets/c13a1ec9-133e-414f-8c79-cc7304ff99e6)

![decipher_tutorial_scvi_1](https://github.com/user-attachments/assets/30908122-928f-4275-a05e-20624782c006)

![decipher_tutorial_scvi_2](https://github.com/user-attachments/assets/ecaa7d3c-2ec2-4975-8da2-656eb2027342)

![decipher_tutorial_scvi_3](https://github.com/user-attachments/assets/e5ebb920-ff43-472f-b3b1-0ccbb843b24d)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: access <[email protected]>
Co-authored-by: Ori Kronfeld <[email protected]>
  • Loading branch information
4 people authored Nov 13, 2024
1 parent 2046e7c commit 54ba452
Show file tree
Hide file tree
Showing 11 changed files with 624 additions and 0 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,19 @@ to [Semantic Versioning]. Full commit history is available in the

## Version 1.2

### 1.3.0 (2024-XX-XX)

#### Added

- Add {class}`scvi.external.Decipher` for dimensionality reduction and interpretable
    representation learning in single-cell RNA sequencing data {pr}`3015`.

#### Fixed

#### Changed

#### Removed

### 1.2.1 (2024-XX-XX)

#### Added
Expand Down
1 change: 1 addition & 0 deletions docs/api/developer.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ Module classes in the external API with respective generative and inference proc
external.velovi.VELOVAE
external.mrvi.MRVAE
external.methylvi.METHYLVAE
external.decipher.DecipherPyroModule
```

Expand Down
1 change: 1 addition & 0 deletions docs/api/user.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ import scvi
external.VELOVI
external.MRVI
external.METHYLVI
external.Decipher
```

## Data loading
Expand Down
10 changes: 10 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,16 @@ @article{Martens2023
publisher={Nature Publishing Group}
}

@article{Nazaret23,
title={Deep generative model deciphers derailed trajectories in acute myeloid leukemia},
author={Nazaret, Achille and Fan, Joy Linyue and Lavallee, Vincent-Philippe and Cornish, Andrew E and Kiseliovas, Vaidotas and Masilionis, Ignas and Chun, Jaeyoung and Bowman, Robert L and Eisman, Shira E and Wang, James and others},
journal={bioRxiv},
pages={2023--11},
year={2023},
publisher={Cold Spring Harbor Laboratory},
doi = {10.1101/2023.11.11.566719}
}

@article{Sheng22,
title = {Probabilistic machine learning ensures accurate ambient denoising in droplet-based single-cell omics},
author = {Caibin Sheng and Rui Lopes and Gang Li and Sven Schuierer and Annick Waldt and Rachel Cuttat and Slavica Dimitrieva and Audrey Kauffmann and Eric Durand and Giorgio G. Galli and Guglielmo Roma and Antoine de Weck},
Expand Down
2 changes: 2 additions & 0 deletions src/scvi/external/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .cellassign import CellAssign
from .contrastivevi import ContrastiveVI
from .decipher import Decipher
from .gimvi import GIMVI
from .methylvi import METHYLVI
from .mrvi import MRVI
Expand All @@ -15,6 +16,7 @@
"SCAR",
"SOLO",
"GIMVI",
"Decipher",
"RNAStereoscope",
"SpatialStereoscope",
"CellAssign",
Expand Down
4 changes: 4 additions & 0 deletions src/scvi/external/decipher/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from ._model import Decipher
from ._module import DecipherPyroModule

# Names exported by the decipher subpackage.
__all__ = ["Decipher", "DecipherPyroModule"]
99 changes: 99 additions & 0 deletions src/scvi/external/decipher/_components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from collections.abc import Sequence

import torch
import torch.nn as nn


class ConditionalDenseNN(nn.Module):
    """Dense neural network with multiple outputs, optionally conditioned on a context variable.

    Derived from ``pyro.nn.dense_nn.ConditionalDenseNN`` with some modifications.
    Every hidden linear layer is followed by batch normalization and then the
    activation; the final (output) linear layer has neither.

    Parameters
    ----------
    input_dim
        Dimension of the input
    hidden_dims
        Dimensions of the hidden layers (excluding the output layer)
    output_dims
        Dimensions of each output layer
    context_dim
        Dimension of the context input.
    deep_context_injection
        If True, inject the context into every hidden layer.
        If False, only inject the context into the first hidden layer
        (concatenated with the input).
    activation
        Activation function to use between hidden layers (not applied to the outputs).
        Default: torch.nn.ReLU()
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: Sequence[int],
        output_dims: Sequence[int] = (1,),
        context_dim: int = 0,
        deep_context_injection: bool = False,
        activation=torch.nn.ReLU(),
    ):
        super().__init__()

        self.input_dim = input_dim
        self.context_dim = context_dim
        self.hidden_dims = hidden_dims
        self.output_dims = output_dims
        self.deep_context_injection = deep_context_injection
        self.n_output_layers = len(self.output_dims)
        self.output_total_dim = sum(self.output_dims)

        # The multiple outputs are computed as a single output layer, and then split
        last_output_end_idx = 0
        self.output_slices = []
        for dim in self.output_dims:
            self.output_slices.append(slice(last_output_end_idx, last_output_end_idx + dim))
            last_output_end_idx += dim

        # Build the linear layers; the context widens the first layer's input and,
        # with deep injection, the input of every subsequent layer as well.
        deep_context_dim = self.context_dim if self.deep_context_injection else 0
        layers = []
        batch_norms = []
        if len(hidden_dims):
            layers.append(torch.nn.Linear(input_dim + context_dim, hidden_dims[0]))
            batch_norms.append(nn.BatchNorm1d(hidden_dims[0]))
            for i in range(1, len(hidden_dims)):
                layers.append(
                    torch.nn.Linear(hidden_dims[i - 1] + deep_context_dim, hidden_dims[i])
                )
                batch_norms.append(nn.BatchNorm1d(hidden_dims[i]))

            layers.append(
                torch.nn.Linear(hidden_dims[-1] + deep_context_dim, self.output_total_dim)
            )
        else:
            # No hidden layers: a single linear map from (input, context) to the outputs.
            layers.append(torch.nn.Linear(input_dim + context_dim, self.output_total_dim))

        self.layers = torch.nn.ModuleList(layers)

        self.activation_fn = activation
        self.batch_norms = torch.nn.ModuleList(batch_norms)

    def forward(self, x, context=None):
        """Run the network on ``x``, optionally conditioned on ``context``.

        Parameters
        ----------
        x
            Input tensor whose last dimension is the feature dimension.
        context
            Context tensor; it is broadcast over the leading dimensions of ``x``.
            Required when the network was built with ``context_dim > 0``.

        Returns
        -------
        A single tensor when there is exactly one output, otherwise a tuple with
        one tensor per entry of ``output_dims`` (slices of the last dimension).

        Raises
        ------
        ValueError
            If the network expects a context (``context_dim > 0``) but none is given.
        """
        if self.context_dim > 0 and context is None:
            # Fail early with a clear message instead of an opaque error from torch.cat.
            raise ValueError(
                f"`context` must be provided because the network was built with "
                f"context_dim={self.context_dim}."
            )

        if context is not None:
            # We must be able to broadcast the size of the context over the input
            context = context.expand(x.size()[:-1] + (context.size(-1),))

        h = x
        last_layer_idx = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            if self.context_dim > 0 and (self.deep_context_injection or i == 0):
                h = torch.cat([context, h], dim=-1)
            h = layer(h)
            # Batch norm and activation follow every layer except the output layer.
            if i < last_layer_idx:
                h = self.batch_norms[i](h)
                h = self.activation_fn(h)

        if self.n_output_layers == 1:
            return h

        h = h.reshape(list(x.size()[:-1]) + [self.output_total_dim])
        return tuple(h[..., s] for s in self.output_slices)
153 changes: 153 additions & 0 deletions src/scvi/external/decipher/_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import logging
from collections.abc import Sequence

import numpy as np
import pyro
import torch
from anndata import AnnData

from scvi._constants import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import LayerField
from scvi.model.base import BaseModelClass, PyroSviTrainMixin
from scvi.train import PyroTrainingPlan
from scvi.utils import setup_anndata_dsp

from ._module import DecipherPyroModule
from ._trainingplan import DecipherTrainingPlan

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class Decipher(PyroSviTrainMixin, BaseModelClass):
    """Decipher model for single-cell data analysis :cite:p:`Nazaret23`.

    Parameters
    ----------
    adata
        AnnData object that has been registered via
        :meth:`~scvi.external.Decipher.setup_anndata`.
    dim_v
        Dimension of the interpretable latent space v.
    dim_z
        Dimension of the intermediate latent space z.
    layers_v_to_z
        Hidden layer sizes for the v to z decoder network.
    layers_z_to_x
        Hidden layer sizes for the z to x decoder network.
    beta
        Regularization parameter for the KL divergence.

    Notes
    -----
    All parameters other than ``adata`` are accepted through ``**kwargs`` and
    forwarded unchanged to the module class.
    """

    _module_cls = DecipherPyroModule
    _training_plan_cls = DecipherTrainingPlan

    def __init__(self, adata: AnnData, **kwargs):
        # Reset Pyro's global parameter store before building a new module.
        pyro.clear_param_store()

        super().__init__(adata)

        dim_genes = self.summary_stats.n_vars

        self.module = self._module_cls(
            dim_genes,
            **kwargs,
        )

        self.init_params_ = self._get_init_params(locals())

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        layer: str | None = None,
        **kwargs,
    ) -> AnnData | None:
        """%(summary)s.

        Parameters
        ----------
        %(param_adata)s
        %(param_layer)s
        """
        # Must stay the first statement: it captures the call arguments via locals().
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
        ]
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

    def train(
        self,
        max_epochs: int | None = None,
        accelerator: str = "auto",
        device: int | str = "auto",
        train_size: float = 0.9,
        validation_size: float | None = None,
        shuffle_set_split: bool = True,
        batch_size: int = 128,
        early_stopping: bool = False,
        training_plan: PyroTrainingPlan | None = None,
        datasplitter_kwargs: dict | None = None,
        plan_kwargs: dict | None = None,
        **trainer_kwargs,
    ):
        """Train the model, applying Decipher-specific defaults.

        All arguments are forwarded to the parent ``train`` implementation. Two
        defaults are filled in when the caller did not set them:

        * ``early_stopping_monitor`` (in ``trainer_kwargs``) defaults to
          ``"nll_validation"``.
        * ``drop_last`` (in ``datasplitter_kwargs``) defaults to ``True``.

        The ``datasplitter_kwargs`` dictionary passed by the caller is not modified.
        """
        trainer_kwargs.setdefault("early_stopping_monitor", "nll_validation")
        # Copy so that adding the default does not mutate the caller's dictionary.
        datasplitter_kwargs = dict(datasplitter_kwargs or {})
        # NOTE(review): presumably avoids a trailing tiny batch, which batch
        # normalization handles poorly during training — confirm.
        datasplitter_kwargs.setdefault("drop_last", True)
        super().train(
            max_epochs=max_epochs,
            accelerator=accelerator,
            device=device,
            train_size=train_size,
            validation_size=validation_size,
            shuffle_set_split=shuffle_set_split,
            batch_size=batch_size,
            early_stopping=early_stopping,
            plan_kwargs=plan_kwargs,
            training_plan=training_plan,
            datasplitter_kwargs=datasplitter_kwargs,
            **trainer_kwargs,
        )

    def get_latent_representation(
        self,
        adata: AnnData | None = None,
        indices: Sequence[int] | None = None,
        batch_size: int | None = None,
        give_z: bool = False,
    ) -> np.ndarray:
        """Get the latent representation of the data.

        Parameters
        ----------
        adata
            AnnData object with the data to get the latent representation of.
        indices
            Indices of the data to get the latent representation of.
        batch_size
            Batch size to use for the data loader.
        give_z
            Whether to return the intermediate latent space z or the top-level
            latent space v.

        Returns
        -------
        Array of encoder location outputs, concatenated over all batches of the
        data loader.
        """
        self._check_if_trained(warn=False)
        adata = self._validate_anndata(adata)

        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
        latent_locs = []
        # Inference only: disable autograd so no graph is built or kept per batch.
        with torch.no_grad():
            for tensors in scdl:
                # The encoders are fed log1p-transformed expression values.
                x = torch.log1p(tensors[REGISTRY_KEYS.X_KEY])
                x = x.to(self.module.device)
                z_loc, _ = self.module.encoder_x_to_z(x)
                if give_z:
                    latent = z_loc
                else:
                    latent, _ = self.module.encoder_zx_to_v(torch.cat([z_loc, x], dim=-1))
                # Move each batch off the accelerator right away to bound device memory.
                latent_locs.append(latent.detach().cpu())
        return torch.cat(latent_locs).numpy()
Loading

0 comments on commit 54ba452

Please sign in to comment.