Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added option "always save model" #1621

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions deeppavlov/core/trainers/nn_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class NNTrainer(FitTrainer):
log_on_k_batches: count of random train batches to calculate metrics in log (default is ``1``)
max_test_batches: maximum batches count for pipeline testing and evaluation, overrides ``log_on_k_batches``,
ignored if negative (default is ``-1``)
always_save_model: if True, we always save the obtained weights of our model, regardless of the metric.
(default if ``False``)
**kwargs: additional parameters whose names will be logged but otherwise ignored


Expand Down Expand Up @@ -107,6 +109,7 @@ def __init__(self, chainer_config: dict, *,
validate_first: bool = True,
validation_patience: int = 5, val_every_n_epochs: int = -1, val_every_n_batches: int = -1,
log_every_n_batches: int = -1, log_every_n_epochs: int = -1, log_on_k_batches: int = 1,
always_save_model: bool = False,
**kwargs) -> None:
super().__init__(chainer_config, batch_size=batch_size, metrics=metrics, evaluation_targets=evaluation_targets,
show_examples=show_examples, max_test_batches=max_test_batches, **kwargs)
Expand Down Expand Up @@ -141,6 +144,7 @@ def _improved(op):
self.max_epochs = epochs
self.epoch = start_epoch_num
self.max_batches = max_batches
self.always_save_model = always_save_model

self.train_batches_seen = 0
self.examples = 0
Expand Down Expand Up @@ -207,6 +211,11 @@ def _validate(self, iterator: DataLearningIterator,
self.score_best = score
log.info('Saving model')
self.save()
elif self.always_save_model:
log.info(f'Changed {m_name} from {self.score_best} to {score}')
self.score_best = score
log.info('But due to always_save_model, saving the model')
self.save()
else:
log.info('Did not improve on the {} of {}'.format(m_name, self.score_best))

Expand Down