Skip to content

Commit

Permalink
feat(external): implement METHLYVI for scBS-seq data (#2834)
Browse files Browse the repository at this point in the history
This PR implements methylation variational inference (methylVI) for
modeling methylation data from single-cell bisulfite sequencing
(scBS-seq) experiments, described originally in [Weinberger and Lee
(2024).](https://openreview.net/forum?id=Mg2DM0F3AY). In short, methylVI
assumes that counts of methylated cytosines across pre-defined genomic
region features follow a [Beta-Binomial
distribution](https://docs.scvi-tools.org/en/latest/api/reference/scvi.distributions.BetaBinomial.html).
As a placeholder citation I put in a reference to the workshop paper
with the intention of swapping in a citation to the full paper once we
upload it to bioRxiv.

All implementation code can be found in `scvi/external/methylvi`. Of
note, as part of this PR I've added a new `DecoderMETHYLVI` class in
`external/methylvi/_base_components.py` for obtaining decoded
Beta-Binomial parameters for BS-seq data. This class may be useful for
the development of future models, and might be worth moving out of
external. Otherwise, the bulk of the implementation code can be found in
`scvi/external/methylvi/_model.py` and
`scvi/external/methylvi/_module.py`.

I structured my implementation of methylVI so that the model could be
used with a single methylation modality an AnnData object (e.g. CpG gene
body features) or multiple modalities in a MuData object (e.g. both CpG
and non-CpG gene body features as separate modalities). To avoid
repeating too much code, under the hood I tried to treat models trained
with AnnData's as if they were trained using a MuData with a single
modality (see my comments for more details).

@martinkim0 @canergen for testing I added some basic sanity check tests
(loosely based on the tests in `tests/model/test_peakvi.py`) confirming
that (1) the model trains without errors and (2) downstream analysis
functions (e.g. `get_latent_representation`) run without errors. Let me
know if there are additional tests you'd like to see.

---------

Co-authored-by: Ethan Weinberger <[email protected]>
Co-authored-by: Martin Kim <[email protected]>
Co-authored-by: Ori Kronfeld <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
5 people authored Oct 6, 2024
1 parent b5d81d4 commit fc79a3a
Show file tree
Hide file tree
Showing 14 changed files with 1,317 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ to [Semantic Versioning]. Full commit history is available in the
data {pr}`2756`.
- Add support for reference mapping with {class}`mudata.MuData` models to
{class}`scvi.model.base.ArchesMixin` {pr}`2578`.
- Add {class}`scvi.external.METHYLVI` for modeling methylation data from single-cell
bisulfite sequencing (scBS-seq) experiments {pr}`2834`.
- Add argument `return_mean` to {meth}`scvi.model.base.VAEMixin.get_reconstruction_error`
and {meth}`scvi.model.base.VAEMixin.get_elbo` to allow computation
without averaging across cells {pr}`2362`.
Expand Down
1 change: 1 addition & 0 deletions docs/api/developer.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ Module classes in the external API with respective generative and inference proc
external.contrastivevi.ContrastiveVAE
external.velovi.VELOVAE
external.mrvi.MRVAE
external.methylvi.METHYLVAE
```

Expand Down
2 changes: 1 addition & 1 deletion docs/api/user.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ import scvi
external.POISSONVI
external.VELOVI
external.MRVI
external.METHYLVI
```

## Data loading
Expand Down
8 changes: 8 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,14 @@ @article{Weinberger23
publisher={Nature Publishing Group}
}

@inproceedings{Weinberger2023a,
title={A deep generative model of single-cell methylomic data},
author={Ethan Weinberger and Su-In Lee},
booktitle={NeurIPS 2023 Generative AI and Biology (GenBio) Workshop},
year={2023},
url={https://openreview.net/forum?id=Mg2DM0F3AY}
}

@article{Xu21,
title = {Probabilistic harmonization and annotation of single-cell transcriptomics data with deep generative models},
author = {Chenling Xu and Romain Lopez and Edouard Mehlman and Jeffrey Regier and Michael I. Jordan and Nir Yosef},
Expand Down
15 changes: 15 additions & 0 deletions docs/user_guide/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,21 @@ scvi-tools is composed of models that can perform one or many analysis tasks. In
- :cite:p:`Martens2023`
```

## BS-seq analysis

```{eval-rst}
.. list-table::
:widths: 15 100 25
:header-rows: 1
* - Model
- Tasks
- Reference
* - :doc:`/user_guide/models/methylvi`
- Dimensionality reduction, removal of unwanted variation, integration across replicates, donors, and technologies, differential methylation, imputation, normalization of other cell- and sample-level confounding factors
- :cite:p:`Weinberger2023a`
```

## Multimodal analysis

### CITE-seq
Expand Down
161 changes: 161 additions & 0 deletions docs/user_guide/models/methylvi.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# MethylVI

**methylVI** [^ref1] (Python class {class}`~scvi.external.METHYLVI`) is a generative model of scBS-seq data that can subsequently
be used for many common downstream tasks.

The advantages of methylVI are:

- Comprehensive in capabilities.
- Scalable to very large datasets (>1 million cells).

The limitations of methylVI include:

- Effectively requires a GPU for fast inference.
- Latent space is not interpretable, unlike that of a linear method.

```{topic} Tutorials:
- Work in progress.
```

## Preliminaries

MethylVI takes as input scBS-seq count matrices representing methylation measurements aggregated over pre-defined
regions of interest (e.g. gene bodies, known regulatory regions, etc.). Depending on the system being investigated,
such measurements may be separated based on methylation context (e.g. CpG methylation versus non-CpG methylation).

For each context, methylVI accepts two count matrices as input $Y^{C}_{mc}$ and $Y^{C}_{cov}$. Here $C$ refers to
an arbitrary methylation context, and each of these matrices has data from $N$ cells and $M$ genomic regions.
Each entry in $Y_{cov}$ represents the _total_ number of cytosines profiled at a given region in a cell, while the
entries in $Y_{mc}$ denote the number of _methylated_ cytosines in a region for a cell. Additionally, a vector of
categorical covariates $S$, representing batch, donor, etc, is an optional input to the model.

## Generative process

MethylVI posits that the observed number of methylated cytosines in context $C$ for cell $i$ in region $j$,
$y^{C}_{ij}$, is generated by the following process:

```{math}
:nowrap: true
\begin{align}
z_{i} &\sim \mathcal{N}(0, I_d) \\
\mu^{C}_{ij} &= f_{\theta^{C}}(z_{i}, s_i)_j \\
p^{C}_{ijk} &\sim \text{Beta}(\mu^{C}_{ij}, \gamma^{C}_j) \\
y^{C}_{ijk} &\sim \text{Ber}(p^{C}_{ijk}) \\
y^{C}_{ij} &= \sum_{k}y^{C}_{ijk}
\end{align}
```

In brief, we assume that detection of an individual cytosine $k$ within region $j$ for cell $i$ as methylated
can be modeled as a Bernoulli random variable. The parameters of these Bernoulli distributions are
assumed to be similar for all cytosines $k$ within region $j$, which we model as draws from a Beta distribution with
parameters that depend on a cell-specific latent variable $z_i$ that captures underlying methylation state as well
as a batch covariate $s_i$. The outcomes of these Bernoulli draws are then summed to obtain our number of methylated
cytosines within the given region.

The above hierarchical process can be expressed more compactly as:

```{math}
:nowrap: true
\begin{align}
z_{i} &\sim \mathcal{N}(0, I_d) \\
\mu^{C}_{ij} &= f_{\theta^{C}}(z_{i}, s_i)_j \\
y^{C}_{ij} &\sim \text{BetaBinomial}(n_{ij}, \mu_{ij}, \gamma_{j})
\end{align}
```

For each methylation context $C$, the MethylVI generative process uses a single neural network:

```{math}
:nowrap: true
\begin{align}
f_{\theta^{C}}(z_{i}, s_i) &: \mathbb{R}^{d} \times \{0, 1\}^K \to \left(0,1\right)^M
\end{align}
```

which estimates regions' the methylation levels.

The latent variables, along with their description are summarized in the following table:

```{eval-rst}
.. list-table::
:widths: 20 90 15
:header-rows: 1
* - Latent variable
- Description
- Code variable (if different)
* - :math:`z_i \in \mathbb{R}^d`
- Low-dimensional representation capturing the state of a cell
- ``z``
* - :math:`\mu_i \in \left(0,1\right)^{M}`
- Per-region methylation level estimates
- ``mu``
* - :math:`\gamma_i \in \left(0,1\right)`
- Region-wise dispersion factor
- ``d``
```

## Inference

MethylVI uses variational inference, specifically auto-encoding variational Bayes
(see {doc}`/user_guide/background/variational_inference`) to learn both the model parameters (the neural network params,
dispersion parameters, etc.) and an approximate posterior distribution. In particular, we approximate the true posterior
distribution with a mean-field variational distribution $q_{\phi}(z_i \mid y_i, n_i, s_i)$ chosen to be Gaussian
with a diagonal covariance matrix. Here $y_i$ ($n_i$) is used as a shorthand to denote the concatenation of the numbers
of methylated (total) cytosines for each region in all contexts, and $\phi$ denotes a set of learned weights used to
infer the parameters of our approximate posterior.

## Tasks

Here we provide an overview of some of the tasks that MethylVI can perform. Please see {class}`scvi.external.METHYLVI`
for the full API reference.

### Dimensionality reduction

For dimensionality reduction, the mean of the approximate posterior $q_\phi(z_i \mid y_i, n_i)$ is returned by default.
This is achieved using the method:

```
>>> adata.obsm["X_methylvi"] = model.get_latent_representation()
```

Users may also return samples from this distribution, as opposed to the mean, by passing the argument `give_mean=False`.
The latent representation can be used to create a nearest neighbor graph with scanpy with:

```
>>> import scanpy as sc
>>> sc.pp.neighbors(adata, use_rep="X_methylvi")
>>> adata.obsp["distances"]
```

### Transfer learning

A MethylVI model can be pre-trained on reference data and updated with query data using {meth}`~scvi.external.METHYLVI.load_query_data`, which then facilitates transfer of metadata like cell type annotations. See the {doc}`/user_guide/background/transfer_learning` guide for more information.

### Estimation of methylation levels

In {meth}`~scvi.external.METHYLVI.get_normalized_methylation` MethylVI returns the expected value of $\mu_i$ under the approximate posterior. For one cell $i$, this can be written as:

```{math}
:nowrap: true
\begin{align}
\mathbb{E}_{q_\phi(z_i \mid y_i, n_i)}\left[f_{\theta}\left(z_{i}, s_i)\right) \right],
\end{align}
```

As the expectation can be expensive to compute, by default, MethylVI uses the mean of $z_i$ as a point estimate, but this behaviour can be changed by setting `use_z_mean=False` argument.

### Differential methylation

Differential methylation analysis is achieved with {meth}`~scvi.external.METHYLVI.differential_methylation`.
MethylVI tests differences in methylation levels $\mu^{C}_{i} = f_{\theta^{C}}\left(z_{i}, s_i)\right)$.

[^ref1]:
Ethan Weinberger and Su-In Lee (2021),
_A deep generative model of single-cell methylomic data_,
[OpenReview](https://openreview.net/forum?id=Mg2DM0F3AY).
2 changes: 2 additions & 0 deletions src/scvi/external/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .cellassign import CellAssign
from .contrastivevi import ContrastiveVI
from .gimvi import GIMVI
from .methylvi import METHYLVI
from .mrvi import MRVI
from .poissonvi import POISSONVI
from .scar import SCAR
Expand All @@ -23,4 +24,5 @@
"ContrastiveVI",
"VELOVI",
"MRVI",
"METHYLVI",
]
6 changes: 6 additions & 0 deletions src/scvi/external/methylvi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from ._base_components import DecoderMETHYLVI
from ._constants import METHYLVI_REGISTRY_KEYS
from ._model import METHYLVI as METHYLVI
from ._module import METHYLVAE

__all__ = ["METHYLVI_REGISTRY_KEYS", "DecoderMETHYLVI", "METHYLVAE", "METHYLVI"]
113 changes: 113 additions & 0 deletions src/scvi/external/methylvi/_base_components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from collections.abc import Iterable

import torch
from torch import nn

from scvi.nn import FCLayers


class DecoderMETHYLVI(nn.Module):
"""Decodes data from latent space of ``n_input`` dimensions into ``n_output`` dimensions.
Uses a fully-connected neural network of ``n_hidden`` layers.
Parameters
----------
n_input
The dimensionality of the input (latent space)
n_output
The dimensionality of the output (data space)
n_cat_list
A list containing the number of categories
for each category of interest. Each category will be
included using a one-hot encoding
n_layers
The number of fully-connected hidden layers
n_hidden
The number of nodes per hidden layer
dropout_rate
Dropout rate to apply to each of the hidden layers
inject_covariates
Whether to inject covariates in each layer, or just the first (default).
use_batch_norm
Whether to use batch norm in layers
use_layer_norm
Whether to use layer norm in layers
scale_activation
Activation layer to use for px_scale_decoder
**kwargs
Keyword args for :class:`~scvi.nn.FCLayers`.
"""

def __init__(
self,
n_input: int,
n_output: int,
n_cat_list: Iterable[int] = None,
n_layers: int = 1,
n_hidden: int = 128,
inject_covariates: bool = True,
use_batch_norm: bool = False,
use_layer_norm: bool = False,
**kwargs,
):
super().__init__()
self.px_decoder = FCLayers(
n_in=n_input,
n_out=n_hidden,
n_cat_list=n_cat_list,
n_layers=n_layers,
n_hidden=n_hidden,
dropout_rate=0,
inject_covariates=inject_covariates,
use_batch_norm=use_batch_norm,
use_layer_norm=use_layer_norm,
**kwargs,
)

self.px_mu_decoder = nn.Sequential(
nn.Linear(n_hidden, n_output),
nn.Sigmoid(),
)
self.px_gamma_decoder = nn.Sequential(
nn.Linear(n_hidden, n_output),
nn.Sigmoid(),
)

def forward(
self,
dispersion: str,
z: torch.Tensor,
*cat_list: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""The forward computation for a single sample.
#. Decodes the data from the latent space using the decoder network
#. Returns parameters for the beta-binomial distribution of methylation
#. If ``dispersion != 'region-cell'`` then value for that param will be ``None``
Parameters
----------
dispersion
One of the following
* ``'region'`` - dispersion parameter of NB is constant per region across cells
* ``'region-cell'`` - dispersion can differ for every region in every cell
z :
tensor with shape ``(n_input,)``
library_size
library size
cat_list
list of category membership(s) for this sample
Returns
-------
2-tuple of :py:class:`torch.Tensor`
parameters for the Beta distribution of mean methylation values
"""
px = self.px_decoder(z, *cat_list)
px_mu = self.px_mu_decoder(px)
px_gamma = self.px_gamma_decoder(px) if dispersion == "region-cell" else None

return px_mu, px_gamma
9 changes: 9 additions & 0 deletions src/scvi/external/methylvi/_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import NamedTuple


class _METHYLVI_REGISTRY_KEYS_NT(NamedTuple):
MC_KEY: str = "mc"
COV_KEY: str = "cov"


METHYLVI_REGISTRY_KEYS = _METHYLVI_REGISTRY_KEYS_NT()
Loading

0 comments on commit fc79a3a

Please sign in to comment.