From 9fd03ce1c8844492210fc1b020caa19bcea10f3e Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Thu, 10 Oct 2024 14:06:54 -0400 Subject: [PATCH 01/40] first draft of moving decipher model into scvi-tools --- src/scvi/external/__init__.py | 2 + src/scvi/external/decipher/__init__.py | 4 + src/scvi/external/decipher/_components.py | 110 ++++++++++++++ src/scvi/external/decipher/_model.py | 104 ++++++++++++++ src/scvi/external/decipher/_module.py | 168 ++++++++++++++++++++++ src/scvi/module/base/_base_module.py | 40 ++++-- src/scvi/train/_trainingplans.py | 119 ++++++++++++--- tests/external/decipher/test_decipher.py | 22 +++ tests/external/mrvi/test_model.py | 8 +- 9 files changed, 547 insertions(+), 30 deletions(-) create mode 100644 src/scvi/external/decipher/__init__.py create mode 100644 src/scvi/external/decipher/_components.py create mode 100644 src/scvi/external/decipher/_model.py create mode 100644 src/scvi/external/decipher/_module.py create mode 100644 tests/external/decipher/test_decipher.py diff --git a/src/scvi/external/__init__.py b/src/scvi/external/__init__.py index 4e46ca1846..c54253c2da 100644 --- a/src/scvi/external/__init__.py +++ b/src/scvi/external/__init__.py @@ -1,5 +1,6 @@ from .cellassign import CellAssign from .contrastivevi import ContrastiveVI +from .decipher import Decipher from .gimvi import GIMVI from .methylvi import METHYLVI from .mrvi import MRVI @@ -15,6 +16,7 @@ "SCAR", "SOLO", "GIMVI", + "Decipher", "RNAStereoscope", "SpatialStereoscope", "CellAssign", diff --git a/src/scvi/external/decipher/__init__.py b/src/scvi/external/decipher/__init__.py new file mode 100644 index 0000000000..d1a6049056 --- /dev/null +++ b/src/scvi/external/decipher/__init__.py @@ -0,0 +1,4 @@ +from ._model import Decipher +from ._module import DecipherPyroModule + +__all__ = ["Decipher", "DecipherPyroModule"] diff --git a/src/scvi/external/decipher/_components.py b/src/scvi/external/decipher/_components.py new file mode 100644 index 0000000000..2951528aca --- /dev/null +++ b/src/scvi/external/decipher/_components.py @@ -0,0 +1,110 @@ +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn + + +class ConditionalDenseNN(nn.Module): + """Dense neural network with multiple outputs, optionally conditioned on a context variable. + + (Derived from pyro.nn.dense_nn.ConditionalDenseNN with some modifications [1]) + + Parameters + ---------- + input_dim : int + Dimension of the input + hidden_dims : sequence of ints + Dimensions of the hidden layers (excluding the output layer) + output_dims : sequence of ints (optional) + Dimensions of each output layer + Default: (1,) + context_dim : int (optional) + Dimension of the context input. + Default: 0. No context input. + deep_context_injection : bool (optional) + If True, inject the context into every hidden layer. + If False, only inject the context into the first hidden layer (concatenated with the input). + Default: False. + activation : torch.nn.Module (optional) + Activation function to use between hidden layers (not applied to the outputs). + Default: torch.nn.ReLU() + """ + + def __init__( + self, + input_dim: int, + hidden_dims: Sequence[int], + output_dims: Sequence = (1,), + context_dim: int = 0, + deep_context_injection: bool = False, + activation=torch.nn.ReLU(), + ): + super().__init__() + + self.input_dim = input_dim + self.context_dim = context_dim + self.hidden_dims = hidden_dims + self.output_dims = output_dims + self.deep_context_injection = deep_context_injection + self.n_output_layers = len(self.output_dims) + self.output_total_dim = sum(self.output_dims) + + # The multiple outputs are computed as a single output layer, and then split + indices = np.concatenate(([0], np.cumsum(self.output_dims))) + self.output_slices = [slice(s, e) for s, e in zip(indices[:-1], indices[1:])] + + # Create masked layers + deep_context_dim = self.context_dim if self.deep_context_injection else 0 + layers = [] + batch_norms = [] + if len(hidden_dims): + layers.append(torch.nn.Linear(input_dim + context_dim, hidden_dims[0])) + batch_norms.append(nn.BatchNorm1d(hidden_dims[0])) + for i in range(1, len(hidden_dims)): + layers.append( + torch.nn.Linear( + hidden_dims[i - 1] + deep_context_dim, hidden_dims[i] + ) + ) + batch_norms.append(nn.BatchNorm1d(hidden_dims[i])) + + layers.append( + torch.nn.Linear( + hidden_dims[-1] + deep_context_dim, self.output_total_dim + ) + ) + else: + layers.append( + torch.nn.Linear(input_dim + context_dim, self.output_total_dim) + ) + + self.layers = torch.nn.ModuleList(layers) + + self.f = activation + self.batch_norms = torch.nn.ModuleList(batch_norms) + + def forward(self, x, context=None): + if context is not None: + # We must be able to broadcast the size of the context over the input + context = context.expand(x.size()[:-1] + (context.size(-1),)) + + h = x + for i, layer in enumerate(self.layers): + if self.context_dim > 0 and (self.deep_context_injection or i == 0): + h = torch.cat([context, h], dim=-1) + h = layer(h) + if i < len(self.layers) - 1: + h = self.batch_norms[i](h) + h = self.f(h) + + if self.n_output_layers == 1: + return h + else: + h = h.reshape(list(x.size()[:-1]) + [self.output_total_dim]) + + if self.n_output_layers == 1: + return h + + else: + return tuple([h[..., s] for s in self.output_slices]) diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py new file mode 100644 index 0000000000..2381ff8d85 --- /dev/null +++ b/src/scvi/external/decipher/_model.py @@ -0,0 +1,104 @@ +from collections.abc import Sequence +from typing import TYPE_CHECKING + +import logging + +from anndata import AnnData +import pyro + +from scvi._constants import REGISTRY_KEYS +from scvi.data import AnnDataManager +from scvi.data.fields import LayerField, CategoricalJointObsField +from scvi.utils import setup_anndata_dsp +from scvi.train import PyroTrainingPlan + +from scvi.model.base import BaseModelClass, PyroSviTrainMixin + +from ._module import DecipherPyroModule + +if TYPE_CHECKING: + from collections.abc import Sequence + + from anndata import AnnData + +logger = logging.getLogger(__name__) + + +class Decipher(PyroSviTrainMixin, BaseModelClass): + _module_cls = DecipherPyroModule + + def __init__(self, adata: AnnData, **kwargs): + pyro.clear_param_store() + + super().__init__(adata) + + dim_genes = self.summary_stats.n_vars + + self.module = self._module_cls( + dim_genes, + **kwargs, + ) + + self.init_params = self._get_init_params(locals()) + + @classmethod + @setup_anndata_dsp.dedent + def setup_anndata( + cls, + adata: AnnData, + layer: str | None = None, + **kwargs, + ) -> AnnData | None: + """%(summary)s. + + Parameters + ---------- + %(param_adata)s + %(param_layer)s + """ + + setup_method_args = cls._get_setup_method_args(**locals()) + anndata_fields = [ + LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), + ] + adata_manager = AnnDataManager( + fields=anndata_fields, setup_method_args=setup_method_args + ) + adata_manager.register_fields(adata, **kwargs) + cls.register_manager(adata_manager) + + def train( + self, + max_epochs: int | None = None, + accelerator: str = "auto", + device: int | str = "auto", + train_size: float = 0.9, + validation_size: float | None = None, + shuffle_set_split: bool = True, + batch_size: int = 128, + early_stopping: bool = False, + lr: float | None = None, + training_plan: PyroTrainingPlan | None = None, + datasplitter_kwargs: dict | None = None, + plan_kwargs: dict | None = None, + **trainer_kwargs, + ): + optim_kwargs = trainer_kwargs.pop("optim_kwargs", {}) + optim_kwargs.update({"lr": lr or 5e-3, "weight_decay": 1e-4}) + optim = pyro.optim.ClippedAdam(optim_kwargs) + plan_kwargs = plan_kwargs or {} + plan_kwargs.update({"optim": optim}) + super().train( + max_epochs=max_epochs, + accelerator=accelerator, + device=device, + train_size=train_size, + validation_size=validation_size, + shuffle_set_split=shuffle_set_split, + batch_size=batch_size, + early_stopping=early_stopping, + plan_kwargs=plan_kwargs, + training_plan=training_plan, + datasplitter_kwargs=datasplitter_kwargs, + **trainer_kwargs, + ) diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py new file mode 100644 index 0000000000..861bf9025a --- /dev/null +++ b/src/scvi/external/decipher/_module.py @@ -0,0 +1,168 @@ +from collections.abc import Iterable, Sequence + +import numpy as np +import pyro +import pyro.distributions as dist +import pyro.poutine as poutine +import torch +import torch.nn as nn +import torch.utils.data +from torch.distributions import constraints +from torch.nn.functional import softmax, softplus + +from scvi._constants import REGISTRY_KEYS +from scvi.module.base import PyroBaseModuleClass, auto_move_data + +from ._components import ConditionalDenseNN + + +class DecipherPyroModule(PyroBaseModuleClass): + """Decipher _decipher for single-cell data. + + Parameters + ---------- + config : DecipherConfig or dict + Configuration for the decipher _decipher. + """ + + def __init__( + self, + dim_genes: int, + dim_v: int = 2, + dim_z: int = 10, + layers_v_to_z: Sequence[int] = (64,), + layers_z_to_x: Sequence[int] = tuple(), + beta: float = 0.1, + prior: str = "normal", + ): + super().__init__() + self.dim_v = dim_v + self.dim_z = dim_z + self.dim_genes = dim_genes + self.layers_v_to_z = layers_v_to_z + self.layers_z_to_x = layers_z_to_x + self.beta = beta + self.prior = prior + + self.decoder_v_to_z = ConditionalDenseNN(dim_v, layers_v_to_z, [dim_z] * 2) + self.decoder_z_to_x = ConditionalDenseNN(dim_z, layers_z_to_x, [dim_genes]) + self.encoder_x_to_z = ConditionalDenseNN(dim_genes, [128], [dim_z] * 2) + self.encoder_zx_to_v = ConditionalDenseNN( + dim_genes + dim_z, + [128], + [dim_v, dim_v], + ) + + self.theta = None + + self._epsilon = 1e-5 + + # Hack: to allow auto_move_data to infer device. + self._dummy_param = nn.Parameter(torch.empty(0), requires_grad=False) + + @property + def device(self): + return self._dummy_param.device + + @staticmethod + def _get_fn_args_from_batch( + tensor_dict: dict[str, torch.Tensor] + ) -> Iterable | dict: + x = tensor_dict[REGISTRY_KEYS.X_KEY] + return (x,), {} + + @auto_move_data + def model(self, x: torch.Tensor): + pyro.module("decipher", self) + + self.theta = pyro.param( + "theta", + x.new_ones(self.dim_genes), + constraint=constraints.positive, + ) + + with pyro.plate("batch", len(x)), poutine.scale(scale=1.0): + with poutine.scale(scale=self.beta): + if self.prior == "normal": + prior = dist.Normal(0, x.new_ones(self.dim_v)).to_event(1) + elif self.prior == "gamma": + prior = dist.Gamma(0.3, x.new_ones(self.dim_v) * 0.8).to_event(1) + else: + raise ValueError("Invalid prior, must be normal or gamma") + v = pyro.sample("v", prior) + + z_loc, z_scale = self.decoder_v_to_z(v) + z_scale = softplus(z_scale) + z = pyro.sample("z", dist.Normal(z_loc, z_scale).to_event(1)) + + mu = self.decoder_z_to_x(z) + mu = softmax(mu, dim=-1) + library_size = x.sum(axis=-1, keepdim=True) + # Parametrization of Negative Binomial by the mean and inverse dispersion + # See https://github.com/pytorch/pytorch/issues/42449 + # noinspection PyTypeChecker + logit = torch.log(library_size * mu + self._epsilon) - torch.log( + self.theta + self._epsilon + ) + # noinspection PyUnresolvedReferences + x_dist = dist.NegativeBinomial( + total_count=self.theta + self._epsilon, logits=logit + ) + pyro.sample("x", x_dist.to_event(1), obs=x) + + @auto_move_data + def guide(self, x: torch.Tensor): + pyro.module("decipher", self) + with pyro.plate("batch", len(x)), poutine.scale(scale=1.0): + x = torch.log1p(x) + + z_loc, z_scale = self.encoder_x_to_z(x) + z_scale = softplus(z_scale) + posterior_z = dist.Normal(z_loc, z_scale).to_event(1) + z = pyro.sample("z", posterior_z) + + zx = torch.cat([z, x], dim=-1) + v_loc, v_scale = self.encoder_zx_to_v(zx) + v_scale = softplus(v_scale) + with poutine.scale(scale=self.beta): + if self.prior == "gamma": + posterior_v = dist.Gamma(softplus(v_loc), v_scale).to_event(1) + elif self.prior == "normal" or self.prior == "student-normal": + posterior_v = dist.Normal(v_loc, v_scale).to_event(1) + else: + raise ValueError("Invalid prior, must be normal or gamma") + pyro.sample("v", posterior_v) + return z_loc, v_loc, z_scale, v_scale + + def compute_v_z_numpy(self, x: np.array): + """Compute decipher_v and decipher_z for a given input. + + Parameters + ---------- + x : np.ndarray or torch.Tensor + Input data of shape (n_cells, n_genes). + + Returns + ------- + v : np.ndarray + Decipher components v of shape (n_cells, dim_v). + z : np.ndarray + Decipher latent z of shape (n_cells, dim_z). + """ + if type(x) == np.ndarray: + x = torch.tensor(x, dtype=torch.float32) + + x = torch.log1p(x) + z_loc, _ = self.encoder_x_to_z(x) + zx = torch.cat([z_loc, x], dim=-1) + v_loc, _ = self.encoder_zx_to_v(zx) + return v_loc.detach().numpy(), z_loc.detach().numpy() + + def impute_gene_expression_numpy(self, x): + if type(x) == np.ndarray: + x = torch.tensor(x, dtype=torch.float32) + z_loc, _, _, _ = self.guide(x) + mu = self.decoder_z_to_x(z_loc) + mu = softmax(mu, dim=-1) + library_size = x.sum(axis=-1, keepdim=True) + return (library_size * mu).detach().numpy() diff --git a/src/scvi/module/base/_base_module.py b/src/scvi/module/base/_base_module.py index 39097c9039..e9d97dccdd 100644 --- a/src/scvi/module/base/_base_module.py +++ b/src/scvi/module/base/_base_module.py @@ -96,7 +96,9 @@ def __post_init__(self): object.__setattr__(self, "loss", self.dict_sum(self.loss)) if self.n_obs_minibatch is None and self.reconstruction_loss is None: - raise ValueError("Must provide either n_obs_minibatch or reconstruction_loss") + raise ValueError( + "Must provide either n_obs_minibatch or reconstruction_loss" + ) default = 0 * self.loss if self.reconstruction_loss is None: @@ -106,7 +108,9 @@ def __post_init__(self): if self.kl_global is None: object.__setattr__(self, "kl_global", default) - object.__setattr__(self, "reconstruction_loss", self._as_dict("reconstruction_loss")) + object.__setattr__( + self, "reconstruction_loss", self._as_dict("reconstruction_loss") + ) object.__setattr__(self, "kl_local", self._as_dict("kl_local")) object.__setattr__(self, "kl_global", self._as_dict("kl_global")) object.__setattr__( @@ -119,13 +123,16 @@ def __post_init__(self): if self.reconstruction_loss is not None and self.n_obs_minibatch is None: rec_loss = self.reconstruction_loss - object.__setattr__(self, "n_obs_minibatch", list(rec_loss.values())[0].shape[0]) + object.__setattr__( + self, "n_obs_minibatch", list(rec_loss.values())[0].shape[0] + ) if self.classification_loss is not None and ( self.logits is None or self.true_labels is None ): raise ValueError( - "Must provide `logits` and `true_labels` if `classification_loss` is " "provided." + "Must provide `logits` and `true_labels` if `classification_loss` is " + "provided." ) @staticmethod @@ -184,7 +191,10 @@ def forward( generative_kwargs: dict | None = None, loss_kwargs: dict | None = None, compute_loss=True, - ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, LossOutput]: + ) -> ( + tuple[torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, LossOutput] + ): """Forward pass through the network. Parameters @@ -348,7 +358,9 @@ def __init__(self, on_load_kwargs: dict | None = None): @staticmethod @abstractmethod - def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict: + def _get_fn_args_from_batch( + tensor_dict: dict[str, torch.Tensor] + ) -> Iterable | dict: """Parse the minibatched data to get the correct inputs for ``model`` and ``guide``. In Pyro, ``model`` and ``guide`` must have the same signature. This is a helper method @@ -649,7 +661,9 @@ def load_state_dict(self, state_dict: dict[str, Any]): raise RuntimeError( "Train state is not set. Train for one iteration prior to loading state dict." ) - self.train_state = flax.serialization.from_state_dict(self.train_state, state_dict) + self.train_state = flax.serialization.from_state_dict( + self.train_state, state_dict + ) def to(self, device: Device): """Move module to device.""" @@ -677,7 +691,9 @@ def get_jit_inference_fn( self, get_inference_input_kwargs: dict[str, Any] | None = None, inference_kwargs: dict[str, Any] | None = None, - ) -> Callable[[dict[str, jnp.ndarray], dict[str, jnp.ndarray]], dict[str, jnp.ndarray]]: + ) -> Callable[ + [dict[str, jnp.ndarray], dict[str, jnp.ndarray]], dict[str, jnp.ndarray] + ]: """Create a method to run inference using the bound module. Parameters @@ -744,14 +760,18 @@ def _generic_forward( get_inference_input_kwargs = _get_dict_if_none(get_inference_input_kwargs) get_generative_input_kwargs = _get_dict_if_none(get_generative_input_kwargs) - inference_inputs = module._get_inference_input(tensors, **get_inference_input_kwargs) + inference_inputs = module._get_inference_input( + tensors, **get_inference_input_kwargs + ) inference_outputs = module.inference(**inference_inputs, **inference_kwargs) generative_inputs = module._get_generative_input( tensors, inference_outputs, **get_generative_input_kwargs ) generative_outputs = module.generative(**generative_inputs, **generative_kwargs) if compute_loss: - losses = module.loss(tensors, inference_outputs, generative_outputs, **loss_kwargs) + losses = module.loss( + tensors, inference_outputs, generative_outputs, **loss_kwargs + ) return inference_outputs, generative_outputs, losses else: return inference_outputs, generative_outputs diff --git a/src/scvi/train/_trainingplans.py b/src/scvi/train/_trainingplans.py index 79aa4bf0e3..5067d9cdbf 100644 --- a/src/scvi/train/_trainingplans.py +++ b/src/scvi/train/_trainingplans.py @@ -182,7 +182,9 @@ def __init__( self.optimizer_creator = optimizer_creator if self.optimizer_name == "Custom" and self.optimizer_creator is None: - raise ValueError("If optimizer is 'Custom', `optimizer_creator` must be provided.") + raise ValueError( + "If optimizer is 'Custom', `optimizer_creator` must be provided." + ) self._n_obs_training = None self._n_obs_validation = None @@ -219,7 +221,9 @@ def initialize_train_metrics(self): self.kl_local_train, self.kl_global_train, self.train_metrics, - ) = self._create_elbo_metric_components(mode="train", n_total=self.n_obs_training) + ) = self._create_elbo_metric_components( + mode="train", n_total=self.n_obs_training + ) self.elbo_train.reset() def initialize_val_metrics(self): @@ -230,7 +234,9 @@ def initialize_val_metrics(self): self.kl_local_val, self.kl_global_val, self.val_metrics, - ) = self._create_elbo_metric_components(mode="validation", n_total=self.n_obs_validation) + ) = self._create_elbo_metric_components( + mode="validation", n_total=self.n_obs_validation + ) self.elbo_val.reset() @property @@ -366,7 +372,9 @@ def validation_step(self, batch, batch_idx): ) self.compute_and_log_metrics(scvi_loss, self.val_metrics, "validation") - def _optimizer_creator_fn(self, optimizer_cls: torch.optim.Adam | torch.optim.AdamW): + def _optimizer_creator_fn( + self, optimizer_cls: torch.optim.Adam | torch.optim.AdamW + ): """Create optimizer for the model. This type of function can be passed as the `optimizer_creator` @@ -544,7 +552,9 @@ def loss_adversarial_classifier(self, z, batch_index, predict_true_class=True): if predict_true_class: cls_target = torch.nn.functional.one_hot(batch_index.squeeze(-1), n_classes) else: - one_hot_batch = torch.nn.functional.one_hot(batch_index.squeeze(-1), n_classes) + one_hot_batch = torch.nn.functional.one_hot( + batch_index.squeeze(-1), n_classes + ) # place zeroes where true label is cls_target = (~one_hot_batch.bool()).float() cls_target = cls_target / (n_classes - 1) @@ -572,7 +582,9 @@ def training_step(self, batch, batch_idx): else: opt1, opt2 = opts - inference_outputs, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs) + inference_outputs, _, scvi_loss = self.forward( + batch, loss_kwargs=self.loss_kwargs + ) z = inference_outputs["z"] loss = scvi_loss.loss # fool classifier if doing adversarial training @@ -605,7 +617,10 @@ def on_train_epoch_end(self): def on_validation_epoch_end(self) -> None: """Update the learning rate via scheduler steps.""" - if not self.reduce_lr_on_plateau or "validation" not in self.lr_scheduler_metric: + if ( + not self.reduce_lr_on_plateau + or "validation" not in self.lr_scheduler_metric + ): return else: sch = self.lr_schedulers() @@ -636,7 +651,9 @@ def configure_optimizers(self): ) if self.adversarial_classifier is not False: - params2 = filter(lambda p: p.requires_grad, self.adversarial_classifier.parameters()) + params2 = filter( + lambda p: p.requires_grad, self.adversarial_classifier.parameters() + ) optimizer2 = torch.optim.Adam( params2, lr=1e-3, eps=0.01, weight_decay=self.weight_decay ) @@ -889,6 +906,7 @@ def __init__( super().__init__() self.module = pyro_module self._n_obs_training = None + self._n_obs_validation = None optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else {} if "lr" not in optim_kwargs.keys(): @@ -901,15 +919,18 @@ def __init__( self.n_epochs_kl_warmup = n_epochs_kl_warmup self.use_kl_weight = False if isinstance(self.module.model, PyroModule): - self.use_kl_weight = "kl_weight" in signature(self.module.model.forward).parameters + self.use_kl_weight = ( + "kl_weight" in signature(self.module.model.forward).parameters + ) elif callable(self.module.model): self.use_kl_weight = "kl_weight" in signature(self.module.model).parameters self.scale_elbo = scale_elbo - self.scale_fn = ( - lambda obj: pyro.poutine.scale(obj, self.scale_elbo) if self.scale_elbo != 1 else obj + self.scale_fn = lambda obj: ( + pyro.poutine.scale(obj, self.scale_elbo) if self.scale_elbo != 1 else obj ) self.differentiable_loss_fn = self.loss_fn.differentiable_loss self.training_step_outputs = [] + self.validation_step_outputs = [] def training_step(self, batch, batch_idx): """Training step for Pyro training.""" @@ -942,6 +963,37 @@ def on_train_epoch_end(self): self.log("elbo_train", elbo, prog_bar=True) self.training_step_outputs.clear() + def validation_step(self, batch, batch_idx): + """Validation step for Pyro training.""" + args, kwargs = self.module._get_fn_args_from_batch(batch) + # Set KL weight if necessary. + # Note: if applied, ELBO loss in progress bar is the effective KL annealed loss, not the + # true ELBO. + if self.use_kl_weight: + kwargs.update({"kl_weight": self.kl_weight}) + # pytorch lightning requires a Tensor object for loss + loss = self.differentiable_loss_fn( + self.scale_fn(self.module.model), + self.scale_fn(self.module.guide), + *args, + **kwargs, + ) + out_dict = {"loss": loss} + self.validation_step_outputs.append(out_dict) + return out_dict + + def on_validation_epoch_end(self): + """Validation epoch end for Pyro training.""" + outputs = self.validation_step_outputs + elbo = 0 + n = 0 + for out in outputs: + elbo += out["loss"] + n += 1 + elbo /= n + self.log("elbo_validation", elbo, prog_bar=True) + self.validation_step_outputs.clear() + def configure_optimizers(self): """Configure optimizers for the model.""" return self.optim(self.module.parameters(), **self.optim_kwargs) @@ -981,6 +1033,27 @@ def n_obs_training(self, n_obs: int): self._n_obs_training = n_obs + @property + def n_obs_validation(self): + """Number of validation examples. + + If not `None`, updates the `n_obs` attr + of the Pyro module's `model` and `guide`, if they exist. + """ + return self._n_obs_validation + + @n_obs_validation.setter + def n_obs_validation(self, n_obs: int): + if n_obs is not None: + if hasattr(self.module, "n_obs"): + self.module.n_obs = n_obs + if hasattr(self.module.model, "n_obs"): + self.module.model.n_obs = n_obs + if hasattr(self.module.guide, "n_obs"): + self.module.guide.n_obs = n_obs + + self._n_obs_validation = n_obs + class PyroTrainingPlan(LowLevelPyroTrainingPlan): """Lightning module task to train Pyro scvi-tools modules. @@ -1029,7 +1102,9 @@ def __init__( optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else {} if "lr" not in optim_kwargs.keys(): optim_kwargs.update({"lr": 1e-3}) - self.optim = pyro.optim.Adam(optim_args=optim_kwargs) if optim is None else optim + self.optim = ( + pyro.optim.Adam(optim_args=optim_kwargs) if optim is None else optim + ) # We let SVI take care of all optimization self.automatic_optimization = False @@ -1125,7 +1200,9 @@ def __init__( self.loss_fn = loss() if self.module.logits is False and loss == torch.nn.CrossEntropyLoss: - raise UserWarning("classifier should return logits when using CrossEntropyLoss.") + raise UserWarning( + "classifier should return logits when using CrossEntropyLoss." + ) def forward(self, *args, **kwargs): """Passthrough to the module's forward function.""" @@ -1155,7 +1232,9 @@ def configure_optimizers(self): optim_cls = torch.optim.AdamW else: raise ValueError("Optimizer not understood.") - optimizer = optim_cls(params, lr=self.lr, eps=self.eps, weight_decay=self.weight_decay) + optimizer = optim_cls( + params, lr=self.lr, eps=self.eps, weight_decay=self.weight_decay + ) return optimizer @@ -1221,7 +1300,11 @@ def __init__( def get_optimizer_creator(self) -> JaxOptimizerCreator: """Get optimizer creator for the model.""" - clip_by = optax.clip_by_global_norm(self.max_norm) if self.max_norm else optax.identity() + clip_by = ( + optax.clip_by_global_norm(self.max_norm) + if self.max_norm + else optax.identity() + ) if self.optimizer_name == "Adam": # Replicates PyTorch Adam defaults optim = optax.chain( @@ -1275,9 +1358,9 @@ def loss_fn(params): loss = loss_output.loss return loss, (loss_output, new_model_state) - (loss, (loss_output, new_model_state)), grads = jax.value_and_grad(loss_fn, has_aux=True)( - state.params - ) + (loss, (loss_output, new_model_state)), grads = jax.value_and_grad( + loss_fn, has_aux=True + )(state.params) new_state = state.apply_gradients(grads=grads, state=new_model_state) return new_state, loss, loss_output diff --git a/tests/external/decipher/test_decipher.py b/tests/external/decipher/test_decipher.py new file mode 100644 index 0000000000..de05730eb3 --- /dev/null +++ b/tests/external/decipher/test_decipher.py @@ -0,0 +1,22 @@ +import pytest +import pyro + +from scvi.data import synthetic_iid +from scvi.external import Decipher + + +@pytest.fixture(scope="session") +def adata(): + adata = synthetic_iid() + return adata + + +def test_decipher_train(adata): + Decipher.setup_anndata(adata) + model = Decipher(adata) + model.train( + max_epochs=1, + check_val_every_n_epoch=1, + train_size=0.5, + early_stopping=True, + ) diff --git a/tests/external/mrvi/test_model.py b/tests/external/mrvi/test_model.py index 05edb27496..5e3147c87e 100644 --- a/tests/external/mrvi/test_model.py +++ b/tests/external/mrvi/test_model.py @@ -155,7 +155,9 @@ def test_mrvi_da(model, sample_key, da_kwargs): }, ], ) -def test_mrvi_model_kwargs(adata: AnnData, model_kwargs: dict[str, Any], save_path: str): +def test_mrvi_model_kwargs( + adata: AnnData, model_kwargs: dict[str, Any], save_path: str +): MRVI.setup_anndata( adata, sample_key="sample_str", @@ -173,7 +175,9 @@ def test_mrvi_model_kwargs(adata: AnnData, model_kwargs: dict[str, Any], save_pa def test_mrvi_sample_subset(model: MRVI, adata: AnnData, save_path: str): sample_cov_keys = ["meta1_cat", "meta2", "cont_cov"] sample_subset = [chr(i + ord("a")) for i in range(8)] - model.differential_expression(sample_cov_keys=sample_cov_keys, sample_subset=sample_subset) + model.differential_expression( + sample_cov_keys=sample_cov_keys, sample_subset=sample_subset + ) model_path = os.path.join(save_path, "mrvi_model") model.save(model_path, save_anndata=False, overwrite=True) From 3ad77a28786c1ded3b2f5de205034aa9bebe6691 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Mon, 14 Oct 2024 16:42:24 -0400 Subject: [PATCH 02/40] add early stopping based on predictive nll and freeze batch norm after first epoch --- src/scvi/external/decipher/_model.py | 30 ++--- src/scvi/external/decipher/_module.py | 22 ++-- src/scvi/external/decipher/_trainingplan.py | 123 ++++++++++++++++++++ src/scvi/external/decipher/_utils.py | 45 +++++++ 4 files changed, 191 insertions(+), 29 deletions(-) create mode 100644 src/scvi/external/decipher/_trainingplan.py create mode 100644 src/scvi/external/decipher/_utils.py diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py index 2381ff8d85..971f07e8a6 100644 --- a/src/scvi/external/decipher/_model.py +++ b/src/scvi/external/decipher/_model.py @@ -1,24 +1,20 @@ -from collections.abc import Sequence -from typing import TYPE_CHECKING - import logging +from typing import TYPE_CHECKING -from anndata import AnnData import pyro +from anndata import AnnData from scvi._constants import REGISTRY_KEYS from scvi.data import AnnDataManager -from scvi.data.fields import LayerField, CategoricalJointObsField -from scvi.utils import setup_anndata_dsp -from scvi.train import PyroTrainingPlan - +from scvi.data.fields import LayerField from scvi.model.base import BaseModelClass, PyroSviTrainMixin +from scvi.train import PyroTrainingPlan +from scvi.utils import setup_anndata_dsp from ._module import DecipherPyroModule +from ._trainingplan import DecipherTrainingPlan if TYPE_CHECKING: - from collections.abc import Sequence - from anndata import AnnData logger = logging.getLogger(__name__) @@ -26,6 +22,7 @@ class Decipher(PyroSviTrainMixin, BaseModelClass): _module_cls = DecipherPyroModule + _training_plan_cls = DecipherTrainingPlan def __init__(self, adata: AnnData, **kwargs): pyro.clear_param_store() @@ -56,14 +53,11 @@ def setup_anndata( %(param_adata)s %(param_layer)s """ - setup_method_args = cls._get_setup_method_args(**locals()) anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -77,17 +71,13 @@ def train( shuffle_set_split: bool = True, batch_size: int = 128, early_stopping: bool = False, - lr: float | None = None, training_plan: PyroTrainingPlan | None = None, datasplitter_kwargs: dict | None = None, plan_kwargs: dict | None = None, **trainer_kwargs, ): - optim_kwargs = trainer_kwargs.pop("optim_kwargs", {}) - optim_kwargs.update({"lr": lr or 5e-3, "weight_decay": 1e-4}) - optim = pyro.optim.ClippedAdam(optim_kwargs) - plan_kwargs = plan_kwargs or {} - plan_kwargs.update({"optim": optim}) + if "early_stopping_monitor" not in trainer_kwargs: + trainer_kwargs["early_stopping_monitor"] = "nll_validation" super().train( max_epochs=max_epochs, accelerator=accelerator, diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index 861bf9025a..11a892a03d 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -57,17 +57,17 @@ def __init__( self._epsilon = 1e-5 + self.n_obs = None # Populated by PyroTrainingPlan + # Hack: to allow auto_move_data to infer device. - self._dummy_param = nn.Parameter(torch.empty(0), requires_grad=False) + self._dummy_param = nn.Parameter(torch.empty(0)) @property def device(self): return self._dummy_param.device @staticmethod - def _get_fn_args_from_batch( - tensor_dict: dict[str, torch.Tensor] - ) -> Iterable | dict: + def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict: x = tensor_dict[REGISTRY_KEYS.X_KEY] return (x,), {} @@ -81,7 +81,10 @@ def model(self, x: torch.Tensor): constraint=constraints.positive, ) - with pyro.plate("batch", len(x)), poutine.scale(scale=1.0): + with ( + pyro.plate("batch", size=self.n_obs, subsample_size=x.shape[0]), + poutine.scale(scale=1.0), + ): with poutine.scale(scale=self.beta): if self.prior == "normal": prior = dist.Normal(0, x.new_ones(self.dim_v)).to_event(1) @@ -105,15 +108,16 @@ def model(self, x: torch.Tensor): self.theta + self._epsilon ) # noinspection PyUnresolvedReferences - x_dist = dist.NegativeBinomial( - total_count=self.theta + self._epsilon, logits=logit - ) + x_dist = dist.NegativeBinomial(total_count=self.theta + self._epsilon, logits=logit) pyro.sample("x", x_dist.to_event(1), obs=x) @auto_move_data def guide(self, x: torch.Tensor): pyro.module("decipher", self) - with pyro.plate("batch", len(x)), poutine.scale(scale=1.0): + with ( + pyro.plate("batch", size=self.n_obs, subsample_size=x.shape[0]), + poutine.scale(scale=1.0), + ): x = torch.log1p(x) z_loc, z_scale = self.encoder_x_to_z(x) diff --git a/src/scvi/external/decipher/_trainingplan.py b/src/scvi/external/decipher/_trainingplan.py new file mode 100644 index 0000000000..10a9223e86 --- /dev/null +++ b/src/scvi/external/decipher/_trainingplan.py @@ -0,0 +1,123 @@ +import pyro +import torch + +from scvi.module.base import ( + PyroBaseModuleClass, +) +from scvi.train import LowLevelPyroTrainingPlan + +from ._utils import predictive_log_likelihood + + +class DecipherTrainingPlan(LowLevelPyroTrainingPlan): + """Lightning module task to train the Decipher Pyro module. + + Parameters + ---------- + pyro_module + An instance of :class:`~scvi.module.base.PyroBaseModuleClass`. This object + should have callable `model` and `guide` attributes or methods. + loss_fn + A Pyro loss. Should be a subclass of :class:`~pyro.infer.ELBO`. + If `None`, defaults to :class:`~pyro.infer.Trace_ELBO`. + optim + A Pyro optimizer instance, e.g., :class:`~pyro.optim.Adam`. If `None`, + defaults to :class:`pyro.optim.Adam` optimizer with a learning rate of `1e-3`. + optim_kwargs + Keyword arguments for **default** optimiser :class:`pyro.optim.Adam`. + """ + + def __init__( + self, + pyro_module: PyroBaseModuleClass, + loss_fn: pyro.infer.ELBO | None = None, + optim: pyro.optim.PyroOptim | None = None, + optim_kwargs: dict | None = None, + ): + super().__init__( + pyro_module=pyro_module, + loss_fn=loss_fn, + ) + optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else {} + if "lr" not in optim_kwargs.keys(): + optim_kwargs.update({"lr": 5e-3, "weight_decay": 1e-4}) + self.optim = pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim + # We let SVI take care of all optimization + self.automatic_optimization = False + + self.svi = pyro.infer.SVI( + model=self.module.model, + guide=self.module.guide, + optim=self.optim, + loss=self.loss_fn, + ) + # See configure_optimizers for what this does + self._dummy_param = torch.nn.Parameter(torch.Tensor([0.0])) + + def training_step(self, batch, batch_idx): + """Training step for Pyro training.""" + args, kwargs = self.module._get_fn_args_from_batch(batch) + + # pytorch lightning requires a Tensor object for loss + loss = torch.Tensor([self.svi.step(*args, **kwargs)]) + + _opt = self.optimizers() + _opt.step() + + out_dict = {"loss": loss} + self.training_step_outputs.append(out_dict) + return out_dict + + def on_train_epoch_start(self): + """Training epoch start for Pyro training.""" + super().on_train_epoch_start() + if self.current_epoch > 0: + # freeze the batch norm layers after the first epoch + # 1) the batch norm layers helps with the initialization + # 2) but then, they seem to imply a strong normal prior on the latent space + for module in self.module.modules(): + if isinstance(module, torch.nn.BatchNorm1d): + module.eval() + + def validation_step(self, batch, batch_idx): + """Validation step for Pyro training.""" + out_dict = super().validation_step(batch, batch_idx) + args, kwargs = self.module._get_fn_args_from_batch(batch) + nll = -predictive_log_likelihood(self.module, *args, **kwargs, n_samples=5) + out_dict["nll"] = nll + self.validation_step_outputs[-1].update(out_dict) + return out_dict + + def on_validation_epoch_end(self): + """Validation epoch end for Pyro training.""" + outputs = self.validation_step_outputs + elbo = 0 + nll = 0 + n = 0 + for out in outputs: + elbo += out["loss"] + nll += out["nll"] + n += 1 + elbo /= n + nll /= n + self.log("elbo_validation", elbo, prog_bar=True) + self.log("nll_validation", nll, prog_bar=True) + self.validation_step_outputs.clear() + + def configure_optimizers(self): + """Shim optimizer for PyTorch Lightning. + + PyTorch Lightning wants to take steps on an optimizer + returned by this function in order to increment the global + step count. See PyTorch Lighinting optimizer manual loop. + + Here we provide a shim optimizer that we can take steps on + at minimal computational cost in order to keep Lightning happy :). + """ + return torch.optim.Adam([self._dummy_param]) + + def optimizer_step(self, *args, **kwargs): + pass + + def backward(self, *args, **kwargs): + pass diff --git a/src/scvi/external/decipher/_utils.py b/src/scvi/external/decipher/_utils.py new file mode 100644 index 0000000000..ebbc1c9edd --- /dev/null +++ b/src/scvi/external/decipher/_utils.py @@ -0,0 +1,45 @@ +import numpy as np +import pyro.poutine as poutine +import torch + + +def predictive_log_likelihood(decipher_module, batch, n_samples=5): + """ + Calculate the predictive log-likelihood for a Decipher module. + + This function performs multiple runs through the dataloader to obtain + an empirical estimate of the predictive log-likelihood. It calculates the + log-likelihood for each run and returns the average. The beta parameter + of the Decipher module is temporarily modified and restored even if an + exception occurs. + + Parameters + ---------- + decipher_module : PyroBaseModuleClass + The Decipher module to evaluate. + batch : torch.Tensor + Batch of data to compute the log-likelihood for. + n_samples : int, optional + Number of passes through the dataloader (default is 5). + + Returns + ------- + float + The average estimated predictive log-likelihood across multiple runs. + """ + log_weights = [] + old_beta = decipher_module.beta + decipher_module.beta = 1.0 + try: + for _ in range(n_samples): + guide_trace = poutine.trace(decipher_module.guide).get_trace(batch) + model_trace = poutine.trace( + poutine.replay(decipher_module.model, trace=guide_trace) + ).get_trace(batch) + log_weights.append(model_trace.log_prob_sum() - guide_trace.log_prob_sum()) + + finally: + decipher_module.beta = old_beta + + log_z = torch.logsumexp(torch.tensor(log_weights) - np.log(n_samples), 0) + return log_z.item() From c8c987febd59b0644caeab7c1c253a46f94ad583 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Mon, 14 Oct 2024 17:45:02 -0400 Subject: [PATCH 03/40] add get latent rep to model class --- src/scvi/external/decipher/_model.py | 28 +++++++++++++++++++++ src/scvi/external/decipher/_module.py | 24 ------------------ src/scvi/external/decipher/_trainingplan.py | 16 +++++++++--- tests/external/decipher/test_decipher.py | 5 ++-- 4 files changed, 43 insertions(+), 30 deletions(-) diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py index 971f07e8a6..75a02164b4 100644 --- a/src/scvi/external/decipher/_model.py +++ b/src/scvi/external/decipher/_model.py @@ -1,7 +1,10 @@ import logging +from collections.abc import Sequence from typing import TYPE_CHECKING +import numpy as np import pyro +import torch from anndata import AnnData from scvi._constants import REGISTRY_KEYS @@ -15,6 +18,8 @@ from ._trainingplan import DecipherTrainingPlan if TYPE_CHECKING: + from collections.abc import Sequence + from anndata import AnnData logger = logging.getLogger(__name__) @@ -92,3 +97,26 @@ def train( datasplitter_kwargs=datasplitter_kwargs, **trainer_kwargs, ) + + def get_latent_representation( + self, + adata: AnnData | None = None, + indices: Sequence[int] | None = None, + batch_size: int | None = None, + give_z: bool = False, + ) -> np.ndarray: + self._check_if_trained(warn=False) + adata = self._validate_anndata(adata) + + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) + latent_locs = [] + for tensors in scdl: + x = tensors[REGISTRY_KEYS.X_KEY] + x = torch.log1p(x) + z_loc, _ = self.module.encoder_x_to_z(x) + if give_z: + latent_locs.append(z_loc) + else: + v_loc, _ = self.module.encoder_zx_to_v(torch.cat([z_loc, x], dim=-1)) + latent_locs.append(v_loc) + return torch.cat(latent_locs).detach().numpy() diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index 11a892a03d..c7baf7f24e 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -138,30 +138,6 @@ def guide(self, x: torch.Tensor): pyro.sample("v", posterior_v) return z_loc, v_loc, z_scale, v_scale - def compute_v_z_numpy(self, x: np.array): - """Compute decipher_v and decipher_z for a given input. - - Parameters - ---------- - x : np.ndarray or torch.Tensor - Input data of shape (n_cells, n_genes). - - Returns - ------- - v : np.ndarray - Decipher components v of shape (n_cells, dim_v). - z : np.ndarray - Decipher latent z of shape (n_cells, dim_z). - """ - if type(x) == np.ndarray: - x = torch.tensor(x, dtype=torch.float32) - - x = torch.log1p(x) - z_loc, _ = self.encoder_x_to_z(x) - zx = torch.cat([z_loc, x], dim=-1) - v_loc, _ = self.encoder_zx_to_v(zx) - return v_loc.detach().numpy(), z_loc.detach().numpy() - def impute_gene_expression_numpy(self, x): if type(x) == np.ndarray: x = torch.tensor(x, dtype=torch.float32) diff --git a/src/scvi/external/decipher/_trainingplan.py b/src/scvi/external/decipher/_trainingplan.py index 10a9223e86..4b1de9682e 100644 --- a/src/scvi/external/decipher/_trainingplan.py +++ b/src/scvi/external/decipher/_trainingplan.py @@ -79,6 +79,16 @@ def on_train_epoch_start(self): if isinstance(module, torch.nn.BatchNorm1d): module.eval() + def on_train_epoch_end(self): + """Training epoch end for Pyro training.""" + outputs = self.training_step_outputs + elbo = 0 + for out in outputs: + elbo += out["loss"] + elbo /= self.n_obs_training + self.log("elbo_train", elbo, prog_bar=True) + self.training_step_outputs.clear() + def validation_step(self, batch, batch_idx): """Validation step for Pyro training.""" out_dict = super().validation_step(batch, batch_idx) @@ -93,13 +103,11 @@ def on_validation_epoch_end(self): outputs = self.validation_step_outputs elbo = 0 nll = 0 - n = 0 for out in outputs: elbo += out["loss"] nll += out["nll"] - n += 1 - elbo /= n - nll /= n + elbo /= self.n_obs_validation + nll /= self.n_obs_validation self.log("elbo_validation", elbo, prog_bar=True) self.log("nll_validation", nll, prog_bar=True) self.validation_step_outputs.clear() diff --git a/tests/external/decipher/test_decipher.py b/tests/external/decipher/test_decipher.py index de05730eb3..0056096eb0 100644 --- a/tests/external/decipher/test_decipher.py +++ b/tests/external/decipher/test_decipher.py @@ -1,5 +1,4 @@ import pytest -import pyro from scvi.data import synthetic_iid from scvi.external import Decipher @@ -15,8 +14,10 @@ def test_decipher_train(adata): Decipher.setup_anndata(adata) model = Decipher(adata) model.train( - max_epochs=1, + max_epochs=2, check_val_every_n_epoch=1, train_size=0.5, early_stopping=True, ) + model.get_latent_representation(give_z=False) + model.get_latent_representation(give_z=True) From bfee40e50e0bb8673415d0c7e7b9659a6174ee02 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Mon, 14 Oct 2024 17:45:45 -0400 Subject: [PATCH 04/40] remove impute fn for now --- src/scvi/external/decipher/_module.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index c7baf7f24e..02eb4f0096 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -67,7 +67,9 @@ def device(self): return self._dummy_param.device @staticmethod - def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict: + def _get_fn_args_from_batch( + tensor_dict: dict[str, torch.Tensor] + ) -> Iterable | dict: x = tensor_dict[REGISTRY_KEYS.X_KEY] return (x,), {} @@ -108,7 +110,9 @@ def model(self, x: torch.Tensor): self.theta + self._epsilon ) # noinspection PyUnresolvedReferences - x_dist = dist.NegativeBinomial(total_count=self.theta + self._epsilon, logits=logit) + x_dist = dist.NegativeBinomial( + total_count=self.theta + self._epsilon, logits=logit + ) pyro.sample("x", x_dist.to_event(1), obs=x) @auto_move_data @@ -137,12 +141,3 @@ def guide(self, x: torch.Tensor): raise ValueError("Invalid prior, must be normal or gamma") pyro.sample("v", posterior_v) return z_loc, v_loc, z_scale, v_scale - - def impute_gene_expression_numpy(self, x): - if type(x) == np.ndarray: - x = torch.tensor(x, dtype=torch.float32) - z_loc, _, _, _ = self.guide(x) - mu = self.decoder_z_to_x(z_loc) - mu = softmax(mu, dim=-1) - library_size = x.sum(axis=-1, keepdim=True) - return (library_size * mu).detach().numpy() From 8c8b52d2a29b429e63cd469a6c6fb1f5b476ac9f Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Mon, 14 Oct 2024 18:27:59 -0400 Subject: [PATCH 05/40] revert base module lints --- src/scvi/module/base/_base_module.py | 40 +++++++--------------------- 1 file changed, 10 insertions(+), 30 deletions(-) diff --git a/src/scvi/module/base/_base_module.py b/src/scvi/module/base/_base_module.py index e9d97dccdd..39097c9039 100644 --- a/src/scvi/module/base/_base_module.py +++ b/src/scvi/module/base/_base_module.py @@ -96,9 +96,7 @@ def __post_init__(self): object.__setattr__(self, "loss", self.dict_sum(self.loss)) if self.n_obs_minibatch is None and self.reconstruction_loss is None: - raise ValueError( - "Must provide either n_obs_minibatch or reconstruction_loss" - ) + raise ValueError("Must provide either n_obs_minibatch or reconstruction_loss") default = 0 * self.loss if self.reconstruction_loss is None: @@ -108,9 +106,7 @@ def __post_init__(self): if self.kl_global is None: object.__setattr__(self, "kl_global", default) - object.__setattr__( - self, "reconstruction_loss", self._as_dict("reconstruction_loss") - ) + object.__setattr__(self, "reconstruction_loss", self._as_dict("reconstruction_loss")) object.__setattr__(self, "kl_local", self._as_dict("kl_local")) object.__setattr__(self, "kl_global", self._as_dict("kl_global")) object.__setattr__( @@ -123,16 +119,13 @@ def __post_init__(self): if self.reconstruction_loss is not None and self.n_obs_minibatch is None: rec_loss = self.reconstruction_loss - object.__setattr__( - self, "n_obs_minibatch", list(rec_loss.values())[0].shape[0] - ) + object.__setattr__(self, "n_obs_minibatch", list(rec_loss.values())[0].shape[0]) if self.classification_loss is not None and ( self.logits is None or self.true_labels is None ): raise ValueError( - "Must provide `logits` and `true_labels` if `classification_loss` is " - "provided." + "Must provide `logits` and `true_labels` if `classification_loss` is " "provided." ) @staticmethod @@ -191,10 +184,7 @@ def forward( generative_kwargs: dict | None = None, loss_kwargs: dict | None = None, compute_loss=True, - ) -> ( - tuple[torch.Tensor, torch.Tensor] - | tuple[torch.Tensor, torch.Tensor, LossOutput] - ): + ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, LossOutput]: """Forward pass through the network. Parameters @@ -358,9 +348,7 @@ def __init__(self, on_load_kwargs: dict | None = None): @staticmethod @abstractmethod - def _get_fn_args_from_batch( - tensor_dict: dict[str, torch.Tensor] - ) -> Iterable | dict: + def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict: """Parse the minibatched data to get the correct inputs for ``model`` and ``guide``. In Pyro, ``model`` and ``guide`` must have the same signature. This is a helper method @@ -661,9 +649,7 @@ def load_state_dict(self, state_dict: dict[str, Any]): raise RuntimeError( "Train state is not set. Train for one iteration prior to loading state dict." ) - self.train_state = flax.serialization.from_state_dict( - self.train_state, state_dict - ) + self.train_state = flax.serialization.from_state_dict(self.train_state, state_dict) def to(self, device: Device): """Move module to device.""" @@ -691,9 +677,7 @@ def get_jit_inference_fn( self, get_inference_input_kwargs: dict[str, Any] | None = None, inference_kwargs: dict[str, Any] | None = None, - ) -> Callable[ - [dict[str, jnp.ndarray], dict[str, jnp.ndarray]], dict[str, jnp.ndarray] - ]: + ) -> Callable[[dict[str, jnp.ndarray], dict[str, jnp.ndarray]], dict[str, jnp.ndarray]]: """Create a method to run inference using the bound module. Parameters @@ -760,18 +744,14 @@ def _generic_forward( get_inference_input_kwargs = _get_dict_if_none(get_inference_input_kwargs) get_generative_input_kwargs = _get_dict_if_none(get_generative_input_kwargs) - inference_inputs = module._get_inference_input( - tensors, **get_inference_input_kwargs - ) + inference_inputs = module._get_inference_input(tensors, **get_inference_input_kwargs) inference_outputs = module.inference(**inference_inputs, **inference_kwargs) generative_inputs = module._get_generative_input( tensors, inference_outputs, **get_generative_input_kwargs ) generative_outputs = module.generative(**generative_inputs, **generative_kwargs) if compute_loss: - losses = module.loss( - tensors, inference_outputs, generative_outputs, **loss_kwargs - ) + losses = module.loss(tensors, inference_outputs, generative_outputs, **loss_kwargs) return inference_outputs, generative_outputs, losses else: return inference_outputs, generative_outputs From 0d9039d8f0fcd931c6d0e3a8aa5acfaa8aa619b5 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Mon, 14 Oct 2024 18:29:40 -0400 Subject: [PATCH 06/40] converges but with blobby latent space --- src/scvi/external/decipher/_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index 02eb4f0096..efbeaa0056 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -84,7 +84,7 @@ def model(self, x: torch.Tensor): ) with ( - pyro.plate("batch", size=self.n_obs, subsample_size=x.shape[0]), + pyro.plate("batch", x.shape[0]), poutine.scale(scale=1.0), ): with poutine.scale(scale=self.beta): @@ -119,7 +119,7 @@ def model(self, x: torch.Tensor): def guide(self, x: torch.Tensor): pyro.module("decipher", self) with ( - pyro.plate("batch", size=self.n_obs, subsample_size=x.shape[0]), + pyro.plate("batch", x.shape[0]), poutine.scale(scale=1.0), ): x = torch.log1p(x) From 600eba1a4a23ed38ccd98fc4a772073a252a89d2 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Mon, 14 Oct 2024 19:00:56 -0400 Subject: [PATCH 07/40] fix batch norm freezing and move pll impl into module --- src/scvi/external/decipher/_module.py | 64 ++++++++++++++++++++- src/scvi/external/decipher/_trainingplan.py | 14 ++--- src/scvi/external/decipher/_utils.py | 45 --------------- tests/external/decipher/test_decipher.py | 2 +- 4 files changed, 69 insertions(+), 56 deletions(-) delete mode 100644 src/scvi/external/decipher/_utils.py diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index efbeaa0056..aa264fd1bc 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -17,12 +17,27 @@ class DecipherPyroModule(PyroBaseModuleClass): - """Decipher _decipher for single-cell data. + """Decipher PyroModule for single-cell data analysis. + + This module implements the Decipher model for dimensionality reduction and + interpretable representation learning in single-cell RNA sequencing data. Parameters ---------- - config : DecipherConfig or dict - Configuration for the decipher _decipher. + dim_genes : int + Number of genes (features) in the dataset. + dim_v : int, optional + Dimension of the interpretable latent space v. Default is 2. + dim_z : int, optional + Dimension of the intermediate latent space z. Default is 10. + layers_v_to_z : Sequence[int], optional + Hidden layer sizes for the v to z decoder network. Default is (64,). + layers_z_to_x : Sequence[int], optional + Hidden layer sizes for the z to x decoder network. Default is empty tuple. + beta : float, optional + Regularization parameter for the KL divergence. Default is 0.1. + prior : str, optional + Type of prior distribution to use. Default is "normal". """ def __init__( @@ -141,3 +156,46 @@ def guide(self, x: torch.Tensor): raise ValueError("Invalid prior, must be normal or gamma") pyro.sample("v", posterior_v) return z_loc, v_loc, z_scale, v_scale + + def predictive_log_likelihood(self, x: torch.Tensor, n_samples=5): + """ + Calculate the predictive log-likelihood for a Decipher module. + + This function performs multiple runs through the dataloader to obtain + an empirical estimate of the predictive log-likelihood. It calculates the + log-likelihood for each run and returns the average. The beta parameter + of the Decipher module is temporarily modified and restored even if an + exception occurs. Used by default as an early stopping criterion. + + Parameters + ---------- + decipher_module : PyroBaseModuleClass + The Decipher module to evaluate. + x : torch.Tensor + Batch of data to compute the log-likelihood for. + n_samples : int, optional + Number of passes through the dataloader (default is 5). + + Returns + ------- + float + The average estimated predictive log-likelihood across multiple runs. + """ + log_weights = [] + old_beta = self.beta + self.beta = 1.0 + try: + for _ in range(n_samples): + guide_trace = poutine.trace(self.guide).get_trace(x) + model_trace = poutine.trace( + poutine.replay(self.model, trace=guide_trace) + ).get_trace(x) + log_weights.append( + model_trace.log_prob_sum() - guide_trace.log_prob_sum() + ) + + finally: + self.beta = old_beta + + log_z = torch.logsumexp(torch.tensor(log_weights) - np.log(n_samples), 0) + return log_z.item() diff --git a/src/scvi/external/decipher/_trainingplan.py b/src/scvi/external/decipher/_trainingplan.py index 4b1de9682e..8c72eae656 100644 --- a/src/scvi/external/decipher/_trainingplan.py +++ b/src/scvi/external/decipher/_trainingplan.py @@ -6,8 +6,6 @@ ) from scvi.train import LowLevelPyroTrainingPlan -from ._utils import predictive_log_likelihood - class DecipherTrainingPlan(LowLevelPyroTrainingPlan): """Lightning module task to train the Decipher Pyro module. @@ -41,7 +39,9 @@ def __init__( optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else {} if "lr" not in optim_kwargs.keys(): optim_kwargs.update({"lr": 5e-3, "weight_decay": 1e-4}) - self.optim = pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim + self.optim = ( + pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim + ) # We let SVI take care of all optimization self.automatic_optimization = False @@ -68,9 +68,9 @@ def training_step(self, batch, batch_idx): self.training_step_outputs.append(out_dict) return out_dict - def on_train_epoch_start(self): - """Training epoch start for Pyro training.""" - super().on_train_epoch_start() + def on_validation_model_train(self): + """Prepare the model for validation by switching to train mode.""" + super().on_validation_model_train() if self.current_epoch > 0: # freeze the batch norm layers after the first epoch # 1) the batch norm layers helps with the initialization @@ -93,7 +93,7 @@ def validation_step(self, batch, batch_idx): """Validation step for Pyro training.""" out_dict = super().validation_step(batch, batch_idx) args, kwargs = self.module._get_fn_args_from_batch(batch) - nll = -predictive_log_likelihood(self.module, *args, **kwargs, n_samples=5) + nll = -self.module.predictive_log_likelihood(*args, **kwargs, n_samples=5) out_dict["nll"] = nll self.validation_step_outputs[-1].update(out_dict) return out_dict diff --git a/src/scvi/external/decipher/_utils.py b/src/scvi/external/decipher/_utils.py deleted file mode 100644 index ebbc1c9edd..0000000000 --- a/src/scvi/external/decipher/_utils.py +++ /dev/null @@ -1,45 +0,0 @@ -import numpy as np -import pyro.poutine as poutine -import torch - - -def predictive_log_likelihood(decipher_module, batch, n_samples=5): - """ - Calculate the predictive log-likelihood for a Decipher module. - - This function performs multiple runs through the dataloader to obtain - an empirical estimate of the predictive log-likelihood. It calculates the - log-likelihood for each run and returns the average. The beta parameter - of the Decipher module is temporarily modified and restored even if an - exception occurs. - - Parameters - ---------- - decipher_module : PyroBaseModuleClass - The Decipher module to evaluate. - batch : torch.Tensor - Batch of data to compute the log-likelihood for. - n_samples : int, optional - Number of passes through the dataloader (default is 5). - - Returns - ------- - float - The average estimated predictive log-likelihood across multiple runs. - """ - log_weights = [] - old_beta = decipher_module.beta - decipher_module.beta = 1.0 - try: - for _ in range(n_samples): - guide_trace = poutine.trace(decipher_module.guide).get_trace(batch) - model_trace = poutine.trace( - poutine.replay(decipher_module.model, trace=guide_trace) - ).get_trace(batch) - log_weights.append(model_trace.log_prob_sum() - guide_trace.log_prob_sum()) - - finally: - decipher_module.beta = old_beta - - log_z = torch.logsumexp(torch.tensor(log_weights) - np.log(n_samples), 0) - return log_z.item() diff --git a/tests/external/decipher/test_decipher.py b/tests/external/decipher/test_decipher.py index 0056096eb0..d71c6ab642 100644 --- a/tests/external/decipher/test_decipher.py +++ b/tests/external/decipher/test_decipher.py @@ -14,7 +14,7 @@ def test_decipher_train(adata): Decipher.setup_anndata(adata) model = Decipher(adata) model.train( - max_epochs=2, + max_epochs=3, check_val_every_n_epoch=1, train_size=0.5, early_stopping=True, From 7fe1bebc891d3b8b2c940f657b31eb212955aa9f Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Mon, 14 Oct 2024 19:13:59 -0400 Subject: [PATCH 08/40] fix save/load --- src/scvi/external/decipher/_model.py | 10 +++++++--- src/scvi/external/decipher/_trainingplan.py | 2 +- tests/external/decipher/test_decipher.py | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py index 75a02164b4..0e74e74ab2 100644 --- a/src/scvi/external/decipher/_model.py +++ b/src/scvi/external/decipher/_model.py @@ -41,7 +41,7 @@ def __init__(self, adata: AnnData, **kwargs): **kwargs, ) - self.init_params = self._get_init_params(locals()) + self.init_params_ = self._get_init_params(locals()) @classmethod @setup_anndata_dsp.dedent @@ -62,7 +62,9 @@ def setup_anndata( anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), ] - adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) + adata_manager = AnnDataManager( + fields=anndata_fields, setup_method_args=setup_method_args + ) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -108,7 +110,9 @@ def get_latent_representation( self._check_if_trained(warn=False) adata = self._validate_anndata(adata) - scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) + scdl = self._make_data_loader( + adata=adata, indices=indices, batch_size=batch_size + ) latent_locs = [] for tensors in scdl: x = tensors[REGISTRY_KEYS.X_KEY] diff --git a/src/scvi/external/decipher/_trainingplan.py b/src/scvi/external/decipher/_trainingplan.py index 8c72eae656..32371eceb5 100644 --- a/src/scvi/external/decipher/_trainingplan.py +++ b/src/scvi/external/decipher/_trainingplan.py @@ -109,7 +109,7 @@ def on_validation_epoch_end(self): elbo /= self.n_obs_validation nll /= self.n_obs_validation self.log("elbo_validation", elbo, prog_bar=True) - self.log("nll_validation", nll, prog_bar=True) + self.log("nll_validation", nll, prog_bar=False) self.validation_step_outputs.clear() def configure_optimizers(self): diff --git a/tests/external/decipher/test_decipher.py b/tests/external/decipher/test_decipher.py index d71c6ab642..0056096eb0 100644 --- a/tests/external/decipher/test_decipher.py +++ b/tests/external/decipher/test_decipher.py @@ -14,7 +14,7 @@ def test_decipher_train(adata): Decipher.setup_anndata(adata) model = Decipher(adata) model.train( - max_epochs=3, + max_epochs=2, check_val_every_n_epoch=1, train_size=0.5, early_stopping=True, From e48fe3be747f7b553b8e0adb1862ae23a4a11701 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Mon, 14 Oct 2024 21:13:06 -0400 Subject: [PATCH 09/40] drop last and fix loss scaling --- src/scvi/external/decipher/_model.py | 3 ++ src/scvi/external/decipher/_module.py | 2 -- src/scvi/external/decipher/_trainingplan.py | 40 ++++++++++++--------- tests/external/mrvi/test_model.py | 8 ++--- 4 files changed, 29 insertions(+), 24 deletions(-) diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py index 0e74e74ab2..75a8db88de 100644 --- a/src/scvi/external/decipher/_model.py +++ b/src/scvi/external/decipher/_model.py @@ -85,6 +85,9 @@ def train( ): if "early_stopping_monitor" not in trainer_kwargs: trainer_kwargs["early_stopping_monitor"] = "nll_validation" + datasplitter_kwargs = datasplitter_kwargs or {} + if "drop_last" not in datasplitter_kwargs: + datasplitter_kwargs["drop_last"] = True super().train( max_epochs=max_epochs, accelerator=accelerator, diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index aa264fd1bc..433bc83706 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -169,8 +169,6 @@ def predictive_log_likelihood(self, x: torch.Tensor, n_samples=5): Parameters ---------- - decipher_module : PyroBaseModuleClass - The Decipher module to evaluate. x : torch.Tensor Batch of data to compute the log-likelihood for. n_samples : int, optional diff --git a/src/scvi/external/decipher/_trainingplan.py b/src/scvi/external/decipher/_trainingplan.py index 32371eceb5..992ad164a8 100644 --- a/src/scvi/external/decipher/_trainingplan.py +++ b/src/scvi/external/decipher/_trainingplan.py @@ -38,7 +38,9 @@ def __init__( ) optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else {} if "lr" not in optim_kwargs.keys(): - optim_kwargs.update({"lr": 5e-3, "weight_decay": 1e-4}) + optim_kwargs.update({"lr": 5e-3}) + if "weight_decay" not in optim_kwargs.keys(): + optim_kwargs.update({"weight_decay": 1e-4}) self.optim = ( pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim ) @@ -54,38 +56,41 @@ def __init__( # See configure_optimizers for what this does self._dummy_param = torch.nn.Parameter(torch.Tensor([0.0])) + def on_validation_model_train(self): + """Prepare the model for validation by switching to train mode.""" + super().on_validation_model_train() + if self.current_epoch > 0: + # freeze the batch norm layers after the first epoch + # 1) the batch norm layers helps with the initialization + # 2) but then, they seem to imply a strong normal prior on the latent space + for module in self.module.modules(): + if isinstance(module, torch.nn.BatchNorm1d): + module.eval() + def training_step(self, batch, batch_idx): """Training step for Pyro training.""" args, kwargs = self.module._get_fn_args_from_batch(batch) # pytorch lightning requires a Tensor object for loss loss = torch.Tensor([self.svi.step(*args, **kwargs)]) + n_obs = args[0].shape[0] _opt = self.optimizers() _opt.step() - out_dict = {"loss": loss} + out_dict = {"loss": loss, "n_obs": n_obs} self.training_step_outputs.append(out_dict) return out_dict - def on_validation_model_train(self): - """Prepare the model for validation by switching to train mode.""" - super().on_validation_model_train() - if self.current_epoch > 0: - # freeze the batch norm layers after the first epoch - # 1) the batch norm layers helps with the initialization - # 2) but then, they seem to imply a strong normal prior on the latent space - for module in self.module.modules(): - if isinstance(module, torch.nn.BatchNorm1d): - module.eval() - def on_train_epoch_end(self): """Training epoch end for Pyro training.""" outputs = self.training_step_outputs elbo = 0 + n_obs = 0 for out in outputs: elbo += out["loss"] - elbo /= self.n_obs_training + n_obs += out["n_obs"] + elbo /= n_obs self.log("elbo_train", elbo, prog_bar=True) self.training_step_outputs.clear() @@ -95,6 +100,7 @@ def validation_step(self, batch, batch_idx): args, kwargs = self.module._get_fn_args_from_batch(batch) nll = -self.module.predictive_log_likelihood(*args, **kwargs, n_samples=5) out_dict["nll"] = nll + out_dict["n_obs"] = args[0].shape[0] self.validation_step_outputs[-1].update(out_dict) return out_dict @@ -103,11 +109,13 @@ def on_validation_epoch_end(self): outputs = self.validation_step_outputs elbo = 0 nll = 0 + n_obs = 0 for out in outputs: elbo += out["loss"] nll += out["nll"] - elbo /= self.n_obs_validation - nll /= self.n_obs_validation + n_obs += out["n_obs"] + elbo /= n_obs + nll /= n_obs self.log("elbo_validation", elbo, prog_bar=True) self.log("nll_validation", nll, prog_bar=False) self.validation_step_outputs.clear() diff --git a/tests/external/mrvi/test_model.py b/tests/external/mrvi/test_model.py index 5e3147c87e..05edb27496 100644 --- a/tests/external/mrvi/test_model.py +++ b/tests/external/mrvi/test_model.py @@ -155,9 +155,7 @@ def test_mrvi_da(model, sample_key, da_kwargs): }, ], ) -def test_mrvi_model_kwargs( - adata: AnnData, model_kwargs: dict[str, Any], save_path: str -): +def test_mrvi_model_kwargs(adata: AnnData, model_kwargs: dict[str, Any], save_path: str): MRVI.setup_anndata( adata, sample_key="sample_str", @@ -175,9 +173,7 @@ def test_mrvi_model_kwargs( def test_mrvi_sample_subset(model: MRVI, adata: AnnData, save_path: str): sample_cov_keys = ["meta1_cat", "meta2", "cont_cov"] sample_subset = [chr(i + ord("a")) for i in range(8)] - model.differential_expression( - sample_cov_keys=sample_cov_keys, sample_subset=sample_subset - ) + model.differential_expression(sample_cov_keys=sample_cov_keys, sample_subset=sample_subset) model_path = os.path.join(save_path, "mrvi_model") model.save(model_path, save_anndata=False, overwrite=True) From 2baf11dbefa5687cf708472be75be06ac3408a90 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Tue, 15 Oct 2024 11:00:14 -0400 Subject: [PATCH 10/40] fix tests, remove validation step from base training plan --- src/scvi/external/decipher/_trainingplan.py | 14 ++- src/scvi/train/_trainingplans.py | 119 +++----------------- 2 files changed, 28 insertions(+), 105 deletions(-) diff --git a/src/scvi/external/decipher/_trainingplan.py b/src/scvi/external/decipher/_trainingplan.py index 992ad164a8..c07e19b18a 100644 --- a/src/scvi/external/decipher/_trainingplan.py +++ b/src/scvi/external/decipher/_trainingplan.py @@ -53,6 +53,8 @@ def __init__( optim=self.optim, loss=self.loss_fn, ) + self.validation_step_outputs = [] + # See configure_optimizers for what this does self._dummy_param = torch.nn.Parameter(torch.Tensor([0.0])) @@ -96,12 +98,16 @@ def on_train_epoch_end(self): def validation_step(self, batch, batch_idx): """Validation step for Pyro training.""" - out_dict = super().validation_step(batch, batch_idx) args, kwargs = self.module._get_fn_args_from_batch(batch) + loss = self.differentiable_loss_fn( + self.scale_fn(self.module.model), + self.scale_fn(self.module.guide), + *args, + **kwargs, + ) nll = -self.module.predictive_log_likelihood(*args, **kwargs, n_samples=5) - out_dict["nll"] = nll - out_dict["n_obs"] = args[0].shape[0] - self.validation_step_outputs[-1].update(out_dict) + out_dict = {"loss": loss, "nll": nll, "n_obs": args[0].shape[0]} + self.validation_step_outputs.append(out_dict) return out_dict def on_validation_epoch_end(self): diff --git a/src/scvi/train/_trainingplans.py b/src/scvi/train/_trainingplans.py index 5067d9cdbf..79aa4bf0e3 100644 --- a/src/scvi/train/_trainingplans.py +++ b/src/scvi/train/_trainingplans.py @@ -182,9 +182,7 @@ def __init__( self.optimizer_creator = optimizer_creator if self.optimizer_name == "Custom" and self.optimizer_creator is None: - raise ValueError( - "If optimizer is 'Custom', `optimizer_creator` must be provided." - ) + raise ValueError("If optimizer is 'Custom', `optimizer_creator` must be provided.") self._n_obs_training = None self._n_obs_validation = None @@ -221,9 +219,7 @@ def initialize_train_metrics(self): self.kl_local_train, self.kl_global_train, self.train_metrics, - ) = self._create_elbo_metric_components( - mode="train", n_total=self.n_obs_training - ) + ) = self._create_elbo_metric_components(mode="train", n_total=self.n_obs_training) self.elbo_train.reset() def initialize_val_metrics(self): @@ -234,9 +230,7 @@ def initialize_val_metrics(self): self.kl_local_val, self.kl_global_val, self.val_metrics, - ) = self._create_elbo_metric_components( - mode="validation", n_total=self.n_obs_validation - ) + ) = self._create_elbo_metric_components(mode="validation", n_total=self.n_obs_validation) self.elbo_val.reset() @property @@ -372,9 +366,7 @@ def validation_step(self, batch, batch_idx): ) self.compute_and_log_metrics(scvi_loss, self.val_metrics, "validation") - def _optimizer_creator_fn( - self, optimizer_cls: torch.optim.Adam | torch.optim.AdamW - ): + def _optimizer_creator_fn(self, optimizer_cls: torch.optim.Adam | torch.optim.AdamW): """Create optimizer for the model. This type of function can be passed as the `optimizer_creator` @@ -552,9 +544,7 @@ def loss_adversarial_classifier(self, z, batch_index, predict_true_class=True): if predict_true_class: cls_target = torch.nn.functional.one_hot(batch_index.squeeze(-1), n_classes) else: - one_hot_batch = torch.nn.functional.one_hot( - batch_index.squeeze(-1), n_classes - ) + one_hot_batch = torch.nn.functional.one_hot(batch_index.squeeze(-1), n_classes) # place zeroes where true label is cls_target = (~one_hot_batch.bool()).float() cls_target = cls_target / (n_classes - 1) @@ -582,9 +572,7 @@ def training_step(self, batch, batch_idx): else: opt1, opt2 = opts - inference_outputs, _, scvi_loss = self.forward( - batch, loss_kwargs=self.loss_kwargs - ) + inference_outputs, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs) z = inference_outputs["z"] loss = scvi_loss.loss # fool classifier if doing adversarial training @@ -617,10 +605,7 @@ def on_train_epoch_end(self): def on_validation_epoch_end(self) -> None: """Update the learning rate via scheduler steps.""" - if ( - not self.reduce_lr_on_plateau - or "validation" not in self.lr_scheduler_metric - ): + if not self.reduce_lr_on_plateau or "validation" not in self.lr_scheduler_metric: return else: sch = self.lr_schedulers() @@ -651,9 +636,7 @@ def configure_optimizers(self): ) if self.adversarial_classifier is not False: - params2 = filter( - lambda p: p.requires_grad, self.adversarial_classifier.parameters() - ) + params2 = filter(lambda p: p.requires_grad, self.adversarial_classifier.parameters()) optimizer2 = torch.optim.Adam( params2, lr=1e-3, eps=0.01, weight_decay=self.weight_decay ) @@ -906,7 +889,6 @@ def __init__( super().__init__() self.module = pyro_module self._n_obs_training = None - self._n_obs_validation = None optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else {} if "lr" not in optim_kwargs.keys(): @@ -919,18 +901,15 @@ def __init__( self.n_epochs_kl_warmup = n_epochs_kl_warmup self.use_kl_weight = False if isinstance(self.module.model, PyroModule): - self.use_kl_weight = ( - "kl_weight" in signature(self.module.model.forward).parameters - ) + self.use_kl_weight = "kl_weight" in signature(self.module.model.forward).parameters elif callable(self.module.model): self.use_kl_weight = "kl_weight" in signature(self.module.model).parameters self.scale_elbo = scale_elbo - self.scale_fn = lambda obj: ( - pyro.poutine.scale(obj, self.scale_elbo) if self.scale_elbo != 1 else obj + self.scale_fn = ( + lambda obj: pyro.poutine.scale(obj, self.scale_elbo) if self.scale_elbo != 1 else obj ) self.differentiable_loss_fn = self.loss_fn.differentiable_loss self.training_step_outputs = [] - self.validation_step_outputs = [] def training_step(self, batch, batch_idx): """Training step for Pyro training.""" @@ -963,37 +942,6 @@ def on_train_epoch_end(self): self.log("elbo_train", elbo, prog_bar=True) self.training_step_outputs.clear() - def validation_step(self, batch, batch_idx): - """Validation step for Pyro training.""" - args, kwargs = self.module._get_fn_args_from_batch(batch) - # Set KL weight if necessary. - # Note: if applied, ELBO loss in progress bar is the effective KL annealed loss, not the - # true ELBO. - if self.use_kl_weight: - kwargs.update({"kl_weight": self.kl_weight}) - # pytorch lightning requires a Tensor object for loss - loss = self.differentiable_loss_fn( - self.scale_fn(self.module.model), - self.scale_fn(self.module.guide), - *args, - **kwargs, - ) - out_dict = {"loss": loss} - self.validation_step_outputs.append(out_dict) - return out_dict - - def on_validation_epoch_end(self): - """Validation epoch end for Pyro training.""" - outputs = self.validation_step_outputs - elbo = 0 - n = 0 - for out in outputs: - elbo += out["loss"] - n += 1 - elbo /= n - self.log("elbo_validation", elbo, prog_bar=True) - self.validation_step_outputs.clear() - def configure_optimizers(self): """Configure optimizers for the model.""" return self.optim(self.module.parameters(), **self.optim_kwargs) @@ -1033,27 +981,6 @@ def n_obs_training(self, n_obs: int): self._n_obs_training = n_obs - @property - def n_obs_validation(self): - """Number of validation examples. - - If not `None`, updates the `n_obs` attr - of the Pyro module's `model` and `guide`, if they exist. - """ - return self._n_obs_validation - - @n_obs_validation.setter - def n_obs_validation(self, n_obs: int): - if n_obs is not None: - if hasattr(self.module, "n_obs"): - self.module.n_obs = n_obs - if hasattr(self.module.model, "n_obs"): - self.module.model.n_obs = n_obs - if hasattr(self.module.guide, "n_obs"): - self.module.guide.n_obs = n_obs - - self._n_obs_validation = n_obs - class PyroTrainingPlan(LowLevelPyroTrainingPlan): """Lightning module task to train Pyro scvi-tools modules. @@ -1102,9 +1029,7 @@ def __init__( optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else {} if "lr" not in optim_kwargs.keys(): optim_kwargs.update({"lr": 1e-3}) - self.optim = ( - pyro.optim.Adam(optim_args=optim_kwargs) if optim is None else optim - ) + self.optim = pyro.optim.Adam(optim_args=optim_kwargs) if optim is None else optim # We let SVI take care of all optimization self.automatic_optimization = False @@ -1200,9 +1125,7 @@ def __init__( self.loss_fn = loss() if self.module.logits is False and loss == torch.nn.CrossEntropyLoss: - raise UserWarning( - "classifier should return logits when using CrossEntropyLoss." - ) + raise UserWarning("classifier should return logits when using CrossEntropyLoss.") def forward(self, *args, **kwargs): """Passthrough to the module's forward function.""" @@ -1232,9 +1155,7 @@ def configure_optimizers(self): optim_cls = torch.optim.AdamW else: raise ValueError("Optimizer not understood.") - optimizer = optim_cls( - params, lr=self.lr, eps=self.eps, weight_decay=self.weight_decay - ) + optimizer = optim_cls(params, lr=self.lr, eps=self.eps, weight_decay=self.weight_decay) return optimizer @@ -1300,11 +1221,7 @@ def __init__( def get_optimizer_creator(self) -> JaxOptimizerCreator: """Get optimizer creator for the model.""" - clip_by = ( - optax.clip_by_global_norm(self.max_norm) - if self.max_norm - else optax.identity() - ) + clip_by = optax.clip_by_global_norm(self.max_norm) if self.max_norm else optax.identity() if self.optimizer_name == "Adam": # Replicates PyTorch Adam defaults optim = optax.chain( @@ -1358,9 +1275,9 @@ def loss_fn(params): loss = loss_output.loss return loss, (loss_output, new_model_state) - (loss, (loss_output, new_model_state)), grads = jax.value_and_grad( - loss_fn, has_aux=True - )(state.params) + (loss, (loss_output, new_model_state)), grads = jax.value_and_grad(loss_fn, has_aux=True)( + state.params + ) new_state = state.apply_gradients(grads=grads, state=new_model_state) return new_state, loss, loss_output From b67f6c265cebc0a7cb687094ccb3379c97cf6d47 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Thu, 10 Oct 2024 14:06:54 -0400 Subject: [PATCH 11/40] first draft of moving decipher model into scvi-tools --- src/scvi/external/__init__.py | 2 + src/scvi/external/decipher/__init__.py | 4 + src/scvi/external/decipher/_components.py | 110 ++++++++++++++ src/scvi/external/decipher/_model.py | 104 ++++++++++++++ src/scvi/external/decipher/_module.py | 168 ++++++++++++++++++++++ src/scvi/module/base/_base_module.py | 40 ++++-- src/scvi/train/_trainingplans.py | 119 ++++++++++++--- tests/external/decipher/test_decipher.py | 22 +++ tests/external/mrvi/test_model.py | 8 +- 9 files changed, 547 insertions(+), 30 deletions(-) create mode 100644 src/scvi/external/decipher/__init__.py create mode 100644 src/scvi/external/decipher/_components.py create mode 100644 src/scvi/external/decipher/_model.py create mode 100644 src/scvi/external/decipher/_module.py create mode 100644 tests/external/decipher/test_decipher.py diff --git a/src/scvi/external/__init__.py b/src/scvi/external/__init__.py index 4e46ca1846..c54253c2da 100644 --- a/src/scvi/external/__init__.py +++ b/src/scvi/external/__init__.py @@ -1,5 +1,6 @@ from .cellassign import CellAssign from .contrastivevi import ContrastiveVI +from .decipher import Decipher from .gimvi import GIMVI from .methylvi import METHYLVI from .mrvi import MRVI @@ -15,6 +16,7 @@ "SCAR", "SOLO", "GIMVI", + "Decipher", "RNAStereoscope", "SpatialStereoscope", "CellAssign", diff --git a/src/scvi/external/decipher/__init__.py b/src/scvi/external/decipher/__init__.py new file mode 100644 index 0000000000..d1a6049056 --- /dev/null +++ b/src/scvi/external/decipher/__init__.py @@ -0,0 +1,4 @@ +from ._model import Decipher +from ._module import DecipherPyroModule + +__all__ = ["Decipher", "DecipherPyroModule"] diff --git a/src/scvi/external/decipher/_components.py b/src/scvi/external/decipher/_components.py new file mode 100644 index 0000000000..2951528aca --- /dev/null +++ b/src/scvi/external/decipher/_components.py @@ -0,0 +1,110 @@ +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn + + +class ConditionalDenseNN(nn.Module): + """Dense neural network with multiple outputs, optionally conditioned on a context variable. + + (Derived from pyro.nn.dense_nn.ConditionalDenseNN with some modifications [1]) + + Parameters + ---------- + input_dim : int + Dimension of the input + hidden_dims : sequence of ints + Dimensions of the hidden layers (excluding the output layer) + output_dims : sequence of ints (optional) + Dimensions of each output layer + Default: (1,) + context_dim : int (optional) + Dimension of the context input. + Default: 0. No context input. + deep_context_injection : bool (optional) + If True, inject the context into every hidden layer. + If False, only inject the context into the first hidden layer (concatenated with the input). + Default: False. + activation : torch.nn.Module (optional) + Activation function to use between hidden layers (not applied to the outputs). + Default: torch.nn.ReLU() + """ + + def __init__( + self, + input_dim: int, + hidden_dims: Sequence[int], + output_dims: Sequence = (1,), + context_dim: int = 0, + deep_context_injection: bool = False, + activation=torch.nn.ReLU(), + ): + super().__init__() + + self.input_dim = input_dim + self.context_dim = context_dim + self.hidden_dims = hidden_dims + self.output_dims = output_dims + self.deep_context_injection = deep_context_injection + self.n_output_layers = len(self.output_dims) + self.output_total_dim = sum(self.output_dims) + + # The multiple outputs are computed as a single output layer, and then split + indices = np.concatenate(([0], np.cumsum(self.output_dims))) + self.output_slices = [slice(s, e) for s, e in zip(indices[:-1], indices[1:])] + + # Create masked layers + deep_context_dim = self.context_dim if self.deep_context_injection else 0 + layers = [] + batch_norms = [] + if len(hidden_dims): + layers.append(torch.nn.Linear(input_dim + context_dim, hidden_dims[0])) + batch_norms.append(nn.BatchNorm1d(hidden_dims[0])) + for i in range(1, len(hidden_dims)): + layers.append( + torch.nn.Linear( + hidden_dims[i - 1] + deep_context_dim, hidden_dims[i] + ) + ) + batch_norms.append(nn.BatchNorm1d(hidden_dims[i])) + + layers.append( + torch.nn.Linear( + hidden_dims[-1] + deep_context_dim, self.output_total_dim + ) + ) + else: + layers.append( + torch.nn.Linear(input_dim + context_dim, self.output_total_dim) + ) + + self.layers = torch.nn.ModuleList(layers) + + self.f = activation + self.batch_norms = torch.nn.ModuleList(batch_norms) + + def forward(self, x, context=None): + if context is not None: + # We must be able to broadcast the size of the context over the input + context = context.expand(x.size()[:-1] + (context.size(-1),)) + + h = x + for i, layer in enumerate(self.layers): + if self.context_dim > 0 and (self.deep_context_injection or i == 0): + h = torch.cat([context, h], dim=-1) + h = layer(h) + if i < len(self.layers) - 1: + h = self.batch_norms[i](h) + h = self.f(h) + + if self.n_output_layers == 1: + return h + else: + h = h.reshape(list(x.size()[:-1]) + [self.output_total_dim]) + + if self.n_output_layers == 1: + return h + + else: + return tuple([h[..., s] for s in self.output_slices]) diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py new file mode 100644 index 0000000000..2381ff8d85 --- /dev/null +++ b/src/scvi/external/decipher/_model.py @@ -0,0 +1,104 @@ +from collections.abc import Sequence +from typing import TYPE_CHECKING + +import logging + +from anndata import AnnData +import pyro + +from scvi._constants import REGISTRY_KEYS +from scvi.data import AnnDataManager +from scvi.data.fields import LayerField, CategoricalJointObsField +from scvi.utils import setup_anndata_dsp +from scvi.train import PyroTrainingPlan + +from scvi.model.base import BaseModelClass, PyroSviTrainMixin + +from ._module import DecipherPyroModule + +if TYPE_CHECKING: + from collections.abc import Sequence + + from anndata import AnnData + +logger = logging.getLogger(__name__) + + +class Decipher(PyroSviTrainMixin, BaseModelClass): + _module_cls = DecipherPyroModule + + def __init__(self, adata: AnnData, **kwargs): + pyro.clear_param_store() + + super().__init__(adata) + + dim_genes = self.summary_stats.n_vars + + self.module = self._module_cls( + dim_genes, + **kwargs, + ) + + self.init_params = self._get_init_params(locals()) + + @classmethod + @setup_anndata_dsp.dedent + def setup_anndata( + cls, + adata: AnnData, + layer: str | None = None, + **kwargs, + ) -> AnnData | None: + """%(summary)s. + + Parameters + ---------- + %(param_adata)s + %(param_layer)s + """ + + setup_method_args = cls._get_setup_method_args(**locals()) + anndata_fields = [ + LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), + ] + adata_manager = AnnDataManager( + fields=anndata_fields, setup_method_args=setup_method_args + ) + adata_manager.register_fields(adata, **kwargs) + cls.register_manager(adata_manager) + + def train( + self, + max_epochs: int | None = None, + accelerator: str = "auto", + device: int | str = "auto", + train_size: float = 0.9, + validation_size: float | None = None, + shuffle_set_split: bool = True, + batch_size: int = 128, + early_stopping: bool = False, + lr: float | None = None, + training_plan: PyroTrainingPlan | None = None, + datasplitter_kwargs: dict | None = None, + plan_kwargs: dict | None = None, + **trainer_kwargs, + ): + optim_kwargs = trainer_kwargs.pop("optim_kwargs", {}) + optim_kwargs.update({"lr": lr or 5e-3, "weight_decay": 1e-4}) + optim = pyro.optim.ClippedAdam(optim_kwargs) + plan_kwargs = plan_kwargs or {} + plan_kwargs.update({"optim": optim}) + super().train( + max_epochs=max_epochs, + accelerator=accelerator, + device=device, + train_size=train_size, + validation_size=validation_size, + shuffle_set_split=shuffle_set_split, + batch_size=batch_size, + early_stopping=early_stopping, + plan_kwargs=plan_kwargs, + training_plan=training_plan, + datasplitter_kwargs=datasplitter_kwargs, + **trainer_kwargs, + ) diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py new file mode 100644 index 0000000000..861bf9025a --- /dev/null +++ b/src/scvi/external/decipher/_module.py @@ -0,0 +1,168 @@ +from collections.abc import Iterable, Sequence + +import numpy as np +import pyro +import pyro.distributions as dist +import pyro.poutine as poutine +import torch +import torch.nn as nn +import torch.utils.data +from torch.distributions import constraints +from torch.nn.functional import softmax, softplus + +from scvi._constants import REGISTRY_KEYS +from scvi.module.base import PyroBaseModuleClass, auto_move_data + +from ._components import ConditionalDenseNN + + +class DecipherPyroModule(PyroBaseModuleClass): + """Decipher _decipher for single-cell data. + + Parameters + ---------- + config : DecipherConfig or dict + Configuration for the decipher _decipher. + """ + + def __init__( + self, + dim_genes: int, + dim_v: int = 2, + dim_z: int = 10, + layers_v_to_z: Sequence[int] = (64,), + layers_z_to_x: Sequence[int] = tuple(), + beta: float = 0.1, + prior: str = "normal", + ): + super().__init__() + self.dim_v = dim_v + self.dim_z = dim_z + self.dim_genes = dim_genes + self.layers_v_to_z = layers_v_to_z + self.layers_z_to_x = layers_z_to_x + self.beta = beta + self.prior = prior + + self.decoder_v_to_z = ConditionalDenseNN(dim_v, layers_v_to_z, [dim_z] * 2) + self.decoder_z_to_x = ConditionalDenseNN(dim_z, layers_z_to_x, [dim_genes]) + self.encoder_x_to_z = ConditionalDenseNN(dim_genes, [128], [dim_z] * 2) + self.encoder_zx_to_v = ConditionalDenseNN( + dim_genes + dim_z, + [128], + [dim_v, dim_v], + ) + + self.theta = None + + self._epsilon = 1e-5 + + # Hack: to allow auto_move_data to infer device. + self._dummy_param = nn.Parameter(torch.empty(0), requires_grad=False) + + @property + def device(self): + return self._dummy_param.device + + @staticmethod + def _get_fn_args_from_batch( + tensor_dict: dict[str, torch.Tensor] + ) -> Iterable | dict: + x = tensor_dict[REGISTRY_KEYS.X_KEY] + return (x,), {} + + @auto_move_data + def model(self, x: torch.Tensor): + pyro.module("decipher", self) + + self.theta = pyro.param( + "theta", + x.new_ones(self.dim_genes), + constraint=constraints.positive, + ) + + with pyro.plate("batch", len(x)), poutine.scale(scale=1.0): + with poutine.scale(scale=self.beta): + if self.prior == "normal": + prior = dist.Normal(0, x.new_ones(self.dim_v)).to_event(1) + elif self.prior == "gamma": + prior = dist.Gamma(0.3, x.new_ones(self.dim_v) * 0.8).to_event(1) + else: + raise ValueError("Invalid prior, must be normal or gamma") + v = pyro.sample("v", prior) + + z_loc, z_scale = self.decoder_v_to_z(v) + z_scale = softplus(z_scale) + z = pyro.sample("z", dist.Normal(z_loc, z_scale).to_event(1)) + + mu = self.decoder_z_to_x(z) + mu = softmax(mu, dim=-1) + library_size = x.sum(axis=-1, keepdim=True) + # Parametrization of Negative Binomial by the mean and inverse dispersion + # See https://github.com/pytorch/pytorch/issues/42449 + # noinspection PyTypeChecker + logit = torch.log(library_size * mu + self._epsilon) - torch.log( + self.theta + self._epsilon + ) + # noinspection PyUnresolvedReferences + x_dist = dist.NegativeBinomial( + total_count=self.theta + self._epsilon, logits=logit + ) + pyro.sample("x", x_dist.to_event(1), obs=x) + + @auto_move_data + def guide(self, x: torch.Tensor): + pyro.module("decipher", self) + with pyro.plate("batch", len(x)), poutine.scale(scale=1.0): + x = torch.log1p(x) + + z_loc, z_scale = self.encoder_x_to_z(x) + z_scale = softplus(z_scale) + posterior_z = dist.Normal(z_loc, z_scale).to_event(1) + z = pyro.sample("z", posterior_z) + + zx = torch.cat([z, x], dim=-1) + v_loc, v_scale = self.encoder_zx_to_v(zx) + v_scale = softplus(v_scale) + with poutine.scale(scale=self.beta): + if self.prior == "gamma": + posterior_v = dist.Gamma(softplus(v_loc), v_scale).to_event(1) + elif self.prior == "normal" or self.prior == "student-normal": + posterior_v = dist.Normal(v_loc, v_scale).to_event(1) + else: + raise ValueError("Invalid prior, must be normal or gamma") + pyro.sample("v", posterior_v) + return z_loc, v_loc, z_scale, v_scale + + def compute_v_z_numpy(self, x: np.array): + """Compute decipher_v and decipher_z for a given input. + + Parameters + ---------- + x : np.ndarray or torch.Tensor + Input data of shape (n_cells, n_genes). + + Returns + ------- + v : np.ndarray + Decipher components v of shape (n_cells, dim_v). + z : np.ndarray + Decipher latent z of shape (n_cells, dim_z). + """ + if type(x) == np.ndarray: + x = torch.tensor(x, dtype=torch.float32) + + x = torch.log1p(x) + z_loc, _ = self.encoder_x_to_z(x) + zx = torch.cat([z_loc, x], dim=-1) + v_loc, _ = self.encoder_zx_to_v(zx) + return v_loc.detach().numpy(), z_loc.detach().numpy() + + def impute_gene_expression_numpy(self, x): + if type(x) == np.ndarray: + x = torch.tensor(x, dtype=torch.float32) + z_loc, _, _, _ = self.guide(x) + mu = self.decoder_z_to_x(z_loc) + mu = softmax(mu, dim=-1) + library_size = x.sum(axis=-1, keepdim=True) + return (library_size * mu).detach().numpy() diff --git a/src/scvi/module/base/_base_module.py b/src/scvi/module/base/_base_module.py index 39097c9039..e9d97dccdd 100644 --- a/src/scvi/module/base/_base_module.py +++ b/src/scvi/module/base/_base_module.py @@ -96,7 +96,9 @@ def __post_init__(self): object.__setattr__(self, "loss", self.dict_sum(self.loss)) if self.n_obs_minibatch is None and self.reconstruction_loss is None: - raise ValueError("Must provide either n_obs_minibatch or reconstruction_loss") + raise ValueError( + "Must provide either n_obs_minibatch or reconstruction_loss" + ) default = 0 * self.loss if self.reconstruction_loss is None: @@ -106,7 +108,9 @@ def __post_init__(self): if self.kl_global is None: object.__setattr__(self, "kl_global", default) - object.__setattr__(self, "reconstruction_loss", self._as_dict("reconstruction_loss")) + object.__setattr__( + self, "reconstruction_loss", self._as_dict("reconstruction_loss") + ) object.__setattr__(self, "kl_local", self._as_dict("kl_local")) object.__setattr__(self, "kl_global", self._as_dict("kl_global")) object.__setattr__( @@ -119,13 +123,16 @@ def __post_init__(self): if self.reconstruction_loss is not None and self.n_obs_minibatch is None: rec_loss = self.reconstruction_loss - object.__setattr__(self, "n_obs_minibatch", list(rec_loss.values())[0].shape[0]) + object.__setattr__( + self, "n_obs_minibatch", list(rec_loss.values())[0].shape[0] + ) if self.classification_loss is not None and ( self.logits is None or self.true_labels is None ): raise ValueError( - "Must provide `logits` and `true_labels` if `classification_loss` is " "provided." + "Must provide `logits` and `true_labels` if `classification_loss` is " + "provided." ) @staticmethod @@ -184,7 +191,10 @@ def forward( generative_kwargs: dict | None = None, loss_kwargs: dict | None = None, compute_loss=True, - ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, LossOutput]: + ) -> ( + tuple[torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, LossOutput] + ): """Forward pass through the network. Parameters @@ -348,7 +358,9 @@ def __init__(self, on_load_kwargs: dict | None = None): @staticmethod @abstractmethod - def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict: + def _get_fn_args_from_batch( + tensor_dict: dict[str, torch.Tensor] + ) -> Iterable | dict: """Parse the minibatched data to get the correct inputs for ``model`` and ``guide``. In Pyro, ``model`` and ``guide`` must have the same signature. This is a helper method @@ -649,7 +661,9 @@ def load_state_dict(self, state_dict: dict[str, Any]): raise RuntimeError( "Train state is not set. Train for one iteration prior to loading state dict." ) - self.train_state = flax.serialization.from_state_dict(self.train_state, state_dict) + self.train_state = flax.serialization.from_state_dict( + self.train_state, state_dict + ) def to(self, device: Device): """Move module to device.""" @@ -677,7 +691,9 @@ def get_jit_inference_fn( self, get_inference_input_kwargs: dict[str, Any] | None = None, inference_kwargs: dict[str, Any] | None = None, - ) -> Callable[[dict[str, jnp.ndarray], dict[str, jnp.ndarray]], dict[str, jnp.ndarray]]: + ) -> Callable[ + [dict[str, jnp.ndarray], dict[str, jnp.ndarray]], dict[str, jnp.ndarray] + ]: """Create a method to run inference using the bound module. Parameters @@ -744,14 +760,18 @@ def _generic_forward( get_inference_input_kwargs = _get_dict_if_none(get_inference_input_kwargs) get_generative_input_kwargs = _get_dict_if_none(get_generative_input_kwargs) - inference_inputs = module._get_inference_input(tensors, **get_inference_input_kwargs) + inference_inputs = module._get_inference_input( + tensors, **get_inference_input_kwargs + ) inference_outputs = module.inference(**inference_inputs, **inference_kwargs) generative_inputs = module._get_generative_input( tensors, inference_outputs, **get_generative_input_kwargs ) generative_outputs = module.generative(**generative_inputs, **generative_kwargs) if compute_loss: - losses = module.loss(tensors, inference_outputs, generative_outputs, **loss_kwargs) + losses = module.loss( + tensors, inference_outputs, generative_outputs, **loss_kwargs + ) return inference_outputs, generative_outputs, losses else: return inference_outputs, generative_outputs diff --git a/src/scvi/train/_trainingplans.py b/src/scvi/train/_trainingplans.py index 79aa4bf0e3..5067d9cdbf 100644 --- a/src/scvi/train/_trainingplans.py +++ b/src/scvi/train/_trainingplans.py @@ -182,7 +182,9 @@ def __init__( self.optimizer_creator = optimizer_creator if self.optimizer_name == "Custom" and self.optimizer_creator is None: - raise ValueError("If optimizer is 'Custom', `optimizer_creator` must be provided.") + raise ValueError( + "If optimizer is 'Custom', `optimizer_creator` must be provided." + ) self._n_obs_training = None self._n_obs_validation = None @@ -219,7 +221,9 @@ def initialize_train_metrics(self): self.kl_local_train, self.kl_global_train, self.train_metrics, - ) = self._create_elbo_metric_components(mode="train", n_total=self.n_obs_training) + ) = self._create_elbo_metric_components( + mode="train", n_total=self.n_obs_training + ) self.elbo_train.reset() def initialize_val_metrics(self): @@ -230,7 +234,9 @@ def initialize_val_metrics(self): self.kl_local_val, self.kl_global_val, self.val_metrics, - ) = self._create_elbo_metric_components(mode="validation", n_total=self.n_obs_validation) + ) = self._create_elbo_metric_components( + mode="validation", n_total=self.n_obs_validation + ) self.elbo_val.reset() @property @@ -366,7 +372,9 @@ def validation_step(self, batch, batch_idx): ) self.compute_and_log_metrics(scvi_loss, self.val_metrics, "validation") - def _optimizer_creator_fn(self, optimizer_cls: torch.optim.Adam | torch.optim.AdamW): + def _optimizer_creator_fn( + self, optimizer_cls: torch.optim.Adam | torch.optim.AdamW + ): """Create optimizer for the model. This type of function can be passed as the `optimizer_creator` @@ -544,7 +552,9 @@ def loss_adversarial_classifier(self, z, batch_index, predict_true_class=True): if predict_true_class: cls_target = torch.nn.functional.one_hot(batch_index.squeeze(-1), n_classes) else: - one_hot_batch = torch.nn.functional.one_hot(batch_index.squeeze(-1), n_classes) + one_hot_batch = torch.nn.functional.one_hot( + batch_index.squeeze(-1), n_classes + ) # place zeroes where true label is cls_target = (~one_hot_batch.bool()).float() cls_target = cls_target / (n_classes - 1) @@ -572,7 +582,9 @@ def training_step(self, batch, batch_idx): else: opt1, opt2 = opts - inference_outputs, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs) + inference_outputs, _, scvi_loss = self.forward( + batch, loss_kwargs=self.loss_kwargs + ) z = inference_outputs["z"] loss = scvi_loss.loss # fool classifier if doing adversarial training @@ -605,7 +617,10 @@ def on_train_epoch_end(self): def on_validation_epoch_end(self) -> None: """Update the learning rate via scheduler steps.""" - if not self.reduce_lr_on_plateau or "validation" not in self.lr_scheduler_metric: + if ( + not self.reduce_lr_on_plateau + or "validation" not in self.lr_scheduler_metric + ): return else: sch = self.lr_schedulers() @@ -636,7 +651,9 @@ def configure_optimizers(self): ) if self.adversarial_classifier is not False: - params2 = filter(lambda p: p.requires_grad, self.adversarial_classifier.parameters()) + params2 = filter( + lambda p: p.requires_grad, self.adversarial_classifier.parameters() + ) optimizer2 = torch.optim.Adam( params2, lr=1e-3, eps=0.01, weight_decay=self.weight_decay ) @@ -889,6 +906,7 @@ def __init__( super().__init__() self.module = pyro_module self._n_obs_training = None + self._n_obs_validation = None optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else {} if "lr" not in optim_kwargs.keys(): @@ -901,15 +919,18 @@ def __init__( self.n_epochs_kl_warmup = n_epochs_kl_warmup self.use_kl_weight = False if isinstance(self.module.model, PyroModule): - self.use_kl_weight = "kl_weight" in signature(self.module.model.forward).parameters + self.use_kl_weight = ( + "kl_weight" in signature(self.module.model.forward).parameters + ) elif callable(self.module.model): self.use_kl_weight = "kl_weight" in signature(self.module.model).parameters self.scale_elbo = scale_elbo - self.scale_fn = ( - lambda obj: pyro.poutine.scale(obj, self.scale_elbo) if self.scale_elbo != 1 else obj + self.scale_fn = lambda obj: ( + pyro.poutine.scale(obj, self.scale_elbo) if self.scale_elbo != 1 else obj ) self.differentiable_loss_fn = self.loss_fn.differentiable_loss self.training_step_outputs = [] + self.validation_step_outputs = [] def training_step(self, batch, batch_idx): """Training step for Pyro training.""" @@ -942,6 +963,37 @@ def on_train_epoch_end(self): self.log("elbo_train", elbo, prog_bar=True) self.training_step_outputs.clear() + def validation_step(self, batch, batch_idx): + """Validation step for Pyro training.""" + args, kwargs = self.module._get_fn_args_from_batch(batch) + # Set KL weight if necessary. + # Note: if applied, ELBO loss in progress bar is the effective KL annealed loss, not the + # true ELBO. + if self.use_kl_weight: + kwargs.update({"kl_weight": self.kl_weight}) + # pytorch lightning requires a Tensor object for loss + loss = self.differentiable_loss_fn( + self.scale_fn(self.module.model), + self.scale_fn(self.module.guide), + *args, + **kwargs, + ) + out_dict = {"loss": loss} + self.validation_step_outputs.append(out_dict) + return out_dict + + def on_validation_epoch_end(self): + """Validation epoch end for Pyro training.""" + outputs = self.validation_step_outputs + elbo = 0 + n = 0 + for out in outputs: + elbo += out["loss"] + n += 1 + elbo /= n + self.log("elbo_validation", elbo, prog_bar=True) + self.validation_step_outputs.clear() + def configure_optimizers(self): """Configure optimizers for the model.""" return self.optim(self.module.parameters(), **self.optim_kwargs) @@ -981,6 +1033,27 @@ def n_obs_training(self, n_obs: int): self._n_obs_training = n_obs + @property + def n_obs_validation(self): + """Number of validation examples. + + If not `None`, updates the `n_obs` attr + of the Pyro module's `model` and `guide`, if they exist. + """ + return self._n_obs_validation + + @n_obs_validation.setter + def n_obs_validation(self, n_obs: int): + if n_obs is not None: + if hasattr(self.module, "n_obs"): + self.module.n_obs = n_obs + if hasattr(self.module.model, "n_obs"): + self.module.model.n_obs = n_obs + if hasattr(self.module.guide, "n_obs"): + self.module.guide.n_obs = n_obs + + self._n_obs_validation = n_obs + class PyroTrainingPlan(LowLevelPyroTrainingPlan): """Lightning module task to train Pyro scvi-tools modules. @@ -1029,7 +1102,9 @@ def __init__( optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else {} if "lr" not in optim_kwargs.keys(): optim_kwargs.update({"lr": 1e-3}) - self.optim = pyro.optim.Adam(optim_args=optim_kwargs) if optim is None else optim + self.optim = ( + pyro.optim.Adam(optim_args=optim_kwargs) if optim is None else optim + ) # We let SVI take care of all optimization self.automatic_optimization = False @@ -1125,7 +1200,9 @@ def __init__( self.loss_fn = loss() if self.module.logits is False and loss == torch.nn.CrossEntropyLoss: - raise UserWarning("classifier should return logits when using CrossEntropyLoss.") + raise UserWarning( + "classifier should return logits when using CrossEntropyLoss." + ) def forward(self, *args, **kwargs): """Passthrough to the module's forward function.""" @@ -1155,7 +1232,9 @@ def configure_optimizers(self): optim_cls = torch.optim.AdamW else: raise ValueError("Optimizer not understood.") - optimizer = optim_cls(params, lr=self.lr, eps=self.eps, weight_decay=self.weight_decay) + optimizer = optim_cls( + params, lr=self.lr, eps=self.eps, weight_decay=self.weight_decay + ) return optimizer @@ -1221,7 +1300,11 @@ def __init__( def get_optimizer_creator(self) -> JaxOptimizerCreator: """Get optimizer creator for the model.""" - clip_by = optax.clip_by_global_norm(self.max_norm) if self.max_norm else optax.identity() + clip_by = ( + optax.clip_by_global_norm(self.max_norm) + if self.max_norm + else optax.identity() + ) if self.optimizer_name == "Adam": # Replicates PyTorch Adam defaults optim = optax.chain( @@ -1275,9 +1358,9 @@ def loss_fn(params): loss = loss_output.loss return loss, (loss_output, new_model_state) - (loss, (loss_output, new_model_state)), grads = jax.value_and_grad(loss_fn, has_aux=True)( - state.params - ) + (loss, (loss_output, new_model_state)), grads = jax.value_and_grad( + loss_fn, has_aux=True + )(state.params) new_state = state.apply_gradients(grads=grads, state=new_model_state) return new_state, loss, loss_output diff --git a/tests/external/decipher/test_decipher.py b/tests/external/decipher/test_decipher.py new file mode 100644 index 0000000000..de05730eb3 --- /dev/null +++ b/tests/external/decipher/test_decipher.py @@ -0,0 +1,22 @@ +import pytest +import pyro + +from scvi.data import synthetic_iid +from scvi.external import Decipher + + +@pytest.fixture(scope="session") +def adata(): + adata = synthetic_iid() + return adata + + +def test_decipher_train(adata): + Decipher.setup_anndata(adata) + model = Decipher(adata) + model.train( + max_epochs=1, + check_val_every_n_epoch=1, + train_size=0.5, + early_stopping=True, + ) diff --git a/tests/external/mrvi/test_model.py b/tests/external/mrvi/test_model.py index 05edb27496..5e3147c87e 100644 --- a/tests/external/mrvi/test_model.py +++ b/tests/external/mrvi/test_model.py @@ -155,7 +155,9 @@ def test_mrvi_da(model, sample_key, da_kwargs): }, ], ) -def test_mrvi_model_kwargs(adata: AnnData, model_kwargs: dict[str, Any], save_path: str): +def test_mrvi_model_kwargs( + adata: AnnData, model_kwargs: dict[str, Any], save_path: str +): MRVI.setup_anndata( adata, sample_key="sample_str", @@ -173,7 +175,9 @@ def test_mrvi_model_kwargs(adata: AnnData, model_kwargs: dict[str, Any], save_pa def test_mrvi_sample_subset(model: MRVI, adata: AnnData, save_path: str): sample_cov_keys = ["meta1_cat", "meta2", "cont_cov"] sample_subset = [chr(i + ord("a")) for i in range(8)] - model.differential_expression(sample_cov_keys=sample_cov_keys, sample_subset=sample_subset) + model.differential_expression( + sample_cov_keys=sample_cov_keys, sample_subset=sample_subset + ) model_path = os.path.join(save_path, "mrvi_model") model.save(model_path, save_anndata=False, overwrite=True) From 84ff3b054c5c0fda9a4d18131d577f5789945feb Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Mon, 14 Oct 2024 16:42:24 -0400 Subject: [PATCH 12/40] add early stopping based on predictive nll and freeze batch norm after first epoch --- src/scvi/external/decipher/_model.py | 30 ++--- src/scvi/external/decipher/_module.py | 22 ++-- src/scvi/external/decipher/_trainingplan.py | 123 ++++++++++++++++++++ src/scvi/external/decipher/_utils.py | 45 +++++++ 4 files changed, 191 insertions(+), 29 deletions(-) create mode 100644 src/scvi/external/decipher/_trainingplan.py create mode 100644 src/scvi/external/decipher/_utils.py diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py index 2381ff8d85..971f07e8a6 100644 --- a/src/scvi/external/decipher/_model.py +++ b/src/scvi/external/decipher/_model.py @@ -1,24 +1,20 @@ -from collections.abc import Sequence -from typing import TYPE_CHECKING - import logging +from typing import TYPE_CHECKING -from anndata import AnnData import pyro +from anndata import AnnData from scvi._constants import REGISTRY_KEYS from scvi.data import AnnDataManager -from scvi.data.fields import LayerField, CategoricalJointObsField -from scvi.utils import setup_anndata_dsp -from scvi.train import PyroTrainingPlan - +from scvi.data.fields import LayerField from scvi.model.base import BaseModelClass, PyroSviTrainMixin +from scvi.train import PyroTrainingPlan +from scvi.utils import setup_anndata_dsp from ._module import DecipherPyroModule +from ._trainingplan import DecipherTrainingPlan if TYPE_CHECKING: - from collections.abc import Sequence - from anndata import AnnData logger = logging.getLogger(__name__) @@ -26,6 +22,7 @@ class Decipher(PyroSviTrainMixin, BaseModelClass): _module_cls = DecipherPyroModule + _training_plan_cls = DecipherTrainingPlan def __init__(self, adata: AnnData, **kwargs): pyro.clear_param_store() @@ -56,14 +53,11 @@ def setup_anndata( %(param_adata)s %(param_layer)s """ - setup_method_args = cls._get_setup_method_args(**locals()) anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -77,17 +71,13 @@ def train( shuffle_set_split: bool = True, batch_size: int = 128, early_stopping: bool = False, - lr: float | None = None, training_plan: PyroTrainingPlan | None = None, datasplitter_kwargs: dict | None = None, plan_kwargs: dict | None = None, **trainer_kwargs, ): - optim_kwargs = trainer_kwargs.pop("optim_kwargs", {}) - optim_kwargs.update({"lr": lr or 5e-3, "weight_decay": 1e-4}) - optim = pyro.optim.ClippedAdam(optim_kwargs) - plan_kwargs = plan_kwargs or {} - plan_kwargs.update({"optim": optim}) + if "early_stopping_monitor" not in trainer_kwargs: + trainer_kwargs["early_stopping_monitor"] = "nll_validation" super().train( max_epochs=max_epochs, accelerator=accelerator, diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index 861bf9025a..11a892a03d 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -57,17 +57,17 @@ def __init__( self._epsilon = 1e-5 + self.n_obs = None # Populated by PyroTrainingPlan + # Hack: to allow auto_move_data to infer device. - self._dummy_param = nn.Parameter(torch.empty(0), requires_grad=False) + self._dummy_param = nn.Parameter(torch.empty(0)) @property def device(self): return self._dummy_param.device @staticmethod - def _get_fn_args_from_batch( - tensor_dict: dict[str, torch.Tensor] - ) -> Iterable | dict: + def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict: x = tensor_dict[REGISTRY_KEYS.X_KEY] return (x,), {} @@ -81,7 +81,10 @@ def model(self, x: torch.Tensor): constraint=constraints.positive, ) - with pyro.plate("batch", len(x)), poutine.scale(scale=1.0): + with ( + pyro.plate("batch", size=self.n_obs, subsample_size=x.shape[0]), + poutine.scale(scale=1.0), + ): with poutine.scale(scale=self.beta): if self.prior == "normal": prior = dist.Normal(0, x.new_ones(self.dim_v)).to_event(1) @@ -105,15 +108,16 @@ def model(self, x: torch.Tensor): self.theta + self._epsilon ) # noinspection PyUnresolvedReferences - x_dist = dist.NegativeBinomial( - total_count=self.theta + self._epsilon, logits=logit - ) + x_dist = dist.NegativeBinomial(total_count=self.theta + self._epsilon, logits=logit) pyro.sample("x", x_dist.to_event(1), obs=x) @auto_move_data def guide(self, x: torch.Tensor): pyro.module("decipher", self) - with pyro.plate("batch", len(x)), poutine.scale(scale=1.0): + with ( + pyro.plate("batch", size=self.n_obs, subsample_size=x.shape[0]), + poutine.scale(scale=1.0), + ): x = torch.log1p(x) z_loc, z_scale = self.encoder_x_to_z(x) diff --git a/src/scvi/external/decipher/_trainingplan.py b/src/scvi/external/decipher/_trainingplan.py new file mode 100644 index 0000000000..10a9223e86 --- /dev/null +++ b/src/scvi/external/decipher/_trainingplan.py @@ -0,0 +1,123 @@ +import pyro +import torch + +from scvi.module.base import ( + PyroBaseModuleClass, +) +from scvi.train import LowLevelPyroTrainingPlan + +from ._utils import predictive_log_likelihood + + +class DecipherTrainingPlan(LowLevelPyroTrainingPlan): + """Lightning module task to train the Decipher Pyro module. + + Parameters + ---------- + pyro_module + An instance of :class:`~scvi.module.base.PyroBaseModuleClass`. This object + should have callable `model` and `guide` attributes or methods. + loss_fn + A Pyro loss. Should be a subclass of :class:`~pyro.infer.ELBO`. + If `None`, defaults to :class:`~pyro.infer.Trace_ELBO`. + optim + A Pyro optimizer instance, e.g., :class:`~pyro.optim.Adam`. If `None`, + defaults to :class:`pyro.optim.Adam` optimizer with a learning rate of `1e-3`. + optim_kwargs + Keyword arguments for **default** optimiser :class:`pyro.optim.Adam`. + """ + + def __init__( + self, + pyro_module: PyroBaseModuleClass, + loss_fn: pyro.infer.ELBO | None = None, + optim: pyro.optim.PyroOptim | None = None, + optim_kwargs: dict | None = None, + ): + super().__init__( + pyro_module=pyro_module, + loss_fn=loss_fn, + ) + optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else {} + if "lr" not in optim_kwargs.keys(): + optim_kwargs.update({"lr": 5e-3, "weight_decay": 1e-4}) + self.optim = pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim + # We let SVI take care of all optimization + self.automatic_optimization = False + + self.svi = pyro.infer.SVI( + model=self.module.model, + guide=self.module.guide, + optim=self.optim, + loss=self.loss_fn, + ) + # See configure_optimizers for what this does + self._dummy_param = torch.nn.Parameter(torch.Tensor([0.0])) + + def training_step(self, batch, batch_idx): + """Training step for Pyro training.""" + args, kwargs = self.module._get_fn_args_from_batch(batch) + + # pytorch lightning requires a Tensor object for loss + loss = torch.Tensor([self.svi.step(*args, **kwargs)]) + + _opt = self.optimizers() + _opt.step() + + out_dict = {"loss": loss} + self.training_step_outputs.append(out_dict) + return out_dict + + def on_train_epoch_start(self): + """Training epoch start for Pyro training.""" + super().on_train_epoch_start() + if self.current_epoch > 0: + # freeze the batch norm layers after the first epoch + # 1) the batch norm layers helps with the initialization + # 2) but then, they seem to imply a strong normal prior on the latent space + for module in self.module.modules(): + if isinstance(module, torch.nn.BatchNorm1d): + module.eval() + + def validation_step(self, batch, batch_idx): + """Validation step for Pyro training.""" + out_dict = super().validation_step(batch, batch_idx) + args, kwargs = self.module._get_fn_args_from_batch(batch) + nll = -predictive_log_likelihood(self.module, *args, **kwargs, n_samples=5) + out_dict["nll"] = nll + self.validation_step_outputs[-1].update(out_dict) + return out_dict + + def on_validation_epoch_end(self): + """Validation epoch end for Pyro training.""" + outputs = self.validation_step_outputs + elbo = 0 + nll = 0 + n = 0 + for out in outputs: + elbo += out["loss"] + nll += out["nll"] + n += 1 + elbo /= n + nll /= n + self.log("elbo_validation", elbo, prog_bar=True) + self.log("nll_validation", nll, prog_bar=True) + self.validation_step_outputs.clear() + + def configure_optimizers(self): + """Shim optimizer for PyTorch Lightning. + + PyTorch Lightning wants to take steps on an optimizer + returned by this function in order to increment the global + step count. See PyTorch Lighinting optimizer manual loop. + + Here we provide a shim optimizer that we can take steps on + at minimal computational cost in order to keep Lightning happy :). + """ + return torch.optim.Adam([self._dummy_param]) + + def optimizer_step(self, *args, **kwargs): + pass + + def backward(self, *args, **kwargs): + pass diff --git a/src/scvi/external/decipher/_utils.py b/src/scvi/external/decipher/_utils.py new file mode 100644 index 0000000000..ebbc1c9edd --- /dev/null +++ b/src/scvi/external/decipher/_utils.py @@ -0,0 +1,45 @@ +import numpy as np +import pyro.poutine as poutine +import torch + + +def predictive_log_likelihood(decipher_module, batch, n_samples=5): + """ + Calculate the predictive log-likelihood for a Decipher module. + + This function performs multiple runs through the dataloader to obtain + an empirical estimate of the predictive log-likelihood. It calculates the + log-likelihood for each run and returns the average. The beta parameter + of the Decipher module is temporarily modified and restored even if an + exception occurs. + + Parameters + ---------- + decipher_module : PyroBaseModuleClass + The Decipher module to evaluate. + batch : torch.Tensor + Batch of data to compute the log-likelihood for. + n_samples : int, optional + Number of passes through the dataloader (default is 5). + + Returns + ------- + float + The average estimated predictive log-likelihood across multiple runs. + """ + log_weights = [] + old_beta = decipher_module.beta + decipher_module.beta = 1.0 + try: + for _ in range(n_samples): + guide_trace = poutine.trace(decipher_module.guide).get_trace(batch) + model_trace = poutine.trace( + poutine.replay(decipher_module.model, trace=guide_trace) + ).get_trace(batch) + log_weights.append(model_trace.log_prob_sum() - guide_trace.log_prob_sum()) + + finally: + decipher_module.beta = old_beta + + log_z = torch.logsumexp(torch.tensor(log_weights) - np.log(n_samples), 0) + return log_z.item() From f22741b298e1ed35c7885c54f1068c905e68e015 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Mon, 14 Oct 2024 17:45:02 -0400 Subject: [PATCH 13/40] add get latent rep to model class --- src/scvi/external/decipher/_model.py | 28 +++++++++++++++++++++ src/scvi/external/decipher/_module.py | 24 ------------------ src/scvi/external/decipher/_trainingplan.py | 16 +++++++++--- tests/external/decipher/test_decipher.py | 5 ++-- 4 files changed, 43 insertions(+), 30 deletions(-) diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py index 971f07e8a6..75a02164b4 100644 --- a/src/scvi/external/decipher/_model.py +++ b/src/scvi/external/decipher/_model.py @@ -1,7 +1,10 @@ import logging +from collections.abc import Sequence from typing import TYPE_CHECKING +import numpy as np import pyro +import torch from anndata import AnnData from scvi._constants import REGISTRY_KEYS @@ -15,6 +18,8 @@ from ._trainingplan import DecipherTrainingPlan if TYPE_CHECKING: + from collections.abc import Sequence + from anndata import AnnData logger = logging.getLogger(__name__) @@ -92,3 +97,26 @@ def train( datasplitter_kwargs=datasplitter_kwargs, **trainer_kwargs, ) + + def get_latent_representation( + self, + adata: AnnData | None = None, + indices: Sequence[int] | None = None, + batch_size: int | None = None, + give_z: bool = False, + ) -> np.ndarray: + self._check_if_trained(warn=False) + adata = self._validate_anndata(adata) + + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) + latent_locs = [] + for tensors in scdl: + x = tensors[REGISTRY_KEYS.X_KEY] + x = torch.log1p(x) + z_loc, _ = self.module.encoder_x_to_z(x) + if give_z: + latent_locs.append(z_loc) + else: + v_loc, _ = self.module.encoder_zx_to_v(torch.cat([z_loc, x], dim=-1)) + latent_locs.append(v_loc) + return torch.cat(latent_locs).detach().numpy() diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index 11a892a03d..c7baf7f24e 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -138,30 +138,6 @@ def guide(self, x: torch.Tensor): pyro.sample("v", posterior_v) return z_loc, v_loc, z_scale, v_scale - def compute_v_z_numpy(self, x: np.array): - """Compute decipher_v and decipher_z for a given input. - - Parameters - ---------- - x : np.ndarray or torch.Tensor - Input data of shape (n_cells, n_genes). - - Returns - ------- - v : np.ndarray - Decipher components v of shape (n_cells, dim_v). - z : np.ndarray - Decipher latent z of shape (n_cells, dim_z). - """ - if type(x) == np.ndarray: - x = torch.tensor(x, dtype=torch.float32) - - x = torch.log1p(x) - z_loc, _ = self.encoder_x_to_z(x) - zx = torch.cat([z_loc, x], dim=-1) - v_loc, _ = self.encoder_zx_to_v(zx) - return v_loc.detach().numpy(), z_loc.detach().numpy() - def impute_gene_expression_numpy(self, x): if type(x) == np.ndarray: x = torch.tensor(x, dtype=torch.float32) diff --git a/src/scvi/external/decipher/_trainingplan.py b/src/scvi/external/decipher/_trainingplan.py index 10a9223e86..4b1de9682e 100644 --- a/src/scvi/external/decipher/_trainingplan.py +++ b/src/scvi/external/decipher/_trainingplan.py @@ -79,6 +79,16 @@ def on_train_epoch_start(self): if isinstance(module, torch.nn.BatchNorm1d): module.eval() + def on_train_epoch_end(self): + """Training epoch end for Pyro training.""" + outputs = self.training_step_outputs + elbo = 0 + for out in outputs: + elbo += out["loss"] + elbo /= self.n_obs_training + self.log("elbo_train", elbo, prog_bar=True) + self.training_step_outputs.clear() + def validation_step(self, batch, batch_idx): """Validation step for Pyro training.""" out_dict = super().validation_step(batch, batch_idx) @@ -93,13 +103,11 @@ def on_validation_epoch_end(self): outputs = self.validation_step_outputs elbo = 0 nll = 0 - n = 0 for out in outputs: elbo += out["loss"] nll += out["nll"] - n += 1 - elbo /= n - nll /= n + elbo /= self.n_obs_validation + nll /= self.n_obs_validation self.log("elbo_validation", elbo, prog_bar=True) self.log("nll_validation", nll, prog_bar=True) self.validation_step_outputs.clear() diff --git a/tests/external/decipher/test_decipher.py b/tests/external/decipher/test_decipher.py index de05730eb3..0056096eb0 100644 --- a/tests/external/decipher/test_decipher.py +++ b/tests/external/decipher/test_decipher.py @@ -1,5 +1,4 @@ import pytest -import pyro from scvi.data import synthetic_iid from scvi.external import Decipher @@ -15,8 +14,10 @@ def test_decipher_train(adata): Decipher.setup_anndata(adata) model = Decipher(adata) model.train( - max_epochs=1, + max_epochs=2, check_val_every_n_epoch=1, train_size=0.5, early_stopping=True, ) + model.get_latent_representation(give_z=False) + model.get_latent_representation(give_z=True) From 89c93198981c051db0413873dab5f6e4c123236a Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Mon, 14 Oct 2024 17:45:45 -0400 Subject: [PATCH 14/40] remove impute fn for now --- src/scvi/external/decipher/_module.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index c7baf7f24e..02eb4f0096 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -67,7 +67,9 @@ def device(self): return self._dummy_param.device @staticmethod - def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict: + def _get_fn_args_from_batch( + tensor_dict: dict[str, torch.Tensor] + ) -> Iterable | dict: x = tensor_dict[REGISTRY_KEYS.X_KEY] return (x,), {} @@ -108,7 +110,9 @@ def model(self, x: torch.Tensor): self.theta + self._epsilon ) # noinspection PyUnresolvedReferences - x_dist = dist.NegativeBinomial(total_count=self.theta + self._epsilon, logits=logit) + x_dist = dist.NegativeBinomial( + total_count=self.theta + self._epsilon, logits=logit + ) pyro.sample("x", x_dist.to_event(1), obs=x) @auto_move_data @@ -137,12 +141,3 @@ def guide(self, x: torch.Tensor): raise ValueError("Invalid prior, must be normal or gamma") pyro.sample("v", posterior_v) return z_loc, v_loc, z_scale, v_scale - - def impute_gene_expression_numpy(self, x): - if type(x) == np.ndarray: - x = torch.tensor(x, dtype=torch.float32) - z_loc, _, _, _ = self.guide(x) - mu = self.decoder_z_to_x(z_loc) - mu = softmax(mu, dim=-1) - library_size = x.sum(axis=-1, keepdim=True) - return (library_size * mu).detach().numpy() From 812d5db1638d441add997d3828a68ee5c24cb76a Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Mon, 14 Oct 2024 18:27:59 -0400 Subject: [PATCH 15/40] revert base module lints --- src/scvi/module/base/_base_module.py | 40 +++++++--------------------- 1 file changed, 10 insertions(+), 30 deletions(-) diff --git a/src/scvi/module/base/_base_module.py b/src/scvi/module/base/_base_module.py index e9d97dccdd..39097c9039 100644 --- a/src/scvi/module/base/_base_module.py +++ b/src/scvi/module/base/_base_module.py @@ -96,9 +96,7 @@ def __post_init__(self): object.__setattr__(self, "loss", self.dict_sum(self.loss)) if self.n_obs_minibatch is None and self.reconstruction_loss is None: - raise ValueError( - "Must provide either n_obs_minibatch or reconstruction_loss" - ) + raise ValueError("Must provide either n_obs_minibatch or reconstruction_loss") default = 0 * self.loss if self.reconstruction_loss is None: @@ -108,9 +106,7 @@ def __post_init__(self): if self.kl_global is None: object.__setattr__(self, "kl_global", default) - object.__setattr__( - self, "reconstruction_loss", self._as_dict("reconstruction_loss") - ) + object.__setattr__(self, "reconstruction_loss", self._as_dict("reconstruction_loss")) object.__setattr__(self, "kl_local", self._as_dict("kl_local")) object.__setattr__(self, "kl_global", self._as_dict("kl_global")) object.__setattr__( @@ -123,16 +119,13 @@ def __post_init__(self): if self.reconstruction_loss is not None and self.n_obs_minibatch is None: rec_loss = self.reconstruction_loss - object.__setattr__( - self, "n_obs_minibatch", list(rec_loss.values())[0].shape[0] - ) + object.__setattr__(self, "n_obs_minibatch", list(rec_loss.values())[0].shape[0]) if self.classification_loss is not None and ( self.logits is None or self.true_labels is None ): raise ValueError( - "Must provide `logits` and `true_labels` if `classification_loss` is " - "provided." + "Must provide `logits` and `true_labels` if `classification_loss` is " "provided." ) @staticmethod @@ -191,10 +184,7 @@ def forward( generative_kwargs: dict | None = None, loss_kwargs: dict | None = None, compute_loss=True, - ) -> ( - tuple[torch.Tensor, torch.Tensor] - | tuple[torch.Tensor, torch.Tensor, LossOutput] - ): + ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, LossOutput]: """Forward pass through the network. Parameters @@ -358,9 +348,7 @@ def __init__(self, on_load_kwargs: dict | None = None): @staticmethod @abstractmethod - def _get_fn_args_from_batch( - tensor_dict: dict[str, torch.Tensor] - ) -> Iterable | dict: + def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict: """Parse the minibatched data to get the correct inputs for ``model`` and ``guide``. In Pyro, ``model`` and ``guide`` must have the same signature. This is a helper method @@ -661,9 +649,7 @@ def load_state_dict(self, state_dict: dict[str, Any]): raise RuntimeError( "Train state is not set. Train for one iteration prior to loading state dict." ) - self.train_state = flax.serialization.from_state_dict( - self.train_state, state_dict - ) + self.train_state = flax.serialization.from_state_dict(self.train_state, state_dict) def to(self, device: Device): """Move module to device.""" @@ -691,9 +677,7 @@ def get_jit_inference_fn( self, get_inference_input_kwargs: dict[str, Any] | None = None, inference_kwargs: dict[str, Any] | None = None, - ) -> Callable[ - [dict[str, jnp.ndarray], dict[str, jnp.ndarray]], dict[str, jnp.ndarray] - ]: + ) -> Callable[[dict[str, jnp.ndarray], dict[str, jnp.ndarray]], dict[str, jnp.ndarray]]: """Create a method to run inference using the bound module. Parameters @@ -760,18 +744,14 @@ def _generic_forward( get_inference_input_kwargs = _get_dict_if_none(get_inference_input_kwargs) get_generative_input_kwargs = _get_dict_if_none(get_generative_input_kwargs) - inference_inputs = module._get_inference_input( - tensors, **get_inference_input_kwargs - ) + inference_inputs = module._get_inference_input(tensors, **get_inference_input_kwargs) inference_outputs = module.inference(**inference_inputs, **inference_kwargs) generative_inputs = module._get_generative_input( tensors, inference_outputs, **get_generative_input_kwargs ) generative_outputs = module.generative(**generative_inputs, **generative_kwargs) if compute_loss: - losses = module.loss( - tensors, inference_outputs, generative_outputs, **loss_kwargs - ) + losses = module.loss(tensors, inference_outputs, generative_outputs, **loss_kwargs) return inference_outputs, generative_outputs, losses else: return inference_outputs, generative_outputs From 4ba55a25f231f943f0de2de1b53bf7dd997f2c07 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Mon, 14 Oct 2024 18:29:40 -0400 Subject: [PATCH 16/40] converges but with blobby latent space --- src/scvi/external/decipher/_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index 02eb4f0096..efbeaa0056 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -84,7 +84,7 @@ def model(self, x: torch.Tensor): ) with ( - pyro.plate("batch", size=self.n_obs, subsample_size=x.shape[0]), + pyro.plate("batch", x.shape[0]), poutine.scale(scale=1.0), ): with poutine.scale(scale=self.beta): @@ -119,7 +119,7 @@ def model(self, x: torch.Tensor): def guide(self, x: torch.Tensor): pyro.module("decipher", self) with ( - pyro.plate("batch", size=self.n_obs, subsample_size=x.shape[0]), + pyro.plate("batch", x.shape[0]), poutine.scale(scale=1.0), ): x = torch.log1p(x) From 0be6d068b797b402c2b9b1284485b47af79c7a8f Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Mon, 14 Oct 2024 19:00:56 -0400 Subject: [PATCH 17/40] fix batch norm freezing and move pll impl into module --- src/scvi/external/decipher/_module.py | 64 ++++++++++++++++++++- src/scvi/external/decipher/_trainingplan.py | 14 ++--- src/scvi/external/decipher/_utils.py | 45 --------------- tests/external/decipher/test_decipher.py | 2 +- 4 files changed, 69 insertions(+), 56 deletions(-) delete mode 100644 src/scvi/external/decipher/_utils.py diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index efbeaa0056..aa264fd1bc 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -17,12 +17,27 @@ class DecipherPyroModule(PyroBaseModuleClass): - """Decipher _decipher for single-cell data. + """Decipher PyroModule for single-cell data analysis. + + This module implements the Decipher model for dimensionality reduction and + interpretable representation learning in single-cell RNA sequencing data. Parameters ---------- - config : DecipherConfig or dict - Configuration for the decipher _decipher. + dim_genes : int + Number of genes (features) in the dataset. + dim_v : int, optional + Dimension of the interpretable latent space v. Default is 2. + dim_z : int, optional + Dimension of the intermediate latent space z. Default is 10. + layers_v_to_z : Sequence[int], optional + Hidden layer sizes for the v to z decoder network. Default is (64,). + layers_z_to_x : Sequence[int], optional + Hidden layer sizes for the z to x decoder network. Default is empty tuple. + beta : float, optional + Regularization parameter for the KL divergence. Default is 0.1. + prior : str, optional + Type of prior distribution to use. Default is "normal". """ def __init__( @@ -141,3 +156,46 @@ def guide(self, x: torch.Tensor): raise ValueError("Invalid prior, must be normal or gamma") pyro.sample("v", posterior_v) return z_loc, v_loc, z_scale, v_scale + + def predictive_log_likelihood(self, x: torch.Tensor, n_samples=5): + """ + Calculate the predictive log-likelihood for a Decipher module. + + This function performs multiple runs through the dataloader to obtain + an empirical estimate of the predictive log-likelihood. It calculates the + log-likelihood for each run and returns the average. The beta parameter + of the Decipher module is temporarily modified and restored even if an + exception occurs. Used by default as an early stopping criterion. + + Parameters + ---------- + decipher_module : PyroBaseModuleClass + The Decipher module to evaluate. + x : torch.Tensor + Batch of data to compute the log-likelihood for. + n_samples : int, optional + Number of passes through the dataloader (default is 5). + + Returns + ------- + float + The average estimated predictive log-likelihood across multiple runs. + """ + log_weights = [] + old_beta = self.beta + self.beta = 1.0 + try: + for _ in range(n_samples): + guide_trace = poutine.trace(self.guide).get_trace(x) + model_trace = poutine.trace( + poutine.replay(self.model, trace=guide_trace) + ).get_trace(x) + log_weights.append( + model_trace.log_prob_sum() - guide_trace.log_prob_sum() + ) + + finally: + self.beta = old_beta + + log_z = torch.logsumexp(torch.tensor(log_weights) - np.log(n_samples), 0) + return log_z.item() diff --git a/src/scvi/external/decipher/_trainingplan.py b/src/scvi/external/decipher/_trainingplan.py index 4b1de9682e..8c72eae656 100644 --- a/src/scvi/external/decipher/_trainingplan.py +++ b/src/scvi/external/decipher/_trainingplan.py @@ -6,8 +6,6 @@ ) from scvi.train import LowLevelPyroTrainingPlan -from ._utils import predictive_log_likelihood - class DecipherTrainingPlan(LowLevelPyroTrainingPlan): """Lightning module task to train the Decipher Pyro module. @@ -41,7 +39,9 @@ def __init__( optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else {} if "lr" not in optim_kwargs.keys(): optim_kwargs.update({"lr": 5e-3, "weight_decay": 1e-4}) - self.optim = pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim + self.optim = ( + pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim + ) # We let SVI take care of all optimization self.automatic_optimization = False @@ -68,9 +68,9 @@ def training_step(self, batch, batch_idx): self.training_step_outputs.append(out_dict) return out_dict - def on_train_epoch_start(self): - """Training epoch start for Pyro training.""" - super().on_train_epoch_start() + def on_validation_model_train(self): + """Prepare the model for validation by switching to train mode.""" + super().on_validation_model_train() if self.current_epoch > 0: # freeze the batch norm layers after the first epoch # 1) the batch norm layers helps with the initialization @@ -93,7 +93,7 @@ def validation_step(self, batch, batch_idx): """Validation step for Pyro training.""" out_dict = super().validation_step(batch, batch_idx) args, kwargs = self.module._get_fn_args_from_batch(batch) - nll = -predictive_log_likelihood(self.module, *args, **kwargs, n_samples=5) + nll = -self.module.predictive_log_likelihood(*args, **kwargs, n_samples=5) out_dict["nll"] = nll self.validation_step_outputs[-1].update(out_dict) return out_dict diff --git a/src/scvi/external/decipher/_utils.py b/src/scvi/external/decipher/_utils.py deleted file mode 100644 index ebbc1c9edd..0000000000 --- a/src/scvi/external/decipher/_utils.py +++ /dev/null @@ -1,45 +0,0 @@ -import numpy as np -import pyro.poutine as poutine -import torch - - -def predictive_log_likelihood(decipher_module, batch, n_samples=5): - """ - Calculate the predictive log-likelihood for a Decipher module. - - This function performs multiple runs through the dataloader to obtain - an empirical estimate of the predictive log-likelihood. It calculates the - log-likelihood for each run and returns the average. The beta parameter - of the Decipher module is temporarily modified and restored even if an - exception occurs. - - Parameters - ---------- - decipher_module : PyroBaseModuleClass - The Decipher module to evaluate. - batch : torch.Tensor - Batch of data to compute the log-likelihood for. - n_samples : int, optional - Number of passes through the dataloader (default is 5). - - Returns - ------- - float - The average estimated predictive log-likelihood across multiple runs. - """ - log_weights = [] - old_beta = decipher_module.beta - decipher_module.beta = 1.0 - try: - for _ in range(n_samples): - guide_trace = poutine.trace(decipher_module.guide).get_trace(batch) - model_trace = poutine.trace( - poutine.replay(decipher_module.model, trace=guide_trace) - ).get_trace(batch) - log_weights.append(model_trace.log_prob_sum() - guide_trace.log_prob_sum()) - - finally: - decipher_module.beta = old_beta - - log_z = torch.logsumexp(torch.tensor(log_weights) - np.log(n_samples), 0) - return log_z.item() diff --git a/tests/external/decipher/test_decipher.py b/tests/external/decipher/test_decipher.py index 0056096eb0..d71c6ab642 100644 --- a/tests/external/decipher/test_decipher.py +++ b/tests/external/decipher/test_decipher.py @@ -14,7 +14,7 @@ def test_decipher_train(adata): Decipher.setup_anndata(adata) model = Decipher(adata) model.train( - max_epochs=2, + max_epochs=3, check_val_every_n_epoch=1, train_size=0.5, early_stopping=True, From 2cdb5393b8b868f46ddb8d8d74a7d5c988d38013 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Mon, 14 Oct 2024 19:13:59 -0400 Subject: [PATCH 18/40] fix save/load --- src/scvi/external/decipher/_model.py | 10 +++++++--- src/scvi/external/decipher/_trainingplan.py | 2 +- tests/external/decipher/test_decipher.py | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py index 75a02164b4..0e74e74ab2 100644 --- a/src/scvi/external/decipher/_model.py +++ b/src/scvi/external/decipher/_model.py @@ -41,7 +41,7 @@ def __init__(self, adata: AnnData, **kwargs): **kwargs, ) - self.init_params = self._get_init_params(locals()) + self.init_params_ = self._get_init_params(locals()) @classmethod @setup_anndata_dsp.dedent @@ -62,7 +62,9 @@ def setup_anndata( anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), ] - adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) + adata_manager = AnnDataManager( + fields=anndata_fields, setup_method_args=setup_method_args + ) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -108,7 +110,9 @@ def get_latent_representation( self._check_if_trained(warn=False) adata = self._validate_anndata(adata) - scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) + scdl = self._make_data_loader( + adata=adata, indices=indices, batch_size=batch_size + ) latent_locs = [] for tensors in scdl: x = tensors[REGISTRY_KEYS.X_KEY] diff --git a/src/scvi/external/decipher/_trainingplan.py b/src/scvi/external/decipher/_trainingplan.py index 8c72eae656..32371eceb5 100644 --- a/src/scvi/external/decipher/_trainingplan.py +++ b/src/scvi/external/decipher/_trainingplan.py @@ -109,7 +109,7 @@ def on_validation_epoch_end(self): elbo /= self.n_obs_validation nll /= self.n_obs_validation self.log("elbo_validation", elbo, prog_bar=True) - self.log("nll_validation", nll, prog_bar=True) + self.log("nll_validation", nll, prog_bar=False) self.validation_step_outputs.clear() def configure_optimizers(self): diff --git a/tests/external/decipher/test_decipher.py b/tests/external/decipher/test_decipher.py index d71c6ab642..0056096eb0 100644 --- a/tests/external/decipher/test_decipher.py +++ b/tests/external/decipher/test_decipher.py @@ -14,7 +14,7 @@ def test_decipher_train(adata): Decipher.setup_anndata(adata) model = Decipher(adata) model.train( - max_epochs=3, + max_epochs=2, check_val_every_n_epoch=1, train_size=0.5, early_stopping=True, From 1caa156fdb443b83d6f82b8f99b203a1f57e53c0 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Mon, 14 Oct 2024 21:13:06 -0400 Subject: [PATCH 19/40] drop last and fix loss scaling --- src/scvi/external/decipher/_model.py | 3 ++ src/scvi/external/decipher/_module.py | 2 -- src/scvi/external/decipher/_trainingplan.py | 40 ++++++++++++--------- tests/external/mrvi/test_model.py | 8 ++--- 4 files changed, 29 insertions(+), 24 deletions(-) diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py index 0e74e74ab2..75a8db88de 100644 --- a/src/scvi/external/decipher/_model.py +++ b/src/scvi/external/decipher/_model.py @@ -85,6 +85,9 @@ def train( ): if "early_stopping_monitor" not in trainer_kwargs: trainer_kwargs["early_stopping_monitor"] = "nll_validation" + datasplitter_kwargs = datasplitter_kwargs or {} + if "drop_last" not in datasplitter_kwargs: + datasplitter_kwargs["drop_last"] = True super().train( max_epochs=max_epochs, accelerator=accelerator, diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index aa264fd1bc..433bc83706 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -169,8 +169,6 @@ def predictive_log_likelihood(self, x: torch.Tensor, n_samples=5): Parameters ---------- - decipher_module : PyroBaseModuleClass - The Decipher module to evaluate. x : torch.Tensor Batch of data to compute the log-likelihood for. n_samples : int, optional diff --git a/src/scvi/external/decipher/_trainingplan.py b/src/scvi/external/decipher/_trainingplan.py index 32371eceb5..992ad164a8 100644 --- a/src/scvi/external/decipher/_trainingplan.py +++ b/src/scvi/external/decipher/_trainingplan.py @@ -38,7 +38,9 @@ def __init__( ) optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else {} if "lr" not in optim_kwargs.keys(): - optim_kwargs.update({"lr": 5e-3, "weight_decay": 1e-4}) + optim_kwargs.update({"lr": 5e-3}) + if "weight_decay" not in optim_kwargs.keys(): + optim_kwargs.update({"weight_decay": 1e-4}) self.optim = ( pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim ) @@ -54,38 +56,41 @@ def __init__( # See configure_optimizers for what this does self._dummy_param = torch.nn.Parameter(torch.Tensor([0.0])) + def on_validation_model_train(self): + """Prepare the model for validation by switching to train mode.""" + super().on_validation_model_train() + if self.current_epoch > 0: + # freeze the batch norm layers after the first epoch + # 1) the batch norm layers helps with the initialization + # 2) but then, they seem to imply a strong normal prior on the latent space + for module in self.module.modules(): + if isinstance(module, torch.nn.BatchNorm1d): + module.eval() + def training_step(self, batch, batch_idx): """Training step for Pyro training.""" args, kwargs = self.module._get_fn_args_from_batch(batch) # pytorch lightning requires a Tensor object for loss loss = torch.Tensor([self.svi.step(*args, **kwargs)]) + n_obs = args[0].shape[0] _opt = self.optimizers() _opt.step() - out_dict = {"loss": loss} + out_dict = {"loss": loss, "n_obs": n_obs} self.training_step_outputs.append(out_dict) return out_dict - def on_validation_model_train(self): - """Prepare the model for validation by switching to train mode.""" - super().on_validation_model_train() - if self.current_epoch > 0: - # freeze the batch norm layers after the first epoch - # 1) the batch norm layers helps with the initialization - # 2) but then, they seem to imply a strong normal prior on the latent space - for module in self.module.modules(): - if isinstance(module, torch.nn.BatchNorm1d): - module.eval() - def on_train_epoch_end(self): """Training epoch end for Pyro training.""" outputs = self.training_step_outputs elbo = 0 + n_obs = 0 for out in outputs: elbo += out["loss"] - elbo /= self.n_obs_training + n_obs += out["n_obs"] + elbo /= n_obs self.log("elbo_train", elbo, prog_bar=True) self.training_step_outputs.clear() @@ -95,6 +100,7 @@ def validation_step(self, batch, batch_idx): args, kwargs = self.module._get_fn_args_from_batch(batch) nll = -self.module.predictive_log_likelihood(*args, **kwargs, n_samples=5) out_dict["nll"] = nll + out_dict["n_obs"] = args[0].shape[0] self.validation_step_outputs[-1].update(out_dict) return out_dict @@ -103,11 +109,13 @@ def on_validation_epoch_end(self): outputs = self.validation_step_outputs elbo = 0 nll = 0 + n_obs = 0 for out in outputs: elbo += out["loss"] nll += out["nll"] - elbo /= self.n_obs_validation - nll /= self.n_obs_validation + n_obs += out["n_obs"] + elbo /= n_obs + nll /= n_obs self.log("elbo_validation", elbo, prog_bar=True) self.log("nll_validation", nll, prog_bar=False) self.validation_step_outputs.clear() diff --git a/tests/external/mrvi/test_model.py b/tests/external/mrvi/test_model.py index 5e3147c87e..05edb27496 100644 --- a/tests/external/mrvi/test_model.py +++ b/tests/external/mrvi/test_model.py @@ -155,9 +155,7 @@ def test_mrvi_da(model, sample_key, da_kwargs): }, ], ) -def test_mrvi_model_kwargs( - adata: AnnData, model_kwargs: dict[str, Any], save_path: str -): +def test_mrvi_model_kwargs(adata: AnnData, model_kwargs: dict[str, Any], save_path: str): MRVI.setup_anndata( adata, sample_key="sample_str", @@ -175,9 +173,7 @@ def test_mrvi_model_kwargs( def test_mrvi_sample_subset(model: MRVI, adata: AnnData, save_path: str): sample_cov_keys = ["meta1_cat", "meta2", "cont_cov"] sample_subset = [chr(i + ord("a")) for i in range(8)] - model.differential_expression( - sample_cov_keys=sample_cov_keys, sample_subset=sample_subset - ) + model.differential_expression(sample_cov_keys=sample_cov_keys, sample_subset=sample_subset) model_path = os.path.join(save_path, "mrvi_model") model.save(model_path, save_anndata=False, overwrite=True) From 9cc3aa6ef76c5411c73e8b813ed6c481a5a01d0c Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Tue, 15 Oct 2024 11:00:14 -0400 Subject: [PATCH 20/40] fix tests, remove validation step from base training plan --- src/scvi/external/decipher/_trainingplan.py | 14 ++- src/scvi/train/_trainingplans.py | 119 +++----------------- 2 files changed, 28 insertions(+), 105 deletions(-) diff --git a/src/scvi/external/decipher/_trainingplan.py b/src/scvi/external/decipher/_trainingplan.py index 992ad164a8..c07e19b18a 100644 --- a/src/scvi/external/decipher/_trainingplan.py +++ b/src/scvi/external/decipher/_trainingplan.py @@ -53,6 +53,8 @@ def __init__( optim=self.optim, loss=self.loss_fn, ) + self.validation_step_outputs = [] + # See configure_optimizers for what this does self._dummy_param = torch.nn.Parameter(torch.Tensor([0.0])) @@ -96,12 +98,16 @@ def on_train_epoch_end(self): def validation_step(self, batch, batch_idx): """Validation step for Pyro training.""" - out_dict = super().validation_step(batch, batch_idx) args, kwargs = self.module._get_fn_args_from_batch(batch) + loss = self.differentiable_loss_fn( + self.scale_fn(self.module.model), + self.scale_fn(self.module.guide), + *args, + **kwargs, + ) nll = -self.module.predictive_log_likelihood(*args, **kwargs, n_samples=5) - out_dict["nll"] = nll - out_dict["n_obs"] = args[0].shape[0] - self.validation_step_outputs[-1].update(out_dict) + out_dict = {"loss": loss, "nll": nll, "n_obs": args[0].shape[0]} + self.validation_step_outputs.append(out_dict) return out_dict def on_validation_epoch_end(self): diff --git a/src/scvi/train/_trainingplans.py b/src/scvi/train/_trainingplans.py index 5067d9cdbf..79aa4bf0e3 100644 --- a/src/scvi/train/_trainingplans.py +++ b/src/scvi/train/_trainingplans.py @@ -182,9 +182,7 @@ def __init__( self.optimizer_creator = optimizer_creator if self.optimizer_name == "Custom" and self.optimizer_creator is None: - raise ValueError( - "If optimizer is 'Custom', `optimizer_creator` must be provided." - ) + raise ValueError("If optimizer is 'Custom', `optimizer_creator` must be provided.") self._n_obs_training = None self._n_obs_validation = None @@ -221,9 +219,7 @@ def initialize_train_metrics(self): self.kl_local_train, self.kl_global_train, self.train_metrics, - ) = self._create_elbo_metric_components( - mode="train", n_total=self.n_obs_training - ) + ) = self._create_elbo_metric_components(mode="train", n_total=self.n_obs_training) self.elbo_train.reset() def initialize_val_metrics(self): @@ -234,9 +230,7 @@ def initialize_val_metrics(self): self.kl_local_val, self.kl_global_val, self.val_metrics, - ) = self._create_elbo_metric_components( - mode="validation", n_total=self.n_obs_validation - ) + ) = self._create_elbo_metric_components(mode="validation", n_total=self.n_obs_validation) self.elbo_val.reset() @property @@ -372,9 +366,7 @@ def validation_step(self, batch, batch_idx): ) self.compute_and_log_metrics(scvi_loss, self.val_metrics, "validation") - def _optimizer_creator_fn( - self, optimizer_cls: torch.optim.Adam | torch.optim.AdamW - ): + def _optimizer_creator_fn(self, optimizer_cls: torch.optim.Adam | torch.optim.AdamW): """Create optimizer for the model. This type of function can be passed as the `optimizer_creator` @@ -552,9 +544,7 @@ def loss_adversarial_classifier(self, z, batch_index, predict_true_class=True): if predict_true_class: cls_target = torch.nn.functional.one_hot(batch_index.squeeze(-1), n_classes) else: - one_hot_batch = torch.nn.functional.one_hot( - batch_index.squeeze(-1), n_classes - ) + one_hot_batch = torch.nn.functional.one_hot(batch_index.squeeze(-1), n_classes) # place zeroes where true label is cls_target = (~one_hot_batch.bool()).float() cls_target = cls_target / (n_classes - 1) @@ -582,9 +572,7 @@ def training_step(self, batch, batch_idx): else: opt1, opt2 = opts - inference_outputs, _, scvi_loss = self.forward( - batch, loss_kwargs=self.loss_kwargs - ) + inference_outputs, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs) z = inference_outputs["z"] loss = scvi_loss.loss # fool classifier if doing adversarial training @@ -617,10 +605,7 @@ def on_train_epoch_end(self): def on_validation_epoch_end(self) -> None: """Update the learning rate via scheduler steps.""" - if ( - not self.reduce_lr_on_plateau - or "validation" not in self.lr_scheduler_metric - ): + if not self.reduce_lr_on_plateau or "validation" not in self.lr_scheduler_metric: return else: sch = self.lr_schedulers() @@ -651,9 +636,7 @@ def configure_optimizers(self): ) if self.adversarial_classifier is not False: - params2 = filter( - lambda p: p.requires_grad, self.adversarial_classifier.parameters() - ) + params2 = filter(lambda p: p.requires_grad, self.adversarial_classifier.parameters()) optimizer2 = torch.optim.Adam( params2, lr=1e-3, eps=0.01, weight_decay=self.weight_decay ) @@ -906,7 +889,6 @@ def __init__( super().__init__() self.module = pyro_module self._n_obs_training = None - self._n_obs_validation = None optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else {} if "lr" not in optim_kwargs.keys(): @@ -919,18 +901,15 @@ def __init__( self.n_epochs_kl_warmup = n_epochs_kl_warmup self.use_kl_weight = False if isinstance(self.module.model, PyroModule): - self.use_kl_weight = ( - "kl_weight" in signature(self.module.model.forward).parameters - ) + self.use_kl_weight = "kl_weight" in signature(self.module.model.forward).parameters elif callable(self.module.model): self.use_kl_weight = "kl_weight" in signature(self.module.model).parameters self.scale_elbo = scale_elbo - self.scale_fn = lambda obj: ( - pyro.poutine.scale(obj, self.scale_elbo) if self.scale_elbo != 1 else obj + self.scale_fn = ( + lambda obj: pyro.poutine.scale(obj, self.scale_elbo) if self.scale_elbo != 1 else obj ) self.differentiable_loss_fn = self.loss_fn.differentiable_loss self.training_step_outputs = [] - self.validation_step_outputs = [] def training_step(self, batch, batch_idx): """Training step for Pyro training.""" @@ -963,37 +942,6 @@ def on_train_epoch_end(self): self.log("elbo_train", elbo, prog_bar=True) self.training_step_outputs.clear() - def validation_step(self, batch, batch_idx): - """Validation step for Pyro training.""" - args, kwargs = self.module._get_fn_args_from_batch(batch) - # Set KL weight if necessary. - # Note: if applied, ELBO loss in progress bar is the effective KL annealed loss, not the - # true ELBO. - if self.use_kl_weight: - kwargs.update({"kl_weight": self.kl_weight}) - # pytorch lightning requires a Tensor object for loss - loss = self.differentiable_loss_fn( - self.scale_fn(self.module.model), - self.scale_fn(self.module.guide), - *args, - **kwargs, - ) - out_dict = {"loss": loss} - self.validation_step_outputs.append(out_dict) - return out_dict - - def on_validation_epoch_end(self): - """Validation epoch end for Pyro training.""" - outputs = self.validation_step_outputs - elbo = 0 - n = 0 - for out in outputs: - elbo += out["loss"] - n += 1 - elbo /= n - self.log("elbo_validation", elbo, prog_bar=True) - self.validation_step_outputs.clear() - def configure_optimizers(self): """Configure optimizers for the model.""" return self.optim(self.module.parameters(), **self.optim_kwargs) @@ -1033,27 +981,6 @@ def n_obs_training(self, n_obs: int): self._n_obs_training = n_obs - @property - def n_obs_validation(self): - """Number of validation examples. - - If not `None`, updates the `n_obs` attr - of the Pyro module's `model` and `guide`, if they exist. - """ - return self._n_obs_validation - - @n_obs_validation.setter - def n_obs_validation(self, n_obs: int): - if n_obs is not None: - if hasattr(self.module, "n_obs"): - self.module.n_obs = n_obs - if hasattr(self.module.model, "n_obs"): - self.module.model.n_obs = n_obs - if hasattr(self.module.guide, "n_obs"): - self.module.guide.n_obs = n_obs - - self._n_obs_validation = n_obs - class PyroTrainingPlan(LowLevelPyroTrainingPlan): """Lightning module task to train Pyro scvi-tools modules. @@ -1102,9 +1029,7 @@ def __init__( optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else {} if "lr" not in optim_kwargs.keys(): optim_kwargs.update({"lr": 1e-3}) - self.optim = ( - pyro.optim.Adam(optim_args=optim_kwargs) if optim is None else optim - ) + self.optim = pyro.optim.Adam(optim_args=optim_kwargs) if optim is None else optim # We let SVI take care of all optimization self.automatic_optimization = False @@ -1200,9 +1125,7 @@ def __init__( self.loss_fn = loss() if self.module.logits is False and loss == torch.nn.CrossEntropyLoss: - raise UserWarning( - "classifier should return logits when using CrossEntropyLoss." - ) + raise UserWarning("classifier should return logits when using CrossEntropyLoss.") def forward(self, *args, **kwargs): """Passthrough to the module's forward function.""" @@ -1232,9 +1155,7 @@ def configure_optimizers(self): optim_cls = torch.optim.AdamW else: raise ValueError("Optimizer not understood.") - optimizer = optim_cls( - params, lr=self.lr, eps=self.eps, weight_decay=self.weight_decay - ) + optimizer = optim_cls(params, lr=self.lr, eps=self.eps, weight_decay=self.weight_decay) return optimizer @@ -1300,11 +1221,7 @@ def __init__( def get_optimizer_creator(self) -> JaxOptimizerCreator: """Get optimizer creator for the model.""" - clip_by = ( - optax.clip_by_global_norm(self.max_norm) - if self.max_norm - else optax.identity() - ) + clip_by = optax.clip_by_global_norm(self.max_norm) if self.max_norm else optax.identity() if self.optimizer_name == "Adam": # Replicates PyTorch Adam defaults optim = optax.chain( @@ -1358,9 +1275,9 @@ def loss_fn(params): loss = loss_output.loss return loss, (loss_output, new_model_state) - (loss, (loss_output, new_model_state)), grads = jax.value_and_grad( - loss_fn, has_aux=True - )(state.params) + (loss, (loss_output, new_model_state)), grads = jax.value_and_grad(loss_fn, has_aux=True)( + state.params + ) new_state = state.apply_gradients(grads=grads, state=new_model_state) return new_state, loss, loss_output From 1bc5805ca88ce958e26754dbbc2e079dbd63f4be Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 15 Oct 2024 15:00:33 +0000 Subject: [PATCH 21/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/external/decipher/_components.py | 16 +++++----------- src/scvi/external/decipher/_model.py | 8 ++------ src/scvi/external/decipher/_module.py | 12 +++--------- src/scvi/external/decipher/_trainingplan.py | 4 +--- 4 files changed, 11 insertions(+), 29 deletions(-) diff --git a/src/scvi/external/decipher/_components.py b/src/scvi/external/decipher/_components.py index 2951528aca..5439190fb9 100644 --- a/src/scvi/external/decipher/_components.py +++ b/src/scvi/external/decipher/_components.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence import numpy as np import torch @@ -52,7 +52,7 @@ def __init__( # The multiple outputs are computed as a single output layer, and then split indices = np.concatenate(([0], np.cumsum(self.output_dims))) - self.output_slices = [slice(s, e) for s, e in zip(indices[:-1], indices[1:])] + self.output_slices = [slice(s, e) for s, e in zip(indices[:-1], indices[1:], strict=False)] # Create masked layers deep_context_dim = self.context_dim if self.deep_context_injection else 0 @@ -63,21 +63,15 @@ def __init__( batch_norms.append(nn.BatchNorm1d(hidden_dims[0])) for i in range(1, len(hidden_dims)): layers.append( - torch.nn.Linear( - hidden_dims[i - 1] + deep_context_dim, hidden_dims[i] - ) + torch.nn.Linear(hidden_dims[i - 1] + deep_context_dim, hidden_dims[i]) ) batch_norms.append(nn.BatchNorm1d(hidden_dims[i])) layers.append( - torch.nn.Linear( - hidden_dims[-1] + deep_context_dim, self.output_total_dim - ) + torch.nn.Linear(hidden_dims[-1] + deep_context_dim, self.output_total_dim) ) else: - layers.append( - torch.nn.Linear(input_dim + context_dim, self.output_total_dim) - ) + layers.append(torch.nn.Linear(input_dim + context_dim, self.output_total_dim)) self.layers = torch.nn.ModuleList(layers) diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py index 75a8db88de..48f730f2eb 100644 --- a/src/scvi/external/decipher/_model.py +++ b/src/scvi/external/decipher/_model.py @@ -62,9 +62,7 @@ def setup_anndata( anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -113,9 +111,7 @@ def get_latent_representation( self._check_if_trained(warn=False) adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) latent_locs = [] for tensors in scdl: x = tensors[REGISTRY_KEYS.X_KEY] diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index 433bc83706..c327654527 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -82,9 +82,7 @@ def device(self): return self._dummy_param.device @staticmethod - def _get_fn_args_from_batch( - tensor_dict: dict[str, torch.Tensor] - ) -> Iterable | dict: + def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict: x = tensor_dict[REGISTRY_KEYS.X_KEY] return (x,), {} @@ -125,9 +123,7 @@ def model(self, x: torch.Tensor): self.theta + self._epsilon ) # noinspection PyUnresolvedReferences - x_dist = dist.NegativeBinomial( - total_count=self.theta + self._epsilon, logits=logit - ) + x_dist = dist.NegativeBinomial(total_count=self.theta + self._epsilon, logits=logit) pyro.sample("x", x_dist.to_event(1), obs=x) @auto_move_data @@ -188,9 +184,7 @@ def predictive_log_likelihood(self, x: torch.Tensor, n_samples=5): model_trace = poutine.trace( poutine.replay(self.model, trace=guide_trace) ).get_trace(x) - log_weights.append( - model_trace.log_prob_sum() - guide_trace.log_prob_sum() - ) + log_weights.append(model_trace.log_prob_sum() - guide_trace.log_prob_sum()) finally: self.beta = old_beta diff --git a/src/scvi/external/decipher/_trainingplan.py b/src/scvi/external/decipher/_trainingplan.py index c07e19b18a..adcb9dbbbe 100644 --- a/src/scvi/external/decipher/_trainingplan.py +++ b/src/scvi/external/decipher/_trainingplan.py @@ -41,9 +41,7 @@ def __init__( optim_kwargs.update({"lr": 5e-3}) if "weight_decay" not in optim_kwargs.keys(): optim_kwargs.update({"weight_decay": 1e-4}) - self.optim = ( - pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim - ) + self.optim = pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim # We let SVI take care of all optimization self.automatic_optimization = False From 46a23207e7a3fbb4cb3ee994ab27a9ae6d820213 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 15 Oct 2024 15:02:25 +0000 Subject: [PATCH 22/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/external/decipher/_components.py | 16 +++++----------- src/scvi/external/decipher/_model.py | 8 ++------ src/scvi/external/decipher/_module.py | 12 +++--------- src/scvi/external/decipher/_trainingplan.py | 4 +--- 4 files changed, 11 insertions(+), 29 deletions(-) diff --git a/src/scvi/external/decipher/_components.py b/src/scvi/external/decipher/_components.py index 2951528aca..5439190fb9 100644 --- a/src/scvi/external/decipher/_components.py +++ b/src/scvi/external/decipher/_components.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence import numpy as np import torch @@ -52,7 +52,7 @@ def __init__( # The multiple outputs are computed as a single output layer, and then split indices = np.concatenate(([0], np.cumsum(self.output_dims))) - self.output_slices = [slice(s, e) for s, e in zip(indices[:-1], indices[1:])] + self.output_slices = [slice(s, e) for s, e in zip(indices[:-1], indices[1:], strict=False)] # Create masked layers deep_context_dim = self.context_dim if self.deep_context_injection else 0 @@ -63,21 +63,15 @@ def __init__( batch_norms.append(nn.BatchNorm1d(hidden_dims[0])) for i in range(1, len(hidden_dims)): layers.append( - torch.nn.Linear( - hidden_dims[i - 1] + deep_context_dim, hidden_dims[i] - ) + torch.nn.Linear(hidden_dims[i - 1] + deep_context_dim, hidden_dims[i]) ) batch_norms.append(nn.BatchNorm1d(hidden_dims[i])) layers.append( - torch.nn.Linear( - hidden_dims[-1] + deep_context_dim, self.output_total_dim - ) + torch.nn.Linear(hidden_dims[-1] + deep_context_dim, self.output_total_dim) ) else: - layers.append( - torch.nn.Linear(input_dim + context_dim, self.output_total_dim) - ) + layers.append(torch.nn.Linear(input_dim + context_dim, self.output_total_dim)) self.layers = torch.nn.ModuleList(layers) diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py index 75a8db88de..48f730f2eb 100644 --- a/src/scvi/external/decipher/_model.py +++ b/src/scvi/external/decipher/_model.py @@ -62,9 +62,7 @@ def setup_anndata( anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -113,9 +111,7 @@ def get_latent_representation( self._check_if_trained(warn=False) adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) latent_locs = [] for tensors in scdl: x = tensors[REGISTRY_KEYS.X_KEY] diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index 433bc83706..c327654527 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -82,9 +82,7 @@ def device(self): return self._dummy_param.device @staticmethod - def _get_fn_args_from_batch( - tensor_dict: dict[str, torch.Tensor] - ) -> Iterable | dict: + def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict: x = tensor_dict[REGISTRY_KEYS.X_KEY] return (x,), {} @@ -125,9 +123,7 @@ def model(self, x: torch.Tensor): self.theta + self._epsilon ) # noinspection PyUnresolvedReferences - x_dist = dist.NegativeBinomial( - total_count=self.theta + self._epsilon, logits=logit - ) + x_dist = dist.NegativeBinomial(total_count=self.theta + self._epsilon, logits=logit) pyro.sample("x", x_dist.to_event(1), obs=x) @auto_move_data @@ -188,9 +184,7 @@ def predictive_log_likelihood(self, x: torch.Tensor, n_samples=5): model_trace = poutine.trace( poutine.replay(self.model, trace=guide_trace) ).get_trace(x) - log_weights.append( - model_trace.log_prob_sum() - guide_trace.log_prob_sum() - ) + log_weights.append(model_trace.log_prob_sum() - guide_trace.log_prob_sum()) finally: self.beta = old_beta diff --git a/src/scvi/external/decipher/_trainingplan.py b/src/scvi/external/decipher/_trainingplan.py index c07e19b18a..adcb9dbbbe 100644 --- a/src/scvi/external/decipher/_trainingplan.py +++ b/src/scvi/external/decipher/_trainingplan.py @@ -41,9 +41,7 @@ def __init__( optim_kwargs.update({"lr": 5e-3}) if "weight_decay" not in optim_kwargs.keys(): optim_kwargs.update({"weight_decay": 1e-4}) - self.optim = ( - pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim - ) + self.optim = pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim # We let SVI take care of all optimization self.automatic_optimization = False From 13e130e97b4cc0f9389dbfbf22be8c43403507d2 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Tue, 15 Oct 2024 11:05:24 -0400 Subject: [PATCH 23/40] fix ruff --- src/scvi/external/decipher/_components.py | 19 ++++++++++++++----- src/scvi/external/decipher/_model.py | 13 ++++++------- src/scvi/external/decipher/_module.py | 14 ++++++++++---- 3 files changed, 30 insertions(+), 16 deletions(-) diff --git a/src/scvi/external/decipher/_components.py b/src/scvi/external/decipher/_components.py index 5439190fb9..8ee6fc1070 100644 --- a/src/scvi/external/decipher/_components.py +++ b/src/scvi/external/decipher/_components.py @@ -24,7 +24,8 @@ class ConditionalDenseNN(nn.Module): Default: 0. No context input. deep_context_injection : bool (optional) If True, inject the context into every hidden layer. - If False, only inject the context into the first hidden layer (concatenated with the input). + If False, only inject the context into the first hidden layer + (concatenated with the input). Default: False. activation : torch.nn.Module (optional) Activation function to use between hidden layers (not applied to the outputs). @@ -52,7 +53,9 @@ def __init__( # The multiple outputs are computed as a single output layer, and then split indices = np.concatenate(([0], np.cumsum(self.output_dims))) - self.output_slices = [slice(s, e) for s, e in zip(indices[:-1], indices[1:], strict=False)] + self.output_slices = [ + slice(s, e) for s, e in zip(indices[:-1], indices[1:], strict=False) + ] # Create masked layers deep_context_dim = self.context_dim if self.deep_context_injection else 0 @@ -63,15 +66,21 @@ def __init__( batch_norms.append(nn.BatchNorm1d(hidden_dims[0])) for i in range(1, len(hidden_dims)): layers.append( - torch.nn.Linear(hidden_dims[i - 1] + deep_context_dim, hidden_dims[i]) + torch.nn.Linear( + hidden_dims[i - 1] + deep_context_dim, hidden_dims[i] + ) ) batch_norms.append(nn.BatchNorm1d(hidden_dims[i])) layers.append( - torch.nn.Linear(hidden_dims[-1] + deep_context_dim, self.output_total_dim) + torch.nn.Linear( + hidden_dims[-1] + deep_context_dim, self.output_total_dim + ) ) else: - layers.append(torch.nn.Linear(input_dim + context_dim, self.output_total_dim)) + layers.append( + torch.nn.Linear(input_dim + context_dim, self.output_total_dim) + ) self.layers = torch.nn.ModuleList(layers) diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py index 48f730f2eb..2243c52bc3 100644 --- a/src/scvi/external/decipher/_model.py +++ b/src/scvi/external/decipher/_model.py @@ -17,11 +17,6 @@ from ._module import DecipherPyroModule from ._trainingplan import DecipherTrainingPlan -if TYPE_CHECKING: - from collections.abc import Sequence - - from anndata import AnnData - logger = logging.getLogger(__name__) @@ -62,7 +57,9 @@ def setup_anndata( anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), ] - adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) + adata_manager = AnnDataManager( + fields=anndata_fields, setup_method_args=setup_method_args + ) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -111,7 +108,9 @@ def get_latent_representation( self._check_if_trained(warn=False) adata = self._validate_anndata(adata) - scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) + scdl = self._make_data_loader( + adata=adata, indices=indices, batch_size=batch_size + ) latent_locs = [] for tensors in scdl: x = tensors[REGISTRY_KEYS.X_KEY] diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index c327654527..e547d24f34 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -46,7 +46,7 @@ def __init__( dim_v: int = 2, dim_z: int = 10, layers_v_to_z: Sequence[int] = (64,), - layers_z_to_x: Sequence[int] = tuple(), + layers_z_to_x: Sequence[int] = (), beta: float = 0.1, prior: str = "normal", ): @@ -82,7 +82,9 @@ def device(self): return self._dummy_param.device @staticmethod - def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict: + def _get_fn_args_from_batch( + tensor_dict: dict[str, torch.Tensor] + ) -> Iterable | dict: x = tensor_dict[REGISTRY_KEYS.X_KEY] return (x,), {} @@ -123,7 +125,9 @@ def model(self, x: torch.Tensor): self.theta + self._epsilon ) # noinspection PyUnresolvedReferences - x_dist = dist.NegativeBinomial(total_count=self.theta + self._epsilon, logits=logit) + x_dist = dist.NegativeBinomial( + total_count=self.theta + self._epsilon, logits=logit + ) pyro.sample("x", x_dist.to_event(1), obs=x) @auto_move_data @@ -184,7 +188,9 @@ def predictive_log_likelihood(self, x: torch.Tensor, n_samples=5): model_trace = poutine.trace( poutine.replay(self.model, trace=guide_trace) ).get_trace(x) - log_weights.append(model_trace.log_prob_sum() - guide_trace.log_prob_sum()) + log_weights.append( + model_trace.log_prob_sum() - guide_trace.log_prob_sum() + ) finally: self.beta = old_beta From 089037c0e40f72772948621f20d1af580a8bb35d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 15 Oct 2024 15:05:37 +0000 Subject: [PATCH 24/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/external/decipher/_components.py | 16 ++++------------ src/scvi/external/decipher/_model.py | 9 ++------- src/scvi/external/decipher/_module.py | 12 +++--------- 3 files changed, 9 insertions(+), 28 deletions(-) diff --git a/src/scvi/external/decipher/_components.py b/src/scvi/external/decipher/_components.py index 8ee6fc1070..2f40d4e337 100644 --- a/src/scvi/external/decipher/_components.py +++ b/src/scvi/external/decipher/_components.py @@ -53,9 +53,7 @@ def __init__( # The multiple outputs are computed as a single output layer, and then split indices = np.concatenate(([0], np.cumsum(self.output_dims))) - self.output_slices = [ - slice(s, e) for s, e in zip(indices[:-1], indices[1:], strict=False) - ] + self.output_slices = [slice(s, e) for s, e in zip(indices[:-1], indices[1:], strict=False)] # Create masked layers deep_context_dim = self.context_dim if self.deep_context_injection else 0 @@ -66,21 +64,15 @@ def __init__( batch_norms.append(nn.BatchNorm1d(hidden_dims[0])) for i in range(1, len(hidden_dims)): layers.append( - torch.nn.Linear( - hidden_dims[i - 1] + deep_context_dim, hidden_dims[i] - ) + torch.nn.Linear(hidden_dims[i - 1] + deep_context_dim, hidden_dims[i]) ) batch_norms.append(nn.BatchNorm1d(hidden_dims[i])) layers.append( - torch.nn.Linear( - hidden_dims[-1] + deep_context_dim, self.output_total_dim - ) + torch.nn.Linear(hidden_dims[-1] + deep_context_dim, self.output_total_dim) ) else: - layers.append( - torch.nn.Linear(input_dim + context_dim, self.output_total_dim) - ) + layers.append(torch.nn.Linear(input_dim + context_dim, self.output_total_dim)) self.layers = torch.nn.ModuleList(layers) diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py index 2243c52bc3..06dd297632 100644 --- a/src/scvi/external/decipher/_model.py +++ b/src/scvi/external/decipher/_model.py @@ -1,6 +1,5 @@ import logging from collections.abc import Sequence -from typing import TYPE_CHECKING import numpy as np import pyro @@ -57,9 +56,7 @@ def setup_anndata( anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -108,9 +105,7 @@ def get_latent_representation( self._check_if_trained(warn=False) adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) latent_locs = [] for tensors in scdl: x = tensors[REGISTRY_KEYS.X_KEY] diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index e547d24f34..682c29d3fa 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -82,9 +82,7 @@ def device(self): return self._dummy_param.device @staticmethod - def _get_fn_args_from_batch( - tensor_dict: dict[str, torch.Tensor] - ) -> Iterable | dict: + def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict: x = tensor_dict[REGISTRY_KEYS.X_KEY] return (x,), {} @@ -125,9 +123,7 @@ def model(self, x: torch.Tensor): self.theta + self._epsilon ) # noinspection PyUnresolvedReferences - x_dist = dist.NegativeBinomial( - total_count=self.theta + self._epsilon, logits=logit - ) + x_dist = dist.NegativeBinomial(total_count=self.theta + self._epsilon, logits=logit) pyro.sample("x", x_dist.to_event(1), obs=x) @auto_move_data @@ -188,9 +184,7 @@ def predictive_log_likelihood(self, x: torch.Tensor, n_samples=5): model_trace = poutine.trace( poutine.replay(self.model, trace=guide_trace) ).get_trace(x) - log_weights.append( - model_trace.log_prob_sum() - guide_trace.log_prob_sum() - ) + log_weights.append(model_trace.log_prob_sum() - guide_trace.log_prob_sum()) finally: self.beta = old_beta From e2267ec833478a5b66585608b99291dba816ef2f Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Mon, 21 Oct 2024 11:15:35 -0400 Subject: [PATCH 25/40] fix tensor device bug for get latent --- src/scvi/external/decipher/_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py index 06dd297632..ab5a35ce7c 100644 --- a/src/scvi/external/decipher/_model.py +++ b/src/scvi/external/decipher/_model.py @@ -110,10 +110,11 @@ def get_latent_representation( for tensors in scdl: x = tensors[REGISTRY_KEYS.X_KEY] x = torch.log1p(x) + x = x.to(self.module.device) z_loc, _ = self.module.encoder_x_to_z(x) if give_z: latent_locs.append(z_loc) else: v_loc, _ = self.module.encoder_zx_to_v(torch.cat([z_loc, x], dim=-1)) latent_locs.append(v_loc) - return torch.cat(latent_locs).detach().numpy() + return torch.cat(latent_locs).detach().cpu().numpy() From e3562e2d09e025190b097a36bbf762421dc2d05c Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Mon, 21 Oct 2024 11:44:41 -0400 Subject: [PATCH 26/40] add docs and release note for decipher --- CHANGELOG.md | 3 ++ docs/references.bib | 9 ++++++ src/scvi/external/decipher/_components.py | 15 ++++------ src/scvi/external/decipher/_model.py | 36 ++++++++++++++++++++++- src/scvi/external/decipher/_module.py | 26 ++++++++-------- 5 files changed, 66 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 046502dbd8..fb35017f4e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ to [Semantic Versioning]. Full commit history is available in the #### Added +- Add {class}`scvi.external.Decipher` for dimensionality reduction and interpretable + representation learning in single-cell RNA sequencing data {pr}`3015`. + #### Fixed - Breaking Change: Fix `get_outlier_cell_sample_pairs` function in {class}`scvi.external.MRVI` diff --git a/docs/references.bib b/docs/references.bib index 5a017924ff..978c7c8673 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -212,6 +212,15 @@ @article{Martens2023 publisher={Nature Publishing Group} } +@article{Nazaret23, + title={Deep generative model deciphers derailed trajectories in acute myeloid leukemia}, + author={Nazaret, Achille and Fan, Joy Linyue and Lavallee, Vincent-Philippe and Cornish, Andrew E and Kiseliovas, Vaidotas and Masilionis, Ignas and Chun, Jaeyoung and Bowman, Robert L and Eisman, Shira E and Wang, James and others}, + journal={bioRxiv}, + pages={2023--11}, + year={2023}, + publisher={Cold Spring Harbor Laboratory} +} + @article{Sheng22, title = {Probabilistic machine learning ensures accurate ambient denoising in droplet-based single-cell omics}, author = {Caibin Sheng and Rui Lopes and Gang Li and Sven Schuierer and Annick Waldt and Rachel Cuttat and Slavica Dimitrieva and Audrey Kauffmann and Eric Durand and Giorgio G. Galli and Guglielmo Roma and Antoine de Weck}, diff --git a/src/scvi/external/decipher/_components.py b/src/scvi/external/decipher/_components.py index 8ee6fc1070..e1efb053ab 100644 --- a/src/scvi/external/decipher/_components.py +++ b/src/scvi/external/decipher/_components.py @@ -12,22 +12,19 @@ class ConditionalDenseNN(nn.Module): Parameters ---------- - input_dim : int + input_dim Dimension of the input - hidden_dims : sequence of ints + hidden_dims Dimensions of the hidden layers (excluding the output layer) - output_dims : sequence of ints (optional) + output_dims Dimensions of each output layer - Default: (1,) - context_dim : int (optional) + context_dim Dimension of the context input. - Default: 0. No context input. - deep_context_injection : bool (optional) + deep_context_injection If True, inject the context into every hidden layer. If False, only inject the context into the first hidden layer (concatenated with the input). - Default: False. - activation : torch.nn.Module (optional) + activation Activation function to use between hidden layers (not applied to the outputs). Default: torch.nn.ReLU() """ diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py index 2243c52bc3..4590bde80d 100644 --- a/src/scvi/external/decipher/_model.py +++ b/src/scvi/external/decipher/_model.py @@ -1,6 +1,5 @@ import logging from collections.abc import Sequence -from typing import TYPE_CHECKING import numpy as np import pyro @@ -21,6 +20,27 @@ class Decipher(PyroSviTrainMixin, BaseModelClass): + """Decipher model for single-cell data analysis :cite:p:`Nazaret2023`. + + Parameters + ---------- + adata + AnnData object that has been registered via + :meth:`~scvi.model.Decipher.setup_anndata`. + dim_v + Dimension of the interpretable latent space v. + dim_z + Dimension of the intermediate latent space z. + layers_v_to_z + Hidden layer sizes for the v to z decoder network. + layers_z_to_x + Hidden layer sizes for the z to x decoder network. + beta + Regularization parameter for the KL divergence. + prior + Type of prior distribution to use. + """ + _module_cls = DecipherPyroModule _training_plan_cls = DecipherTrainingPlan @@ -105,6 +125,20 @@ def get_latent_representation( batch_size: int | None = None, give_z: bool = False, ) -> np.ndarray: + """Get the latent representation of the data. + + Parameters + ---------- + adata + AnnData object with the data to get the latent representation of. + indices + Indices of the data to get the latent representation of. + batch_size + Batch size to use for the data loader. + give_z + Whether to return the intermediate latent space z or the top-level + latent space v. + """ self._check_if_trained(warn=False) adata = self._validate_anndata(adata) diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index e547d24f34..bf00c120e4 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -24,20 +24,20 @@ class DecipherPyroModule(PyroBaseModuleClass): Parameters ---------- - dim_genes : int + dim_genes Number of genes (features) in the dataset. - dim_v : int, optional - Dimension of the interpretable latent space v. Default is 2. - dim_z : int, optional - Dimension of the intermediate latent space z. Default is 10. - layers_v_to_z : Sequence[int], optional - Hidden layer sizes for the v to z decoder network. Default is (64,). - layers_z_to_x : Sequence[int], optional - Hidden layer sizes for the z to x decoder network. Default is empty tuple. - beta : float, optional - Regularization parameter for the KL divergence. Default is 0.1. - prior : str, optional - Type of prior distribution to use. Default is "normal". + dim_v + Dimension of the interpretable latent space v. + dim_z + Dimension of the intermediate latent space z. + layers_v_to_z + Hidden layer sizes for the v to z decoder network. + layers_z_to_x + Hidden layer sizes for the z to x decoder network. + beta + Regularization parameter for the KL divergence. + prior + Type of prior distribution to use. """ def __init__( From 74f9323b1f8b495e92c9a18ff7ebe6f79d0ab785 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Mon, 21 Oct 2024 20:25:02 +0300 Subject: [PATCH 27/40] check if this fixes the cuda test --- .github/workflows/test_linux_cuda.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test_linux_cuda.yml b/.github/workflows/test_linux_cuda.yml index 0eca4d65c8..e75617cd4c 100644 --- a/.github/workflows/test_linux_cuda.yml +++ b/.github/workflows/test_linux_cuda.yml @@ -31,10 +31,10 @@ jobs: run: shell: bash -e {0} # -e to fail on error - container: - image: ghcr.io/scverse/scvi-tools:py3.12-cu12-base - #image: ghcr.io/scverse/scvi-tools:py3.12-cu12-${{ env.BRANCH_NAME }}-base - options: --user root --gpus all + #container: + # image: ghcr.io/scverse/scvi-tools:py3.12-cu12-base + # image: ghcr.io/scverse/scvi-tools:py3.12-cu12-${{ env.BRANCH_NAME }}-base + # options: --user root --gpus all name: integration From ec85018b8b3d01b3c1100e8fbd70908a2940d167 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Mon, 21 Oct 2024 20:32:35 +0300 Subject: [PATCH 28/40] check if this fixes the cuda test --- .github/workflows/test_linux_cuda.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test_linux_cuda.yml b/.github/workflows/test_linux_cuda.yml index e75617cd4c..0eca4d65c8 100644 --- a/.github/workflows/test_linux_cuda.yml +++ b/.github/workflows/test_linux_cuda.yml @@ -31,10 +31,10 @@ jobs: run: shell: bash -e {0} # -e to fail on error - #container: - # image: ghcr.io/scverse/scvi-tools:py3.12-cu12-base - # image: ghcr.io/scverse/scvi-tools:py3.12-cu12-${{ env.BRANCH_NAME }}-base - # options: --user root --gpus all + container: + image: ghcr.io/scverse/scvi-tools:py3.12-cu12-base + #image: ghcr.io/scverse/scvi-tools:py3.12-cu12-${{ env.BRANCH_NAME }}-base + options: --user root --gpus all name: integration From 94c4da6ffcfbd360899e51e8140c4e40fcceb135 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Mon, 21 Oct 2024 21:21:16 +0300 Subject: [PATCH 29/40] check if this fixes the cuda test --- .github/workflows/test_linux_cuda.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test_linux_cuda.yml b/.github/workflows/test_linux_cuda.yml index 0eca4d65c8..e148b86915 100644 --- a/.github/workflows/test_linux_cuda.yml +++ b/.github/workflows/test_linux_cuda.yml @@ -48,6 +48,8 @@ jobs: # run: echo "BRANCH_NAME=$(echo $GITHUB_REF | awk -F'/' '{print $3}')" >> $GITHUB_ENV - uses: actions/checkout@v4 + with: + fetch-depth: 0 - uses: actions/setup-python@v5 with: From 8dceb5f26273335c67477b09990f219500b10bab Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Mon, 21 Oct 2024 21:02:12 -0400 Subject: [PATCH 30/40] add Decipher to docs --- docs/api/developer.md | 1 + docs/api/user.md | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/api/developer.md b/docs/api/developer.md index 3d3f6504c6..abbb1717b3 100644 --- a/docs/api/developer.md +++ b/docs/api/developer.md @@ -180,6 +180,7 @@ Module classes in the external API with respective generative and inference proc external.velovi.VELOVAE external.mrvi.MRVAE external.methylvi.METHYLVAE + external.decipher.DecipherPyroModule ``` diff --git a/docs/api/user.md b/docs/api/user.md index 10da50816b..825815d64a 100644 --- a/docs/api/user.md +++ b/docs/api/user.md @@ -62,6 +62,7 @@ import scvi external.VELOVI external.MRVI external.METHYLVI + external.Decipher ``` ## Data loading From a06d9368b36dd09722966d2ec2382c1ea93446d1 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Tue, 22 Oct 2024 13:37:07 -0400 Subject: [PATCH 31/40] fix docs --- src/scvi/external/decipher/_model.py | 2 +- src/scvi/external/decipher/_module.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py index 637e380ad1..3e8115996b 100644 --- a/src/scvi/external/decipher/_model.py +++ b/src/scvi/external/decipher/_model.py @@ -20,7 +20,7 @@ class Decipher(PyroSviTrainMixin, BaseModelClass): - """Decipher model for single-cell data analysis :cite:p:`Nazaret2023`. + """Decipher model for single-cell data analysis :cite:p:`Nazaret23`. Parameters ---------- diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index 23011aa90c..2245a6a6dd 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -17,7 +17,7 @@ class DecipherPyroModule(PyroBaseModuleClass): - """Decipher PyroModule for single-cell data analysis. + """Pyro Module for the Decipher model. This module implements the Decipher model for dimensionality reduction and interpretable representation learning in single-cell RNA sequencing data. From 17e7d4a8f6203d7a9041efe62c1849ce4e6340a1 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Tue, 22 Oct 2024 13:48:56 -0400 Subject: [PATCH 32/40] add doi --- docs/references.bib | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/references.bib b/docs/references.bib index 978c7c8673..3baa082b55 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -218,7 +218,8 @@ @article{Nazaret23 journal={bioRxiv}, pages={2023--11}, year={2023}, - publisher={Cold Spring Harbor Laboratory} + publisher={Cold Spring Harbor Laboratory}, + doi = {10.1101/2023.11.11.566719} } @article{Sheng22, From c49e14d68196a48e6a104ce555252b6a2da3ab14 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Sun, 3 Nov 2024 12:44:44 -0500 Subject: [PATCH 33/40] remove prior arg --- src/scvi/external/decipher/_model.py | 10 +++++---- src/scvi/external/decipher/_module.py | 30 ++++++++++----------------- 2 files changed, 17 insertions(+), 23 deletions(-) diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py index 637e380ad1..a2092caeb8 100644 --- a/src/scvi/external/decipher/_model.py +++ b/src/scvi/external/decipher/_model.py @@ -37,8 +37,6 @@ class Decipher(PyroSviTrainMixin, BaseModelClass): Hidden layer sizes for the z to x decoder network. beta Regularization parameter for the KL divergence. - prior - Type of prior distribution to use. """ _module_cls = DecipherPyroModule @@ -77,7 +75,9 @@ def setup_anndata( anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), ] - adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) + adata_manager = AnnDataManager( + fields=anndata_fields, setup_method_args=setup_method_args + ) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -140,7 +140,9 @@ def get_latent_representation( self._check_if_trained(warn=False) adata = self._validate_anndata(adata) - scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) + scdl = self._make_data_loader( + adata=adata, indices=indices, batch_size=batch_size + ) latent_locs = [] for tensors in scdl: x = tensors[REGISTRY_KEYS.X_KEY] diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index 23011aa90c..7b06b057dc 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -36,8 +36,6 @@ class DecipherPyroModule(PyroBaseModuleClass): Hidden layer sizes for the z to x decoder network. beta Regularization parameter for the KL divergence. - prior - Type of prior distribution to use. """ def __init__( @@ -48,7 +46,6 @@ def __init__( layers_v_to_z: Sequence[int] = (64,), layers_z_to_x: Sequence[int] = (), beta: float = 0.1, - prior: str = "normal", ): super().__init__() self.dim_v = dim_v @@ -57,7 +54,6 @@ def __init__( self.layers_v_to_z = layers_v_to_z self.layers_z_to_x = layers_z_to_x self.beta = beta - self.prior = prior self.decoder_v_to_z = ConditionalDenseNN(dim_v, layers_v_to_z, [dim_z] * 2) self.decoder_z_to_x = ConditionalDenseNN(dim_z, layers_z_to_x, [dim_genes]) @@ -82,7 +78,9 @@ def device(self): return self._dummy_param.device @staticmethod - def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict: + def _get_fn_args_from_batch( + tensor_dict: dict[str, torch.Tensor] + ) -> Iterable | dict: x = tensor_dict[REGISTRY_KEYS.X_KEY] return (x,), {} @@ -101,12 +99,7 @@ def model(self, x: torch.Tensor): poutine.scale(scale=1.0), ): with poutine.scale(scale=self.beta): - if self.prior == "normal": - prior = dist.Normal(0, x.new_ones(self.dim_v)).to_event(1) - elif self.prior == "gamma": - prior = dist.Gamma(0.3, x.new_ones(self.dim_v) * 0.8).to_event(1) - else: - raise ValueError("Invalid prior, must be normal or gamma") + prior = dist.Normal(0, x.new_ones(self.dim_v)).to_event(1) v = pyro.sample("v", prior) z_loc, z_scale = self.decoder_v_to_z(v) @@ -123,7 +116,9 @@ def model(self, x: torch.Tensor): self.theta + self._epsilon ) # noinspection PyUnresolvedReferences - x_dist = dist.NegativeBinomial(total_count=self.theta + self._epsilon, logits=logit) + x_dist = dist.NegativeBinomial( + total_count=self.theta + self._epsilon, logits=logit + ) pyro.sample("x", x_dist.to_event(1), obs=x) @auto_move_data @@ -144,12 +139,7 @@ def guide(self, x: torch.Tensor): v_loc, v_scale = self.encoder_zx_to_v(zx) v_scale = softplus(v_scale) with poutine.scale(scale=self.beta): - if self.prior == "gamma": - posterior_v = dist.Gamma(softplus(v_loc), v_scale).to_event(1) - elif self.prior == "normal" or self.prior == "student-normal": - posterior_v = dist.Normal(v_loc, v_scale).to_event(1) - else: - raise ValueError("Invalid prior, must be normal or gamma") + posterior_v = dist.Normal(v_loc, v_scale).to_event(1) pyro.sample("v", posterior_v) return z_loc, v_loc, z_scale, v_scale @@ -184,7 +174,9 @@ def predictive_log_likelihood(self, x: torch.Tensor, n_samples=5): model_trace = poutine.trace( poutine.replay(self.model, trace=guide_trace) ).get_trace(x) - log_weights.append(model_trace.log_prob_sum() - guide_trace.log_prob_sum()) + log_weights.append( + model_trace.log_prob_sum() - guide_trace.log_prob_sum() + ) finally: self.beta = old_beta From e25173b6eda8fffc17800ab854b30f7eee06e704 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 3 Nov 2024 17:45:13 +0000 Subject: [PATCH 34/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/external/decipher/_model.py | 8 ++------ src/scvi/external/decipher/_module.py | 12 +++--------- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py index 5cfde98e35..50841b3ddc 100644 --- a/src/scvi/external/decipher/_model.py +++ b/src/scvi/external/decipher/_model.py @@ -75,9 +75,7 @@ def setup_anndata( anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -140,9 +138,7 @@ def get_latent_representation( self._check_if_trained(warn=False) adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) latent_locs = [] for tensors in scdl: x = tensors[REGISTRY_KEYS.X_KEY] diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index 889f3f432f..f96fa2dd70 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -78,9 +78,7 @@ def device(self): return self._dummy_param.device @staticmethod - def _get_fn_args_from_batch( - tensor_dict: dict[str, torch.Tensor] - ) -> Iterable | dict: + def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict: x = tensor_dict[REGISTRY_KEYS.X_KEY] return (x,), {} @@ -116,9 +114,7 @@ def model(self, x: torch.Tensor): self.theta + self._epsilon ) # noinspection PyUnresolvedReferences - x_dist = dist.NegativeBinomial( - total_count=self.theta + self._epsilon, logits=logit - ) + x_dist = dist.NegativeBinomial(total_count=self.theta + self._epsilon, logits=logit) pyro.sample("x", x_dist.to_event(1), obs=x) @auto_move_data @@ -174,9 +170,7 @@ def predictive_log_likelihood(self, x: torch.Tensor, n_samples=5): model_trace = poutine.trace( poutine.replay(self.model, trace=guide_trace) ).get_trace(x) - log_weights.append( - model_trace.log_prob_sum() - guide_trace.log_prob_sum() - ) + log_weights.append(model_trace.log_prob_sum() - guide_trace.log_prob_sum()) finally: self.beta = old_beta From 43ab670ce57294bbfb0ca34696583a143456c512 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Tue, 12 Nov 2024 19:48:24 -0500 Subject: [PATCH 35/40] address comments --- src/scvi/external/decipher/_components.py | 29 +++++++++-------- src/scvi/external/decipher/_module.py | 38 +++++++++++------------ 2 files changed, 35 insertions(+), 32 deletions(-) diff --git a/src/scvi/external/decipher/_components.py b/src/scvi/external/decipher/_components.py index 46e457a018..660048f1f6 100644 --- a/src/scvi/external/decipher/_components.py +++ b/src/scvi/external/decipher/_components.py @@ -50,7 +50,9 @@ def __init__( # The multiple outputs are computed as a single output layer, and then split indices = np.concatenate(([0], np.cumsum(self.output_dims))) - self.output_slices = [slice(s, e) for s, e in zip(indices[:-1], indices[1:], strict=False)] + self.output_slices = [ + slice(s, e) for s, e in zip(indices[:-1], indices[1:], strict=False) + ] # Create masked layers deep_context_dim = self.context_dim if self.deep_context_injection else 0 @@ -61,19 +63,25 @@ def __init__( batch_norms.append(nn.BatchNorm1d(hidden_dims[0])) for i in range(1, len(hidden_dims)): layers.append( - torch.nn.Linear(hidden_dims[i - 1] + deep_context_dim, hidden_dims[i]) + torch.nn.Linear( + hidden_dims[i - 1] + deep_context_dim, hidden_dims[i] + ) ) batch_norms.append(nn.BatchNorm1d(hidden_dims[i])) layers.append( - torch.nn.Linear(hidden_dims[-1] + deep_context_dim, self.output_total_dim) + torch.nn.Linear( + hidden_dims[-1] + deep_context_dim, self.output_total_dim + ) ) else: - layers.append(torch.nn.Linear(input_dim + context_dim, self.output_total_dim)) + layers.append( + torch.nn.Linear(input_dim + context_dim, self.output_total_dim) + ) self.layers = torch.nn.ModuleList(layers) - self.f = activation + self.activation_fn = activation self.batch_norms = torch.nn.ModuleList(batch_norms) def forward(self, x, context=None): @@ -88,15 +96,10 @@ def forward(self, x, context=None): h = layer(h) if i < len(self.layers) - 1: h = self.batch_norms[i](h) - h = self.f(h) + h = self.activation_fn(h) if self.n_output_layers == 1: return h - else: - h = h.reshape(list(x.size()[:-1]) + [self.output_total_dim]) - - if self.n_output_layers == 1: - return h - else: - return tuple([h[..., s] for s in self.output_slices]) + h = h.reshape(list(x.size()[:-1]) + [self.output_total_dim]) + return tuple([h[..., s] for s in self.output_slices]) diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index f96fa2dd70..87c1146d4d 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -15,6 +15,8 @@ from ._components import ConditionalDenseNN +TEST_BETA = 1.0 + class DecipherPyroModule(PyroBaseModuleClass): """Pyro Module for the Decipher model. @@ -61,7 +63,7 @@ def __init__( self.encoder_zx_to_v = ConditionalDenseNN( dim_genes + dim_z, [128], - [dim_v, dim_v], + [dim_v] * 2, ) self.theta = None @@ -78,12 +80,14 @@ def device(self): return self._dummy_param.device @staticmethod - def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict: + def _get_fn_args_from_batch( + tensor_dict: dict[str, torch.Tensor] + ) -> Iterable | dict: x = tensor_dict[REGISTRY_KEYS.X_KEY] return (x,), {} @auto_move_data - def model(self, x: torch.Tensor): + def model(self, x: torch.Tensor, beta: float | None = None): pyro.module("decipher", self) self.theta = pyro.param( @@ -96,7 +100,7 @@ def model(self, x: torch.Tensor): pyro.plate("batch", x.shape[0]), poutine.scale(scale=1.0), ): - with poutine.scale(scale=self.beta): + with poutine.scale(scale=beta or self.beta): prior = dist.Normal(0, x.new_ones(self.dim_v)).to_event(1) v = pyro.sample("v", prior) @@ -114,11 +118,13 @@ def model(self, x: torch.Tensor): self.theta + self._epsilon ) # noinspection PyUnresolvedReferences - x_dist = dist.NegativeBinomial(total_count=self.theta + self._epsilon, logits=logit) + x_dist = dist.NegativeBinomial( + total_count=self.theta + self._epsilon, logits=logit + ) pyro.sample("x", x_dist.to_event(1), obs=x) @auto_move_data - def guide(self, x: torch.Tensor): + def guide(self, x: torch.Tensor, beta: float | None = None): pyro.module("decipher", self) with ( pyro.plate("batch", x.shape[0]), @@ -134,7 +140,7 @@ def guide(self, x: torch.Tensor): zx = torch.cat([z, x], dim=-1) v_loc, v_scale = self.encoder_zx_to_v(zx) v_scale = softplus(v_scale) - with poutine.scale(scale=self.beta): + with poutine.scale(scale=beta or self.beta): posterior_v = dist.Normal(v_loc, v_scale).to_event(1) pyro.sample("v", posterior_v) return z_loc, v_loc, z_scale, v_scale @@ -162,18 +168,12 @@ def predictive_log_likelihood(self, x: torch.Tensor, n_samples=5): The average estimated predictive log-likelihood across multiple runs. """ log_weights = [] - old_beta = self.beta - self.beta = 1.0 - try: - for _ in range(n_samples): - guide_trace = poutine.trace(self.guide).get_trace(x) - model_trace = poutine.trace( - poutine.replay(self.model, trace=guide_trace) - ).get_trace(x) - log_weights.append(model_trace.log_prob_sum() - guide_trace.log_prob_sum()) - - finally: - self.beta = old_beta + for _ in range(n_samples): + guide_trace = poutine.trace(self.guide).get_trace(x, beta=TEST_BETA) + model_trace = poutine.trace( + poutine.replay(self.model, trace=guide_trace) + ).get_trace(x, beta=TEST_BETA) + log_weights.append(model_trace.log_prob_sum() - guide_trace.log_prob_sum()) log_z = torch.logsumexp(torch.tensor(log_weights) - np.log(n_samples), 0) return log_z.item() From 9850619d018595e54aa6821704a301391a0922b6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Nov 2024 00:48:38 +0000 Subject: [PATCH 36/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/external/decipher/_components.py | 16 ++++------------ src/scvi/external/decipher/_module.py | 14 +++++--------- 2 files changed, 9 insertions(+), 21 deletions(-) diff --git a/src/scvi/external/decipher/_components.py b/src/scvi/external/decipher/_components.py index 660048f1f6..8900ddab9f 100644 --- a/src/scvi/external/decipher/_components.py +++ b/src/scvi/external/decipher/_components.py @@ -50,9 +50,7 @@ def __init__( # The multiple outputs are computed as a single output layer, and then split indices = np.concatenate(([0], np.cumsum(self.output_dims))) - self.output_slices = [ - slice(s, e) for s, e in zip(indices[:-1], indices[1:], strict=False) - ] + self.output_slices = [slice(s, e) for s, e in zip(indices[:-1], indices[1:], strict=False)] # Create masked layers deep_context_dim = self.context_dim if self.deep_context_injection else 0 @@ -63,21 +61,15 @@ def __init__( batch_norms.append(nn.BatchNorm1d(hidden_dims[0])) for i in range(1, len(hidden_dims)): layers.append( - torch.nn.Linear( - hidden_dims[i - 1] + deep_context_dim, hidden_dims[i] - ) + torch.nn.Linear(hidden_dims[i - 1] + deep_context_dim, hidden_dims[i]) ) batch_norms.append(nn.BatchNorm1d(hidden_dims[i])) layers.append( - torch.nn.Linear( - hidden_dims[-1] + deep_context_dim, self.output_total_dim - ) + torch.nn.Linear(hidden_dims[-1] + deep_context_dim, self.output_total_dim) ) else: - layers.append( - torch.nn.Linear(input_dim + context_dim, self.output_total_dim) - ) + layers.append(torch.nn.Linear(input_dim + context_dim, self.output_total_dim)) self.layers = torch.nn.ModuleList(layers) diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py index 87c1146d4d..fb864ebe48 100644 --- a/src/scvi/external/decipher/_module.py +++ b/src/scvi/external/decipher/_module.py @@ -80,9 +80,7 @@ def device(self): return self._dummy_param.device @staticmethod - def _get_fn_args_from_batch( - tensor_dict: dict[str, torch.Tensor] - ) -> Iterable | dict: + def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict: x = tensor_dict[REGISTRY_KEYS.X_KEY] return (x,), {} @@ -118,9 +116,7 @@ def model(self, x: torch.Tensor, beta: float | None = None): self.theta + self._epsilon ) # noinspection PyUnresolvedReferences - x_dist = dist.NegativeBinomial( - total_count=self.theta + self._epsilon, logits=logit - ) + x_dist = dist.NegativeBinomial(total_count=self.theta + self._epsilon, logits=logit) pyro.sample("x", x_dist.to_event(1), obs=x) @auto_move_data @@ -170,9 +166,9 @@ def predictive_log_likelihood(self, x: torch.Tensor, n_samples=5): log_weights = [] for _ in range(n_samples): guide_trace = poutine.trace(self.guide).get_trace(x, beta=TEST_BETA) - model_trace = poutine.trace( - poutine.replay(self.model, trace=guide_trace) - ).get_trace(x, beta=TEST_BETA) + model_trace = poutine.trace(poutine.replay(self.model, trace=guide_trace)).get_trace( + x, beta=TEST_BETA + ) log_weights.append(model_trace.log_prob_sum() - guide_trace.log_prob_sum()) log_z = torch.logsumexp(torch.tensor(log_weights) - np.log(n_samples), 0) From 61df480bfdfa1c27d706ddf6f7757cdd9ade2167 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Wed, 13 Nov 2024 10:09:27 -0500 Subject: [PATCH 37/40] Update test_linux_cuda.yml --- .github/workflows/test_linux_cuda.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/test_linux_cuda.yml b/.github/workflows/test_linux_cuda.yml index e148b86915..0eca4d65c8 100644 --- a/.github/workflows/test_linux_cuda.yml +++ b/.github/workflows/test_linux_cuda.yml @@ -48,8 +48,6 @@ jobs: # run: echo "BRANCH_NAME=$(echo $GITHUB_REF | awk -F'/' '{print $3}')" >> $GITHUB_ENV - uses: actions/checkout@v4 - with: - fetch-depth: 0 - uses: actions/setup-python@v5 with: From 26131c56cc90db73b27bca2c937b342af01e6049 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Wed, 13 Nov 2024 10:53:46 -0500 Subject: [PATCH 38/40] simplify output slice code --- src/scvi/external/decipher/_components.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/scvi/external/decipher/_components.py b/src/scvi/external/decipher/_components.py index 8900ddab9f..f49d6de203 100644 --- a/src/scvi/external/decipher/_components.py +++ b/src/scvi/external/decipher/_components.py @@ -49,8 +49,13 @@ def __init__( self.output_total_dim = sum(self.output_dims) # The multiple outputs are computed as a single output layer, and then split - indices = np.concatenate(([0], np.cumsum(self.output_dims))) - self.output_slices = [slice(s, e) for s, e in zip(indices[:-1], indices[1:], strict=False)] + last_output_end_idx = 0 + self.output_slices = [] + for dim in self.output_dims: + self.output_slices.append( + slice(last_output_end_idx, last_output_end_idx + dim) + ) + last_output_end_idx += dim # Create masked layers deep_context_dim = self.context_dim if self.deep_context_injection else 0 @@ -61,15 +66,21 @@ def __init__( batch_norms.append(nn.BatchNorm1d(hidden_dims[0])) for i in range(1, len(hidden_dims)): layers.append( - torch.nn.Linear(hidden_dims[i - 1] + deep_context_dim, hidden_dims[i]) + torch.nn.Linear( + hidden_dims[i - 1] + deep_context_dim, hidden_dims[i] + ) ) batch_norms.append(nn.BatchNorm1d(hidden_dims[i])) layers.append( - torch.nn.Linear(hidden_dims[-1] + deep_context_dim, self.output_total_dim) + torch.nn.Linear( + hidden_dims[-1] + deep_context_dim, self.output_total_dim + ) ) else: - layers.append(torch.nn.Linear(input_dim + context_dim, self.output_total_dim)) + layers.append( + torch.nn.Linear(input_dim + context_dim, self.output_total_dim) + ) self.layers = torch.nn.ModuleList(layers) From c0eda75caedc9c63cdfbead754ca2cee1716e307 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Nov 2024 15:54:02 +0000 Subject: [PATCH 39/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/external/decipher/_components.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/src/scvi/external/decipher/_components.py b/src/scvi/external/decipher/_components.py index f49d6de203..12900b29a6 100644 --- a/src/scvi/external/decipher/_components.py +++ b/src/scvi/external/decipher/_components.py @@ -1,6 +1,5 @@ from collections.abc import Sequence -import numpy as np import torch import torch.nn as nn @@ -52,9 +51,7 @@ def __init__( last_output_end_idx = 0 self.output_slices = [] for dim in self.output_dims: - self.output_slices.append( - slice(last_output_end_idx, last_output_end_idx + dim) - ) + self.output_slices.append(slice(last_output_end_idx, last_output_end_idx + dim)) last_output_end_idx += dim # Create masked layers @@ -66,21 +63,15 @@ def __init__( batch_norms.append(nn.BatchNorm1d(hidden_dims[0])) for i in range(1, len(hidden_dims)): layers.append( - torch.nn.Linear( - hidden_dims[i - 1] + deep_context_dim, hidden_dims[i] - ) + torch.nn.Linear(hidden_dims[i - 1] + deep_context_dim, hidden_dims[i]) ) batch_norms.append(nn.BatchNorm1d(hidden_dims[i])) layers.append( - torch.nn.Linear( - hidden_dims[-1] + deep_context_dim, self.output_total_dim - ) + torch.nn.Linear(hidden_dims[-1] + deep_context_dim, self.output_total_dim) ) else: - layers.append( - torch.nn.Linear(input_dim + context_dim, self.output_total_dim) - ) + layers.append(torch.nn.Linear(input_dim + context_dim, self.output_total_dim)) self.layers = torch.nn.ModuleList(layers) From d8c5ac464628b167bbcfdecb4ba603c6443e1440 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Wed, 13 Nov 2024 16:29:15 -0500 Subject: [PATCH 40/40] move change in changelog to 1.3.0 --- CHANGELOG.md | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fb35017f4e..7dd7c35377 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,13 +6,23 @@ to [Semantic Versioning]. Full commit history is available in the ## Version 1.2 -### 1.2.1 (2024-XX-XX) +### 1.3.0 (2024-XX-XX) #### Added +#### Fixed + +#### Changed + +#### Removed + - Add {class}`scvi.external.Decipher` for dimensionality reduction and interpretable representation learning in single-cell RNA sequencing data {pr}`3015`. +### 1.2.1 (2024-XX-XX) + +#### Added + #### Fixed - Breaking Change: Fix `get_outlier_cell_sample_pairs` function in {class}`scvi.external.MRVI`