Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding model with cycle consistency and VampPrior #2421

Draft
wants to merge 76 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
a582251
add model and tests
Hrovatin Jan 14, 2024
bb359ac
update documentation
Hrovatin Jan 14, 2024
14e41f1
move embedding to device
Hrovatin Jan 14, 2024
5863b1f
Merge branch 'main' into main
martinkim0 Jan 19, 2024
3f49266
pr comments
Hrovatin Jan 21, 2024
6605682
Merge branch 'main' of https://github.com/Hrovatin/scvi-tools
Hrovatin Jan 21, 2024
3319fc8
Merge branch 'main' into main
martinkim0 Jan 22, 2024
3a67d1c
Merge branch 'main' into main
martinkim0 Feb 5, 2024
5edad83
updates
Hrovatin Feb 7, 2024
a4b080e
Merge branch 'main' of https://github.com/Hrovatin/scvi-tools
Hrovatin Feb 7, 2024
a24ef28
Merge branch 'main' into main
martinkim0 Feb 20, 2024
9b05bca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2024
8c25dba
Merge branch 'main' into main
martinkim0 Mar 11, 2024
c5f5c37
Update scvi/external/sysvi/_base_components.py
martinkim0 Mar 15, 2024
5b4838c
Update scvi/external/sysvi/_base_components.py
martinkim0 Mar 15, 2024
c885e20
Update scvi/external/sysvi/_base_components.py
martinkim0 Mar 15, 2024
661bbc6
Update scvi/external/sysvi/_model.py
martinkim0 Mar 15, 2024
9a49d24
Update scvi/external/sysvi/_model.py
martinkim0 Mar 15, 2024
3622eee
Update scvi/external/sysvi/_model.py
martinkim0 Mar 15, 2024
e4c1ef9
Update scvi/external/sysvi/_model.py
martinkim0 Mar 15, 2024
0f7bd06
Update scvi/external/sysvi/_model.py
martinkim0 Mar 15, 2024
9e0cba9
Update scvi/external/sysvi/_model.py
martinkim0 Mar 15, 2024
54f5734
Update scvi/external/sysvi/_model.py
martinkim0 Mar 15, 2024
df959ed
merge
Hrovatin Sep 14, 2024
f65c403
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2024
5fcb6f6
merge
Hrovatin Sep 14, 2024
3f1cffe
Merge branch 'main' of https://github.com/Hrovatin/scvi-tools
Hrovatin Sep 14, 2024
3c93e7f
merge
Hrovatin Sep 14, 2024
3e8ffd0
extend var documentation and remove unused "linear" mode
Hrovatin Sep 14, 2024
4ce4614
update var documentation
Hrovatin Sep 14, 2024
f81abda
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2024
bdd2c7a
var activation to softplus
Hrovatin Sep 14, 2024
4507602
clarify pseudoinputs_data_indices description and assert it matches s…
Hrovatin Sep 14, 2024
cc19d91
also check ndim for pseudoinputs_data_indices
Hrovatin Sep 14, 2024
c15b8a2
remove obtainng cycle latent representations in user interface
Hrovatin Sep 14, 2024
aacdafe
latent always returns np and add option to return_dist
Hrovatin Sep 14, 2024
de6db13
bugfix
Hrovatin Sep 14, 2024
235767c
rm unused cycle latent retrieval
Hrovatin Sep 14, 2024
5f04d65
bugfix
Hrovatin Sep 14, 2024
69759ec
rm adata validation parts repeated in super
Hrovatin Sep 14, 2024
2e447c8
bugfix
Hrovatin Sep 14, 2024
9e8cc35
put back original _validate_anndata
Hrovatin Sep 14, 2024
b0829c0
remove too many custom checks from adata validation
Hrovatin Sep 14, 2024
bf2e850
covariate type explanation
Hrovatin Sep 14, 2024
bd5a882
revert var exp
Hrovatin Sep 14, 2024
7f50353
bugfix
Hrovatin Sep 14, 2024
01db60a
rm reconstr loss f that not reused
Hrovatin Sep 14, 2024
6cdd3d6
explain why cycle loss computed on standardized values
Hrovatin Sep 14, 2024
62eb924
add return statement
Hrovatin Sep 14, 2024
b9a9047
Merge branch 'main' of https://github.com/Hrovatin/scvi-tools
Hrovatin Sep 14, 2024
b95cb03
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2024
9a2a33e
Merge branch 'main' of https://github.com/scverse/scvi-tools
Hrovatin Oct 19, 2024
c049ea1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 19, 2024
041c6a4
unify SysVI with scvi-tools code
Hrovatin Oct 20, 2024
7743c19
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 20, 2024
1641654
use adatamanager to access filed statistics
Hrovatin Oct 26, 2024
7c3bddc
documentation
Hrovatin Oct 26, 2024
eb872bc
Merge branch 'main' of https://github.com/Hrovatin/scvi-tools
Hrovatin Oct 26, 2024
5015e82
optionally change var activation function
Hrovatin Oct 27, 2024
b64b2ba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 27, 2024
9b78ba3
documentation improvements
Hrovatin Oct 27, 2024
2aab4b4
Use real instead of mock covariates in cycle
Hrovatin Oct 27, 2024
1b50be1
improve documentation
Hrovatin Oct 27, 2024
1abf15d
Merge branch 'main' of https://github.com/Hrovatin/scvi-tools
Hrovatin Oct 27, 2024
6d32d89
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 27, 2024
2e59be7
improve documentation
Hrovatin Oct 27, 2024
4e0988e
fix bug introduced when renaming parameters in generative function
Hrovatin Oct 27, 2024
f0eb2b0
Merge branch 'main' of https://github.com/Hrovatin/scvi-tools
Hrovatin Oct 27, 2024
3f8be47
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 27, 2024
776a8e2
rename tests to prevent automatic test failure
Hrovatin Oct 27, 2024
ebbe1ca
Merge branch 'main' of https://github.com/Hrovatin/scvi-tools
Hrovatin Oct 27, 2024
35e3385
fix typo in docstring and formatting
Hrovatin Oct 27, 2024
4601808
ruff fixes
Hrovatin Oct 27, 2024
87a367e
bugfix in embedding of covariates
Hrovatin Nov 1, 2024
df4b335
bugfix in test for checking cov embeding
Hrovatin Nov 1, 2024
6a2e3eb
Change var activation to softplus
Hrovatin Nov 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions scvi/external/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .scbasset import SCBASSET
from .solo import SOLO
from .stereoscope import RNAStereoscope, SpatialStereoscope
from .sysvi import SysVI
from .tangram import Tangram

__all__ = [
Expand All @@ -19,4 +20,5 @@
"SCBASSET",
"POISSONVI",
"ContrastiveVI",
"SysVI",
]
3 changes: 3 additions & 0 deletions scvi/external/sysvi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ._model import SysVI

__all__ = ["SysVI"]
347 changes: 347 additions & 0 deletions scvi/external/sysvi/_base_components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,347 @@
from __future__ import annotations

from collections import OrderedDict
from typing import Literal

import torch
from torch.distributions import Normal
from torch.nn import (
BatchNorm1d,
Dropout,
LayerNorm,
Linear,
Module,
Parameter,
ReLU,
Sequential,
)


class Embedding(Module):
Hrovatin marked this conversation as resolved.
Show resolved Hide resolved
"""Module for obtaining embedding of categorical covariates

Parameters
----------
size
N categories
cov_embed_dims
Dimensions of embedding
normalize
Apply layer normalization
"""

def __init__(self, size, cov_embed_dims: int = 10, normalize: bool = True):
martinkim0 marked this conversation as resolved.
Show resolved Hide resolved
super().__init__()

self.normalize = normalize

self.embedding = torch.nn.Embedding(size, cov_embed_dims)

if self.normalize:
# TODO this could probably be implemented more efficiently as embed gives same result for every sample in
# a give class. However, if we have many balanced classes there wont be many repetitions within minibatch
self.layer_norm = torch.nn.LayerNorm(
cov_embed_dims, elementwise_affine=False
)

def forward(self, x):
x = self.embedding(x)
if self.normalize:
x = self.layer_norm(x)

return x


class EncoderDecoder(Module):
"""Module that can be used as probabilistic encoder or decoder

Based on inputs and optional covariates predicts output mean and var

Parameters
----------
n_input
n_output
n_cov
n_hidden
n_layers
martinkim0 marked this conversation as resolved.
Show resolved Hide resolved
var_eps
See :class:`~scvi.external.sysvi.nn.VarEncoder`
martinkim0 marked this conversation as resolved.
Show resolved Hide resolved
var_mode
See :class:`~scvi.external.sysvi.nn.VarEncoder`
martinkim0 marked this conversation as resolved.
Show resolved Hide resolved
sample
Return samples from predicted distribution
kwargs
Passed to :class:`~scvi.external.sysvi.nn.Layers`
martinkim0 marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(
self,
n_input: int,
n_output: int,
n_cov: int,
n_hidden: int = 256,
n_layers: int = 3,
var_eps: float = 1e-4,
var_mode: str = "feature",
martinkim0 marked this conversation as resolved.
Show resolved Hide resolved
sample: bool = False,
**kwargs,
):
super().__init__()
self.sample = sample

self.var_eps = var_eps

self.decoder_y = Layers(
n_in=n_input,
n_cov=n_cov,
n_out=n_hidden,
n_hidden=n_hidden,
n_layers=n_layers,
**kwargs,
)

self.mean_encoder = Linear(n_hidden, n_output)
self.var_encoder = VarEncoder(n_hidden, n_output, mode=var_mode, eps=var_eps)

def forward(self, x, cov: torch.Tensor | None = None):
martinkim0 marked this conversation as resolved.
Show resolved Hide resolved
y = self.decoder_y(x=x, cov=cov)
# TODO better handling of inappropriate edge-case values than nan_to_num or at least warn
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have there been edge case values other than NaNs?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dislike nan_to_num. It's numerically unstable but gives the user no good insight of this issue. Can you describe the type of None errors that you are getting.

y_m = torch.nan_to_num(self.mean_encoder(y))
y_v = self.var_encoder(y, x_m=y_m)

outputs = {"y_m": y_m, "y_v": y_v}

# Sample from latent distribution
martinkim0 marked this conversation as resolved.
Show resolved Hide resolved
if self.sample:
y = Normal(y_m, y_v.sqrt()).rsample()
outputs["y"] = y

return outputs


class Layers(Module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this class is similar to the existing implementation of FCLayers, could you refactor Layers so that it subclasses FCLayers and overrides methods as needed? This will make it clearer what's different about this implementation as well as remove duplicate code in inject_into_layer and set_online_update_hooks, as it looks like these methods are identical to the original

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inheriting this will make it also easier for us to add new functionality to sysVI. Are there issues with the FClayers class?

"""A helper class to build fully-connected layers for a neural network.

Adapted from scVI FCLayers to use covariates more flexibly

Parameters
----------
n_in
The dimensionality of the main input
n_out
The dimensionality of the output
n_cov
Dimensionality of covariates.
If there are no cov this should be set to None -
in this case cov will not be used.
n_layers
The number of fully-connected hidden layers
n_hidden
The number of nodes per hidden layer
dropout_rate
Dropout rate to apply to each of the hidden layers
use_batch_norm
Whether to have `BatchNorm` layers or not
use_layer_norm
Whether to have `LayerNorm` layers or not
use_activation
Whether to have layer activation or not
bias
Whether to learn bias in linear layers or not
inject_covariates
Whether to inject covariates in each layer, or just the first.
activation_fn
Which activation function to use
"""

def __init__(
self,
n_in: int,
n_out: int,
n_cov: int | None = None,
n_layers: int = 1,
n_hidden: int = 128,
dropout_rate: float = 0.1,
use_batch_norm: bool = True,
use_layer_norm: bool = False,
use_activation: bool = True,
bias: bool = True,
inject_covariates: bool = True,
activation_fn: Module = ReLU,
):
super().__init__()

self.inject_covariates = inject_covariates
self.n_cov = n_cov if n_cov is not None else 0

layers_dim = [n_in] + (n_layers - 1) * [n_hidden] + [n_out]

self.fc_layers = Sequential(
OrderedDict(
[
(
f"Layer {i}",
Sequential(
Linear(
n_in + self.n_cov * self.inject_into_layer(i),
n_out,
bias=bias,
),
# non-default params come from defaults in original Tensorflow implementation
BatchNorm1d(n_out, momentum=0.01, eps=0.001)
if use_batch_norm
else None,
LayerNorm(n_out, elementwise_affine=False)
if use_layer_norm
else None,
activation_fn() if use_activation else None,
Dropout(p=dropout_rate) if dropout_rate > 0 else None,
),
)
for i, (n_in, n_out) in enumerate(
zip(layers_dim[:-1], layers_dim[1:])
)
]
)
)

def inject_into_layer(self, layer_num) -> bool:
"""Helper to determine if covariates should be injected."""
user_cond = layer_num == 0 or (layer_num > 0 and self.inject_covariates)
return user_cond

def set_online_update_hooks(self, hook_first_layer=True):
self.hooks = []

def _hook_fn_weight(grad):
new_grad = torch.zeros_like(grad)
if self.n_cov > 0:
new_grad[:, -self.n_cov :] = grad[:, -self.n_cov :]
return new_grad

def _hook_fn_zero_out(grad):
return grad * 0

for i, layers in enumerate(self.fc_layers):
for layer in layers:
if i == 0 and not hook_first_layer:
continue
if isinstance(layer, Linear):
if self.inject_into_layer(i):
w = layer.weight.register_hook(_hook_fn_weight)
else:
w = layer.weight.register_hook(_hook_fn_zero_out)
self.hooks.append(w)
b = layer.bias.register_hook(_hook_fn_zero_out)
self.hooks.append(b)

def forward(self, x: torch.Tensor, cov: torch.Tensor | None = None):
"""
Forward computation on ``x``.

Parameters
----------
x
tensor of values with shape ``(n_in,)``
cov
tensor of covariate values with shape ``(n_cov,)`` or None

Returns
-------
py:class:`torch.Tensor`
tensor of shape ``(n_out,)``

"""
for i, layers in enumerate(self.fc_layers):
for layer in layers:
if layer is not None:
if isinstance(layer, BatchNorm1d):
if x.dim() == 3:
x = torch.cat(
[(layer(slice_x)).unsqueeze(0) for slice_x in x], dim=0
)
else:
x = layer(x)
else:
# Injection of covariates
if (
self.n_cov > 0
and isinstance(layer, Linear)
and self.inject_into_layer(i)
):
x = torch.cat((x, cov), dim=-1)
x = layer(x)
return x


class VarEncoder(Module):
"""Encode variance (strictly positive).

Parameters
----------
n_input
Number of input dimensions, used if mode is sample_feature
n_output
Number of variances to predict
mode
How to compute var
'sample_feature' - learn per sample and feature
'feature' - learn per feature, constant across samples
'linear' - linear with respect to input mean, var = a1 * mean + a0;
not suggested to be used due to bad implementation for positive constraining
Hrovatin marked this conversation as resolved.
Show resolved Hide resolved
eps
Hrovatin marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(
self,
n_input: int,
n_output: int,
mode: Literal["sample_feature", "feature", "linear"],
eps: float = 1e-4,
):
super().__init__()

self.eps = eps
self.mode = mode
if self.mode == "sample_feature":
self.encoder = Linear(n_input, n_output)
elif self.mode == "feature":
self.var_param = Parameter(torch.zeros(1, n_output))
elif self.mode == "linear":
self.var_param_a1 = Parameter(torch.tensor([1.0]))
self.var_param_a0 = Parameter(torch.tensor([self.eps]))
else:
raise ValueError("Mode not recognised.")
self.activation = torch.exp
Hrovatin marked this conversation as resolved.
Show resolved Hide resolved

def forward(self, x: torch.Tensor, x_m: torch.Tensor):
"""Forward pass through model

Parameters
----------
x
Used to encode var if mode is sample_feature; dim = n_samples x n_input
x_m
Used to predict var instead of x if mode is linear; dim = n_samples x 1

Returns
-------
Predicted var
"""
# Force to be non nan - TODO come up with better way to do so
if self.mode == "sample_feature":
v = self.encoder(x)
v = (
torch.nan_to_num(self.activation(v)) + self.eps
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please try to avoid nan_to_num

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what would you suggest?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using softplus activation will be much more stable than exp. Do you need exp here for specific reasons? Clamping v would otherwise be safe (something like 20).

) # Ensure that var is strictly positive
elif self.mode == "feature":
v = self.var_param.expand(x.shape[0], -1) # Broadcast to input size
v = (
torch.nan_to_num(self.activation(v)) + self.eps
) # Ensure that var is strictly positive
elif self.mode == "linear":
v = self.var_param_a1 * x_m.detach().clone() + self.var_param_a0
# TODO come up with a better way to constrain this to positive while having lin relationship
# Could activation be used for log-lin relationship?
v = torch.clamp(torch.nan_to_num(v), min=self.eps)
Hrovatin marked this conversation as resolved.
Show resolved Hide resolved
return v
Loading
Loading