-
Notifications
You must be signed in to change notification settings - Fork 346
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(external): implement METHLYVI for scBS-seq data (#2834)
This PR implements methylation variational inference (methylVI) for modeling methylation data from single-cell bisulfite sequencing (scBS-seq) experiments, described originally in [Weinberger and Lee (2024).](https://openreview.net/forum?id=Mg2DM0F3AY). In short, methylVI assumes that counts of methylated cytosines across pre-defined genomic region features follow a [Beta-Binomial distribution](https://docs.scvi-tools.org/en/latest/api/reference/scvi.distributions.BetaBinomial.html). As a placeholder citation I put in a reference to the workshop paper with the intention of swapping in a citation to the full paper once we upload it to bioRxiv. All implementation code can be found in `scvi/external/methylvi`. Of note, as part of this PR I've added a new `DecoderMETHYLVI` class in `external/methylvi/_base_components.py` for obtaining decoded Beta-Binomial parameters for BS-seq data. This class may be useful for the development of future models, and might be worth moving out of external. Otherwise, the bulk of the implementation code can be found in `scvi/external/methylvi/_model.py` and `scvi/external/methylvi/_module.py`. I structured my implementation of methylVI so that the model could be used with a single methylation modality an AnnData object (e.g. CpG gene body features) or multiple modalities in a MuData object (e.g. both CpG and non-CpG gene body features as separate modalities). To avoid repeating too much code, under the hood I tried to treat models trained with AnnData's as if they were trained using a MuData with a single modality (see my comments for more details). @martinkim0 @canergen for testing I added some basic sanity check tests (loosely based on the tests in `tests/model/test_peakvi.py`) confirming that (1) the model trains without errors and (2) downstream analysis functions (e.g. `get_latent_representation`) run without errors. Let me know if there are additional tests you'd like to see. --------- Co-authored-by: Ethan Weinberger <[email protected]> Co-authored-by: Martin Kim <[email protected]> Co-authored-by: Ori Kronfeld <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
b5d81d4
commit fc79a3a
Showing
14 changed files
with
1,317 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,7 +61,7 @@ import scvi | |
external.POISSONVI | ||
external.VELOVI | ||
external.MRVI | ||
external.METHYLVI | ||
``` | ||
|
||
## Data loading | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
# MethylVI | ||
|
||
**methylVI** [^ref1] (Python class {class}`~scvi.external.METHYLVI`) is a generative model of scBS-seq data that can subsequently | ||
be used for many common downstream tasks. | ||
|
||
The advantages of methylVI are: | ||
|
||
- Comprehensive in capabilities. | ||
- Scalable to very large datasets (>1 million cells). | ||
|
||
The limitations of methylVI include: | ||
|
||
- Effectively requires a GPU for fast inference. | ||
- Latent space is not interpretable, unlike that of a linear method. | ||
|
||
```{topic} Tutorials: | ||
- Work in progress. | ||
``` | ||
|
||
## Preliminaries | ||
|
||
MethylVI takes as input scBS-seq count matrices representing methylation measurements aggregated over pre-defined | ||
regions of interest (e.g. gene bodies, known regulatory regions, etc.). Depending on the system being investigated, | ||
such measurements may be separated based on methylation context (e.g. CpG methylation versus non-CpG methylation). | ||
|
||
For each context, methylVI accepts two count matrices as input $Y^{C}_{mc}$ and $Y^{C}_{cov}$. Here $C$ refers to | ||
an arbitrary methylation context, and each of these matrices has data from $N$ cells and $M$ genomic regions. | ||
Each entry in $Y_{cov}$ represents the _total_ number of cytosines profiled at a given region in a cell, while the | ||
entries in $Y_{mc}$ denote the number of _methylated_ cytosines in a region for a cell. Additionally, a vector of | ||
categorical covariates $S$, representing batch, donor, etc, is an optional input to the model. | ||
|
||
## Generative process | ||
|
||
MethylVI posits that the observed number of methylated cytosines in context $C$ for cell $i$ in region $j$, | ||
$y^{C}_{ij}$, is generated by the following process: | ||
|
||
```{math} | ||
:nowrap: true | ||
\begin{align} | ||
z_{i} &\sim \mathcal{N}(0, I_d) \\ | ||
\mu^{C}_{ij} &= f_{\theta^{C}}(z_{i}, s_i)_j \\ | ||
p^{C}_{ijk} &\sim \text{Beta}(\mu^{C}_{ij}, \gamma^{C}_j) \\ | ||
y^{C}_{ijk} &\sim \text{Ber}(p^{C}_{ijk}) \\ | ||
y^{C}_{ij} &= \sum_{k}y^{C}_{ijk} | ||
\end{align} | ||
``` | ||
|
||
In brief, we assume that detection of an individual cytosine $k$ within region $j$ for cell $i$ as methylated | ||
can be modeled as a Bernoulli random variable. The parameters of these Bernoulli distributions are | ||
assumed to be similar for all cytosines $k$ within region $j$, which we model as draws from a Beta distribution with | ||
parameters that depend on a cell-specific latent variable $z_i$ that captures underlying methylation state as well | ||
as a batch covariate $s_i$. The outcomes of these Bernoulli draws are then summed to obtain our number of methylated | ||
cytosines within the given region. | ||
|
||
The above hierarchical process can be expressed more compactly as: | ||
|
||
```{math} | ||
:nowrap: true | ||
\begin{align} | ||
z_{i} &\sim \mathcal{N}(0, I_d) \\ | ||
\mu^{C}_{ij} &= f_{\theta^{C}}(z_{i}, s_i)_j \\ | ||
y^{C}_{ij} &\sim \text{BetaBinomial}(n_{ij}, \mu_{ij}, \gamma_{j}) | ||
\end{align} | ||
``` | ||
|
||
For each methylation context $C$, the MethylVI generative process uses a single neural network: | ||
|
||
```{math} | ||
:nowrap: true | ||
\begin{align} | ||
f_{\theta^{C}}(z_{i}, s_i) &: \mathbb{R}^{d} \times \{0, 1\}^K \to \left(0,1\right)^M | ||
\end{align} | ||
``` | ||
|
||
which estimates regions' the methylation levels. | ||
|
||
The latent variables, along with their description are summarized in the following table: | ||
|
||
```{eval-rst} | ||
.. list-table:: | ||
:widths: 20 90 15 | ||
:header-rows: 1 | ||
* - Latent variable | ||
- Description | ||
- Code variable (if different) | ||
* - :math:`z_i \in \mathbb{R}^d` | ||
- Low-dimensional representation capturing the state of a cell | ||
- ``z`` | ||
* - :math:`\mu_i \in \left(0,1\right)^{M}` | ||
- Per-region methylation level estimates | ||
- ``mu`` | ||
* - :math:`\gamma_i \in \left(0,1\right)` | ||
- Region-wise dispersion factor | ||
- ``d`` | ||
``` | ||
|
||
## Inference | ||
|
||
MethylVI uses variational inference, specifically auto-encoding variational Bayes | ||
(see {doc}`/user_guide/background/variational_inference`) to learn both the model parameters (the neural network params, | ||
dispersion parameters, etc.) and an approximate posterior distribution. In particular, we approximate the true posterior | ||
distribution with a mean-field variational distribution $q_{\phi}(z_i \mid y_i, n_i, s_i)$ chosen to be Gaussian | ||
with a diagonal covariance matrix. Here $y_i$ ($n_i$) is used as a shorthand to denote the concatenation of the numbers | ||
of methylated (total) cytosines for each region in all contexts, and $\phi$ denotes a set of learned weights used to | ||
infer the parameters of our approximate posterior. | ||
|
||
## Tasks | ||
|
||
Here we provide an overview of some of the tasks that MethylVI can perform. Please see {class}`scvi.external.METHYLVI` | ||
for the full API reference. | ||
|
||
### Dimensionality reduction | ||
|
||
For dimensionality reduction, the mean of the approximate posterior $q_\phi(z_i \mid y_i, n_i)$ is returned by default. | ||
This is achieved using the method: | ||
|
||
``` | ||
>>> adata.obsm["X_methylvi"] = model.get_latent_representation() | ||
``` | ||
|
||
Users may also return samples from this distribution, as opposed to the mean, by passing the argument `give_mean=False`. | ||
The latent representation can be used to create a nearest neighbor graph with scanpy with: | ||
|
||
``` | ||
>>> import scanpy as sc | ||
>>> sc.pp.neighbors(adata, use_rep="X_methylvi") | ||
>>> adata.obsp["distances"] | ||
``` | ||
|
||
### Transfer learning | ||
|
||
A MethylVI model can be pre-trained on reference data and updated with query data using {meth}`~scvi.external.METHYLVI.load_query_data`, which then facilitates transfer of metadata like cell type annotations. See the {doc}`/user_guide/background/transfer_learning` guide for more information. | ||
|
||
### Estimation of methylation levels | ||
|
||
In {meth}`~scvi.external.METHYLVI.get_normalized_methylation` MethylVI returns the expected value of $\mu_i$ under the approximate posterior. For one cell $i$, this can be written as: | ||
|
||
```{math} | ||
:nowrap: true | ||
\begin{align} | ||
\mathbb{E}_{q_\phi(z_i \mid y_i, n_i)}\left[f_{\theta}\left(z_{i}, s_i)\right) \right], | ||
\end{align} | ||
``` | ||
|
||
As the expectation can be expensive to compute, by default, MethylVI uses the mean of $z_i$ as a point estimate, but this behaviour can be changed by setting `use_z_mean=False` argument. | ||
|
||
### Differential methylation | ||
|
||
Differential methylation analysis is achieved with {meth}`~scvi.external.METHYLVI.differential_methylation`. | ||
MethylVI tests differences in methylation levels $\mu^{C}_{i} = f_{\theta^{C}}\left(z_{i}, s_i)\right)$. | ||
|
||
[^ref1]: | ||
Ethan Weinberger and Su-In Lee (2021), | ||
_A deep generative model of single-cell methylomic data_, | ||
[OpenReview](https://openreview.net/forum?id=Mg2DM0F3AY). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from ._base_components import DecoderMETHYLVI | ||
from ._constants import METHYLVI_REGISTRY_KEYS | ||
from ._model import METHYLVI as METHYLVI | ||
from ._module import METHYLVAE | ||
|
||
__all__ = ["METHYLVI_REGISTRY_KEYS", "DecoderMETHYLVI", "METHYLVAE", "METHYLVI"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
from collections.abc import Iterable | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from scvi.nn import FCLayers | ||
|
||
|
||
class DecoderMETHYLVI(nn.Module): | ||
"""Decodes data from latent space of ``n_input`` dimensions into ``n_output`` dimensions. | ||
Uses a fully-connected neural network of ``n_hidden`` layers. | ||
Parameters | ||
---------- | ||
n_input | ||
The dimensionality of the input (latent space) | ||
n_output | ||
The dimensionality of the output (data space) | ||
n_cat_list | ||
A list containing the number of categories | ||
for each category of interest. Each category will be | ||
included using a one-hot encoding | ||
n_layers | ||
The number of fully-connected hidden layers | ||
n_hidden | ||
The number of nodes per hidden layer | ||
dropout_rate | ||
Dropout rate to apply to each of the hidden layers | ||
inject_covariates | ||
Whether to inject covariates in each layer, or just the first (default). | ||
use_batch_norm | ||
Whether to use batch norm in layers | ||
use_layer_norm | ||
Whether to use layer norm in layers | ||
scale_activation | ||
Activation layer to use for px_scale_decoder | ||
**kwargs | ||
Keyword args for :class:`~scvi.nn.FCLayers`. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
n_input: int, | ||
n_output: int, | ||
n_cat_list: Iterable[int] = None, | ||
n_layers: int = 1, | ||
n_hidden: int = 128, | ||
inject_covariates: bool = True, | ||
use_batch_norm: bool = False, | ||
use_layer_norm: bool = False, | ||
**kwargs, | ||
): | ||
super().__init__() | ||
self.px_decoder = FCLayers( | ||
n_in=n_input, | ||
n_out=n_hidden, | ||
n_cat_list=n_cat_list, | ||
n_layers=n_layers, | ||
n_hidden=n_hidden, | ||
dropout_rate=0, | ||
inject_covariates=inject_covariates, | ||
use_batch_norm=use_batch_norm, | ||
use_layer_norm=use_layer_norm, | ||
**kwargs, | ||
) | ||
|
||
self.px_mu_decoder = nn.Sequential( | ||
nn.Linear(n_hidden, n_output), | ||
nn.Sigmoid(), | ||
) | ||
self.px_gamma_decoder = nn.Sequential( | ||
nn.Linear(n_hidden, n_output), | ||
nn.Sigmoid(), | ||
) | ||
|
||
def forward( | ||
self, | ||
dispersion: str, | ||
z: torch.Tensor, | ||
*cat_list: int, | ||
) -> tuple[torch.Tensor, torch.Tensor]: | ||
"""The forward computation for a single sample. | ||
#. Decodes the data from the latent space using the decoder network | ||
#. Returns parameters for the beta-binomial distribution of methylation | ||
#. If ``dispersion != 'region-cell'`` then value for that param will be ``None`` | ||
Parameters | ||
---------- | ||
dispersion | ||
One of the following | ||
* ``'region'`` - dispersion parameter of NB is constant per region across cells | ||
* ``'region-cell'`` - dispersion can differ for every region in every cell | ||
z : | ||
tensor with shape ``(n_input,)`` | ||
library_size | ||
library size | ||
cat_list | ||
list of category membership(s) for this sample | ||
Returns | ||
------- | ||
2-tuple of :py:class:`torch.Tensor` | ||
parameters for the Beta distribution of mean methylation values | ||
""" | ||
px = self.px_decoder(z, *cat_list) | ||
px_mu = self.px_mu_decoder(px) | ||
px_gamma = self.px_gamma_decoder(px) if dispersion == "region-cell" else None | ||
|
||
return px_mu, px_gamma |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from typing import NamedTuple | ||
|
||
|
||
class _METHYLVI_REGISTRY_KEYS_NT(NamedTuple): | ||
MC_KEY: str = "mc" | ||
COV_KEY: str = "cov" | ||
|
||
|
||
METHYLVI_REGISTRY_KEYS = _METHYLVI_REGISTRY_KEYS_NT() |
Oops, something went wrong.