feat: implement Decipher model in external #3015

Merged Nov 13, 2024 (51 commits)

Commits
9fd03ce
first draft of moving decipher model into scvi-tools
justjhong Oct 10, 2024
3ad77a2
add early stopping based on predictive nll and freeze batch norm afte…
justjhong Oct 14, 2024
c8c987f
add get latent rep to model class
justjhong Oct 14, 2024
bfee40e
remove impute fn for now
justjhong Oct 14, 2024
8c8b52d
revert base module lints
justjhong Oct 14, 2024
0d9039d
converges but with blobby latent space
justjhong Oct 14, 2024
600eba1
fix batch norm freezing and move pll impl into module
justjhong Oct 14, 2024
7fe1beb
fix save/load
justjhong Oct 14, 2024
e48fe3b
drop last and fix loss scaling
justjhong Oct 15, 2024
2baf11d
fix tests, remove validation step from base training plan
justjhong Oct 15, 2024
b67f6c2
first draft of moving decipher model into scvi-tools
justjhong Oct 10, 2024
84ff3b0
add early stopping based on predictive nll and freeze batch norm afte…
justjhong Oct 14, 2024
f22741b
add get latent rep to model class
justjhong Oct 14, 2024
89c9319
remove impute fn for now
justjhong Oct 14, 2024
812d5db
revert base module lints
justjhong Oct 14, 2024
4ba55a2
converges but with blobby latent space
justjhong Oct 14, 2024
0be6d06
fix batch norm freezing and move pll impl into module
justjhong Oct 14, 2024
2cdb539
fix save/load
justjhong Oct 14, 2024
1caa156
drop last and fix loss scaling
justjhong Oct 15, 2024
9cc3aa6
fix tests, remove validation step from base training plan
justjhong Oct 15, 2024
1bc5805
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 15, 2024
6491fbc
fix merge conflicts
justjhong Oct 15, 2024
46a2320
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 15, 2024
13e130e
fix ruff
justjhong Oct 15, 2024
089037c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 15, 2024
477513e
Merge remote-tracking branch 'origin/main' into jhong/decipher
Oct 20, 2024
4a239c9
Merge branch 'main' into jhong/decipher
ori-kron-wis Oct 20, 2024
0d682a0
Merge remote-tracking branch 'origin/jhong/decipher' into jhong/decipher
Oct 20, 2024
e2267ec
fix tensor device bug for get latent
justjhong Oct 21, 2024
e3562e2
add docs and release note for decipher
justjhong Oct 21, 2024
3bd9fab
Merge branch 'jhong/decipher' of github.com:scverse/scvi-tools into j…
justjhong Oct 21, 2024
74f9323
check if this fixes the cuda test
ori-kron-wis Oct 21, 2024
ec85018
check if this fixes the cuda test
ori-kron-wis Oct 21, 2024
94c4da6
check if this fixes the cuda test
ori-kron-wis Oct 21, 2024
8dceb5f
add Decipher to docs
justjhong Oct 22, 2024
f9a30ab
Merge branch 'jhong/decipher' of github.com:scverse/scvi-tools into j…
justjhong Oct 22, 2024
b711502
Merge branch 'main' into jhong/decipher
justjhong Oct 22, 2024
a06d936
fix docs
justjhong Oct 22, 2024
17e7d4a
add doi
justjhong Oct 22, 2024
5db40fd
Merge branch 'main' into jhong/decipher
justjhong Oct 31, 2024
c49e14d
remove prior arg
justjhong Nov 3, 2024
342b48a
Merge branch 'jhong/decipher' of github.com:scverse/scvi-tools into j…
justjhong Nov 3, 2024
e25173b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 3, 2024
0c151b5
Merge branch 'main' into jhong/decipher
justjhong Nov 13, 2024
43ab670
address comments
justjhong Nov 13, 2024
9850619
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2024
61df480
Update test_linux_cuda.yml
justjhong Nov 13, 2024
38e1076
Merge branch 'main' into jhong/decipher
justjhong Nov 13, 2024
26131c5
simplify output slice code
justjhong Nov 13, 2024
c0eda75
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2024
d8c5ac4
move change in changelog to 1.3.0
justjhong Nov 13, 2024
13 changes: 13 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,19 @@ to [Semantic Versioning]. Full commit history is available in the

## Version 1.2

### 1.3.0 (2024-XX-XX)

#### Added

- Add {class}`scvi.external.Decipher` for dimensionality reduction and interpretable
  representation learning in single-cell RNA sequencing data {pr}`3015`.

#### Fixed

#### Changed

#### Removed

### 1.2.1 (2024-XX-XX)

#### Added
1 change: 1 addition & 0 deletions docs/api/developer.md
@@ -180,6 +180,7 @@ Module classes in the external API with respective generative and inference proc
external.velovi.VELOVAE
external.mrvi.MRVAE
external.methylvi.METHYLVAE
external.decipher.DecipherPyroModule

```

1 change: 1 addition & 0 deletions docs/api/user.md
@@ -62,6 +62,7 @@ import scvi
external.VELOVI
external.MRVI
external.METHYLVI
external.Decipher
```

## Data loading
10 changes: 10 additions & 0 deletions docs/references.bib
@@ -212,6 +212,16 @@ @article{Martens2023
publisher={Nature Publishing Group}
}

@article{Nazaret23,
title={Deep generative model deciphers derailed trajectories in acute myeloid leukemia},
author={Nazaret, Achille and Fan, Joy Linyue and Lavallee, Vincent-Philippe and Cornish, Andrew E and Kiseliovas, Vaidotas and Masilionis, Ignas and Chun, Jaeyoung and Bowman, Robert L and Eisman, Shira E and Wang, James and others},
journal={bioRxiv},
pages={2023--11},
year={2023},
publisher={Cold Spring Harbor Laboratory},
doi={10.1101/2023.11.11.566719}
}

@article{Sheng22,
title = {Probabilistic machine learning ensures accurate ambient denoising in droplet-based single-cell omics},
author = {Caibin Sheng and Rui Lopes and Gang Li and Sven Schuierer and Annick Waldt and Rachel Cuttat and Slavica Dimitrieva and Audrey Kauffmann and Eric Durand and Giorgio G. Galli and Guglielmo Roma and Antoine de Weck},
2 changes: 2 additions & 0 deletions src/scvi/external/__init__.py
@@ -1,5 +1,6 @@
from .cellassign import CellAssign
from .contrastivevi import ContrastiveVI
from .decipher import Decipher
from .gimvi import GIMVI
from .methylvi import METHYLVI
from .mrvi import MRVI
@@ -15,6 +16,7 @@
"SCAR",
"SOLO",
"GIMVI",
"Decipher",
"RNAStereoscope",
"SpatialStereoscope",
"CellAssign",
4 changes: 4 additions & 0 deletions src/scvi/external/decipher/__init__.py
@@ -0,0 +1,4 @@
from ._model import Decipher
from ._module import DecipherPyroModule

__all__ = ["Decipher", "DecipherPyroModule"]
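
The subpackage re-exports the high-level model through `scvi.external`, while the Pyro module stays importable for developers. A quick sanity check of the two public entry points added by this PR (a sketch, assuming a development install of this branch):

```python
# High-level model, listed in docs/api/user.md
from scvi.external import Decipher

# Pyro module, listed in docs/api/developer.md
from scvi.external.decipher import DecipherPyroModule
```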
99 changes: 99 additions & 0 deletions src/scvi/external/decipher/_components.py
@@ -0,0 +1,99 @@
from collections.abc import Sequence

import torch
import torch.nn as nn


class ConditionalDenseNN(nn.Module):
"""Dense neural network with multiple outputs, optionally conditioned on a context variable.

(Derived from pyro.nn.dense_nn.ConditionalDenseNN, with some modifications.)

Parameters
----------
input_dim
Dimension of the input
hidden_dims
Dimensions of the hidden layers (excluding the output layer)
output_dims
Dimensions of each output layer
context_dim
Dimension of the context input.
deep_context_injection
If True, inject the context into every hidden layer.
If False, only inject the context into the first hidden layer
(concatenated with the input).
activation
Activation function to use between hidden layers (not applied to the outputs).
Default: torch.nn.ReLU()
"""

def __init__(
self,
input_dim: int,
hidden_dims: Sequence[int],
output_dims: Sequence = (1,),
context_dim: int = 0,
deep_context_injection: bool = False,
activation=torch.nn.ReLU(),
):
super().__init__()

self.input_dim = input_dim
self.context_dim = context_dim
self.hidden_dims = hidden_dims
self.output_dims = output_dims
self.deep_context_injection = deep_context_injection
self.n_output_layers = len(self.output_dims)
self.output_total_dim = sum(self.output_dims)

# The multiple outputs are computed as a single output layer, and then split
last_output_end_idx = 0
self.output_slices = []
for dim in self.output_dims:
self.output_slices.append(slice(last_output_end_idx, last_output_end_idx + dim))
last_output_end_idx += dim

# Create the hidden layers (each followed by batch norm) and the final output layer
deep_context_dim = self.context_dim if self.deep_context_injection else 0
layers = []
batch_norms = []
if len(hidden_dims):
layers.append(torch.nn.Linear(input_dim + context_dim, hidden_dims[0]))
batch_norms.append(nn.BatchNorm1d(hidden_dims[0]))
for i in range(1, len(hidden_dims)):
layers.append(
torch.nn.Linear(hidden_dims[i - 1] + deep_context_dim, hidden_dims[i])
)
batch_norms.append(nn.BatchNorm1d(hidden_dims[i]))

layers.append(
torch.nn.Linear(hidden_dims[-1] + deep_context_dim, self.output_total_dim)
)
else:
layers.append(torch.nn.Linear(input_dim + context_dim, self.output_total_dim))

self.layers = torch.nn.ModuleList(layers)

self.activation_fn = activation
self.batch_norms = torch.nn.ModuleList(batch_norms)

def forward(self, x, context=None):
if context is not None:
# We must be able to broadcast the size of the context over the input
context = context.expand(x.size()[:-1] + (context.size(-1),))

h = x
for i, layer in enumerate(self.layers):
if self.context_dim > 0 and (self.deep_context_injection or i == 0):
h = torch.cat([context, h], dim=-1)
h = layer(h)
if i < len(self.layers) - 1:
h = self.batch_norms[i](h)
h = self.activation_fn(h)

if self.n_output_layers == 1:
return h

h = h.reshape(list(x.size()[:-1]) + [self.output_total_dim])
return tuple([h[..., s] for s in self.output_slices])
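
The multi-head mechanism above computes a single linear output and slices it per head. A minimal sketch of how the component behaves (shapes and the `loc`/`scale` naming are illustrative; the private import path follows the file location above):

```python
import torch

from scvi.external.decipher._components import ConditionalDenseNN

# Two output heads (e.g., a location and a scale) sharing hidden layers,
# with a 3-dim context injected into every layer via deep_context_injection.
net = ConditionalDenseNN(
    input_dim=10,
    hidden_dims=[32, 32],
    output_dims=(5, 5),
    context_dim=3,
    deep_context_injection=True,
)
net.eval()  # use BatchNorm1d running statistics

x = torch.randn(4, 10)
context = torch.randn(4, 3)
loc, scale = net(x, context)   # one tensor per entry of output_dims
print(loc.shape, scale.shape)  # torch.Size([4, 5]) torch.Size([4, 5])
```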
153 changes: 153 additions & 0 deletions src/scvi/external/decipher/_model.py
@@ -0,0 +1,153 @@
import logging
from collections.abc import Sequence

import numpy as np
import pyro
import torch
from anndata import AnnData

from scvi._constants import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import LayerField
from scvi.model.base import BaseModelClass, PyroSviTrainMixin
from scvi.train import PyroTrainingPlan
from scvi.utils import setup_anndata_dsp

from ._module import DecipherPyroModule
from ._trainingplan import DecipherTrainingPlan

logger = logging.getLogger(__name__)


class Decipher(PyroSviTrainMixin, BaseModelClass):
"""Decipher model for single-cell data analysis :cite:p:`Nazaret23`.

Parameters
----------
adata
AnnData object that has been registered via
:meth:`~scvi.external.Decipher.setup_anndata`.
dim_v
Dimension of the interpretable latent space v.
dim_z
Dimension of the intermediate latent space z.
layers_v_to_z
Hidden layer sizes for the v to z decoder network.
layers_z_to_x
Hidden layer sizes for the z to x decoder network.
beta
Regularization parameter for the KL divergence.
"""

_module_cls = DecipherPyroModule
_training_plan_cls = DecipherTrainingPlan

def __init__(self, adata: AnnData, **kwargs):
pyro.clear_param_store()

super().__init__(adata)

dim_genes = self.summary_stats.n_vars

self.module = self._module_cls(
dim_genes,
**kwargs,
)

self.init_params_ = self._get_init_params(locals())

@classmethod
@setup_anndata_dsp.dedent
def setup_anndata(
cls,
adata: AnnData,
layer: str | None = None,
**kwargs,
) -> AnnData | None:
"""%(summary)s.

Parameters
----------
%(param_adata)s
%(param_layer)s
"""
setup_method_args = cls._get_setup_method_args(**locals())
anndata_fields = [
LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
]
adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)

def train(
self,
max_epochs: int | None = None,
accelerator: str = "auto",
device: int | str = "auto",
train_size: float = 0.9,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
early_stopping: bool = False,
training_plan: PyroTrainingPlan | None = None,
datasplitter_kwargs: dict | None = None,
plan_kwargs: dict | None = None,
**trainer_kwargs,
):
if "early_stopping_monitor" not in trainer_kwargs:
trainer_kwargs["early_stopping_monitor"] = "nll_validation"
datasplitter_kwargs = datasplitter_kwargs or {}
if "drop_last" not in datasplitter_kwargs:
datasplitter_kwargs["drop_last"] = True
super().train(
max_epochs=max_epochs,
accelerator=accelerator,
device=device,
train_size=train_size,
validation_size=validation_size,
shuffle_set_split=shuffle_set_split,
batch_size=batch_size,
early_stopping=early_stopping,
plan_kwargs=plan_kwargs,
training_plan=training_plan,
datasplitter_kwargs=datasplitter_kwargs,
**trainer_kwargs,
)

def get_latent_representation(
self,
adata: AnnData | None = None,
indices: Sequence[int] | None = None,
batch_size: int | None = None,
give_z: bool = False,
) -> np.ndarray:
"""Get the latent representation of the data.

Parameters
----------
adata
AnnData object with the data to get the latent representation of.
indices
Indices of the data to get the latent representation of.
batch_size
Batch size to use for the data loader.
give_z
Whether to return the intermediate latent space z or the top-level
latent space v.
"""
self._check_if_trained(warn=False)
adata = self._validate_anndata(adata)

scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
latent_locs = []
for tensors in scdl:
x = tensors[REGISTRY_KEYS.X_KEY]
x = torch.log1p(x)
x = x.to(self.module.device)
z_loc, _ = self.module.encoder_x_to_z(x)
if give_z:
latent_locs.append(z_loc)
else:
v_loc, _ = self.module.encoder_zx_to_v(torch.cat([z_loc, x], dim=-1))
latent_locs.append(v_loc)
return torch.cat(latent_locs).detach().cpu().numpy()
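
Taken together, the new class follows the usual scvi-tools workflow. A minimal end-to-end sketch on synthetic counts (module hyperparameters such as `dim_v` and `dim_z` keep the defaults defined in `DecipherPyroModule`, which is not shown in this diff):

```python
import numpy as np
from anndata import AnnData

from scvi.external import Decipher

# Toy count matrix: 200 cells x 100 genes (the model expects raw counts).
counts = np.random.poisson(1.0, size=(200, 100)).astype(np.float32)
adata = AnnData(counts)

Decipher.setup_anndata(adata)  # registers X as count data
model = Decipher(adata)
model.train(max_epochs=5)  # early stopping would monitor "nll_validation"

v = model.get_latent_representation()             # interpretable space v
z = model.get_latent_representation(give_z=True)  # intermediate space z
print(v.shape, z.shape)
```

Note that `train` defaults the data splitter to `drop_last=True`, so on very small datasets the `batch_size` should not exceed the size of the training split.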