
continue training of already saved model (extending TrainRunner) #1065

Closed
vitkl opened this issue May 16, 2021 · 11 comments · Fixed by #1091
@vitkl
Contributor

vitkl commented May 16, 2021

It would be great to have the option to continue training an already saved model. @adamgayoso said this needs to go into the TrainRunner. One thing you need to do is combine the old and new training history, like this:

# assumes pandas has been imported as pd
if continue_training and self.is_trained_:
    # shift the index of the new ELBO history so it continues from the old one
    index = range(
        len(self.module.history_),
        len(self.module.history_) + len(trainer.logger.history["train_loss_epoch"]),
    )
    trainer.logger.history["train_loss_epoch"].index = index
    # append the new history to the existing one
    self.module.history_ = pd.concat(
        [self.module.history_, trainer.logger.history["train_loss_epoch"]]
    )
else:
    self.module.history_ = trainer.logger.history["train_loss_epoch"]
    self.history_ = self.module.history_
@njbernstein
Contributor

plus 1 for this

@adamgayoso
Member

There is a simple way where we just maintain the old history so it's not overwritten; however, if we also want to keep the whole optimizer state and learning rate scheduler state, we have to do more engineering. Thoughts? In the first (simple) case it would be a fresh optimizer and schedulers.
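
To make the two options concrete, here is a minimal plain-PyTorch sketch (not scvi-tools code; the stand-in model and file names are illustrative) of resuming with a fresh optimizer versus restoring the optimizer state:

import torch

model = torch.nn.Linear(10, 1)  # stand-in for the real module
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# "Simple" resume: restore only the model weights; the fresh optimizer
# loses Adam's running gradient moments and any scheduler progress.
model.load_state_dict(torch.load("model_params.pt"))

# "Complex" resume: restore the optimizer state as well, so training
# continues with the same gradient statistics and learning rate.
checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optimizer_state"])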

@njbernstein
Contributor

I only want the old history not overwritten, for what it's worth.

@adamgayoso
Member

In the more complicated case, we'd have to

  1. use save_hyperparameters in the training plans and avoid saving the modules (see here)
  2. the save function would have to create a pytorch lightning checkpoint from the trainer attribute of the model (self)

Then we could add a parameter like continue_from_checkpoint: Path to the train methods: you pass the path of the save directory, and the train method loads the training plan from that checkpoint.
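
A rough sketch of that flow using standard PyTorch Lightning calls (the training plan class, the save path, and the continue_from_checkpoint wiring are illustrative here, not existing scvi-tools API):

import pytorch_lightning as pl

class MyTrainingPlan(pl.LightningModule):
    def __init__(self, module, lr=1e-3):
        super().__init__()
        self.module = module
        # record constructor args so Lightning can rebuild the plan from a
        # checkpoint, but skip the module itself (point 1 above)
        self.save_hyperparameters(ignore=["module"])

# point 2: after training, save() would write a Lightning checkpoint
# from the model's trainer attribute:
#     self.trainer.save_checkpoint(save_dir + "/plan.ckpt")

# continue_from_checkpoint would then restore the plan, including the
# optimizer and scheduler states stored in the checkpoint:
#     plan = MyTrainingPlan.load_from_checkpoint(ckpt_path, module=module)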

@adamgayoso
Member

So in either simple or complex case, we can do the following:

Change this line
https://github.com/YosefLab/scvi-tools/blob/a0a608912aff56e94bb89b9e8c4f122a6c776500/scvi/train/_trainrunner.py#L75

to self.model.history = check_and_extend_history(self.trainer.logger.history), where the helper extends the existing history if it is not None.
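
A minimal sketch of what such a check_and_extend_history helper could look like (this helper does not exist in scvi-tools; the two-argument signature and the dict-of-DataFrames assumption are illustrative):

import pandas as pd

def check_and_extend_history(old_history, new_history):
    """Append each new per-metric DataFrame to the old history, if any."""
    if old_history is None:
        return new_history
    extended = {}
    for key, new_df in new_history.items():
        old_df = old_history.get(key)
        if old_df is None:
            extended[key] = new_df
            continue
        # continue the epoch index where the previous run stopped
        new_df = new_df.copy()
        new_df.index = range(len(old_df), len(old_df) + len(new_df))
        extended[key] = pd.concat([old_df, new_df])
    return extended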

@vitkl
Contributor Author

vitkl commented Jun 5, 2021

I think it's ok to create a new optimiser when continuing training (this is what pymc3 does by the way) - just load state param dict and continue history. @adamgayoso is this what you mean by a simple case?
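
For a Pyro-based model, the "load the param dict and keep going" part could look roughly like this, using Pyro's parameter store directly (a sketch, not scvi-tools API; the file name is illustrative):

import pyro

# after the first training run
pyro.get_param_store().save("params.pt")

# in a new session or cluster job: restore the learned parameters,
# then build a fresh optimizer and continue training
pyro.clear_param_store()
pyro.get_param_store().load("params.pt")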

@vitkl
Contributor Author

vitkl commented Jun 5, 2021

My use case for this is 'train -> save -> potentially start a new cluster job -> load -> continue training'. One problem with this, which I see now, is that when the saved model is loaded, one training step is run and the training history is lost - this will solve that issue, right?

@adamgayoso
Member

> I think it's ok to create a new optimiser when continuing training (this is what pymc3 does by the way)

For Pyro-based models we might not have a choice. In general, though, I think it would be nice to maintain the gradient information for optimizers like Adam (this is part of the "complex" solution, though easy with pytorch lightning).

> One problem with this, which I see now, is that when the saved model is loaded, one training step is run and the training history is lost - this will solve that issue, right?

Yes, but again, Pyro models need some special care.

@vitkl
Contributor Author

vitkl commented Jun 10, 2021

I see. In my opinion, a simple solution should just preserve the history, including when the models are loaded.

Is it necessary to train a loaded model for 1 iteration? If this is done just to initialise the guide properly, then maybe it could be done in evaluation mode - for example, by using svi.evaluate_loss rather than svi.step for both training and validation data? #1073
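
To illustrate the distinction, here is a generic Pyro sketch (not the scvi-tools training code): svi.step takes a gradient step, while svi.evaluate_loss only computes the loss, which is enough to trigger lazy initialization of an autoguide without touching the parameters.

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

def model(x):
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("data", x.shape[0]):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=x)

guide = AutoNormal(model)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())

x = torch.randn(100)
svi.evaluate_loss(x)  # initializes guide parameters, no parameter update
# svi.step(x)         # would also take an optimization step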

@vitkl
Contributor Author

vitkl commented Jun 21, 2021

So this commit, f9652f2, solves the issue for loading models, but keeping the history when continuing training remains to be addressed, right?

@adamgayoso
Member

Yes, this issue remains to be addressed (we are getting there). The commit you referenced fixes the loading issue for Pyro models.
