fix tests, remove validation step from base training plan
justjhong committed Oct 15, 2024
1 parent e48fe3b commit 2baf11d
Showing 2 changed files with 28 additions and 105 deletions.
14 changes: 10 additions & 4 deletions src/scvi/external/decipher/_trainingplan.py
@@ -53,6 +53,8 @@ def __init__(
             optim=self.optim,
             loss=self.loss_fn,
         )
+        self.validation_step_outputs = []
+
         # See configure_optimizers for what this does
         self._dummy_param = torch.nn.Parameter(torch.Tensor([0.0]))

@@ -96,12 +98,16 @@ def on_train_epoch_end(self):

     def validation_step(self, batch, batch_idx):
         """Validation step for Pyro training."""
-        out_dict = super().validation_step(batch, batch_idx)
         args, kwargs = self.module._get_fn_args_from_batch(batch)
+        loss = self.differentiable_loss_fn(
+            self.scale_fn(self.module.model),
+            self.scale_fn(self.module.guide),
+            *args,
+            **kwargs,
+        )
         nll = -self.module.predictive_log_likelihood(*args, **kwargs, n_samples=5)
-        out_dict["nll"] = nll
-        out_dict["n_obs"] = args[0].shape[0]
-        self.validation_step_outputs[-1].update(out_dict)
+        out_dict = {"loss": loss, "nll": nll, "n_obs": args[0].shape[0]}
+        self.validation_step_outputs.append(out_dict)
         return out_dict

     def on_validation_epoch_end(self):
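The `on_validation_epoch_end` hook that consumes these per-batch outputs is not shown in this hunk. Purely as an illustration (not the code in this commit), a minimal aggregation could mirror the epoch-end averaging removed from the base class in the second file, with the assumption that the epoch-level NLL is normalized by the total number of observations:

    # Illustrative sketch only -- not the implementation in this commit.
    def on_validation_epoch_end(self):
        outputs = self.validation_step_outputs
        n_obs = sum(out["n_obs"] for out in outputs)
        elbo = sum(out["loss"] for out in outputs) / len(outputs)  # mean per-batch ELBO loss
        nll = sum(out["nll"] for out in outputs) / n_obs           # assumed per-observation NLL
        self.log("elbo_validation", elbo, prog_bar=True)
        self.log("nll_validation", nll, prog_bar=True)
        self.validation_step_outputs.clear()                       # reset for the next epoch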
119 changes: 18 additions & 101 deletions src/scvi/train/_trainingplans.py
@@ -182,9 +182,7 @@ def __init__(
         self.optimizer_creator = optimizer_creator

         if self.optimizer_name == "Custom" and self.optimizer_creator is None:
-            raise ValueError(
-                "If optimizer is 'Custom', `optimizer_creator` must be provided."
-            )
+            raise ValueError("If optimizer is 'Custom', `optimizer_creator` must be provided.")

         self._n_obs_training = None
         self._n_obs_validation = None
@@ -221,9 +219,7 @@ def initialize_train_metrics(self):
             self.kl_local_train,
             self.kl_global_train,
             self.train_metrics,
-        ) = self._create_elbo_metric_components(
-            mode="train", n_total=self.n_obs_training
-        )
+        ) = self._create_elbo_metric_components(mode="train", n_total=self.n_obs_training)
         self.elbo_train.reset()

     def initialize_val_metrics(self):
@@ -234,9 +230,7 @@ def initialize_val_metrics(self):
             self.kl_local_val,
             self.kl_global_val,
             self.val_metrics,
-        ) = self._create_elbo_metric_components(
-            mode="validation", n_total=self.n_obs_validation
-        )
+        ) = self._create_elbo_metric_components(mode="validation", n_total=self.n_obs_validation)
         self.elbo_val.reset()

     @property
@@ -372,9 +366,7 @@ def validation_step(self, batch, batch_idx):
         )
         self.compute_and_log_metrics(scvi_loss, self.val_metrics, "validation")

-    def _optimizer_creator_fn(
-        self, optimizer_cls: torch.optim.Adam | torch.optim.AdamW
-    ):
+    def _optimizer_creator_fn(self, optimizer_cls: torch.optim.Adam | torch.optim.AdamW):
         """Create optimizer for the model.

         This type of function can be passed as the `optimizer_creator`
@@ -552,9 +544,7 @@ def loss_adversarial_classifier(self, z, batch_index, predict_true_class=True):
         if predict_true_class:
             cls_target = torch.nn.functional.one_hot(batch_index.squeeze(-1), n_classes)
         else:
-            one_hot_batch = torch.nn.functional.one_hot(
-                batch_index.squeeze(-1), n_classes
-            )
+            one_hot_batch = torch.nn.functional.one_hot(batch_index.squeeze(-1), n_classes)
             # place zeroes where true label is
             cls_target = (~one_hot_batch.bool()).float()
             cls_target = cls_target / (n_classes - 1)
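To make the soft-target construction above concrete, a small worked example (illustrative only, with a hypothetical batch and n_classes = 3):

    import torch

    batch_index = torch.tensor([[0], [2]])  # hypothetical per-cell batch labels, shape (batch, 1)
    n_classes = 3
    one_hot_batch = torch.nn.functional.one_hot(batch_index.squeeze(-1), n_classes)
    # one_hot_batch -> [[1, 0, 0], [0, 0, 1]]
    cls_target = (~one_hot_batch.bool()).float()  # zero at the true label, one elsewhere
    # cls_target   -> [[0., 1., 1.], [1., 0., 1.]]
    cls_target = cls_target / (n_classes - 1)     # uniform mass over the wrong classes
    # cls_target   -> [[0.0, 0.5, 0.5], [0.5, 0.0, 0.5]]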
@@ -582,9 +572,7 @@ def training_step(self, batch, batch_idx):
         else:
             opt1, opt2 = opts

-        inference_outputs, _, scvi_loss = self.forward(
-            batch, loss_kwargs=self.loss_kwargs
-        )
+        inference_outputs, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
         z = inference_outputs["z"]
         loss = scvi_loss.loss
         # fool classifier if doing adversarial training
@@ -617,10 +605,7 @@ def on_train_epoch_end(self):

     def on_validation_epoch_end(self) -> None:
         """Update the learning rate via scheduler steps."""
-        if (
-            not self.reduce_lr_on_plateau
-            or "validation" not in self.lr_scheduler_metric
-        ):
+        if not self.reduce_lr_on_plateau or "validation" not in self.lr_scheduler_metric:
             return
         else:
             sch = self.lr_schedulers()
@@ -651,9 +636,7 @@ def configure_optimizers(self):
         )

         if self.adversarial_classifier is not False:
-            params2 = filter(
-                lambda p: p.requires_grad, self.adversarial_classifier.parameters()
-            )
+            params2 = filter(lambda p: p.requires_grad, self.adversarial_classifier.parameters())
             optimizer2 = torch.optim.Adam(
                 params2, lr=1e-3, eps=0.01, weight_decay=self.weight_decay
             )
@@ -906,7 +889,6 @@ def __init__(
         super().__init__()
         self.module = pyro_module
         self._n_obs_training = None
-        self._n_obs_validation = None

         optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else {}
         if "lr" not in optim_kwargs.keys():
@@ -919,18 +901,15 @@ def __init__(
         self.n_epochs_kl_warmup = n_epochs_kl_warmup
         self.use_kl_weight = False
         if isinstance(self.module.model, PyroModule):
-            self.use_kl_weight = (
-                "kl_weight" in signature(self.module.model.forward).parameters
-            )
+            self.use_kl_weight = "kl_weight" in signature(self.module.model.forward).parameters
         elif callable(self.module.model):
             self.use_kl_weight = "kl_weight" in signature(self.module.model).parameters
         self.scale_elbo = scale_elbo
-        self.scale_fn = lambda obj: (
-            pyro.poutine.scale(obj, self.scale_elbo) if self.scale_elbo != 1 else obj
+        self.scale_fn = (
+            lambda obj: pyro.poutine.scale(obj, self.scale_elbo) if self.scale_elbo != 1 else obj
         )
         self.differentiable_loss_fn = self.loss_fn.differentiable_loss
         self.training_step_outputs = []
-        self.validation_step_outputs = []

     def training_step(self, batch, batch_idx):
         """Training step for Pyro training."""
@@ -963,37 +942,6 @@ def on_train_epoch_end(self):
         self.log("elbo_train", elbo, prog_bar=True)
         self.training_step_outputs.clear()

-    def validation_step(self, batch, batch_idx):
-        """Validation step for Pyro training."""
-        args, kwargs = self.module._get_fn_args_from_batch(batch)
-        # Set KL weight if necessary.
-        # Note: if applied, ELBO loss in progress bar is the effective KL annealed loss, not the
-        # true ELBO.
-        if self.use_kl_weight:
-            kwargs.update({"kl_weight": self.kl_weight})
-        # pytorch lightning requires a Tensor object for loss
-        loss = self.differentiable_loss_fn(
-            self.scale_fn(self.module.model),
-            self.scale_fn(self.module.guide),
-            *args,
-            **kwargs,
-        )
-        out_dict = {"loss": loss}
-        self.validation_step_outputs.append(out_dict)
-        return out_dict
-
-    def on_validation_epoch_end(self):
-        """Validation epoch end for Pyro training."""
-        outputs = self.validation_step_outputs
-        elbo = 0
-        n = 0
-        for out in outputs:
-            elbo += out["loss"]
-            n += 1
-        elbo /= n
-        self.log("elbo_validation", elbo, prog_bar=True)
-        self.validation_step_outputs.clear()
-
     def configure_optimizers(self):
         """Configure optimizers for the model."""
         return self.optim(self.module.parameters(), **self.optim_kwargs)
@@ -1033,27 +981,6 @@ def n_obs_training(self, n_obs: int):

         self._n_obs_training = n_obs

-    @property
-    def n_obs_validation(self):
-        """Number of validation examples.
-
-        If not `None`, updates the `n_obs` attr
-        of the Pyro module's `model` and `guide`, if they exist.
-        """
-        return self._n_obs_validation
-
-    @n_obs_validation.setter
-    def n_obs_validation(self, n_obs: int):
-        if n_obs is not None:
-            if hasattr(self.module, "n_obs"):
-                self.module.n_obs = n_obs
-            if hasattr(self.module.model, "n_obs"):
-                self.module.model.n_obs = n_obs
-            if hasattr(self.module.guide, "n_obs"):
-                self.module.guide.n_obs = n_obs
-
-        self._n_obs_validation = n_obs
-

 class PyroTrainingPlan(LowLevelPyroTrainingPlan):
     """Lightning module task to train Pyro scvi-tools modules.
@@ -1102,9 +1029,7 @@ def __init__(
         optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else {}
         if "lr" not in optim_kwargs.keys():
             optim_kwargs.update({"lr": 1e-3})
-        self.optim = (
-            pyro.optim.Adam(optim_args=optim_kwargs) if optim is None else optim
-        )
+        self.optim = pyro.optim.Adam(optim_args=optim_kwargs) if optim is None else optim
         # We let SVI take care of all optimization
         self.automatic_optimization = False

@@ -1200,9 +1125,7 @@ def __init__(
         self.loss_fn = loss()

         if self.module.logits is False and loss == torch.nn.CrossEntropyLoss:
-            raise UserWarning(
-                "classifier should return logits when using CrossEntropyLoss."
-            )
+            raise UserWarning("classifier should return logits when using CrossEntropyLoss.")

     def forward(self, *args, **kwargs):
         """Passthrough to the module's forward function."""
@@ -1232,9 +1155,7 @@ def configure_optimizers(self):
             optim_cls = torch.optim.AdamW
         else:
             raise ValueError("Optimizer not understood.")
-        optimizer = optim_cls(
-            params, lr=self.lr, eps=self.eps, weight_decay=self.weight_decay
-        )
+        optimizer = optim_cls(params, lr=self.lr, eps=self.eps, weight_decay=self.weight_decay)

         return optimizer

@@ -1300,11 +1221,7 @@ def __init__(

     def get_optimizer_creator(self) -> JaxOptimizerCreator:
         """Get optimizer creator for the model."""
-        clip_by = (
-            optax.clip_by_global_norm(self.max_norm)
-            if self.max_norm
-            else optax.identity()
-        )
+        clip_by = optax.clip_by_global_norm(self.max_norm) if self.max_norm else optax.identity()
         if self.optimizer_name == "Adam":
             # Replicates PyTorch Adam defaults
             optim = optax.chain(
@@ -1358,9 +1275,9 @@ def loss_fn(params):
             loss = loss_output.loss
             return loss, (loss_output, new_model_state)

-        (loss, (loss_output, new_model_state)), grads = jax.value_and_grad(
-            loss_fn, has_aux=True
-        )(state.params)
+        (loss, (loss_output, new_model_state)), grads = jax.value_and_grad(loss_fn, has_aux=True)(
+            state.params
+        )
         new_state = state.apply_gradients(grads=grads, state=new_model_state)
         return new_state, loss, loss_output
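As an aside, a minimal standalone illustration of the `jax.value_and_grad(..., has_aux=True)` pattern reformatted above, using a hypothetical toy loss rather than scvi-tools code:

    import jax
    import jax.numpy as jnp

    def loss_fn(params):
        x = jnp.array([1.0, 2.0, 3.0])
        pred = params["w"] * x
        loss = jnp.mean((pred - x) ** 2)
        aux = {"n_obs": x.shape[0]}  # auxiliary output carried alongside the loss
        return loss, aux             # has_aux=True expects a (loss, aux) pair

    params = {"w": jnp.array(0.5)}
    # Returns ((value, aux), grads): the scalar loss, the auxiliary output, and d(loss)/d(params).
    (loss, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)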
