Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding model with cycle consistency and VampPrior #2421

Draft
wants to merge 76 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
a582251
add model and tests
Hrovatin Jan 14, 2024
bb359ac
update documentation
Hrovatin Jan 14, 2024
14e41f1
move embedding to device
Hrovatin Jan 14, 2024
5863b1f
Merge branch 'main' into main
martinkim0 Jan 19, 2024
3f49266
pr comments
Hrovatin Jan 21, 2024
6605682
Merge branch 'main' of https://github.com/Hrovatin/scvi-tools
Hrovatin Jan 21, 2024
3319fc8
Merge branch 'main' into main
martinkim0 Jan 22, 2024
3a67d1c
Merge branch 'main' into main
martinkim0 Feb 5, 2024
5edad83
updates
Hrovatin Feb 7, 2024
a4b080e
Merge branch 'main' of https://github.com/Hrovatin/scvi-tools
Hrovatin Feb 7, 2024
a24ef28
Merge branch 'main' into main
martinkim0 Feb 20, 2024
9b05bca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2024
8c25dba
Merge branch 'main' into main
martinkim0 Mar 11, 2024
c5f5c37
Update scvi/external/sysvi/_base_components.py
martinkim0 Mar 15, 2024
5b4838c
Update scvi/external/sysvi/_base_components.py
martinkim0 Mar 15, 2024
c885e20
Update scvi/external/sysvi/_base_components.py
martinkim0 Mar 15, 2024
661bbc6
Update scvi/external/sysvi/_model.py
martinkim0 Mar 15, 2024
9a49d24
Update scvi/external/sysvi/_model.py
martinkim0 Mar 15, 2024
3622eee
Update scvi/external/sysvi/_model.py
martinkim0 Mar 15, 2024
e4c1ef9
Update scvi/external/sysvi/_model.py
martinkim0 Mar 15, 2024
0f7bd06
Update scvi/external/sysvi/_model.py
martinkim0 Mar 15, 2024
9e0cba9
Update scvi/external/sysvi/_model.py
martinkim0 Mar 15, 2024
54f5734
Update scvi/external/sysvi/_model.py
martinkim0 Mar 15, 2024
df959ed
merge
Hrovatin Sep 14, 2024
f65c403
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2024
5fcb6f6
merge
Hrovatin Sep 14, 2024
3f1cffe
Merge branch 'main' of https://github.com/Hrovatin/scvi-tools
Hrovatin Sep 14, 2024
3c93e7f
merge
Hrovatin Sep 14, 2024
3e8ffd0
extend var documentation and remove unused "linear" mode
Hrovatin Sep 14, 2024
4ce4614
update var documentation
Hrovatin Sep 14, 2024
f81abda
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2024
bdd2c7a
var activation to softplus
Hrovatin Sep 14, 2024
4507602
clarify pseudoinputs_data_indices description and assert it matches s…
Hrovatin Sep 14, 2024
cc19d91
also check ndim for pseudoinputs_data_indices
Hrovatin Sep 14, 2024
c15b8a2
remove obtainng cycle latent representations in user interface
Hrovatin Sep 14, 2024
aacdafe
latent always returns np and add option to return_dist
Hrovatin Sep 14, 2024
de6db13
bugfix
Hrovatin Sep 14, 2024
235767c
rm unused cycle latent retrieval
Hrovatin Sep 14, 2024
5f04d65
bugfix
Hrovatin Sep 14, 2024
69759ec
rm adata validation parts repeated in super
Hrovatin Sep 14, 2024
2e447c8
bugfix
Hrovatin Sep 14, 2024
9e8cc35
put back original _validate_anndata
Hrovatin Sep 14, 2024
b0829c0
remove too many custom checks from adata validation
Hrovatin Sep 14, 2024
bf2e850
covariate type explanation
Hrovatin Sep 14, 2024
bd5a882
revert var exp
Hrovatin Sep 14, 2024
7f50353
bugfix
Hrovatin Sep 14, 2024
01db60a
rm reconstr loss f that not reused
Hrovatin Sep 14, 2024
6cdd3d6
explain why cycle loss computed on standardized values
Hrovatin Sep 14, 2024
62eb924
add return statement
Hrovatin Sep 14, 2024
b9a9047
Merge branch 'main' of https://github.com/Hrovatin/scvi-tools
Hrovatin Sep 14, 2024
b95cb03
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2024
9a2a33e
Merge branch 'main' of https://github.com/scverse/scvi-tools
Hrovatin Oct 19, 2024
c049ea1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 19, 2024
041c6a4
unify SysVI with scvi-tools code
Hrovatin Oct 20, 2024
7743c19
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 20, 2024
1641654
use adatamanager to access filed statistics
Hrovatin Oct 26, 2024
7c3bddc
documentation
Hrovatin Oct 26, 2024
eb872bc
Merge branch 'main' of https://github.com/Hrovatin/scvi-tools
Hrovatin Oct 26, 2024
5015e82
optionally change var activation function
Hrovatin Oct 27, 2024
b64b2ba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 27, 2024
9b78ba3
documentation improvements
Hrovatin Oct 27, 2024
2aab4b4
Use real instead of mock covariates in cycle
Hrovatin Oct 27, 2024
1b50be1
improve documentation
Hrovatin Oct 27, 2024
1abf15d
Merge branch 'main' of https://github.com/Hrovatin/scvi-tools
Hrovatin Oct 27, 2024
6d32d89
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 27, 2024
2e59be7
improve documentation
Hrovatin Oct 27, 2024
4e0988e
fix bug introduced when renaming parameters in generative function
Hrovatin Oct 27, 2024
f0eb2b0
Merge branch 'main' of https://github.com/Hrovatin/scvi-tools
Hrovatin Oct 27, 2024
3f8be47
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 27, 2024
776a8e2
rename tests to prevent automatic test failure
Hrovatin Oct 27, 2024
ebbe1ca
Merge branch 'main' of https://github.com/Hrovatin/scvi-tools
Hrovatin Oct 27, 2024
35e3385
fix typo in docstring and formatting
Hrovatin Oct 27, 2024
4601808
ruff fixes
Hrovatin Oct 27, 2024
87a367e
bugfix in embedding of covariates
Hrovatin Nov 1, 2024
df4b335
bugfix in test for checking cov embeding
Hrovatin Nov 1, 2024
6a2e3eb
Change var activation to softplus
Hrovatin Nov 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions scvi/external/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .cellassign import CellAssign
from .contrastivevi import ContrastiveVI
from .csi.model import Model as SysVI
from .gimvi import GIMVI
from .poissonvi import POISSONVI
from .scar import SCAR
Expand All @@ -19,4 +20,5 @@
"SCBASSET",
"POISSONVI",
"ContrastiveVI",
"SysVI",
]
7 changes: 7 additions & 0 deletions scvi/external/csi/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from ._model import (
Model,
)

__all__ = [
"Model",
]
322 changes: 322 additions & 0 deletions scvi/external/csi/model/_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
import logging
from collections.abc import Sequence
from typing import Optional, Union

import numpy as np
import pandas as pd
import torch
from anndata import AnnData
from typing_extensions import Literal

from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import (
LayerField,
ObsmField,
)
from scvi.external.csi.module import Module
from scvi.model.base import BaseModelClass
from scvi.utils import setup_anndata_dsp

from ._training import TrainingCustom
from ._utils import prepare_metadata

logger = logging.getLogger(__name__)


class Model(TrainingCustom, BaseModelClass):
martinkim0 marked this conversation as resolved.
Show resolved Hide resolved
martinkim0 marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
adata: AnnData,
prior: Literal["standard_normal", "vamp"] = "vamp",
n_prior_components=5,
pseudoinputs_data_indices: Optional[np.array] = None,
**model_kwargs,
):
"""Integration model based on cVAE with optional VampPrior and latent cycle-consistency loss.

Parameters
----------
adata
AnnData object that has been registered via :meth:`~scvi-tools.SysVI.setup_anndata`.
martinkim0 marked this conversation as resolved.
Show resolved Hide resolved
prior
The prior distribution to be used. You can choose between "standard_normal" and "vamp".
n_prior_components
Number of prior components in VampPrior.
pseudoinputs_data_indices
By default VampPrior pseudoinputs are randomly selected from data.
Alternatively, one can specify pseudoinput indices using this parameter.
**model_kwargs
Keyword args for :class:`~scvi.external.csi.module.Module`
"""
martinkim0 marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(adata)

if prior == "vamp":
if pseudoinputs_data_indices is None:
pseudoinputs_data_indices = np.random.randint(
0, adata.shape[0], n_prior_components
)
pseudoinput_data = next(
iter(
self._make_data_loader(
adata=adata,
indices=pseudoinputs_data_indices,
batch_size=n_prior_components,
shuffle=False,
)
)
)
else:
pseudoinput_data = None

n_cov_const = (
adata.obsm["covariates"].shape[1] if "covariates" in adata.obsm else 0
)
cov_embed_sizes = (
pd.DataFrame(adata.obsm["covariates_embed"]).nunique(axis=0).to_list()
if "covariates_embed" in adata.obsm
else []
)

# self.summary_stats provides information about anndata dimensions and other tensor info
self.module = Module(
n_input=adata.shape[1],
n_cov_const=n_cov_const,
cov_embed_sizes=cov_embed_sizes,
n_system=adata.obsm["system"].shape[1],
prior=prior,
n_prior_components=n_prior_components,
pseudoinput_data=pseudoinput_data,
**model_kwargs,
)

self._model_summary_string = (
"cVAE model with optional VampPrior and latent cycle-consistency loss"
)
# necessary line to get params that will be used for saving/loading
self.init_params_ = self._get_init_params(locals())

logger.info("The model has been initialized")

@torch.no_grad()
def embed(
martinkim0 marked this conversation as resolved.
Show resolved Hide resolved
self,
adata: AnnData,
indices: Optional[Sequence[int]] = None,
cycle: bool = False,
give_mean: bool = True,
batch_size: Optional[int] = None,
as_numpy: bool = True,
) -> Union[np.ndarray, torch.Tensor]:
"""Return the latent representation for each cell.

Parameters
----------
adata
Input adata for which latent representation should be obtained.
indices
Data indices to embed. If None embedd all cells.
cycle
Return latent embedding of the cycle pass.
give_mean
Return the posterior mean instead of a sample from the posterior.
batch_size
Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
as_numpy
Return in numpy rather than torch format.

Returns
-------
Latent Embedding
"""
# Check model and adata
self._check_if_trained(warn=False)
# TODO extend to check if adata setup is correct wrt training data
adata = self._validate_anndata(adata)
if indices is None:
indices = np.arange(adata.n_obs)
# Prediction
# Do not shuffle to retain order
tensors_fwd = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size, shuffle=False
)
predicted = []
for tensors in tensors_fwd:
# Inference
inference_inputs = self.module._get_inference_input(tensors)
inference_outputs = self.module.inference(**inference_inputs)
if cycle:
selected_system = self.module.random_select_systems(tensors["system"])
generative_inputs = self.module._get_generative_input(
tensors,
inference_outputs,
selected_system=selected_system,
)
generative_outputs = self.module.generative(
**generative_inputs, x_x=False, x_y=True
)
inference_cycle_inputs = self.module._get_inference_cycle_input(
tensors=tensors,
generative_outputs=generative_outputs,
selected_system=selected_system,
)
inference_outputs = self.module.inference(**inference_cycle_inputs)
if give_mean:
predicted += [inference_outputs["z_m"]]
else:
predicted += [inference_outputs["z"]]

predicted = torch.cat(predicted)

if as_numpy:
predicted = predicted.cpu().numpy()
return predicted

@classmethod
@setup_anndata_dsp.dedent
def setup_anndata(
cls,
adata: AnnData,
system_key: str,
layer: Optional[str] = None,
categorical_covariate_keys: Optional[list[str]] = None,
categorical_covariate_embed_keys: Optional[list[str]] = None,
continuous_covariate_keys: Optional[list[str]] = None,
covariate_categ_orders: Optional[dict] = None,
covariate_key_orders: Optional[dict] = None,
system_order: Optional[list[str]] = None,
**kwargs,
) -> AnnData:
"""Prepare adata for input to Model

Parameters
----------
adata
Adata object - will be modified in place.
system_key
martinkim0 marked this conversation as resolved.
Show resolved Hide resolved
Name of obs column with categorical system information.
layer
AnnData layer to use, default is X.
Should contain normalized and log+1 transformed expression.
categorical_covariate_keys
Name of obs column with additional categorical covariate information. Will be one hot encoded.
categorical_covariate_embed_keys
Name of obs column with additional categorical covariate information. Embedding will be learned.
This can be useful if the number of categories is very large, which would increase memory usage.
If using this type of covariate representation please also cite
`scPoli <[https://doi.org/10.1038/s41592-023-02035-2]>`_ .
continuous_covariate_keys
Name of obs column with additional continuous covariate information.
covariate_categ_orders
Covariate encoding information. Should be used if a new adata is to be set up according
to setup of an existing adata. Access via adata.uns['covariate_categ_orders'] of already setup adata.
covariate_key_orders
Covariate encoding information. Should be used if a new adata is to be set up according
to setup of an existing adata. Access via adata.uns['covariate_key_orders'] of already setup adata.
system_order
Same as covariate_orders, but for system. Access via adata.uns['system_order']
"""
setup_method_args = cls._get_setup_method_args(**locals())

# Make sure var names are unique
if adata.shape[1] != len(set(adata.var_names)):
raise ValueError("Adata var_names are not unique")

# If setup is to be prepared wtr another adata specs make sure all relevant info is present
if covariate_categ_orders or covariate_key_orders or system_order:
assert system_order is not None
if (
categorical_covariate_keys is not None
or categorical_covariate_embed_keys is not None
or continuous_covariate_keys is not None
):
assert covariate_categ_orders is not None
assert covariate_key_orders is not None

# Make system embedding with specific category order

# Define order of system categories
if system_order is None:
system_order = sorted(adata.obs[system_key].unique())
# Validate that the provided system_order matches the categories in adata.obs[system_key]
if set(system_order) != set(adata.obs[system_key].unique()):
raise ValueError(
"Provided system_order does not match the categories in adata.obs[system_key]"
)

# Make one-hot embedding with specified order
systems_dict = dict(
zip(system_order, ([float(i) for i in range(0, len(system_order))]))
)
adata.uns["system_order"] = system_order
system_cat = pd.Series(
pd.Categorical(
values=adata.obs[system_key], categories=system_order, ordered=True
),
index=adata.obs.index,
name="system",
)
adata.obsm["system"] = pd.get_dummies(system_cat, dtype=float)

# Set up covariates
# TODO this could be handled by specific field type in registry

# System must not be in cov
if categorical_covariate_keys is not None:
if system_key in categorical_covariate_keys:
raise ValueError("system_key should not be within covariate keys")
if categorical_covariate_embed_keys is not None:
if system_key in categorical_covariate_embed_keys:
raise ValueError("system_key should not be within covariate keys")
if continuous_covariate_keys is not None:
if system_key in continuous_covariate_keys:
raise ValueError("system_key should not be within covariate keys")

# Prepare covariate training representations/embedding
covariates, covariates_embed, orders_dict, cov_dict = prepare_metadata(
meta_data=adata.obs,
cov_cat_keys=categorical_covariate_keys,
cov_cat_embed_keys=categorical_covariate_embed_keys,
cov_cont_keys=continuous_covariate_keys,
categ_orders=covariate_categ_orders,
key_orders=covariate_key_orders,
)

# Save covariate representation and order information
adata.uns["covariate_categ_orders"] = orders_dict
adata.uns["covariate_key_orders"] = cov_dict
if (
continuous_covariate_keys is not None
or categorical_covariate_keys is not None
):
adata.obsm["covariates"] = covariates
else:
# Remove if present since the presence of this key
# is in model used to determine if cov should be used or not
if "covariates" in adata.obsm:
del adata.obsm["covariates"]
if categorical_covariate_embed_keys is not None:
adata.obsm["covariates_embed"] = covariates_embed
else:
if "covariates_embed" in adata.obsm:
del adata.obsm["covariates_embed"]

# Anndata setup

anndata_fields = [
LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=False),
ObsmField("system", "system"),
]
# Covariate fields are optional
if (
continuous_covariate_keys is not None
or categorical_covariate_keys is not None
):
anndata_fields.append(ObsmField("covariates", "covariates"))
if categorical_covariate_embed_keys is not None:
anndata_fields.append(ObsmField("covariates_embed", "covariates_embed"))
adata_manager = AnnDataManager(
fields=anndata_fields, setup_method_args=setup_method_args
)
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)
14 changes: 14 additions & 0 deletions scvi/external/csi/model/_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from scvi.external.csi.train import TrainingPlanCustom
from scvi.model.base import UnsupervisedTrainingMixin


class TrainingCustom(UnsupervisedTrainingMixin):
"""Train method with custom TrainingPlan."""

# TODO could make custom Trainer (in a custom TrainRunner) to have in init params for early stopping
# "loss" rather than "elbo" components in available param specifications - for now just use
# a loss that is against the param specification

# TODO run and log val before training - already tried some solutions by calling trainer.validate before
martinkim0 marked this conversation as resolved.
Show resolved Hide resolved
# fit and num_sanity_val_steps (designed not to log)
_training_plan_cls = TrainingPlanCustom
Loading
Loading