From e48fe3be747f7b553b8e0adb1862ae23a4a11701 Mon Sep 17 00:00:00 2001
From: Justin Hong
Date: Mon, 14 Oct 2024 21:13:06 -0400
Subject: [PATCH] drop last and fix loss scaling

---
 src/scvi/external/decipher/_model.py        |  3 ++
 src/scvi/external/decipher/_module.py       |  2 --
 src/scvi/external/decipher/_trainingplan.py | 40 ++++++++++++---------
 tests/external/mrvi/test_model.py           |  8 ++---
 4 files changed, 29 insertions(+), 24 deletions(-)

diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py
index 0e74e74ab2..75a8db88de 100644
--- a/src/scvi/external/decipher/_model.py
+++ b/src/scvi/external/decipher/_model.py
@@ -85,6 +85,9 @@ def train(
     ):
         if "early_stopping_monitor" not in trainer_kwargs:
             trainer_kwargs["early_stopping_monitor"] = "nll_validation"
+        datasplitter_kwargs = datasplitter_kwargs or {}
+        if "drop_last" not in datasplitter_kwargs:
+            datasplitter_kwargs["drop_last"] = True
         super().train(
             max_epochs=max_epochs,
             accelerator=accelerator,
diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py
index aa264fd1bc..433bc83706 100644
--- a/src/scvi/external/decipher/_module.py
+++ b/src/scvi/external/decipher/_module.py
@@ -169,8 +169,6 @@ def predictive_log_likelihood(self, x: torch.Tensor, n_samples=5):

         Parameters
         ----------
-        decipher_module : PyroBaseModuleClass
-            The Decipher module to evaluate.
         x : torch.Tensor
             Batch of data to compute the log-likelihood for.
         n_samples : int, optional
diff --git a/src/scvi/external/decipher/_trainingplan.py b/src/scvi/external/decipher/_trainingplan.py
index 32371eceb5..992ad164a8 100644
--- a/src/scvi/external/decipher/_trainingplan.py
+++ b/src/scvi/external/decipher/_trainingplan.py
@@ -38,7 +38,9 @@ def __init__(
         )
         optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else {}
         if "lr" not in optim_kwargs.keys():
-            optim_kwargs.update({"lr": 5e-3, "weight_decay": 1e-4})
+            optim_kwargs.update({"lr": 5e-3})
+        if "weight_decay" not in optim_kwargs.keys():
+            optim_kwargs.update({"weight_decay": 1e-4})
         self.optim = (
             pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim
         )
@@ -54,38 +56,41 @@ def __init__(
         # See configure_optimizers for what this does
         self._dummy_param = torch.nn.Parameter(torch.Tensor([0.0]))

+    def on_validation_model_train(self):
+        """Prepare the model for validation by switching to train mode."""
+        super().on_validation_model_train()
+        if self.current_epoch > 0:
+            # freeze the batch norm layers after the first epoch
+            # 1) the batch norm layers help with the initialization
+            # 2) but then, they seem to imply a strong normal prior on the latent space
+            for module in self.module.modules():
+                if isinstance(module, torch.nn.BatchNorm1d):
+                    module.eval()
+
     def training_step(self, batch, batch_idx):
         """Training step for Pyro training."""
         args, kwargs = self.module._get_fn_args_from_batch(batch)
         # pytorch lightning requires a Tensor object for loss
         loss = torch.Tensor([self.svi.step(*args, **kwargs)])
+        n_obs = args[0].shape[0]

         _opt = self.optimizers()
         _opt.step()

-        out_dict = {"loss": loss}
+        out_dict = {"loss": loss, "n_obs": n_obs}
         self.training_step_outputs.append(out_dict)
         return out_dict

-    def on_validation_model_train(self):
-        """Prepare the model for validation by switching to train mode."""
-        super().on_validation_model_train()
-        if self.current_epoch > 0:
-            # freeze the batch norm layers after the first epoch
-            # 1) the batch norm layers helps with the initialization
-            # 2) but then, they seem to imply a strong normal prior on the latent space
-            for module in self.module.modules():
-                if isinstance(module, torch.nn.BatchNorm1d):
-                    module.eval()
-
     def on_train_epoch_end(self):
         """Training epoch end for Pyro training."""
         outputs = self.training_step_outputs
         elbo = 0
+        n_obs = 0
         for out in outputs:
             elbo += out["loss"]
-        elbo /= self.n_obs_training
+            n_obs += out["n_obs"]
+        elbo /= n_obs
         self.log("elbo_train", elbo, prog_bar=True)
         self.training_step_outputs.clear()
@@ -95,6 +100,7 @@ def validation_step(self, batch, batch_idx):
         args, kwargs = self.module._get_fn_args_from_batch(batch)
         nll = -self.module.predictive_log_likelihood(*args, **kwargs, n_samples=5)
         out_dict["nll"] = nll
+        out_dict["n_obs"] = args[0].shape[0]
         self.validation_step_outputs[-1].update(out_dict)
         return out_dict
@@ -103,11 +109,13 @@ def on_validation_epoch_end(self):
         outputs = self.validation_step_outputs
         elbo = 0
         nll = 0
+        n_obs = 0
         for out in outputs:
             elbo += out["loss"]
             nll += out["nll"]
-        elbo /= self.n_obs_validation
-        nll /= self.n_obs_validation
+            n_obs += out["n_obs"]
+        elbo /= n_obs
+        nll /= n_obs
         self.log("elbo_validation", elbo, prog_bar=True)
         self.log("nll_validation", nll, prog_bar=False)
         self.validation_step_outputs.clear()
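Note on the loss-scaling change above: SVI.step returns the ELBO summed over the batch, and the old code divided the epoch total by self.n_obs_training / self.n_obs_validation. Once the splitter drops the final partial batch (see the _model.py hunk), those registered dataset sizes overcount the observations actually seen, biasing the logged per-observation ELBO low. The patch therefore accumulates n_obs per step and divides by the running total. A minimal, self-contained sketch of the effect; the numbers are hypothetical and not from the patch:

    # Hypothetical stand-ins for self.n_obs_training and the loader batch size.
    dataset_size = 1000
    batch_size = 128

    # With drop_last=True only the 7 full batches are seen; 104 samples are skipped.
    batches = [batch_size] * (dataset_size // batch_size)

    # Stand-in for the summed ELBO that svi.step() would return per batch,
    # assuming a constant per-observation loss of 2.5.
    step_losses = [2.5 * n for n in batches]

    biased = sum(step_losses) / dataset_size    # old: divide by registered size -> 2.24
    unbiased = sum(step_losses) / sum(batches)  # patched: divide by observations seen -> 2.5

    print(biased, unbiased)

The same accumulate-and-divide pattern is applied to both elbo_train and the elbo_validation/nll_validation metrics, so the early-stopping monitor nll_validation is scaled consistently as well.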
diff --git a/tests/external/mrvi/test_model.py b/tests/external/mrvi/test_model.py
index 5e3147c87e..05edb27496 100644
--- a/tests/external/mrvi/test_model.py
+++ b/tests/external/mrvi/test_model.py
@@ -155,9 +155,7 @@ def test_mrvi_da(model, sample_key, da_kwargs):
         },
     ],
 )
-def test_mrvi_model_kwargs(
-    adata: AnnData, model_kwargs: dict[str, Any], save_path: str
-):
+def test_mrvi_model_kwargs(adata: AnnData, model_kwargs: dict[str, Any], save_path: str):
     MRVI.setup_anndata(
         adata,
         sample_key="sample_str",
@@ -175,9 +173,7 @@ def test_mrvi_sample_subset(model: MRVI, adata: AnnData, save_path: str):
     sample_cov_keys = ["meta1_cat", "meta2", "cont_cov"]
     sample_subset = [chr(i + ord("a")) for i in range(8)]
-    model.differential_expression(
-        sample_cov_keys=sample_cov_keys, sample_subset=sample_subset
-    )
+    model.differential_expression(sample_cov_keys=sample_cov_keys, sample_subset=sample_subset)
     model_path = os.path.join(save_path, "mrvi_model")
     model.save(model_path, save_anndata=False, overwrite=True)
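Note on the drop_last default added in _model.py: datasplitter_kwargs are ultimately forwarded to a torch DataLoader, where drop_last=True discards any ragged final batch. Keeping every batch the same size matters for the per-batch n_obs bookkeeping above, and plausibly also for torch.nn.BatchNorm1d, which cannot compute batch statistics from a single sample in train mode (and this training plan only switches BatchNorm to eval after the first epoch). A hedged sketch of the behaviour with toy tensors; shapes and sizes are illustrative only:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.randn(10, 4))  # hypothetical: 10 cells, 4 genes

    keep = DataLoader(ds, batch_size=4, drop_last=False)
    drop = DataLoader(ds, batch_size=4, drop_last=True)

    print([b[0].shape[0] for b in keep])  # [4, 4, 2] -- ragged final batch
    print([b[0].shape[0] for b in drop])  # [4, 4]    -- partial batch discarded

    # A size-1 batch breaks BatchNorm1d while it is still in train mode:
    bn = torch.nn.BatchNorm1d(4)
    try:
        bn(torch.randn(1, 4))
    except ValueError as err:
        print("BatchNorm1d on one sample:", err)

Because the default is only applied when the caller has not set "drop_last" themselves, users can still opt back into keeping partial batches via datasplitter_kwargs.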