drop last and fix loss scaling
justjhong committed Oct 15, 2024
1 parent 7fe1beb commit e48fe3b
Showing 4 changed files with 29 additions and 24 deletions.
src/scvi/external/decipher/_model.py (3 changes: 3 additions & 0 deletions)
@@ -85,6 +85,9 @@ def train(
     ):
         if "early_stopping_monitor" not in trainer_kwargs:
             trainer_kwargs["early_stopping_monitor"] = "nll_validation"
+        datasplitter_kwargs = datasplitter_kwargs or {}
+        if "drop_last" not in datasplitter_kwargs:
+            datasplitter_kwargs["drop_last"] = True
         super().train(
             max_epochs=max_epochs,
             accelerator=accelerator,
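
Context for the `drop_last` default above: in a PyTorch-style data loader, `drop_last=True` discards the final batch whenever the dataset size is not a multiple of `batch_size`, so every batch the model (including its BatchNorm layers) sees has the full batch size. A minimal sketch of the effect, using plain `torch.utils.data` with made-up sizes rather than the scvi-tools data splitter:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(10, 3))  # 10 observations, batch_size=4 below

    sizes = [x.shape[0] for (x,) in DataLoader(dataset, batch_size=4, drop_last=False)]
    print(sizes)  # [4, 4, 2] -- a ragged final batch reaches the model

    sizes = [x.shape[0] for (x,) in DataLoader(dataset, batch_size=4, drop_last=True)]
    print(sizes)  # [4, 4] -- the trailing partial batch is discarded

Dropping the partial batch also means an epoch sees fewer observations than the dataset holds, which is presumably what the per-batch `n_obs` bookkeeping in the training plan below corrects for.
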
src/scvi/external/decipher/_module.py (2 changes: 0 additions & 2 deletions)
@@ -169,8 +169,6 @@ def predictive_log_likelihood(self, x: torch.Tensor, n_samples=5):
         Parameters
         ----------
-        decipher_module : PyroBaseModuleClass
-            The Decipher module to evaluate.
         x : torch.Tensor
             Batch of data to compute the log-likelihood for.
         n_samples : int, optional
src/scvi/external/decipher/_trainingplan.py (40 changes: 24 additions & 16 deletions)
@@ -38,7 +38,9 @@ def __init__(
         )
         optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else {}
         if "lr" not in optim_kwargs.keys():
-            optim_kwargs.update({"lr": 5e-3, "weight_decay": 1e-4})
+            optim_kwargs.update({"lr": 5e-3})
+        if "weight_decay" not in optim_kwargs.keys():
+            optim_kwargs.update({"weight_decay": 1e-4})
         self.optim = (
             pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim
         )
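
The rewritten guard above decouples the two defaults. Previously, a user-supplied `weight_decay` was clobbered whenever `lr` was absent, and the `weight_decay` default was skipped whenever `lr` was supplied. A standalone sketch of the difference (hypothetical helper names; default values taken from the diff):

    def old_defaults(optim_kwargs):
        optim_kwargs = dict(optim_kwargs)
        if "lr" not in optim_kwargs:
            optim_kwargs.update({"lr": 5e-3, "weight_decay": 1e-4})
        return optim_kwargs

    def new_defaults(optim_kwargs):
        optim_kwargs = dict(optim_kwargs)
        if "lr" not in optim_kwargs:
            optim_kwargs.update({"lr": 5e-3})
        if "weight_decay" not in optim_kwargs:
            optim_kwargs.update({"weight_decay": 1e-4})
        return optim_kwargs

    print(old_defaults({"lr": 1e-2}))           # {'lr': 0.01} -- no weight decay applied at all
    print(new_defaults({"lr": 1e-2}))           # {'lr': 0.01, 'weight_decay': 0.0001}
    print(old_defaults({"weight_decay": 0.0}))  # {'weight_decay': 0.0001, 'lr': 0.005} -- user value clobbered
    print(new_defaults({"weight_decay": 0.0}))  # {'weight_decay': 0.0, 'lr': 0.005}
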
Expand All @@ -54,38 +56,41 @@ def __init__(
# See configure_optimizers for what this does
self._dummy_param = torch.nn.Parameter(torch.Tensor([0.0]))

def on_validation_model_train(self):
"""Prepare the model for validation by switching to train mode."""
super().on_validation_model_train()
if self.current_epoch > 0:
# freeze the batch norm layers after the first epoch
# 1) the batch norm layers helps with the initialization
# 2) but then, they seem to imply a strong normal prior on the latent space
for module in self.module.modules():
if isinstance(module, torch.nn.BatchNorm1d):
module.eval()

def training_step(self, batch, batch_idx):
"""Training step for Pyro training."""
args, kwargs = self.module._get_fn_args_from_batch(batch)

# pytorch lightning requires a Tensor object for loss
loss = torch.Tensor([self.svi.step(*args, **kwargs)])
n_obs = args[0].shape[0]

_opt = self.optimizers()
_opt.step()

out_dict = {"loss": loss}
out_dict = {"loss": loss, "n_obs": n_obs}
self.training_step_outputs.append(out_dict)
return out_dict

def on_validation_model_train(self):
"""Prepare the model for validation by switching to train mode."""
super().on_validation_model_train()
if self.current_epoch > 0:
# freeze the batch norm layers after the first epoch
# 1) the batch norm layers helps with the initialization
# 2) but then, they seem to imply a strong normal prior on the latent space
for module in self.module.modules():
if isinstance(module, torch.nn.BatchNorm1d):
module.eval()

def on_train_epoch_end(self):
"""Training epoch end for Pyro training."""
outputs = self.training_step_outputs
elbo = 0
n_obs = 0
for out in outputs:
elbo += out["loss"]
elbo /= self.n_obs_training
n_obs += out["n_obs"]
elbo /= n_obs
self.log("elbo_train", elbo, prog_bar=True)
self.training_step_outputs.clear()
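
The epoch-end change above fixes the denominator of the reported metric. Assuming `svi.step` returns a loss summed over the batch (which the per-observation division implies), dividing by the dataset-level `self.n_obs_training` under-reports the loss once `drop_last=True` discards a partial batch, because the denominator counts observations that never contributed to the sum. A toy illustration with made-up numbers:

    # 10 registered observations, batch_size=4, drop_last=True
    # -> only 8 observations contribute loss each epoch
    true_per_obs_loss = 2.0
    batch_losses = [4 * true_per_obs_loss, 4 * true_per_obs_loss]  # summed per batch

    old_avg = sum(batch_losses) / 10       # 1.6 -- biased low by the dropped observations
    new_avg = sum(batch_losses) / (4 + 4)  # 2.0 -- matches the true per-observation loss
    print(old_avg, new_avg)

The validation hooks below get the same treatment: each step records its batch size, and the epoch-end averages divide by the summed count.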

Expand All @@ -95,6 +100,7 @@ def validation_step(self, batch, batch_idx):
args, kwargs = self.module._get_fn_args_from_batch(batch)
nll = -self.module.predictive_log_likelihood(*args, **kwargs, n_samples=5)
out_dict["nll"] = nll
out_dict["n_obs"] = args[0].shape[0]
self.validation_step_outputs[-1].update(out_dict)
return out_dict

Expand All @@ -103,11 +109,13 @@ def on_validation_epoch_end(self):
outputs = self.validation_step_outputs
elbo = 0
nll = 0
n_obs = 0
for out in outputs:
elbo += out["loss"]
nll += out["nll"]
elbo /= self.n_obs_validation
nll /= self.n_obs_validation
n_obs += out["n_obs"]
elbo /= n_obs
nll /= n_obs
self.log("elbo_validation", elbo, prog_bar=True)
self.log("nll_validation", nll, prog_bar=False)
self.validation_step_outputs.clear()
Expand Down
tests/external/mrvi/test_model.py (8 changes: 2 additions & 6 deletions)
@@ -155,9 +155,7 @@ def test_mrvi_da(model, sample_key, da_kwargs):
         },
     ],
 )
-def test_mrvi_model_kwargs(
-    adata: AnnData, model_kwargs: dict[str, Any], save_path: str
-):
+def test_mrvi_model_kwargs(adata: AnnData, model_kwargs: dict[str, Any], save_path: str):
     MRVI.setup_anndata(
         adata,
         sample_key="sample_str",
Expand All @@ -175,9 +173,7 @@ def test_mrvi_model_kwargs(
def test_mrvi_sample_subset(model: MRVI, adata: AnnData, save_path: str):
sample_cov_keys = ["meta1_cat", "meta2", "cont_cov"]
sample_subset = [chr(i + ord("a")) for i in range(8)]
model.differential_expression(
sample_cov_keys=sample_cov_keys, sample_subset=sample_subset
)
model.differential_expression(sample_cov_keys=sample_cov_keys, sample_subset=sample_subset)

model_path = os.path.join(save_path, "mrvi_model")
model.save(model_path, save_anndata=False, overwrite=True)
Expand Down
