diff --git a/deeppavlov/configs/classifiers/sentiment_twitter.json b/deeppavlov/configs/classifiers/sentiment_twitter.json index 0d02ec5927..be1f205fe9 100644 --- a/deeppavlov/configs/classifiers/sentiment_twitter.json +++ b/deeppavlov/configs/classifiers/sentiment_twitter.json @@ -73,8 +73,8 @@ ], "filters_cnn": 256, "optimizer": "Adam", - "learning_rate": 0.01, - "learning_rate_decay": 0.1, + "learning_rate": 0.1, + "learning_rate_decay": 0.01, "loss": "binary_crossentropy", "last_layer_activation": "softmax", "coef_reg_cnn": 1e-3, @@ -107,7 +107,10 @@ "f1_macro", { "name": "roc_auc", - "inputs": ["y_onehot", "y_pred_probas"] + "inputs": [ + "y_onehot", + "y_pred_probas" + ] } ], "validation_patience": 5, @@ -119,7 +122,31 @@ "valid", "test" ], - "class_name": "nn_trainer" + "class_name": "nn_trainer", + "logger": [ + { + "name": "TensorboardLogger", + "log_dir": "{MODELS_PATH}/sentiment_twitter/Tensorboard_logs" + }, + { + "name": "StdLogger" + }, + { + "name": "WandbLogger", + "API_Key":"40-chars API KEY", + "init":{ + "project": "Tuning Hyperparameters", + "group": "Tuning lr & lr_decay", + "job_type":"lr=0.01, lr_decay=0.01", + "config": { + "description": "add any hyperprameter you want to monitor, architecture discription,..", + "learning_rate": 0.02, + "architecture": "CNN", + "dataset": "sentiment_twitter_data" + } + } + } + ] }, "metadata": { "variables": { @@ -128,6 +155,10 @@ "MODELS_PATH": "{ROOT_PATH}/models", "MODEL_PATH": "{MODELS_PATH}/classifiers/sentiment_twitter_v6" }, + "requirements": [ + "{DEEPPAVLOV_PATH}/requirements/tf.txt", + "{DEEPPAVLOV_PATH}/requirements/fasttext.txt" + ], "download": [ { "url": "http://files.deeppavlov.ai/datasets/sentiment_twitter_data.tar.gz", diff --git a/deeppavlov/core/common/logging/logging_class.py b/deeppavlov/core/common/logging/logging_class.py new file mode 100644 index 0000000000..8014607650 --- /dev/null +++ b/deeppavlov/core/common/logging/logging_class.py @@ -0,0 +1,133 @@ +# Copyright 2022 Neural Networks and Deep Learning lab, MIPT +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import datetime +from itertools import islice +from abc import ABC, abstractmethod +from typing import List, Tuple +from logging import getLogger + +from deeppavlov.core.data.data_learning_iterator import DataLearningIterator +from deeppavlov.core.trainers.nn_trainer import NNTrainer + + +log = getLogger(__name__) + + +class TrainLogger(ABC): + """An abstract class for logging metrics during training process.""" + + def get_report(self, nn_trainer: NNTrainer, iterator: DataLearningIterator, type: str = None) -> dict: + """ " + Get report about current process. + for 'valid' type, 'get_report' function also saves best score on validation data, and the model parameters corresponding to the best score. + + Args: + nn_trainer: 'NNTrainer' object contains parameters required for preparing the report. 
+ iterator: :class:`~deeppavlov.core.data.data_learning_iterator.DataLearningIterator` used for evaluation + type : if "train" returns report about training process, "valid" returns report about validation process. + + Returns: + dict contains data about current 'type' process. + + """ + if type == "train": + if nn_trainer.log_on_k_batches == 0: + report = {"time_spent": str(datetime.timedelta( + seconds=round(time.time() - nn_trainer.start_time + 0.5)))} + else: + data = islice(iterator.gen_batches(nn_trainer.batch_size, data_type="train", shuffle=True), + nn_trainer.log_on_k_batches,) + report = nn_trainer.test( + data, nn_trainer.train_metrics, start_time=nn_trainer.start_time + ) + + report.update( + { + "epochs_done": nn_trainer.epoch, + "batches_seen": nn_trainer.train_batches_seen, + "train_examples_seen": nn_trainer.examples, + } + ) + + metrics: List[Tuple[str, float]] = list( + report.get("metrics", {}).items() + ) + list(nn_trainer.last_result.items()) + + report.update(nn_trainer.last_result) + if nn_trainer.losses: + report["loss"] = sum(nn_trainer.losses) / len(nn_trainer.losses) + nn_trainer.losses.clear() + metrics.append(("loss", report["loss"])) + + elif type == "valid": + report = nn_trainer.test( + iterator.gen_batches( + nn_trainer.batch_size, data_type="valid", shuffle=False + ), + start_time=nn_trainer.start_time, + ) + + report["epochs_done"] = nn_trainer.epoch + report["batches_seen"] = nn_trainer.train_batches_seen + report["train_examples_seen"] = nn_trainer.examples + + metrics = list(report["metrics"].items()) + + m_name, score = metrics[0] + + # Update the patience + if nn_trainer.score_best is None: + nn_trainer.patience = 0 + else: + if nn_trainer.improved(score, nn_trainer.score_best): + nn_trainer.patience = 0 + else: + nn_trainer.patience += 1 + + # Run the validation model-saving logic + if nn_trainer._is_initial_validation(): + log.info("Initial best {} of {}".format(m_name, score)) + nn_trainer.score_best = score + elif nn_trainer._is_first_validation() and nn_trainer.score_best is None: + log.info("First best {} of {}".format(m_name, score)) + nn_trainer.score_best = score + log.info("Saving model") + nn_trainer.save() + elif nn_trainer.improved(score, nn_trainer.score_best): + log.info("Improved best {} of {}".format(m_name, score)) + nn_trainer.score_best = score + log.info("Saving model") + nn_trainer.save() + else: + log.info( + "Did not improve on the {} of {}".format( + m_name, nn_trainer.score_best + ) + ) + + report["impatience"] = nn_trainer.patience + if nn_trainer.validation_patience > 0: + report["patience_limit"] = nn_trainer.validation_patience + + nn_trainer.validation_number += 1 + return report + + @abstractmethod + def __call__() -> None: + raise NotImplementedError + + def close(): + raise NotImplementedError diff --git a/deeppavlov/core/common/logging/std_logger.py b/deeppavlov/core/common/logging/std_logger.py new file mode 100644 index 0000000000..19e6a45677 --- /dev/null +++ b/deeppavlov/core/common/logging/std_logger.py @@ -0,0 +1,53 @@ +from typing import Dict +from logging import getLogger +import json + +from deeppavlov.core.data.data_learning_iterator import DataLearningIterator +from deeppavlov.core.trainers.nn_trainer import NNTrainer +from deeppavlov.core.trainers.utils import NumpyArrayEncoder +from deeppavlov.core.common.logging.logging_class import TrainLogger + +log = getLogger(__name__) + + +class StdLogger(TrainLogger): + """ + StdLogger class for logging report about current training and validation processes to 
stdout. + + Args: + stdlogging (bool): if True, log report to stdout. + the object of this class with stdlogging = False can be used for validation process. + **kwargs: additional parameters whose names will be logged but otherwise ignored + """ + + def __init__(self, stdlogging: bool = True, **kwargs) -> None: + self.stdlogging = stdlogging + + def __call__(self,nn_trainer: NNTrainer, iterator: DataLearningIterator, type: str = None, report: Dict = None, + **kwargs) -> dict: + """ + override call method, to log report to stdout. + + Args: + nn_trainer: NNTrainer object contains parameters required for preparing report. + iterator: :class:`~deeppavlov.core.data.data_learning_iterator.DataLearningIterator` used for evaluation. + type : process type, if "train" logs report about training process, else if "valid" logs report about validation process. + report: dictionary contains current process information, if None, use 'get_report' method to get this report. + **kwargs: additional parameters whose names will be logged but otherwise ignored + Returns: + dict contains logged data to stdout. + + """ + if report is None: + report = self.get_report( + nn_trainer=nn_trainer, iterator=iterator, type=type + ) + if self.stdlogging: + log.info( + json.dumps({type: report}, ensure_ascii=False, cls=NumpyArrayEncoder) + ) + return report + + @staticmethod + def close(): + log.info("Logging to Stdout completed") \ No newline at end of file diff --git a/deeppavlov/core/common/logging/tensorboard_logger.py b/deeppavlov/core/common/logging/tensorboard_logger.py new file mode 100644 index 0000000000..dc99c57c4b --- /dev/null +++ b/deeppavlov/core/common/logging/tensorboard_logger.py @@ -0,0 +1,98 @@ +from pathlib import Path +from typing import List, Tuple, Optional, Dict +from logging import getLogger + +from deeppavlov.core.commands.utils import expand_path + +from deeppavlov.core.data.data_learning_iterator import DataLearningIterator +from deeppavlov.core.trainers.nn_trainer import NNTrainer +from deeppavlov.core.trainers.fit_trainer import FitTrainer +from deeppavlov.core.common.logging.logging_class import TrainLogger + +log = getLogger(__name__) + + +class TensorboardLogger(TrainLogger): + """ + TensorboardLogger class for logging to tesnorboard. + + Args: + fit_trainer: FitTrainer object passed to set Tensorflow as one of its parameter if successful importation. + log_dir (Path): path to local folder to log data into. + + """ + + def __init__(self, fit_trainer:FitTrainer , log_dir: Path = None) -> None: + try: + # noinspection PyPackageRequirements + # noinspection PyUnresolvedReferences + import tensorflow as tf + except ImportError: + log.warning('TensorFlow could not be imported, so tensorboard log directory' + f'`{log_dir}` will be ignored') + else: + log_dir = expand_path(log_dir) + fit_trainer._tf = tf + self.train_log_dir = str(log_dir / 'train_log') + self.valid_log_dir = str(log_dir / 'valid_log') + self.tb_train_writer = tf.summary.FileWriter(self.train_log_dir) + self.tb_valid_writer = tf.summary.FileWriter(self.valid_log_dir) + + def __call__(self, nn_trainer: NNTrainer, iterator: DataLearningIterator, type: str = None, + tensorboard_tag: Optional[str] = None, tensorboard_index: Optional[int] = None, + report: Dict = None, **kwargs) -> dict: + """ + override call method, for 'train' logging type, log metircs of training process to log_dir/train_log. + for 'valid' logging type, log metrics of validation process to log_dir/valid_log. 
+ + Args: + nn_trainer: NNTrainer object contains parameters required for preparing the report. + iterator: :class:`~deeppavlov.core.data.data_learning_iterator.DataLearningIterator` used for evaluation + type : process type, if "train" logs report about training process, else if "valid" logs report about validation process. + tensorboard_tag: one of two options : 'every_n_batches', 'every_n_epochs' + tensorboard_index: one of two options: 'train_batches_seen', 'epoch' corresponding to 'tensorboard_tag' types respectively. + report: dictionary contains current process information, if None, use 'get_report' method to get this report. + **kwargs: additional parameters whose names will be logged but otherwise ignored + + Returns: + dict contains metrics logged to tesnorboard. + + """ + if report is None: + report = self.get_report( + nn_trainer=nn_trainer, iterator=iterator, type=type + ) + + if type == "train": + metrics: List[Tuple[str, float]] = list( + report.get("metrics", {}).items() + ) + list(nn_trainer.last_result.items()) + if report.get("loss", None) is not None: + metrics.append(("loss", report["loss"])) + + if metrics and self.train_log_dir is not None: + summary = nn_trainer._tf.Summary() + + for name, score in metrics: + summary.value.add( + tag=f"{tensorboard_tag}/{name}", simple_value=score + ) + self.tb_train_writer.add_summary(summary, tensorboard_index) + self.tb_train_writer.flush() + else: + metrics = list(report["metrics"].items()) + if tensorboard_tag is not None and self.valid_log_dir is not None: + summary = nn_trainer._tf.Summary() + for name, score in metrics: + summary.value.add( + tag=f"{tensorboard_tag}/{name}", simple_value=score + ) + if tensorboard_index is None: + tensorboard_index = nn_trainer.train_batches_seen + self.tb_valid_writer.add_summary(summary, tensorboard_index) + self.tb_valid_writer.flush() + return report + + @staticmethod + def close(): + log.info("Logging to Tensorboard completed") \ No newline at end of file diff --git a/deeppavlov/core/common/logging/wandb_logger.py b/deeppavlov/core/common/logging/wandb_logger.py new file mode 100644 index 0000000000..c65c559ae6 --- /dev/null +++ b/deeppavlov/core/common/logging/wandb_logger.py @@ -0,0 +1,139 @@ +import time +import datetime +from typing import Dict, Optional +from logging import getLogger + +import wandb + +from deeppavlov.core.data.data_learning_iterator import DataLearningIterator +from deeppavlov.core.trainers.nn_trainer import NNTrainer +from deeppavlov.core.common.logging.logging_class import TrainLogger + + +log = getLogger(__name__) + + +class WandbLogger(TrainLogger): + """ + WandbLogger class for logging report about current training and validation processes to WandB during training. ("https://wandb.ai/site"). + + WandB is a central dashboard to keep track of your hyperparameters, system metrics, and predictions so you can compare models live, and share your findings. + WandB doesn't support more than one run concurrently, so logging will be on "epochs" or "batches" + If val_every_n_epochs > 0 or log_every_n_epochs > 0 in config file, logging to wandb will be on epochs. + Otherwise if val_every_n_batches > 0 or log_every_n_batches > 0 in config file, logging to wandb will be on batches. + if none of them, logging to wandb will be ignored. + + Args: + API_Key (str): authentication key. + relogin (bool): if True, force relogin if already logged in. 
+ commit_on_valid (bool): If False wandb.log just updates the current metrics dict with the row argument and metrics won't be saved until wandb.log is called with commit=True + to commit training and validation reports with the same steps, this argument is True if logging on validation required + val_every_n_epochs: how often (in epochs) to validate the pipeline, ignored if negative or zero + (default is ``-1``) + val_every_n_batches: how often (in batches) to validate the pipeline, ignored if negative or zero + (default is ``-1``) + log_every_n_epochs: how often (in epochs) to calculate metrics on train data, ignored if negative or zero + (default is ``-1``) + log_every_n_batches: how often (in batches) to calculate metrics on train data, ignored if negative or zero + (default is ``-1``) + **kwargs: arguments for wandb initialization, more info: https://docs.wandb.ai/ref/python/init + + """ + + @staticmethod + def login(API_Key: str = None, relogin: bool = True) -> bool: + """ " + static method to login to wandb account, if login or init to wandb failed, logging to wandb will be ignored. + + Args: + API_Key (str): authentication key. + relogin (bool): if True, force relogin if already logged in. + + Returns: + True if login and init processes succeed, otherwise False and logging to wandb will be ignored. + + """ + try: + return wandb.login(key=API_Key, relogin=relogin) + except Exception as e: + log.warning(str(e) + ', logging to WandB will be ignored') + return False + + def __init__(self, API_Key: str = None, relogin: bool = True, val_every_n_epochs: int = -1, + val_every_n_batches: int = -1, log_every_n_batches: int = -1, log_every_n_epochs: int = -1, **kwargs) -> None: + if self.login(API_Key = API_Key, relogin = relogin): + try: + wandb.init(**kwargs.get('init', None)) + self.init_succeed = True + if log_every_n_epochs > 0 or val_every_n_epochs > 0: + self.log_on ='every_n_epochs' + self.commit_on_valid = val_every_n_epochs > 0 + + elif log_every_n_batches > 0 or val_every_n_batches > 0: + self.log_on ='every_n_batches' + self.commit_on_valid = val_every_n_batches > 0 + + except Exception as e: + log.warning(str(e) + ', logging to WandB will be ignored') + self.init_succeed = False + else: + log.warning('login to WandB failed') + self.init_succeed = False + + def __call__( + self, + nn_trainer: NNTrainer, + iterator: DataLearningIterator, + tensorboard_tag: Optional[str] = None, + type: str = None, + report: Dict = None, + **kwargs): + """ + Logging report of the training process to wandb. + + Args: + nn_trainer: 'NNTrainer' object contains parameters required for preparing the report. + iterator: :class:`~deeppavlov.core.data.data_learning_iterator.DataLearningIterator` used for evaluation + tensorboard_tag: one of two options : 'every_n_batches', 'every_n_epochs' + report (dict): report for logging to WandB. If None, use 'get_report' method to get this report. + type (str) : process type, if "train" logs report about training process, else if "valid" logs report about validation process. + + Returns: + dict contains logged data to WandB. 
+ + """ + if not self.init_succeed or tensorboard_tag != self.log_on: + return None + + if report is None: + report = self.get_report(nn_trainer=nn_trainer, iterator=iterator, type=type) + logging_type = type + "/" + for i in report.keys(): + if isinstance(report[i], dict): + for key, value in report[i].items(): + wandb.log({logging_type + key: value}, commit=False) + else: + if i == "time_spent": + t = time.strptime(report[i], "%H:%M:%S") + y_seconds = int( + datetime.timedelta( + hours=t.tm_hour, minutes=t.tm_min, seconds=t.tm_sec + ).total_seconds() + ) + wandb.log({logging_type + i + ("(s)"): y_seconds}, commit=False) + else: + wandb.log({logging_type + i: report[i]}, commit=False) + + # if "val_every_n_epochs" is not None, we have to commit data on validation logging, otherwise on training. + if (self.commit_on_valid and logging_type == "valid/") or ( + not self.commit_on_valid and logging_type == "train/"): + wandb.log({}, commit=True) + + return report + + @staticmethod + def close(): + """close function to commit the not commited logs and to mark a run as finished wiht wanb.finish method, and finishes uploading all data.""" + wandb.log({}, commit=True) + wandb.finish() + log.info("Logging to W&B completed") diff --git a/deeppavlov/core/trainers/fit_trainer.py b/deeppavlov/core/trainers/fit_trainer.py index 0378560564..595d01235c 100644 --- a/deeppavlov/core/trainers/fit_trainer.py +++ b/deeppavlov/core/trainers/fit_trainer.py @@ -17,11 +17,9 @@ import time from itertools import islice from logging import getLogger -from pathlib import Path -from typing import Tuple, Dict, Union, Optional, Iterable, Any, Collection +from typing import List, Tuple, Dict, Union, Optional, Iterable, Any, Collection from deeppavlov.core.commands.infer import build_model -from deeppavlov.core.commands.utils import expand_path from deeppavlov.core.common.chainer import Chainer from deeppavlov.core.common.params import from_params from deeppavlov.core.common.registry import register @@ -50,10 +48,18 @@ class FitTrainer: evaluation_targets: data types on which to evaluate trained pipeline (default is ``('valid', 'test')``) show_examples: a flag used to print inputs, expected outputs and predicted outputs for the last batch in evaluation logs (default is ``False``) - tensorboard_log_dir: path to a directory where tensorboard logs can be stored, ignored if None + logger: list of dictionaries with train and evaluation loggers configuration. 
(default is ``None``) max_test_batches: maximum batches count for pipeline testing and evaluation, ignored if negative (default is ``-1``) + val_every_n_epochs: how often (in epochs) to validate the pipeline, ignored if negative or zero + (default is ``-1``) + val_every_n_batches: how often (in batches) to validate the pipeline, ignored if negative or zero + (default is ``-1``) + log_every_n_epochs: how often (in epochs) to calculate metrics on train data, ignored if negative or zero + (default is ``-1``) + log_every_n_batches: how often (in batches) to calculate metrics on train data, ignored if negative or zero + (default is ``-1``) **kwargs: additional parameters whose names will be logged but otherwise ignored """ @@ -61,8 +67,10 @@ def __init__(self, chainer_config: dict, *, batch_size: int = -1, metrics: Iterable[Union[str, dict]] = ('accuracy',), evaluation_targets: Iterable[str] = ('valid', 'test'), show_examples: bool = False, - tensorboard_log_dir: Optional[Union[str, Path]] = None, max_test_batches: int = -1, + logger: Optional[List[dict]] = None, + val_every_n_batches: int = -1, val_every_n_epochs: int = -1, + log_every_n_batches: int = -1, log_every_n_epochs: int = -1, **kwargs) -> None: if kwargs: log.info(f'{self.__class__.__name__} got additional init parameters {list(kwargs)} that will be ignored:') @@ -75,19 +83,36 @@ def __init__(self, chainer_config: dict, *, batch_size: int = -1, self.max_test_batches = None if max_test_batches < 0 else max_test_batches - self.tensorboard_log_dir: Optional[Path] = tensorboard_log_dir - if tensorboard_log_dir is not None: + from deeppavlov.core.common.logging.logging_class import TrainLogger + from deeppavlov.core.common.logging.std_logger import StdLogger + + self.logger: List[TrainLogger] = [] + + self.tensorboard_idx, self.stdlogger_idx, self.wandblogger_idx = None, None, None + + if logger is None: + logger = [{'name': 'StdLogger'}] + for logger_config in logger: + logger_name = logger_config.pop('name',None) + if logger_name is None: + raise KeyError("There is no 'name' key in logger configuration") + lgr = None try: - # noinspection PyPackageRequirements - # noinspection PyUnresolvedReferences - import tensorflow + if logger_name == 'StdLogger': + lgr = StdLogger(**logger_config) + elif logger_name == 'TensorboardLogger': + from deeppavlov.core.common.logging.tensorboard_logger import TensorboardLogger + lgr = TensorboardLogger(self, **logger_config) + elif logger_name == 'WandbLogger': + from deeppavlov.core.common.logging.wandb_logger import WandbLogger + lgr = WandbLogger(**logger_config, val_every_n_batches = val_every_n_batches, + val_every_n_epochs = val_every_n_epochs, + log_every_n_batches = log_every_n_batches, log_every_n_epochs=log_every_n_epochs) except ImportError: - log.warning('TensorFlow could not be imported, so tensorboard log directory' - f'`{self.tensorboard_log_dir}` will be ignored') - self.tensorboard_log_dir = None - else: - self.tensorboard_log_dir = expand_path(tensorboard_log_dir) - self._tf = tensorflow + log.warning(f'{logger_name} will be ignored due to import error. 
Check that all necessary requirements' + f'are installed') + if lgr is not None: + self.logger.append(lgr) self._built = False self._saved = False @@ -117,13 +142,14 @@ def fit_chainer(self, iterator: Union[DataFittingIterator, DataLearningIterator] # noinspection PyUnresolvedReferences result = component.partial_fit(*preprocessed) - if result is not None and self.tensorboard_log_dir is not None: + if result is not None and self.tensorboard_idx is not None: if writer is None: - writer = self._tf.summary.FileWriter(str(self.tensorboard_log_dir / + writer = self._tf.summary.FileWriter(str(self.logger[self.tensorboard_idx]["log_dir"] / f'partial_fit_{component_index}_log')) for name, score in result.items(): summary = self._tf.Summary() - summary.value.add(tag='partial_fit/' + name, simple_value=score) + summary.value.add( + tag='partial_fit/' + name, simple_value=score) writer.add_summary(summary, i) writer.flush() else: @@ -132,13 +158,14 @@ def fit_chainer(self, iterator: Union[DataFittingIterator, DataLearningIterator] preprocessed = [preprocessed] result: Optional[Dict[str, Iterable[float]]] = component.fit(*preprocessed) - if result is not None and self.tensorboard_log_dir is not None: - writer = self._tf.summary.FileWriter(str(self.tensorboard_log_dir / + if result is not None and self.tensorboard_idx is not None: + writer = self._tf.summary.FileWriter(str(self.logger[self.tensorboard_idx]["log_dir"] / f'fit_log_{component_index}')) for name, scores in result.items(): for i, score in enumerate(scores): summary = self._tf.Summary() - summary.value.add(tag='fit/' + name, simple_value=score) + summary.value.add( + tag='fit/' + name, simple_value=score) writer.add_summary(summary, i) writer.flush() @@ -264,6 +291,7 @@ def evaluate(self, iterator: DataLearningIterator, evaluation_targets: Optional[ report = self.test(data_gen) res[data_type] = report if print_reports: - print(json.dumps({data_type: report}, ensure_ascii=False, cls=NumpyArrayEncoder)) + print(json.dumps({data_type: report}, + ensure_ascii=False, cls=NumpyArrayEncoder)) return res diff --git a/deeppavlov/core/trainers/nn_trainer.py b/deeppavlov/core/trainers/nn_trainer.py index 6f6fd8b4bf..ec578b51eb 100644 --- a/deeppavlov/core/trainers/nn_trainer.py +++ b/deeppavlov/core/trainers/nn_trainer.py @@ -13,18 +13,15 @@ # limitations under the License. 
import datetime -import json import time -from itertools import islice from logging import getLogger -from pathlib import Path -from typing import List, Tuple, Union, Optional, Iterable +from typing import List, Dict, Union, Optional, Iterable from deeppavlov.core.common.errors import ConfigError from deeppavlov.core.common.registry import register from deeppavlov.core.data.data_learning_iterator import DataLearningIterator from deeppavlov.core.trainers.fit_trainer import FitTrainer -from deeppavlov.core.trainers.utils import parse_metrics, NumpyArrayEncoder +from deeppavlov.core.trainers.utils import parse_metrics log = getLogger(__name__) @@ -55,8 +52,8 @@ class NNTrainer(FitTrainer): evaluation_targets: data types on which to evaluate a trained pipeline (default is ``('valid', 'test')``) show_examples: a flag used to print inputs, expected outputs and predicted outputs for the last batch in evaluation logs (default is ``False``) - tensorboard_log_dir: path to a directory where tensorboard logs can be stored, ignored if None - (default is ``None``) + logger : list of dictionaries of possible loggers provided in config file, ignored if None + (default is ``None``), possible loggers: TensorboardLogger, StdLogger and WandbLogger validate_first: flag used to calculate metrics on the ``'valid'`` data type before starting training (default is ``True``) validation_patience: how many times in a row the validation metric has to not improve for early stopping, @@ -88,25 +85,31 @@ class NNTrainer(FitTrainer): """ - def __init__(self, chainer_config: dict, *, - batch_size: int = 1, - epochs: int = -1, - start_epoch_num: int = 0, - max_batches: int = -1, - metrics: Iterable[Union[str, dict]] = ('accuracy',), - train_metrics: Optional[Iterable[Union[str, dict]]] = None, - metric_optimization: str = 'maximize', - evaluation_targets: Iterable[str] = ('valid', 'test'), - show_examples: bool = False, - tensorboard_log_dir: Optional[Union[str, Path]] = None, - max_test_batches: int = -1, - validate_first: bool = True, - validation_patience: int = 5, val_every_n_epochs: int = -1, val_every_n_batches: int = -1, - log_every_n_batches: int = -1, log_every_n_epochs: int = -1, log_on_k_batches: int = 1, - **kwargs) -> None: + def __init__(self, chainer_config: dict, *, + batch_size: int = 1, + epochs: int = -1, + start_epoch_num: int = 0, + max_batches: int = -1, + metrics: Iterable[Union[str, dict]] = ("accuracy",), + train_metrics: Optional[Iterable[Union[str, dict]]] = None, + metric_optimization: str = "maximize", + evaluation_targets: Iterable[str] = ("valid", "test"), + show_examples: bool = False, + logger: Optional[List[Dict]] = None, + max_test_batches: int = -1, + validate_first: bool = True, + validation_patience: int = 5, + val_every_n_epochs: int = -1, + val_every_n_batches: int = -1, + log_every_n_batches: int = -1, + log_every_n_epochs: int = -1, + log_on_k_batches: int = 1, + **kwargs) -> None: super().__init__(chainer_config, batch_size=batch_size, metrics=metrics, evaluation_targets=evaluation_targets, - show_examples=show_examples, tensorboard_log_dir=tensorboard_log_dir, - max_test_batches=max_test_batches, **kwargs) + show_examples=show_examples, logger=logger, max_test_batches=max_test_batches, + val_every_n_batches = val_every_n_batches, val_every_n_epochs = val_every_n_epochs, + log_every_n_batches = log_every_n_batches, log_every_n_epochs=log_every_n_epochs, **kwargs) + if train_metrics is None: self.train_metrics = self.metrics else: @@ -119,14 +122,16 @@ def _improved(op): return 
lambda score, baseline: False if baseline is None or score is None \ else op(score, baseline) - if metric_optimization == 'maximize': + if metric_optimization == "maximize": self.improved = _improved(lambda a, b: a > b) - elif metric_optimization == 'minimize': + elif metric_optimization == "minimize": self.improved = _improved(lambda a, b: a < b) else: - raise ConfigError('metric_optimization has to be one of {}'.format(['maximize', 'minimize'])) + raise ConfigError("metric_optimization has to be one of {}".format(["maximize", "minimize"])) self.validate_first = validate_first + from deeppavlov.core.common.logging.std_logger import StdLogger + self.validate_ = StdLogger(self.stdlogger_idx is not None) self.validation_number = 0 if validate_first else 1 self.validation_patience = validation_patience self.val_every_n_epochs = val_every_n_epochs @@ -146,9 +151,6 @@ def _improved(op): self.losses = [] self.start_time: Optional[float] = None - if self.tensorboard_log_dir is not None: - self.tb_train_writer = self._tf.summary.FileWriter(str(self.tensorboard_log_dir / 'train_log')) - self.tb_valid_writer = self._tf.summary.FileWriter(str(self.tensorboard_log_dir / 'valid_log')) def save(self) -> None: if self._loaded: @@ -162,109 +164,14 @@ def _is_initial_validation(self): def _is_first_validation(self): return self.validation_number == 1 - def _validate(self, iterator: DataLearningIterator, - tensorboard_tag: Optional[str] = None, tensorboard_index: Optional[int] = None) -> None: - self._send_event(event_name='before_validation') - report = self.test(iterator.gen_batches(self.batch_size, data_type='valid', shuffle=False), - start_time=self.start_time) - - report['epochs_done'] = self.epoch - report['batches_seen'] = self.train_batches_seen - report['train_examples_seen'] = self.examples - - metrics = list(report['metrics'].items()) - - if tensorboard_tag is not None and self.tensorboard_log_dir is not None: - summary = self._tf.Summary() - for name, score in metrics: - summary.value.add(tag=f'{tensorboard_tag}/{name}', simple_value=score) - if tensorboard_index is None: - tensorboard_index = self.train_batches_seen - self.tb_valid_writer.add_summary(summary, tensorboard_index) - self.tb_valid_writer.flush() - - m_name, score = metrics[0] - - # Update the patience - if self.score_best is None: - self.patience = 0 - else: - if self.improved(score, self.score_best): - self.patience = 0 - else: - self.patience += 1 - - # Run the validation model-saving logic - if self._is_initial_validation(): - log.info('Initial best {} of {}'.format(m_name, score)) - self.score_best = score - elif self._is_first_validation() and self.score_best is None: - log.info('First best {} of {}'.format(m_name, score)) - self.score_best = score - log.info('Saving model') - self.save() - elif self.improved(score, self.score_best): - log.info(f'Improved best {m_name} from {self.score_best} to {score}') - self.score_best = score - log.info('Saving model') - self.save() - else: - log.info('Did not improve on the {} of {}'.format(m_name, self.score_best)) - - report['impatience'] = self.patience - if self.validation_patience > 0: - report['patience_limit'] = self.validation_patience - - self._send_event(event_name='after_validation', data=report) - report = {'valid': report} - print(json.dumps(report, ensure_ascii=False, cls=NumpyArrayEncoder)) - self.validation_number += 1 - - def _log(self, iterator: DataLearningIterator, - tensorboard_tag: Optional[str] = None, tensorboard_index: Optional[int] = None) -> None: - 
self._send_event(event_name='before_log') - if self.log_on_k_batches == 0: - report = { - 'time_spent': str(datetime.timedelta(seconds=round(time.time() - self.start_time + 0.5))) - } - else: - data = islice(iterator.gen_batches(self.batch_size, data_type='train', shuffle=True), - self.log_on_k_batches) - report = self.test(data, self.train_metrics, start_time=self.start_time) - - report.update({ - 'epochs_done': self.epoch, - 'batches_seen': self.train_batches_seen, - 'train_examples_seen': self.examples - }) - - metrics: List[Tuple[str, float]] = list(report.get('metrics', {}).items()) + list(self.last_result.items()) - - report.update(self.last_result) - if self.losses: - report['loss'] = sum(self.losses) / len(self.losses) - self.losses.clear() - metrics.append(('loss', report['loss'])) - - if metrics and self.tensorboard_log_dir is not None: - summary = self._tf.Summary() - - for name, score in metrics: - summary.value.add(tag=f'{tensorboard_tag}/{name}', simple_value=score) - self.tb_train_writer.add_summary(summary, tensorboard_index) - self.tb_train_writer.flush() - - self._send_event(event_name='after_train_log', data=report) - - report = {'train': report} - print(json.dumps(report, ensure_ascii=False, cls=NumpyArrayEncoder)) - def _send_event(self, event_name: str, data: Optional[dict] = None) -> None: report = { - 'time_spent': str(datetime.timedelta(seconds=round(time.time() - self.start_time + 0.5))), - 'epochs_done': self.epoch, - 'batches_seen': self.train_batches_seen, - 'train_examples_seen': self.examples + "time_spent": str( + datetime.timedelta(seconds=round(time.time() - self.start_time + 0.5)) + ), + "epochs_done": self.epoch, + "batches_seen": self.train_batches_seen, + "train_examples_seen": self.examples, } if data is not None: report.update(data) @@ -274,31 +181,50 @@ def train_on_batches(self, iterator: DataLearningIterator) -> None: """Train pipeline on batches using provided data iterator and initialization parameters""" self.start_time = time.time() if self.validate_first: - self._validate(iterator) + self._send_event(event_name="before_validation") + report = self.validate_(self, iterator, "valid") + self._send_event(event_name="after_validation", data=report) while True: impatient = False - self._send_event(event_name='before_train') - for x, y_true in iterator.gen_batches(self.batch_size, data_type='train'): + self._send_event(event_name="before_train") + for x, y_true in iterator.gen_batches(self.batch_size, data_type="train"): self.last_result = self._chainer.train_on_batch(x, y_true) if self.last_result is None: self.last_result = {} elif not isinstance(self.last_result, dict): - self.last_result = {'loss': self.last_result} - if 'loss' in self.last_result: - self.losses.append(self.last_result.pop('loss')) + self.last_result = {"loss": self.last_result} + if "loss" in self.last_result: + self.losses.append(self.last_result.pop("loss")) self.train_batches_seen += 1 self.examples += len(x) - if self.log_every_n_batches > 0 and self.train_batches_seen % self.log_every_n_batches == 0: - self._log(iterator, tensorboard_tag='every_n_batches', tensorboard_index=self.train_batches_seen) + if ( + self.log_every_n_batches > 0 + and self.train_batches_seen % self.log_every_n_batches == 0 + ): + self._send_event(event_name="before_log") + report = None + + for lgr in self.logger: + report = lgr(self, iterator, type="train", tensorboard_tag="every_n_batches", + tensorboard_index=self.train_batches_seen, report=report) + + # empty report if no logging method. 
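+                    # The report produced by the first logger is passed on to the
+                    # remaining loggers, so metrics are computed only once per logging
+                    # step; with an empty logger list, report stays None and the event
+                    # below carries only the basic progress data from _send_event.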
+ self._send_event(event_name="after_train_log", data=report) + + if (self.val_every_n_batches > 0 and self.train_batches_seen % self.val_every_n_batches == 0): + self._send_event(event_name="before_validation") + report = None + + for lgr in self.logger: + report = lgr(self, iterator, type="valid",tensorboard_tag="every_n_batches", + tensorboard_index=self.train_batches_seen, report = report) - if self.val_every_n_batches > 0 and self.train_batches_seen % self.val_every_n_batches == 0: - self._validate(iterator, - tensorboard_tag='every_n_batches', tensorboard_index=self.train_batches_seen) + self._send_event(event_name="after_validation", data=report) - self._send_event(event_name='after_batch') + self._send_event(event_name="after_batch") if 0 < self.max_batches <= self.train_batches_seen: impatient = True @@ -313,20 +239,35 @@ def train_on_batches(self, iterator: DataLearningIterator) -> None: break self.epoch += 1 + if (self.log_every_n_epochs > 0 and self.epoch % self.log_every_n_epochs == 0): + self._send_event(event_name="before_log") - if self.log_every_n_epochs > 0 and self.epoch % self.log_every_n_epochs == 0: - self._log(iterator, tensorboard_tag='every_n_epochs', tensorboard_index=self.epoch) + report = None - if self.val_every_n_epochs > 0 and self.epoch % self.val_every_n_epochs == 0: - self._validate(iterator, tensorboard_tag='every_n_epochs', tensorboard_index=self.epoch) + for lgr in self.logger: + report = lgr(self, iterator, type="train",tensorboard_tag="every_n_epochs", + tensorboard_index=self.epoch, report=report) - self._send_event(event_name='after_epoch') + self._send_event(event_name="after_train_log", data=report) + + if (self.val_every_n_epochs > 0 and self.epoch % self.val_every_n_epochs == 0): + self._send_event(event_name="before_validation") + + report = None + + for lgr in self.logger: + report = lgr(self, iterator, type="valid",tensorboard_tag="every_n_epochs", + tensorboard_index=self.epoch,report = report) + + self._send_event(event_name="after_validation", data=report) + + self._send_event(event_name="after_epoch") if 0 < self.max_epochs <= self.epoch: break if 0 < self.validation_patience <= self.patience: - log.info('Ran out of patience') + log.info("Ran out of patience") break def train(self, iterator: DataLearningIterator) -> None: @@ -344,3 +285,8 @@ def train(self, iterator: DataLearningIterator) -> None: if self.validation_number < 1: log.info('Save model to capture early training results') self.save() + + for lgr in self.logger: + lgr.close() + + diff --git a/deeppavlov/requirements/wandb.txt b/deeppavlov/requirements/wandb.txt new file mode 100644 index 0000000000..910c81a42a --- /dev/null +++ b/deeppavlov/requirements/wandb.txt @@ -0,0 +1 @@ +wandb==0.12.7 diff --git a/docs/conf.py b/docs/conf.py index b3a4f11237..7454c1a691 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -193,7 +193,7 @@ autodoc_mock_imports = ['bert_dp', 'bs4', 'faiss', 'fastText', 'fasttext', 'gensim', 'hdt', 'kenlm', 'librosa', 'lxml', 'nemo', 'nemo_asr', 'nemo_tts', 'nltk', 'opt_einsum', 'rapidfuzz', 'rasa', 'russian_tagsets', 'sacremoses', 'sortedcontainers', 'spacy', 'tensorflow', 'tensorflow_hub', - 'torch', 'transformers', 'udapi', 'ufal_udpipe', 'whapi', 'xeger'] + 'torch', 'transformers', 'udapi', 'ufal_udpipe','wandb', 'whapi', 'xeger'] extlinks = { 'config': (f'https://github.com/deepmipt/DeepPavlov/blob/{release}/deeppavlov/configs/%s', None) diff --git a/docs/intro/configuration.rst b/docs/intro/configuration.rst index 9f873c5e9c..c5818386a6 100644 --- 
a/docs/intro/configuration.rst +++ b/docs/intro/configuration.rst @@ -222,6 +222,138 @@ _______ | | Default value for ``inputs`` parameter is a concatenation of chainer's ``in_y`` and ``out`` parameters. +Logging data during training process +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +| Logging data is done following two steps: + +* Add at least one of the following arguments in configuration file with strictly positive integer value: + + - ``val_every_n_batches``: how often (in batches) to validate the pipeline. + + - ``val_every_n_epochs``: how often (in epochs) to validate the pipeline. + + - ``log_every_n_epochs``: how often (in epochs) to calculate metrics on train data. + + - ``log_every_n_batches``: how often (in batches) to calculate metrics on train data. + + Logging will be ignored for negative or zero. + + Example: + + .. code:: python + + "train": { + "log_every_n_epochs": 3, + "val_every_n_batches": 2 + } + + To log training data every 3 epochs, and validation data every 2 batches, using the appropriate logging method. + +* Add the logging method: + + Deeppavlov library supports three types of logging: + + - StdLogging: for logging data about current training and validation processes to stdout. + + To log data using this logger, add "logger" list containing dictionary with ``name``: ``StdLogger`` in configuration file. + For example: + + .. code:: python + + "train": { + "logger": [ + { + "name": "StdLogger" + } + ], + ... + } + + - TensorboardLogger: for logging data to Tensorboard, stored in local folder. + + To log data using this logger, add logger name, with local directory path. + + For example: + + .. code:: python + + "train": { + "logger": [ + { + "name": "TensorboardLogger", + "log_dir": "local_folder/Tensorboard_logs" + } + ], + ... + } + + In this case, training data will be stored in "local_folder/Tensorboard_logs/train_log", + and validation data in "local_folder/Tensorboard_logs/valid_log". + + To visualize training logs, use the following command line: + + "tensorboard --logdir local_folder/Tensorboard_logs/train_log" + + - WandbLogger: for logging data to Weights & Biases platform in real time. + + To log data using this logger, add logger name, with API key. + + To get API key: + + Sign up to wandb platform : https://wandb.ai/site if don’t have an account, login and go to setting (upper right corner), copy the API key. + + To create a new run in W&B with specific configurations, add ``init`` keyword with its configuration as dictionary (see https://docs.wandb.ai/ref/python/init). + + For example: + + .. code:: python + + "train": { + "logger": [ + { + "name": "WandbLogger", + "API_Key":"API of 40 characters long", + "init":{ + "project": "project_name", + "group": "group_name", + "job_type":"job_type", + "name":"run_name", + }, + "config": { + "learning_rate": 0.1, + } + } + ], + ... + } + + Logging to W&B will be on epochs if ``log_every_n_epochs`` or ``val_every_n_epochs`` were added to configuration file, otherwise logging on batches if ``log_every_n_batches`` or ``val_every_n_batches`` were added. + + To view run while training, follow the run link logged to stdout. + + To add more than one logger type as dictionary, for example: + + .. code:: python + + "train": { + "logger": [ + { + "name": "TensorboardLogger", + "log_dir": "local_folder/Tensorboard_logs" + }, + { + "name": "StdLogger" + } + ], + ... + } + + Default logging method is ``StdLogger`` (if ``logger`` not provided in configuration file), for no logging, add ``logger`` with empty list. 
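+
+  Note that ``TensorboardLogger`` requires TensorFlow and ``WandbLogger`` requires the ``wandb`` package; a logger whose requirements cannot be imported is skipped with a warning. For example, to disable logging entirely, pass an empty list:
+
+  .. code:: python
+
+      "train": {
+          "logger": [],
+          ...
+      }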
+ + + + + DatasetReader ~~~~~~~~~~~~~