Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: issues 1418, 1420 #1549

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions deeppavlov/core/common/chainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import pickle
from itertools import islice
from pathlib import Path
from logging import getLogger
from types import FunctionType
from typing import Union, Tuple, List, Optional, Hashable, Reversible
Expand Down Expand Up @@ -275,10 +276,10 @@ def get_main_component(self) -> Optional[Serializable]:
log.warning('Cannot get a main component for an empty chainer')
return None

def save(self) -> None:
def save(self, fname: Optional[Union[str, Path]] = None) -> None:
main_component = self.get_main_component()
if isinstance(main_component, Serializable):
main_component.save()
main_component.save(fname)

def load(self) -> None:
for in_params, out_params, component in self.train_pipe:
Expand Down
9 changes: 5 additions & 4 deletions deeppavlov/core/models/tf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,17 @@ def deserialize(self, weights: Iterable[Tuple[str, np.ndarray]]) -> None:
feed_dict[assign_placeholder] = value
self.sess.run(assign_ops, feed_dict=feed_dict)

def save(self, exclude_scopes: tuple = ('Optimizer',)) -> None:
def save(self, exclude_scopes: tuple = ('Optimizer',), fname: Optional[Union[str, Path]] = None) -> None:
"""Save model parameters to self.save_path"""
if not hasattr(self, 'sess'):
raise RuntimeError('Your TensorFlow model {} must'
' have sess attribute!'.format(self.__class__.__name__))
path = str(self.save_path.resolve())
log.info('[saving model to {}]'.format(path))
if fname is None:
fname = str(self.save_path.resolve())
log.info('[saving model to {}]'.format(fname))
var_list = self._get_saveable_variables(exclude_scopes)
saver = tf.train.Saver(var_list)
saver.save(self.sess, path)
saver.save(self.sess, fname)

def serialize(self) -> Tuple[Tuple[str, np.ndarray], ...]:
tf_vars = tf.global_variables()
Expand Down
21 changes: 13 additions & 8 deletions deeppavlov/core/models/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from copy import deepcopy
from logging import getLogger
from pathlib import Path
from typing import Optional
from typing import Optional, Union

import torch
from overrides import overrides
Expand Down Expand Up @@ -127,7 +127,7 @@ def init_from_opt(self, model_func: str) -> None:
raise AttributeError("Model is not defined.")

@overrides
def load(self, fname: Optional[str] = None, *args, **kwargs) -> None:
def load(self, fname: Optional[Union[str, Path]] = None, *args, **kwargs) -> None:
"""Load model from `fname` (if `fname` is not given, use `self.load_path`) to `self.model` along with
the optimizer `self.optimizer`, optionally `self.lr_scheduler`.
If `fname` (if `fname` is not given, use `self.load_path`) does not exist, initialize model from scratch.
Expand All @@ -143,15 +143,18 @@ def load(self, fname: Optional[str] = None, *args, **kwargs) -> None:
if fname is not None:
self.load_path = fname

if isinstance(self.load_path, str):
self.load_path = Path(self.load_path)

model_func = getattr(self, self.opt.get("model_name"), None)

if self.load_path:
log.info(f"Load path {self.load_path} is given.")
if isinstance(self.load_path, Path) and not self.load_path.parent.is_dir():
if not self.load_path.parent.is_dir():
raise ConfigError("Provided load path is incorrect!")

weights_path = Path(self.load_path.resolve())
weights_path = weights_path.with_suffix(f".pth.tar")
weights_path = self.load_path.resolve()
weights_path = weights_path.with_suffix(".pth.tar")
if weights_path.exists():
log.info(f"Load path {weights_path} exists.")
log.info(f"Initializing `{self.__class__.__name__}` from saved.")
Expand All @@ -173,7 +176,7 @@ def load(self, fname: Optional[str] = None, *args, **kwargs) -> None:
self.init_from_opt(model_func)

@overrides
def save(self, fname: Optional[str] = None, *args, **kwargs) -> None:
def save(self, fname: Optional[Union[str, Path]] = None, *args, **kwargs) -> None:
"""Save torch model to `fname` (if `fname` is not given, use `self.save_path`). Checkpoint includes
`model_state_dict`, `optimizer_state_dict`, and `epochs_done` (number of training epochs).

Expand All @@ -187,11 +190,13 @@ def save(self, fname: Optional[str] = None, *args, **kwargs) -> None:
"""
if fname is None:
fname = self.save_path

else:
fname = str(self.save_path) + fname
fname = Path(fname)
if not fname.parent.is_dir():
raise ConfigError("Provided save path is incorrect!")

weights_path = Path(fname).with_suffix(f".pth.tar")
weights_path = fname.with_suffix(f".pth.tar")
log.info(f"Saving model to {weights_path}.")
# move the model to `cpu` before saving to provide consistency
torch.save({
Expand Down
14 changes: 11 additions & 3 deletions deeppavlov/core/trainers/nn_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from deeppavlov.core.data.data_learning_iterator import DataLearningIterator
from deeppavlov.core.trainers.fit_trainer import FitTrainer
from deeppavlov.core.trainers.utils import parse_metrics, NumpyArrayEncoder
from deeppavlov.core.models.serializable import Serializable

log = getLogger(__name__)

Expand Down Expand Up @@ -72,6 +73,8 @@ class NNTrainer(FitTrainer):
log_on_k_batches: count of random train batches to calculate metrics in log (default is ``1``)
max_test_batches: maximum batches count for pipeline testing and evaluation, overrides ``log_on_k_batches``,
ignored if negative (default is ``-1``)
save_every_n_batches: how often (in batches) to save model into f'{save_path}_{current_step}, the best model
is still saved to `save_path`, ignored if negative or zero (default is ``-1``)
**kwargs: additional parameters whose names will be logged but otherwise ignored


Expand Down Expand Up @@ -103,6 +106,7 @@ def __init__(self, chainer_config: dict, *,
validate_first: bool = True,
validation_patience: int = 5, val_every_n_epochs: int = -1, val_every_n_batches: int = -1,
log_every_n_batches: int = -1, log_every_n_epochs: int = -1, log_on_k_batches: int = 1,
save_every_n_batches: int = -1,
**kwargs) -> None:
super().__init__(chainer_config, batch_size=batch_size, metrics=metrics, evaluation_targets=evaluation_targets,
show_examples=show_examples, tensorboard_log_dir=tensorboard_log_dir,
Expand Down Expand Up @@ -134,6 +138,7 @@ def _improved(op):
self.log_every_n_epochs = log_every_n_epochs
self.log_every_n_batches = log_every_n_batches
self.log_on_k_batches = log_on_k_batches if log_on_k_batches >= 0 else None
self.save_every_n_batches = save_every_n_batches

self.max_epochs = epochs
self.epoch = start_epoch_num
Expand All @@ -150,11 +155,10 @@ def _improved(op):
self.tb_train_writer = self._tf.summary.FileWriter(str(self.tensorboard_log_dir / 'train_log'))
self.tb_valid_writer = self._tf.summary.FileWriter(str(self.tensorboard_log_dir / 'valid_log'))

def save(self) -> None:
def save(self, fname: Optional[Union[str, Path]] = None) -> None:
if self._loaded:
raise RuntimeError('Cannot save already finalized chainer')

self._chainer.save()
self._chainer.save(fname)

def _is_initial_validation(self):
return self.validation_number == 0
Expand Down Expand Up @@ -297,6 +301,10 @@ def train_on_batches(self, iterator: DataLearningIterator) -> None:
if self.val_every_n_batches > 0 and self.train_batches_seen % self.val_every_n_batches == 0:
self._validate(iterator,
tensorboard_tag='every_n_batches', tensorboard_index=self.train_batches_seen)

if self.save_every_n_batches > 0 and self.train_batches_seen % self.save_every_n_batches == 0:
log.info(f'Saving model at step: {self.train_batches_seen}')
self.save(fname = f'_{self.train_batches_seen}' )

self._send_event(event_name='after_batch')

Expand Down
5 changes: 3 additions & 2 deletions deeppavlov/models/bert/bert_sequence_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from logging import getLogger
from typing import List, Union, Dict, Optional
from pathlib import Path

import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -443,10 +444,10 @@ def __call__(self,
**kwargs) -> Union[List[List[int]], List[np.ndarray]]:
raise NotImplementedError("You must implement method __call__ in your derived class.")

def save(self, exclude_scopes=('Optimizer', 'EMA/BackupVariables')) -> None:
def save(self, exclude_scopes=('Optimizer', 'EMA/BackupVariables'), fname: Optional[Union[str, Path]] = None) -> None:
if self.ema:
self.sess.run(self.ema.switch_to_train_op)
return super().save(exclude_scopes=exclude_scopes)
return super().save(exclude_scopes=exclude_scopes, fname = fname)

def load(self,
exclude_scopes=('Optimizer',
Expand Down