feat: implement Decipher model in external (#3015)
CC @ANazaret

Implements the Decipher model (https://github.com/azizilab/decipher, https://www.biorxiv.org/content/10.1101/2023.11.11.566719v1) in `external/`. For now, it only includes the base implementation, without many of the downstream workflows from the original implementation. Includes minor non-breaking changes to the `LowLevelPyroTrainingPlan`.

Test: was able to approximately reproduce figures from the tutorial (https://github.com/azizilab/decipher/blob/main/examples/1-tutorial.ipynb); some of the v plots for several random seeds are below.

Original implementation:

![decipher_orig_0](https://github.com/user-attachments/assets/9e3e45ff-b5dd-48fc-bbbf-f98bd2751ea4)
![decipher_orig_1030](https://github.com/user-attachments/assets/2e035c43-8f1a-4d2c-bf33-9ae4d660e7b2)

New implementation:

![decipher_tutorial_scvi_0](https://github.com/user-attachments/assets/c13a1ec9-133e-414f-8c79-cc7304ff99e6)
![decipher_tutorial_scvi_1](https://github.com/user-attachments/assets/30908122-928f-4275-a05e-20624782c006)
![decipher_tutorial_scvi_2](https://github.com/user-attachments/assets/ecaa7d3c-2ec2-4975-8da2-656eb2027342)
![decipher_tutorial_scvi_3](https://github.com/user-attachments/assets/e5ebb920-ff43-472f-b3b1-0ccbb843b24d)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: access <[email protected]>
Co-authored-by: Ori Kronfeld <[email protected]>
1 parent 2046e7c · commit 54ba452
Showing 11 changed files with 624 additions and 0 deletions.
@@ -62,6 +62,7 @@ import scvi
external.VELOVI
external.MRVI
external.METHYLVI
external.Decipher
```

## Data loading
@@ -0,0 +1,4 @@
from ._model import Decipher
from ._module import DecipherPyroModule

__all__ = ["Decipher", "DecipherPyroModule"]
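With these exports in place (matching the new `external.Decipher` docs entry above), the model should be importable from the external namespace. A minimal import sketch, assuming the package layout implied by this commit:

```python
from scvi.external import Decipher
```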
@@ -0,0 +1,99 @@
from collections.abc import Sequence

import torch
import torch.nn as nn


class ConditionalDenseNN(nn.Module):
    """Dense neural network with multiple outputs, optionally conditioned on a context variable.

    (Derived from pyro.nn.dense_nn.ConditionalDenseNN with some modifications [1])

    Parameters
    ----------
    input_dim
        Dimension of the input.
    hidden_dims
        Dimensions of the hidden layers (excluding the output layer).
    output_dims
        Dimensions of each output layer.
    context_dim
        Dimension of the context input.
    deep_context_injection
        If True, inject the context into every hidden layer.
        If False, only inject the context into the first hidden layer
        (concatenated with the input).
    activation
        Activation function to use between hidden layers (not applied to the outputs).
        Default: torch.nn.ReLU()
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: Sequence[int],
        output_dims: Sequence = (1,),
        context_dim: int = 0,
        deep_context_injection: bool = False,
        activation=torch.nn.ReLU(),
    ):
        super().__init__()

        self.input_dim = input_dim
        self.context_dim = context_dim
        self.hidden_dims = hidden_dims
        self.output_dims = output_dims
        self.deep_context_injection = deep_context_injection
        self.n_output_layers = len(self.output_dims)
        self.output_total_dim = sum(self.output_dims)

        # The multiple outputs are computed as a single output layer, and then split
        last_output_end_idx = 0
        self.output_slices = []
        for dim in self.output_dims:
            self.output_slices.append(slice(last_output_end_idx, last_output_end_idx + dim))
            last_output_end_idx += dim

        # Create the dense layers; the context (if any) is concatenated to the input
        # of the first layer, and of every layer when deep_context_injection is enabled
        deep_context_dim = self.context_dim if self.deep_context_injection else 0
        layers = []
        batch_norms = []
        if len(hidden_dims):
            layers.append(torch.nn.Linear(input_dim + context_dim, hidden_dims[0]))
            batch_norms.append(nn.BatchNorm1d(hidden_dims[0]))
            for i in range(1, len(hidden_dims)):
                layers.append(
                    torch.nn.Linear(hidden_dims[i - 1] + deep_context_dim, hidden_dims[i])
                )
                batch_norms.append(nn.BatchNorm1d(hidden_dims[i]))

            layers.append(
                torch.nn.Linear(hidden_dims[-1] + deep_context_dim, self.output_total_dim)
            )
        else:
            layers.append(torch.nn.Linear(input_dim + context_dim, self.output_total_dim))

        self.layers = torch.nn.ModuleList(layers)

        self.activation_fn = activation
        self.batch_norms = torch.nn.ModuleList(batch_norms)

    def forward(self, x, context=None):
        if context is not None:
            # We must be able to broadcast the size of the context over the input
            context = context.expand(x.size()[:-1] + (context.size(-1),))

        h = x
        for i, layer in enumerate(self.layers):
            if self.context_dim > 0 and (self.deep_context_injection or i == 0):
                h = torch.cat([context, h], dim=-1)
            h = layer(h)
            if i < len(self.layers) - 1:
                h = self.batch_norms[i](h)
                h = self.activation_fn(h)

        if self.n_output_layers == 1:
            return h

        h = h.reshape(list(x.size()[:-1]) + [self.output_total_dim])
        return tuple(h[..., s] for s in self.output_slices)
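A small standalone sketch of how the multi-head output and context injection behave (the shapes, the names `net`, `x`, `ctx`, `loc`, `scale`, and the two-head setup are illustrative assumptions, not taken from the commit):

```python
import torch

# Two output heads of size 5 each, computed as one 10-dim output layer and split.
net = ConditionalDenseNN(
    input_dim=10,
    hidden_dims=[64, 64],
    output_dims=(5, 5),
    context_dim=3,
    deep_context_injection=True,  # concatenate the context at every layer
)

x = torch.randn(128, 10)   # batch of 128 inputs
ctx = torch.randn(128, 3)  # matching context batch
loc, scale = net(x, ctx)   # two tensors, each of shape (128, 5)
```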
@@ -0,0 +1,153 @@
import logging
from collections.abc import Sequence

import numpy as np
import pyro
import torch
from anndata import AnnData

from scvi._constants import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import LayerField
from scvi.model.base import BaseModelClass, PyroSviTrainMixin
from scvi.train import PyroTrainingPlan
from scvi.utils import setup_anndata_dsp

from ._module import DecipherPyroModule
from ._trainingplan import DecipherTrainingPlan

logger = logging.getLogger(__name__)


class Decipher(PyroSviTrainMixin, BaseModelClass):
    """Decipher model for single-cell data analysis :cite:p:`Nazaret23`.

    Parameters
    ----------
    adata
        AnnData object that has been registered via
        :meth:`~scvi.model.Decipher.setup_anndata`.
    dim_v
        Dimension of the interpretable latent space v.
    dim_z
        Dimension of the intermediate latent space z.
    layers_v_to_z
        Hidden layer sizes for the v to z decoder network.
    layers_z_to_x
        Hidden layer sizes for the z to x decoder network.
    beta
        Regularization parameter for the KL divergence.
    """

    _module_cls = DecipherPyroModule
    _training_plan_cls = DecipherTrainingPlan

    def __init__(self, adata: AnnData, **kwargs):
        pyro.clear_param_store()

        super().__init__(adata)

        dim_genes = self.summary_stats.n_vars

        self.module = self._module_cls(
            dim_genes,
            **kwargs,
        )

        self.init_params_ = self._get_init_params(locals())

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        layer: str | None = None,
        **kwargs,
    ) -> AnnData | None:
        """%(summary)s.

        Parameters
        ----------
        %(param_adata)s
        %(param_layer)s
        """
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
        ]
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

    def train(
        self,
        max_epochs: int | None = None,
        accelerator: str = "auto",
        device: int | str = "auto",
        train_size: float = 0.9,
        validation_size: float | None = None,
        shuffle_set_split: bool = True,
        batch_size: int = 128,
        early_stopping: bool = False,
        training_plan: PyroTrainingPlan | None = None,
        datasplitter_kwargs: dict | None = None,
        plan_kwargs: dict | None = None,
        **trainer_kwargs,
    ):
        # By default, monitor the validation negative log-likelihood for early
        # stopping and drop the last incomplete batch
        if "early_stopping_monitor" not in trainer_kwargs:
            trainer_kwargs["early_stopping_monitor"] = "nll_validation"
        datasplitter_kwargs = datasplitter_kwargs or {}
        if "drop_last" not in datasplitter_kwargs:
            datasplitter_kwargs["drop_last"] = True
        super().train(
            max_epochs=max_epochs,
            accelerator=accelerator,
            device=device,
            train_size=train_size,
            validation_size=validation_size,
            shuffle_set_split=shuffle_set_split,
            batch_size=batch_size,
            early_stopping=early_stopping,
            plan_kwargs=plan_kwargs,
            training_plan=training_plan,
            datasplitter_kwargs=datasplitter_kwargs,
            **trainer_kwargs,
        )

    def get_latent_representation(
        self,
        adata: AnnData | None = None,
        indices: Sequence[int] | None = None,
        batch_size: int | None = None,
        give_z: bool = False,
    ) -> np.ndarray:
        """Get the latent representation of the data.

        Parameters
        ----------
        adata
            AnnData object with the data to get the latent representation of.
        indices
            Indices of the data to get the latent representation of.
        batch_size
            Batch size to use for the data loader.
        give_z
            Whether to return the intermediate latent space z or the top-level
            latent space v.
        """
        self._check_if_trained(warn=False)
        adata = self._validate_anndata(adata)

        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
        latent_locs = []
        for tensors in scdl:
            x = tensors[REGISTRY_KEYS.X_KEY]
            x = torch.log1p(x)  # the encoders operate on log1p-transformed counts
            x = x.to(self.module.device)
            z_loc, _ = self.module.encoder_x_to_z(x)
            if give_z:
                latent_locs.append(z_loc)
            else:
                v_loc, _ = self.module.encoder_zx_to_v(torch.cat([z_loc, x], dim=-1))
                latent_locs.append(v_loc)
        return torch.cat(latent_locs).detach().cpu().numpy()
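To make the intended workflow concrete, a minimal end-to-end sketch of the new API (the synthetic count matrix and the `max_epochs` value are illustrative assumptions, not part of the commit):

```python
import numpy as np
from anndata import AnnData
from scvi.external import Decipher

# Toy count data: 200 cells x 100 genes, purely for illustration.
counts = np.random.poisson(1.0, size=(200, 100)).astype(np.float32)
adata = AnnData(counts)

Decipher.setup_anndata(adata)  # registers adata.X as count data
model = Decipher(adata)        # extra kwargs are forwarded to DecipherPyroModule
model.train(max_epochs=10)

v = model.get_latent_representation()             # interpretable latent space v
z = model.get_latent_representation(give_z=True)  # intermediate latent space z
```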