diff --git a/CHANGELOG.md b/CHANGELOG.md index c4af421343..7f994bfd76 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed - docs structure were updated during ([#822](https://github.com/catalyst-team/catalyst/pull/822)) -- `utils.process_components` moved from `utils.distributed` to `utils.components` ([#822](https://github.com/catalyst-team/catalyst/pull/822)) +- `utils.process_components` moved from `utils.distributed` to `utils.components` ([#822](https://github.com/catalyst-team/catalyst/pull/822)) +- `catalyst.core.state.State` merged to `catalyst.core.runner._Runner` ([#823](https://github.com/catalyst-team/catalyst/pull/823)) (backward compatibility included) + - `catalyst.core.callback.Callback` now works directly with `catalyst.core.runner._Runner` + - `state_kwargs` renamed to `stage_kwargs` ### Removed diff --git a/README.md b/README.md index b1143c6500..5899d31879 100644 --- a/README.md +++ b/README.md @@ -74,14 +74,14 @@ class CustomRunner(dl.Runner): loss = F.cross_entropy(y_hat, y) accuracy01, accuracy03 = metrics.accuracy(y_hat, y, topk=(1, 3)) - self.state.batch_metrics.update( + self.batch_metrics.update( {"loss": loss, "accuracy01": accuracy01, "accuracy03": accuracy03} ) - if self.state.is_train_loader: + if self.is_train_loader: loss.backward() - self.state.optimizer.step() - self.state.optimizer.zero_grad() + self.optimizer.step() + self.optimizer.zero_grad() runner = CustomRunner() # model training @@ -233,17 +233,17 @@ class CustomRunner(dl.Runner): loss = F.cross_entropy(y_hat, y) accuracy01, accuracy03, accuracy05 = metrics.accuracy(y_hat, y, topk=(1, 3, 5)) - self.state.batch_metrics = { + self.batch_metrics = { "loss": loss, "accuracy01": accuracy01, "accuracy03": accuracy03, "accuracy05": accuracy05, } - if self.state.is_train_loader: + if self.is_train_loader: loss.backward() - self.state.optimizer.step() - self.state.optimizer.zero_grad() + self.optimizer.step() + self.optimizer.zero_grad() runner = CustomRunner() runner.train( @@ -304,7 +304,7 @@ class CustomRunner(dl.Runner): loss_ae = F.mse_loss(x_, x) loss = loss_clf + loss_ae accuracy01, accuracy03, accuracy05 = metrics.accuracy(y_hat, y, topk=(1, 3, 5)) - self.state.batch_metrics = { + self.batch_metrics = { "loss_clf": loss_clf, "loss_ae": loss_ae, "loss": loss, @@ -313,10 +313,10 @@ class CustomRunner(dl.Runner): "accuracy05": accuracy05, } - if self.state.is_train_loader: + if self.is_train_loader: loss.backward() - self.state.optimizer.step() - self.state.optimizer.zero_grad() + self.optimizer.step() + self.optimizer.zero_grad() runner = CustomRunner() runner.train( @@ -402,7 +402,7 @@ class CustomRunner(dl.Runner): loss_logprob = torch.mean(z_logprob) * 0.01 loss = loss_clf + loss_ae + loss_kld + loss_logprob accuracy01, accuracy03, accuracy05 = metrics.accuracy(y_hat, y, topk=(1, 3, 5)) - self.state.batch_metrics = { + self.batch_metrics = { "loss_clf": loss_clf, "loss_ae": loss_ae, "loss_kld": loss_kld, @@ -413,10 +413,10 @@ class CustomRunner(dl.Runner): "accuracy05": accuracy05, } - if self.state.is_train_loader: + if self.is_train_loader: loss.backward() - self.state.optimizer.step() - self.state.optimizer.zero_grad() + self.optimizer.step() + self.optimizer.zero_grad() runner = CustomRunner() runner.train( @@ -479,7 +479,7 @@ class CustomRunner(dl.Runner): loss_iou = 1 - iou loss = loss_clf + loss_iou accuracy01, accuracy03, accuracy05 = metrics.accuracy(y_hat, y, topk=(1, 3, 5)) - 
self.state.batch_metrics = { + self.batch_metrics = { "loss_clf": loss_clf, "loss_iou": loss_iou, "loss": loss, @@ -489,10 +489,10 @@ class CustomRunner(dl.Runner): "accuracy05": accuracy05, } - if self.state.is_train_loader: + if self.is_train_loader: loss.backward() - self.state.optimizer.step() - self.state.optimizer.zero_grad() + self.optimizer.step() + self.optimizer.zero_grad() runner = CustomRunner() runner.train( @@ -590,7 +590,7 @@ class CustomRunner(dl.Runner): batch_metrics["loss_generator"] = \ F.binary_cross_entropy_with_logits(predictions, misleading_labels) - self.state.batch_metrics.update(**batch_metrics) + self.batch_metrics.update(**batch_metrics) runner = CustomRunner() runner.train( @@ -703,7 +703,7 @@ class CustomRunner(dl.Runner): loss_ae = F.mse_loss(x_, x) loss = loss_clf + loss_ae accuracy01, accuracy03, accuracy05 = metrics.accuracy(y_hat, y, topk=(1, 3, 5)) - self.state.batch_metrics = { + self.batch_metrics = { "loss_clf": loss_clf, "loss_ae": loss_ae, "loss": loss, @@ -712,10 +712,10 @@ class CustomRunner(dl.Runner): "accuracy05": accuracy05, } - if self.state.is_train_loader: + if self.is_train_loader: loss.backward() - self.state.optimizer.step() - self.state.optimizer.zero_grad() + self.optimizer.step() + self.optimizer.zero_grad() def datasets_fn(): dataset = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()) @@ -757,7 +757,7 @@ utils.distributed_cmd_run(train) ### Structure - **core** - framework core with main abstractions - - Experiment, Runner, Callback and State. + Experiment, Runner and Callback. - **data** - useful tools and scripts for data processing. - **dl** – runner for training and inference, all of the classic ML and CV/NLP/RecSys metrics diff --git a/bin/tests/check_dl_core.sh b/bin/tests/check_dl2_core.sh similarity index 100% rename from bin/tests/check_dl_core.sh rename to bin/tests/check_dl2_core.sh diff --git a/bin/tests/check_dl_core_callbacks.sh b/bin/tests/check_dl_core_callbacks.sh index 598cfb8069..750cddcb99 100644 --- a/bin/tests/check_dl_core_callbacks.sh +++ b/bin/tests/check_dl_core_callbacks.sh @@ -596,7 +596,7 @@ PYTHONPATH=./examples:./catalyst:${PYTHONPATH} \ python3 -c " import torch from torch.utils.data import DataLoader, TensorDataset -from catalyst.dl import SupervisedRunner, State, Callback, CallbackOrder, CheckpointCallback +from catalyst.dl import SupervisedRunner, CheckpointCallback # experiment_setup logdir = '${LOGDIR}' @@ -661,7 +661,7 @@ PYTHONPATH=./examples:./catalyst:${PYTHONPATH} \ python3 -c " import torch from torch.utils.data import DataLoader, TensorDataset -from catalyst.dl import SupervisedRunner, State, Callback, CallbackOrder, CheckpointCallback +from catalyst.dl import SupervisedRunner, CheckpointCallback # experiment_setup logdir = '${LOGDIR}' @@ -748,7 +748,7 @@ echo ${LOG_MSG} PYTHONPATH=./examples:./catalyst:${PYTHONPATH} python3 -c " import torch from torch.utils.data import DataLoader, TensorDataset -from catalyst.dl import SupervisedRunner, State, Callback, CallbackOrder, CheckpointCallback +from catalyst.dl import SupervisedRunner, CheckpointCallback # experiment_setup logdir = '${LOGDIR}' @@ -810,7 +810,7 @@ PYTHONPATH=./examples:./catalyst:${PYTHONPATH} \ python3 -c " import torch from torch.utils.data import DataLoader, TensorDataset -from catalyst.dl import SupervisedRunner, State, Callback, CallbackOrder, CheckpointCallback +from catalyst.dl import SupervisedRunner, CheckpointCallback # experiment_setup logdir = '${LOGDIR}' diff --git 
a/bin/tests/check_dl_core_periodic_loader_callback.sh b/bin/tests/check_dl_core_periodic_loader_callback.sh index fba69a8576..7fdcd00f34 100644 --- a/bin/tests/check_dl_core_periodic_loader_callback.sh +++ b/bin/tests/check_dl_core_periodic_loader_callback.sh @@ -73,8 +73,7 @@ PYTHONPATH=./examples:./catalyst:${PYTHONPATH} \ import torch from torch.utils.data import DataLoader, TensorDataset from catalyst.dl import ( - SupervisedRunner, State, Callback, CallbackOrder, - PeriodicLoaderCallback, + SupervisedRunner, Callback, CallbackOrder, PeriodicLoaderCallback, ) # experiment_setup @@ -139,8 +138,7 @@ PYTHONPATH=./examples:./catalyst:${PYTHONPATH} \ import torch from torch.utils.data import DataLoader, TensorDataset from catalyst.dl import ( - SupervisedRunner, State, Callback, CallbackOrder, - PeriodicLoaderCallback, + SupervisedRunner, Callback, CallbackOrder, PeriodicLoaderCallback, ) # experiment_setup @@ -205,8 +203,7 @@ PYTHONPATH=./examples:./catalyst:${PYTHONPATH} \ import torch from torch.utils.data import DataLoader, TensorDataset from catalyst.dl import ( - SupervisedRunner, State, Callback, CallbackOrder, - PeriodicLoaderCallback, + SupervisedRunner, Callback, CallbackOrder, PeriodicLoaderCallback, ) # experiment_setup @@ -279,8 +276,7 @@ PYTHONPATH=./examples:./catalyst:${PYTHONPATH} \ import torch from torch.utils.data import DataLoader, TensorDataset from catalyst.dl import ( - SupervisedRunner, State, Callback, CallbackOrder, - PeriodicLoaderCallback, + SupervisedRunner, Callback, CallbackOrder, PeriodicLoaderCallback, ) # experiment_setup @@ -485,8 +481,7 @@ PYTHONPATH=./examples:./catalyst:${PYTHONPATH} \ import torch from torch.utils.data import DataLoader, TensorDataset from catalyst.dl import ( - SupervisedRunner, State, Callback, CallbackOrder, - PeriodicLoaderCallback, + SupervisedRunner, Callback, CallbackOrder, PeriodicLoaderCallback, ) # experiment_setup @@ -577,8 +572,7 @@ PYTHONPATH=./examples:./catalyst:${PYTHONPATH} \ import torch from torch.utils.data import DataLoader, TensorDataset from catalyst.dl import ( - SupervisedRunner, State, Callback, CallbackOrder, - PeriodicLoaderCallback, + SupervisedRunner, Callback, CallbackOrder, PeriodicLoaderCallback, ) # experiment_setup @@ -687,8 +681,7 @@ PYTHONPATH=./examples:./catalyst:${PYTHONPATH} \ import torch from torch.utils.data import DataLoader, TensorDataset from catalyst.dl import ( - SupervisedRunner, State, Callback, CallbackOrder, - PeriodicLoaderCallback, + SupervisedRunner, Callback, CallbackOrder, PeriodicLoaderCallback, ) # experiment_setup diff --git a/catalyst/__version__.py b/catalyst/__version__.py index db39fa13d8..250fd83ba0 100644 --- a/catalyst/__version__.py +++ b/catalyst/__version__.py @@ -1 +1 @@ -__version__ = "20.05.1" +__version__ = "20.06" diff --git a/catalyst/contrib/dl/callbacks/__init__.py b/catalyst/contrib/dl/callbacks/__init__.py index 13e3bd8fb6..3a52e5ed21 100644 --- a/catalyst/contrib/dl/callbacks/__init__.py +++ b/catalyst/contrib/dl/callbacks/__init__.py @@ -4,18 +4,18 @@ from catalyst.tools import settings from .cutmix_callback import CutmixCallback -from .knn import KNNMetricCallback -from .optimizer import SaveModelGradsCallback -from .periodic_loader import PeriodicLoaderCallback -from .perplexity import PerplexityMetricCallback +from .gradnorm_logger import GradNormLogger +from .knn_metric import KNNMetricCallback +from .periodic_loader_callback import PeriodicLoaderCallback +from .perplexity_metric import PerplexityMetricCallback from .telegram_logger import 
TelegramLogger -from .trace import TracerCallback +from .tracer_callback import TracerCallback logger = logging.getLogger(__name__) try: import imageio - from .inference import InferMaskCallback + from .mask_inference import InferMaskCallback except ImportError as ex: if settings.cv_required: logger.warning( @@ -26,7 +26,7 @@ try: import alchemy - from .alchemy import AlchemyLogger + from .alchemy_logger import AlchemyLogger except ImportError as ex: if settings.alchemy_logger_required: logger.warning( @@ -48,7 +48,7 @@ try: import neptune - from .neptune import NeptuneLogger + from .neptune_logger import NeptuneLogger except ImportError as ex: if settings.neptune_logger_required: logger.warning( @@ -59,7 +59,7 @@ try: import wandb - from .wandb import WandbLogger + from .wandb_logger import WandbLogger except ImportError as ex: if settings.wandb_logger_required: logger.warning( diff --git a/catalyst/contrib/dl/callbacks/alchemy.py b/catalyst/contrib/dl/callbacks/alchemy_logger.py similarity index 85% rename from catalyst/contrib/dl/callbacks/alchemy.py rename to catalyst/contrib/dl/callbacks/alchemy_logger.py index 3e4a9ae390..334afa5116 100644 --- a/catalyst/contrib/dl/callbacks/alchemy.py +++ b/catalyst/contrib/dl/callbacks/alchemy_logger.py @@ -3,17 +3,17 @@ from alchemy import Logger from catalyst import utils -from catalyst.core import ( +from catalyst.core.callback import ( Callback, CallbackNode, CallbackOrder, CallbackScope, - State, ) +from catalyst.core.runner import _Runner class AlchemyLogger(Callback): - """Logger callback, translates ``state.*_metrics`` to Alchemy. + """Logger callback, translates ``runner.*_metrics`` to Alchemy. Read about Alchemy here https://alchemy.host Example: @@ -101,43 +101,43 @@ def _log_metrics( name=metric_name, value=metric_value, step=step, ) - def on_batch_end(self, state: State): + def on_batch_end(self, runner: _Runner): """Translate batch metrics to Alchemy.""" if self.log_on_batch_end: - mode = state.loader_name - metrics_ = state.batch_metrics + mode = runner.loader_name + metrics_ = runner.batch_metrics self._log_metrics( metrics=metrics_, - step=state.global_sample_step, + step=runner.global_sample_step, mode=mode, suffix=self.batch_log_suffix, ) - def on_loader_end(self, state: State): + def on_loader_end(self, runner: _Runner): """Translate loader metrics to Alchemy.""" if self.log_on_epoch_end: - mode = state.loader_name - metrics_ = state.loader_metrics + mode = runner.loader_name + metrics_ = runner.loader_metrics self._log_metrics( metrics=metrics_, - step=state.global_epoch, + step=runner.global_epoch, mode=mode, suffix=self.epoch_log_suffix, ) - def on_epoch_end(self, state: State): + def on_epoch_end(self, runner: _Runner): """Translate epoch metrics to Alchemy.""" extra_mode = "_base" splitted_epoch_metrics = utils.split_dict_to_subdicts( - dct=state.epoch_metrics, - prefixes=list(state.loaders.keys()), + dct=runner.epoch_metrics, + prefixes=list(runner.loaders.keys()), extra_key=extra_mode, ) if self.log_on_epoch_end: self._log_metrics( metrics=splitted_epoch_metrics[extra_mode], - step=state.global_epoch, + step=runner.global_epoch, mode=extra_mode, suffix=self.epoch_log_suffix, ) diff --git a/catalyst/contrib/dl/callbacks/cutmix_callback.py b/catalyst/contrib/dl/callbacks/cutmix_callback.py index e86ea0604b..c457910579 100644 --- a/catalyst/contrib/dl/callbacks/cutmix_callback.py +++ b/catalyst/contrib/dl/callbacks/cutmix_callback.py @@ -4,7 +4,8 @@ import torch -from catalyst.dl import CriterionCallback, State +from 
catalyst.core.callbacks import CriterionCallback +from catalyst.core.runner import _Runner class CutmixCallback(CriterionCallback): @@ -51,22 +52,22 @@ def __init__( self.index = None self.is_needed = True - def _compute_loss(self, state: State, criterion): + def _compute_loss(self, runner: _Runner, criterion): """Computes loss. If self.is_needed is ``False`` then calls ``_compute_loss`` from ``CriterionCallback``, otherwise computes loss value. Args: - state (State): current state + runner (_Runner): current runner criterion: that is used to compute loss """ if not self.is_needed: - return super()._compute_loss_value(state, criterion) + return super()._compute_loss_value(runner, criterion) - pred = state.output[self.output_key] - y_a = state.input[self.input_key] - y_b = state.input[self.input_key][self.index] + pred = runner.output[self.output_key] + y_a = runner.input[self.input_key] + y_b = runner.input[self.input_key][self.index] loss = self.lam * criterion(pred, y_a) + (1 - self.lam) * criterion( pred, y_b ) @@ -100,19 +101,19 @@ def _rand_bbox(self, size, lam): return bbx1, bby1, bbx2, bby2 - def on_loader_start(self, state: State) -> None: + def on_loader_start(self, runner: _Runner) -> None: """Checks if it is needed for the loader. Args: - state (State): current state + runner (_Runner): current runner """ - self.is_needed = not self.on_train_only or state.is_train_loader + self.is_needed = not self.on_train_only or runner.is_train_loader - def on_batch_start(self, state: State) -> None: + def on_batch_start(self, runner: _Runner) -> None: """Mixes data according to Cutmix algorithm. Args: - state (State): current state + runner (_Runner): current runner """ if not self.is_needed: return @@ -122,15 +123,15 @@ def on_batch_start(self, state: State) -> None: else: self.lam = 1 - self.index = torch.randperm(state.input[self.fields[0]].shape[0]) - self.index.to(state.device) + self.index = torch.randperm(runner.input[self.fields[0]].shape[0]) + self.index.to(runner.device) bbx1, bby1, bbx2, bby2 = self._rand_bbox( - state.input[self.fields[0]].shape, self.lam + runner.input[self.fields[0]].shape, self.lam ) for f in self.fields: - state.input[f][:, :, bbx1:bbx2, bby1:bby2] = state.input[f][ + runner.input[f][:, :, bbx1:bbx2, bby1:bby2] = runner.input[f][ self.index, :, bbx1:bbx2, bby1:bby2 ] @@ -138,8 +139,8 @@ def on_batch_start(self, state: State) -> None: (bbx2 - bbx1) * (bby2 - bby1) / ( - state.input[self.fields[0]].shape[-1] - * state.input[self.fields[0]].shape[-2] + runner.input[self.fields[0]].shape[-1] + * runner.input[self.fields[0]].shape[-2] ) ) diff --git a/catalyst/contrib/dl/callbacks/optimizer.py b/catalyst/contrib/dl/callbacks/gradnorm_logger.py similarity index 86% rename from catalyst/contrib/dl/callbacks/optimizer.py rename to catalyst/contrib/dl/callbacks/gradnorm_logger.py index b5a523f12f..d03a08f28e 100644 --- a/catalyst/contrib/dl/callbacks/optimizer.py +++ b/catalyst/contrib/dl/callbacks/gradnorm_logger.py @@ -3,11 +3,12 @@ from torch.nn import DataParallel from torch.nn.parallel import DistributedDataParallel -from catalyst.core import Callback, CallbackNode, CallbackOrder, State +from catalyst.core.callback import Callback, CallbackNode, CallbackOrder +from catalyst.core.runner import _Runner from catalyst.tools.typing import Model -class SaveModelGradsCallback(Callback): +class GradNormLogger(Callback): """Callback for logging model gradients.""" def __init__( @@ -64,13 +65,13 @@ def grad_norm(*, model: Model, prefix: str, norm_type: int,) -> Dict: 
return grad_norm - def on_batch_end(self, state: State) -> None: + def on_batch_end(self, runner: _Runner) -> None: """On batch end event Args: - state (State): current state + runner (_Runner): current runner """ - if not state.is_train_loader: + if not runner.is_train_loader: return self._accumulation_counter += 1 @@ -80,13 +81,13 @@ def on_batch_end(self, state: State) -> None: if need_gradient_step: grad_norm = self.grad_norm( - model=state.model, + model=runner.model, prefix=self.grad_norm_prefix, norm_type=self.norm_type, ) - state.batch_metrics.update(**grad_norm) + runner.batch_metrics.update(**grad_norm) self._accumulation_counter = 0 -__all__ = ["SaveModelGradsCallback"] +__all__ = ["GradNormLogger"] diff --git a/catalyst/contrib/dl/callbacks/knn.py b/catalyst/contrib/dl/callbacks/knn_metric.py similarity index 91% rename from catalyst/contrib/dl/callbacks/knn.py rename to catalyst/contrib/dl/callbacks/knn_metric.py index 21e8282103..9734cbc091 100644 --- a/catalyst/contrib/dl/callbacks/knn.py +++ b/catalyst/contrib/dl/callbacks/knn_metric.py @@ -13,11 +13,12 @@ import torch -from catalyst.dl import Callback, CallbackOrder, State +from catalyst.core.callback import Callback, CallbackOrder +from catalyst.core.runner import _Runner class KNNMetricCallback(Callback): - """A callback that returns single metric on ``state.on_loader_end``.""" + """A callback that returns single metric on ``runner.on_loader_end``.""" def __init__( self, @@ -164,27 +165,27 @@ def _knn(self, train_set, test_set=None): return result - def on_batch_end(self, state: State) -> None: + def on_batch_end(self, runner: _Runner) -> None: """Batch end hook. Args: - state (State): current state + runner (_Runner): current runner """ - features: torch.Tensor = state.output[ + features: torch.Tensor = runner.output[ self.features_key ].cpu().detach().numpy() - targets: torch.Tensor = state.input[ + targets: torch.Tensor = runner.input[ self.targets_key ].cpu().detach().numpy() self.features.extend(features) self.targets.extend(targets) - def on_loader_end(self, state: State) -> None: + def on_loader_end(self, runner: _Runner) -> None: """Loader end hook. Args: - state (State): current state + runner (_Runner): current runner """ self.features = np.stack(self.features) self.targets = np.stack(self.targets) @@ -197,11 +198,11 @@ def on_loader_end(self, state: State) -> None: "labels": self.targets, } - self.sets[state.loader_name] = s + self.sets[runner.loader_name] = s y_true, y_pred = self._knn(s) - loader_values = state.loader_metrics + loader_values = runner.loader_metrics if self.num_classes == 2: loader_values[self.prefix] = self.metric_fn( y_true, y_pred, average="binary" @@ -215,11 +216,11 @@ def on_loader_end(self, state: State) -> None: self._reset_cache() - def on_epoch_end(self, state: State) -> None: + def on_epoch_end(self, runner: _Runner) -> None: """Epoch end hook. 
Args: - state (State): current state + runner (_Runner): current runner """ if self.cv_loader_names is not None: for k, vs in self.cv_loader_names.items(): @@ -244,7 +245,7 @@ def on_epoch_end(self, state: State) -> None: y_true, y_pred = self._knn(self.sets[k], self.sets[v]) - loader_values = state.epoch_metrics[f"{k}_{v}_cv"] + loader_values = runner.epoch_metrics[f"{k}_{v}_cv"] if self.num_classes == 2: loader_values[f"{self.prefix}"] = self.metric_fn( diff --git a/catalyst/contrib/dl/callbacks/inference.py b/catalyst/contrib/dl/callbacks/mask_inference.py similarity index 79% rename from catalyst/contrib/dl/callbacks/inference.py rename to catalyst/contrib/dl/callbacks/mask_inference.py index 0176f5daac..9db612f1fc 100644 --- a/catalyst/contrib/dl/callbacks/inference.py +++ b/catalyst/contrib/dl/callbacks/mask_inference.py @@ -7,7 +7,9 @@ import torch import torch.nn.functional as F -from catalyst.dl import Callback, CallbackOrder, State, utils +from catalyst.core.callback import Callback, CallbackOrder +from catalyst.core.runner import _Runner +from catalyst.dl import utils class InferMaskCallback(Callback): @@ -44,16 +46,16 @@ def __init__( self.output_key = output_key self.name_key = name_key self.counter = 0 - self._keys_from_state = ["out_dir", "out_prefix"] + self._keys_from_runner = ["out_dir", "out_prefix"] - def on_stage_start(self, state: State): + def on_stage_start(self, runner: _Runner): """Stage start hook. Args: - state (State): current state + runner (_Runner): current runner """ - for key in self._keys_from_state: - value = getattr(state, key, None) + for key in self._keys_from_runner: + value = getattr(runner, key, None) if value is not None: setattr(self, key, value) # assert self.out_prefix is not None @@ -64,28 +66,28 @@ def on_stage_start(self, state: State): self.out_prefix = str(self.out_dir) + "/" + str(self.out_prefix) os.makedirs(os.path.dirname(self.out_prefix), exist_ok=True) - def on_loader_start(self, state: State): + def on_loader_start(self, runner: _Runner): """Loader start hook. Args: - state (State): current state + runner (_Runner): current runner """ - lm = state.loader_name + lm = runner.loader_name os.makedirs(f"{self.out_prefix}/{lm}/", exist_ok=True) - def on_batch_end(self, state: State): + def on_batch_end(self, runner: _Runner): """Batch end hook. Args: - state (State): current state + runner (_Runner): current runner """ - lm = state.loader_name - names = state.input.get(self.name_key, []) + lm = runner.loader_name + names = runner.input.get(self.name_key, []) - features = state.input[self.input_key].detach().cpu() + features = runner.input[self.input_key].detach().cpu() images = utils.tensor_to_ndimage(features) - logits = state.output[self.output_key] + logits = runner.output[self.output_key] logits = ( torch.unsqueeze_(logits, dim=1) if len(logits.shape) < 4 diff --git a/catalyst/contrib/dl/callbacks/neptune.py b/catalyst/contrib/dl/callbacks/neptune_logger.py similarity index 90% rename from catalyst/contrib/dl/callbacks/neptune.py rename to catalyst/contrib/dl/callbacks/neptune_logger.py index 36c3e1addb..04a1b417b3 100644 --- a/catalyst/contrib/dl/callbacks/neptune.py +++ b/catalyst/contrib/dl/callbacks/neptune_logger.py @@ -2,17 +2,17 @@ import neptune -from catalyst.core import ( +from catalyst.core.callback import ( Callback, CallbackNode, CallbackOrder, CallbackScope, - State, ) +from catalyst.core.runner import _Runner class NeptuneLogger(Callback): - """Logger callback, translates ``state.*_metrics`` to Neptune. 
+ """Logger callback, translates ``runner.*_metrics`` to Neptune. Read about Neptune here https://neptune.ai Example: @@ -135,26 +135,26 @@ def _log_metrics( metric_value = metrics[name] self.experiment.log_metric(metric_name, y=metric_value, x=step) - def on_batch_end(self, state: State): + def on_batch_end(self, runner: _Runner): """Log batch metrics to Neptune.""" if self.log_on_batch_end: - mode = state.loader_name - metrics_ = state.batch_metrics + mode = runner.loader_name + metrics_ = runner.batch_metrics self._log_metrics( metrics=metrics_, - step=state.global_sample_step, + step=runner.global_sample_step, mode=mode, suffix=self.batch_log_suffix, ) - def on_loader_end(self, state: State): + def on_loader_end(self, runner: _Runner): """Translate epoch metrics to Neptune.""" if self.log_on_epoch_end: - mode = state.loader_name - metrics_ = state.loader_metrics + mode = runner.loader_name + metrics_ = runner.loader_metrics self._log_metrics( metrics=metrics_, - step=state.global_epoch, + step=runner.global_epoch, mode=mode, suffix=self.epoch_log_suffix, ) diff --git a/catalyst/contrib/dl/callbacks/periodic_loader.py b/catalyst/contrib/dl/callbacks/periodic_loader_callback.py similarity index 77% rename from catalyst/contrib/dl/callbacks/periodic_loader.py rename to catalyst/contrib/dl/callbacks/periodic_loader_callback.py index 7495679781..d7ad1288fe 100644 --- a/catalyst/contrib/dl/callbacks/periodic_loader.py +++ b/catalyst/contrib/dl/callbacks/periodic_loader_callback.py @@ -4,7 +4,8 @@ from torch.utils.data import DataLoader -from catalyst.core import Callback, CallbackOrder, State +from catalyst.core.callback import Callback, CallbackOrder +from catalyst.core.runner import _Runner class PeriodicLoaderCallback(Callback): @@ -40,21 +41,23 @@ def __init__(self, **kwargs): ) self.loader_periods[loader] = int(period) - def on_stage_start(self, state: State) -> None: + def on_stage_start(self, runner: _Runner) -> None: """Collect information about loaders. Arguments: - state (State): training state + runner (_Runner): current runner """ # store pointers to data loader objects - for name, loader in state.loaders.items(): + for name, loader in runner.loaders.items(): self.loaders[name] = loader # stage validation loader - self.valid_loader = copy.copy(state.valid_loader) + self.valid_loader = copy.copy(runner.valid_loader) is_loaders_match = all( - loader in state.loaders for loader in self.loader_periods.keys() + loader in runner.loaders for loader in self.loader_periods.keys() + ) + is_same_loaders_number = len(self.loader_periods) == len( + runner.loaders ) - is_same_loaders_number = len(self.loader_periods) == len(state.loaders) if is_same_loaders_number and is_loaders_match: # find potential epoch with zero loaders zero_loaders_epochs = list( @@ -63,7 +66,7 @@ def on_stage_start(self, state: State) -> None: (p == 0 or n % p != 0) for p in self.loader_periods.values() ), - range(1, state.num_epochs + 1), + range(1, runner.num_epochs + 1), ) ) if len(zero_loaders_epochs) > 0: @@ -72,7 +75,7 @@ def on_stage_start(self, state: State) -> None: f"There will be no loaders in epoch {epoch_with_err}!" ) - def on_epoch_start(self, state: State) -> None: + def on_epoch_start(self, runner: _Runner) -> None: """Set loaders for current epoch. If validation is not required then the first loader from loaders used in current epoch will be used @@ -82,9 +85,9 @@ def on_epoch_start(self, state: State) -> None: in the epochs where this loader is missing. 
Arguments: - state (State): training state + runner (_Runner): current runner """ - epoch_num = state.epoch + epoch_num = runner.epoch # loaders to use in current epoch epoch_loaders = OrderedDict() for name, loader in self.loaders.items(): @@ -95,27 +98,27 @@ def on_epoch_start(self, state: State) -> None: if len(epoch_loaders) == 0: raise ValueError(f"There is no loaders in epoch {epoch_num}!") first_loader = next(iter(epoch_loaders.keys())) - state.valid_loader = ( + runner.valid_loader = ( self.valid_loader if self.valid_loader in epoch_loaders else first_loader ) - state.loaders = epoch_loaders + runner.loaders = epoch_loaders - def on_epoch_end(self, state: State) -> None: + def on_epoch_end(self, runner: _Runner) -> None: """Store validation metrics and use latest validation score when validation loader is not required. Arguments: - state (State): training state + runner (_Runner): current runner """ - if self.valid_loader in state.loaders: + if self.valid_loader in runner.loaders: self.valid_metrics = { - state.main_metric: state.valid_metrics[state.main_metric] + runner.main_metric: runner.valid_metrics[runner.main_metric] } elif self.valid_metrics is not None: # use previous score on validation - state.valid_metrics = self.valid_metrics + runner.valid_metrics = self.valid_metrics __all__ = ["PeriodicLoaderCallback"] diff --git a/catalyst/contrib/dl/callbacks/perplexity.py b/catalyst/contrib/dl/callbacks/perplexity_metric.py similarity index 100% rename from catalyst/contrib/dl/callbacks/perplexity.py rename to catalyst/contrib/dl/callbacks/perplexity_metric.py diff --git a/catalyst/contrib/dl/callbacks/telegram_logger.py b/catalyst/contrib/dl/callbacks/telegram_logger.py index 3e84e50b3e..f74763c0c8 100644 --- a/catalyst/contrib/dl/callbacks/telegram_logger.py +++ b/catalyst/contrib/dl/callbacks/telegram_logger.py @@ -4,13 +4,14 @@ from urllib.request import Request, urlopen from catalyst import utils -from catalyst.core import Callback, CallbackNode, CallbackOrder, State +from catalyst.core.callback import Callback, CallbackNode, CallbackOrder +from catalyst.core.runner import _Runner from catalyst.tools import settings class TelegramLogger(Callback): """ - Logger callback, translates ``state.metric_manager`` to telegram channel. + Logger callback, translates ``runner.metric_manager`` to telegram channel. 
""" def __init__( @@ -68,26 +69,26 @@ def _send_text(self, text: str): except Exception as e: logging.getLogger(__name__).warning(f"telegram.send.error:{e}") - def on_stage_start(self, state: State): + def on_stage_start(self, runner: _Runner): """Notify about starting a new stage.""" if self.log_on_stage_start: - text = f"{state.stage_name} stage was started" + text = f"{runner.stage_name} stage was started" self._send_text(text) - def on_loader_start(self, state: State): + def on_loader_start(self, runner: _Runner): """Notify about starting running the new loader.""" if self.log_on_loader_start: text = ( - f"{state.loader_name} {state.global_epoch} epoch has started" + f"{runner.loader_name} {runner.global_epoch} epoch has started" ) self._send_text(text) - def on_loader_end(self, state: State): - """Translate ``state.metric_manager`` to telegram channel.""" + def on_loader_end(self, runner: _Runner): + """Translate ``runner.metric_manager`` to telegram channel.""" if self.log_on_loader_end: - metrics = state.loader_metrics + metrics = runner.loader_metrics if self.metrics_to_log is None: metrics_to_log = sorted(metrics.keys()) @@ -95,7 +96,8 @@ def on_loader_end(self, state: State): metrics_to_log = self.metrics_to_log rows: List[str] = [ - f"{state.loader_name} {state.global_epoch} epoch was finished:" + f"{runner.loader_name} {runner.global_epoch}" + f" epoch was finished:" ] for name in metrics_to_log: @@ -106,17 +108,17 @@ def on_loader_end(self, state: State): self._send_text(text) - def on_stage_end(self, state: State): + def on_stage_end(self, runner: _Runner): """Notify about finishing a stage.""" if self.log_on_stage_end: - text = f"{state.stage_name} stage was finished" + text = f"{runner.stage_name} stage was finished" self._send_text(text) - def on_exception(self, state: State): + def on_exception(self, runner: _Runner): """Notify about raised ``Exception``.""" if self.log_on_exception: - exception = state.exception + exception = runner.exception if utils.is_exception(exception) and not isinstance( exception, KeyboardInterrupt ): diff --git a/catalyst/contrib/dl/callbacks/tests/test_optimizer_callback.py b/catalyst/contrib/dl/callbacks/tests/test_gradnorm_logger.py similarity index 88% rename from catalyst/contrib/dl/callbacks/tests/test_optimizer_callback.py rename to catalyst/contrib/dl/callbacks/tests/test_gradnorm_logger.py index 4b063c20b4..c649269daa 100644 --- a/catalyst/contrib/dl/callbacks/tests/test_optimizer_callback.py +++ b/catalyst/contrib/dl/callbacks/tests/test_gradnorm_logger.py @@ -11,14 +11,10 @@ from catalyst.contrib import registry from catalyst.contrib.data.transforms import ToTensor from catalyst.contrib.datasets import MNIST -from catalyst.contrib.dl.callbacks.optimizer import SaveModelGradsCallback -from catalyst.core import ( - Callback, - CallbackOrder, - CriterionCallback, - OptimizerCallback, - State, -) +from catalyst.contrib.dl.callbacks.gradnorm_logger import GradNormLogger +from catalyst.core.callback import Callback, CallbackOrder +from catalyst.core.callbacks import CriterionCallback, OptimizerCallback +from catalyst.core.runner import _Runner from catalyst.dl import SupervisedRunner @@ -107,19 +103,19 @@ def __init__(self, prefix: str): super().__init__(CallbackOrder.External) self.prefix = prefix - def on_batch_end(self, state: State): - if not state.is_train_loader: + def on_batch_end(self, runner: _Runner): + if not runner.is_train_loader: return for layer in ["conv1", "conv2", "fc1"]: for weights in ["weight", "bias"]: tag = 
f"{self.prefix}/{layer}/{weights}" - assert tag in state.batch_metrics - assert isinstance(state.batch_metrics[tag], Number) + assert tag in runner.batch_metrics + assert isinstance(runner.batch_metrics[tag], Number) tag = f"{self.prefix}/total" - assert tag in state.batch_metrics - assert isinstance(state.batch_metrics[tag], Number) + assert tag in runner.batch_metrics + assert isinstance(runner.batch_metrics[tag], Number) def test_save_model_grads(): @@ -139,7 +135,7 @@ def test_save_model_grads(): criterion_callback = CriterionCallback() optimizer_callback = OptimizerCallback() - save_model_grads_callback = SaveModelGradsCallback() + save_model_grads_callback = GradNormLogger() prefix = save_model_grads_callback.grad_norm_prefix test_callback = _OnBatchEndCheckGradsCallback(prefix) diff --git a/catalyst/contrib/dl/callbacks/tests/test_perplexity_callback.py b/catalyst/contrib/dl/callbacks/tests/test_perplexity_callback.py index e8357d969b..a0b123981a 100644 --- a/catalyst/contrib/dl/callbacks/tests/test_perplexity_callback.py +++ b/catalyst/contrib/dl/callbacks/tests/test_perplexity_callback.py @@ -21,13 +21,13 @@ def _handle_batch(self, batch): loss = output[0] logits = output[1].view(-1, vocab_size) - self.state.batch_metrics = {"loss": loss} + self.batch_metrics = {"loss": loss} if masked_lm_labels is not None: - self.state.input["targets"] = masked_lm_labels.view(-1) - self.state.output = {"loss": loss, "logits": logits} + self.input["targets"] = masked_lm_labels.view(-1) + self.output = {"loss": loss, "logits": logits} else: - self.state.input["targets"] = lm_labels.view(-1) - self.state.output = {"loss": loss, "logits": logits} + self.input["targets"] = lm_labels.view(-1) + self.output = {"loss": loss, "logits": logits} texts = [ diff --git a/catalyst/contrib/dl/callbacks/tests/test_trace_callback.py b/catalyst/contrib/dl/callbacks/tests/test_tracer_callback.py similarity index 95% rename from catalyst/contrib/dl/callbacks/tests/test_trace_callback.py rename to catalyst/contrib/dl/callbacks/tests/test_tracer_callback.py index d6617840dd..2a1c2edb97 100644 --- a/catalyst/contrib/dl/callbacks/tests/test_trace_callback.py +++ b/catalyst/contrib/dl/callbacks/tests/test_tracer_callback.py @@ -11,14 +11,10 @@ from catalyst.contrib import registry from catalyst.contrib.data.transforms import ToTensor from catalyst.contrib.datasets import MNIST -from catalyst.contrib.dl.callbacks.trace import TracerCallback -from catalyst.core import ( - Callback, - CallbackOrder, - CriterionCallback, - OptimizerCallback, - State, -) +from catalyst.contrib.dl.callbacks.tracer_callback import TracerCallback +from catalyst.core.callback import Callback, CallbackOrder +from catalyst.core.callbacks import CriterionCallback, OptimizerCallback +from catalyst.core.runner import _Runner from catalyst.dl import SupervisedRunner from catalyst.dl.utils import get_device, get_trace_name @@ -148,10 +144,10 @@ def __init__(self, path: Union[str, Path], inputs: torch.Tensor): self.inputs: torch.Tensor = inputs self.device = get_device() - def on_stage_end(self, state: State): + def on_stage_end(self, runner: _Runner): """ Args: - state (State): Current state. 
+ runner (_Runner): current runner """ assert self.path.exists(), "Traced model was not found" diff --git a/catalyst/contrib/dl/callbacks/trace.py b/catalyst/contrib/dl/callbacks/tracer_callback.py similarity index 87% rename from catalyst/contrib/dl/callbacks/trace.py rename to catalyst/contrib/dl/callbacks/tracer_callback.py index 13e29336cc..d47596b099 100644 --- a/catalyst/contrib/dl/callbacks/trace.py +++ b/catalyst/contrib/dl/callbacks/tracer_callback.py @@ -3,8 +3,9 @@ from pathlib import Path import warnings -from catalyst.dl import Callback, CallbackNode, CallbackOrder, State -from catalyst.dl.utils import save_traced_model, trace_model_from_state +from catalyst.core.callback import Callback, CallbackNode, CallbackOrder +from catalyst.core.runner import _Runner +from catalyst.dl.utils import save_traced_model, trace_model_from_runner class TracerCallback(Callback): @@ -88,12 +89,12 @@ def __init__( out_dir = Path(out_dir) self.out_dir = out_dir - def _trace(self, state: State): + def _trace(self, runner: _Runner): """ Performing model tracing on epoch end if condition metric is improved. Args: - state (State): Current state + runner (_Runner): Current runner """ if self.opt_level is not None: device = "cuda" @@ -106,8 +107,8 @@ def _trace(self, state: State): if self.do_once and self.mode == "best": checkpoint_name_to_restore = "best" - traced_model = trace_model_from_state( - state=state, + traced_model = trace_model_from_runner( + runner=runner, checkpoint_name=checkpoint_name_to_restore, method_name=self.method_name, mode=self.trace_mode, @@ -118,7 +119,7 @@ def _trace(self, state: State): save_traced_model( model=traced_model, - logdir=state.logdir, + logdir=runner.logdir, checkpoint_name=self.mode, method_name=self.method_name, mode=self.trace_mode, @@ -128,16 +129,16 @@ def _trace(self, state: State): out_dir=self.out_dir, ) - def on_epoch_end(self, state: State): + def on_epoch_end(self, runner: _Runner): """ Performing model tracing on epoch end if condition metric is improved Args: - state (State): Current state + runner (_Runner): Current runner """ if not self.do_once: if self.mode == "best": - score = state.valid_metrics[self.metric] + score = runner.valid_metrics[self.metric] if self.best_score is None: self.best_score = score @@ -146,19 +147,19 @@ def on_epoch_end(self, state: State): # will never work very first epoch if self.is_better(score, self.best_score): self.best_score = score - self._trace(state) + self._trace(runner) else: - self._trace(state) + self._trace(runner) - def on_stage_end(self, state: State): + def on_stage_end(self, runner: _Runner): """ Performing model tracing on stage end if `do_once` is True. Args: - state (State): Current state + runner (_Runner): Current runner """ if self.do_once: - self._trace(state) + self._trace(runner) __all__ = ["TracerCallback"] diff --git a/catalyst/contrib/dl/callbacks/visdom_logger.py b/catalyst/contrib/dl/callbacks/visdom_logger.py index 0d736477a5..75178d1800 100644 --- a/catalyst/contrib/dl/callbacks/visdom_logger.py +++ b/catalyst/contrib/dl/callbacks/visdom_logger.py @@ -8,17 +8,17 @@ from alchemy.logger import Logger import visdom -from catalyst.core import ( +from catalyst.core.callback import ( Callback, CallbackNode, CallbackOrder, CallbackScope, - State, ) +from catalyst.core.runner import _Runner class Visdom(Logger): - """Logger, translates ``state.*_metrics`` to Visdom. + """Logger, translates ``runner.*_metrics`` to Visdom. 
Read about Visdom here https://github.com/facebookresearch/visdom Example: @@ -149,7 +149,7 @@ def log_scalar( class VisdomLogger(Callback): - """Logger callback, translates ``state.*_metrics`` to Visdom. + """Logger callback, translates ``runner.*_metrics`` to Visdom. Read about Visdom here https://github.com/facebookresearch/visdom Example: @@ -246,25 +246,25 @@ def __del__(self): """@TODO: Docs. Contribution is welcome.""" self.logger.close() - def on_batch_end(self, state: State): + def on_batch_end(self, runner: _Runner): """Translate batch metrics to Visdom.""" if self.log_on_batch_end: - mode = state.loader_name - metrics_ = state.batch_metrics + mode = runner.loader_name + metrics_ = runner.batch_metrics self._log_metrics( metrics=metrics_, - step=state.global_sample_step, + step=runner.global_sample_step, mode=mode, suffix=self.batch_log_suffix, ) - def on_epoch_end(self, state: State): + def on_epoch_end(self, runner: _Runner): """Translate epoch metrics to Visdom.""" if self.log_on_epoch_end: self._log_metrics( - metrics=state.epoch_metrics, - step=state.global_epoch, - mode=state.loader_name, + metrics=runner.epoch_metrics, + step=runner.global_epoch, + mode=runner.loader_name, suffix=self.epoch_log_suffix, ) diff --git a/catalyst/contrib/dl/callbacks/wandb.py b/catalyst/contrib/dl/callbacks/wandb_logger.py similarity index 85% rename from catalyst/contrib/dl/callbacks/wandb.py rename to catalyst/contrib/dl/callbacks/wandb_logger.py index 089965983d..24c5e5c30e 100644 --- a/catalyst/contrib/dl/callbacks/wandb.py +++ b/catalyst/contrib/dl/callbacks/wandb_logger.py @@ -3,17 +3,17 @@ import wandb from catalyst import utils -from catalyst.core import ( +from catalyst.core.callback import ( Callback, CallbackNode, CallbackOrder, CallbackScope, - State, ) +from catalyst.core.runner import _Runner class WandbLogger(Callback): - """Logger callback, translates ``state.*_metrics`` to Weights & Biases. + """Logger callback, translates ``runner.*_metrics`` to Weights & Biases. 
Read about Weights & Biases here https://docs.wandb.com/ Example: @@ -133,53 +133,53 @@ def key_locate(key: str): } wandb.log(metrics, step=step, commit=commit) - def on_stage_start(self, state: State): + def on_stage_start(self, runner: _Runner): """Initialize Weights & Biases.""" - wandb.init(**self.logging_params, reinit=True, dir=str(state.logdir)) + wandb.init(**self.logging_params, reinit=True, dir=str(runner.logdir)) - def on_stage_end(self, state: State): + def on_stage_end(self, runner: _Runner): """Finish logging to Weights & Biases.""" wandb.join() - def on_batch_end(self, state: State): + def on_batch_end(self, runner: _Runner): """Translate batch metrics to Weights & Biases.""" if self.log_on_batch_end: - mode = state.loader_name - metrics_ = state.batch_metrics + mode = runner.loader_name + metrics_ = runner.batch_metrics self._log_metrics( metrics=metrics_, - step=state.global_sample_step, + step=runner.global_sample_step, mode=mode, suffix=self.batch_log_suffix, commit=True, ) - def on_loader_end(self, state: State): + def on_loader_end(self, runner: _Runner): """Translate loader metrics to Weights & Biases.""" if self.log_on_epoch_end: - mode = state.loader_name - metrics_ = state.loader_metrics + mode = runner.loader_name + metrics_ = runner.loader_metrics self._log_metrics( metrics=metrics_, - step=state.global_epoch, + step=runner.global_epoch, mode=mode, suffix=self.epoch_log_suffix, commit=False, ) - def on_epoch_end(self, state: State): + def on_epoch_end(self, runner: _Runner): """Translate epoch metrics to Weights & Biases.""" extra_mode = "_base" splitted_epoch_metrics = utils.split_dict_to_subdicts( - dct=state.epoch_metrics, - prefixes=list(state.loaders.keys()), + dct=runner.epoch_metrics, + prefixes=list(runner.loaders.keys()), extra_key=extra_mode, ) if self.log_on_epoch_end: self._log_metrics( metrics=splitted_epoch_metrics[extra_mode], - step=state.global_epoch, + step=runner.global_epoch, mode=extra_mode, suffix=self.epoch_log_suffix, commit=True, diff --git a/catalyst/contrib/dl/experiment/__init__.py b/catalyst/contrib/dl/experiment/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/catalyst/contrib/dl/runner/__init__.py b/catalyst/contrib/dl/runner/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/catalyst/core/__init__.py b/catalyst/core/__init__.py index 3432648a38..de20ef250a 100644 --- a/catalyst/core/__init__.py +++ b/catalyst/core/__init__.py @@ -1,14 +1,13 @@ # flake8: noqa # isort:skip_file # import order: -# state # callback # callbacks # experiment # runner -from .state import State -from .callback import Callback, CallbackOrder, CallbackNode, CallbackScope -from .callbacks import * from .experiment import _Experiment from .runner import _Runner, _StageBasedRunner +from .callback import Callback, CallbackOrder, CallbackNode, CallbackScope +from .callbacks import * +from .state import State diff --git a/catalyst/core/callback.py b/catalyst/core/callback.py index d4f08c4099..d7a74d6c31 100644 --- a/catalyst/core/callback.py +++ b/catalyst/core/callback.py @@ -2,7 +2,7 @@ from enum import IntFlag if TYPE_CHECKING: - from .state import State + from catalyst.core.runner import _Runner class CallbackNode(IntFlag): @@ -108,7 +108,6 @@ class Callback: - :py:mod:`catalyst.core.experiment._Experiment` - :py:mod:`catalyst.core.runner._Runner` - - :py:mod:`catalyst.core.state.State` - :py:mod:`catalyst.core.callback.Callback` Abstraction, please check out the implementations: @@ -137,75 +136,75 @@ def 
__init__( self.order = order self.scope = scope - def on_stage_start(self, state: "State"): + def on_stage_start(self, runner: "_Runner"): """Event handler for stage start. Args: - state ("State"): State instance. + runner ("_Runner"): _Runner instance. """ pass - def on_stage_end(self, state: "State"): + def on_stage_end(self, runner: "_Runner"): """Event handler for stage end. Args: - state ("State"): State instance. + runner ("_Runner"): _Runner instance. """ pass - def on_epoch_start(self, state: "State"): + def on_epoch_start(self, runner: "_Runner"): """Event handler for epoch start. Args: - state ("State"): State instance. + runner ("_Runner"): _Runner instance. """ pass - def on_epoch_end(self, state: "State"): + def on_epoch_end(self, runner: "_Runner"): """Event handler for epoch end. Args: - state ("State"): State instance. + runner ("_Runner"): _Runner instance. """ pass - def on_loader_start(self, state: "State"): + def on_loader_start(self, runner: "_Runner"): """Event handler for loader start. Args: - state ("State"): State instance. + runner ("_Runner"): _Runner instance. """ pass - def on_loader_end(self, state: "State"): + def on_loader_end(self, runner: "_Runner"): """Event handler for loader end. Args: - state ("State"): State instance. + runner ("_Runner"): _Runner instance. """ pass - def on_batch_start(self, state: "State"): + def on_batch_start(self, runner: "_Runner"): """Event handler for batch start. Args: - state ("State"): State instance. + runner ("_Runner"): _Runner instance. """ pass - def on_batch_end(self, state: "State"): + def on_batch_end(self, runner: "_Runner"): """Event handler for batch end. Args: - state ("State"): State instance. + runner ("_Runner"): _Runner instance. """ pass - def on_exception(self, state: "State"): + def on_exception(self, runner: "_Runner"): """Event handler for exception case. Args: - state ("State"): State instance. + runner ("_Runner"): _Runner instance. 
""" pass diff --git a/catalyst/core/callbacks/checkpoint.py b/catalyst/core/callbacks/checkpoint.py index 0458a67e64..57980578dd 100644 --- a/catalyst/core/callbacks/checkpoint.py +++ b/catalyst/core/callbacks/checkpoint.py @@ -3,39 +3,41 @@ import os from pathlib import Path -from catalyst.core import Callback, CallbackNode, CallbackOrder, State, utils +from catalyst.core import utils +from catalyst.core.callback import Callback, CallbackNode, CallbackOrder +from catalyst.core.runner import _Runner -def _pack_state(state: State): +def _pack_runner(runner: _Runner): checkpoint = utils.pack_checkpoint( - model=state.model, - criterion=state.criterion, - optimizer=state.optimizer, - scheduler=state.scheduler, - epoch_metrics=dict(state.epoch_metrics), - valid_metrics=dict(state.valid_metrics), - stage_name=state.stage_name, - epoch=state.epoch, - loader_name=state.loader_name, - loader_step=state.loader_batch_step, - global_epoch=state.global_epoch, - checkpoint_data=state.checkpoint_data, - main_metric=state.main_metric, - minimize_metric=state.minimize_metric, - valid_loader=state.valid_loader, + model=runner.model, + criterion=runner.criterion, + optimizer=runner.optimizer, + scheduler=runner.scheduler, + epoch_metrics=dict(runner.epoch_metrics), + valid_metrics=dict(runner.valid_metrics), + stage_name=runner.stage_name, + epoch=runner.epoch, + loader_name=runner.loader_name, + loader_step=runner.loader_batch_step, + global_epoch=runner.global_epoch, + checkpoint_data=runner.checkpoint_data, + main_metric=runner.main_metric, + minimize_metric=runner.minimize_metric, + valid_loader=runner.valid_loader, ) return checkpoint def _load_checkpoint( - *, filename, state: State, load_full: bool = True + *, filename, runner: _Runner, load_full: bool = True ) -> None: """ Load checkpoint from a file. Arguments: filename (str): path to checkpoint - state (State): training state + runner (_Runner): current runner load_full (bool): if true (default) then will be performed loading states for criterion, optimizer and scheduler. File should contain keys required for @@ -54,10 +56,10 @@ def _load_checkpoint( print(f"=> Loading checkpoint {filename}") checkpoint = utils.load_checkpoint(filename) - if not state.stage_name.startswith("infer") and load_full: - state.stage_name = checkpoint["stage_name"] - state.epoch = checkpoint["epoch"] - state.global_epoch = checkpoint["global_epoch"] + if not runner.stage_name.startswith("infer") and load_full: + runner.stage_name = checkpoint["stage_name"] + runner.epoch = checkpoint["epoch"] + runner.global_epoch = checkpoint["global_epoch"] # @TODO: should we also load, # checkpoint_data, main_metric, minimize_metric, valid_loader ? # epoch_metrics, valid_metrics ? 
@@ -65,10 +67,10 @@ def _load_checkpoint( if load_full: utils.unpack_checkpoint( checkpoint, - model=state.model, - criterion=state.criterion, - optimizer=state.optimizer, - scheduler=state.scheduler, + model=runner.model, + criterion=runner.criterion, + optimizer=runner.optimizer, + scheduler=runner.scheduler, ) print( @@ -79,7 +81,7 @@ def _load_checkpoint( ) else: utils.unpack_checkpoint( - checkpoint, model=state.model, + checkpoint, model=runner.model, ) print(f"loaded model checkpoint {filename}") @@ -134,14 +136,14 @@ def _required_files(logdir: str, load_map: Dict[str, str]) -> Dict[str, str]: def _load_states_from_file_map( - *, state: State, load_map: Dict[str, str] + *, runner: _Runner, load_map: Dict[str, str] ) -> None: """ Load state of a model, criterion, optimizer, scheduler from files specified in ``load_map``. Arguments: - state (State): training state + runner (_Runner): current runner load_map (Dict[str, str]): dict with mappings to load. Expected keys - ``'model'``, ``'criterion'`` ``'optimizer'``, ``'scheduler'``, other keys will be @@ -156,7 +158,7 @@ def _load_states_from_file_map( FileNotFoundError: when file/state specified in ``load_map`` is not exist. """ - required_files = _required_files(state.logdir, load_map) + required_files = _required_files(runner.logdir, load_map) for filename in required_files.keys(): if not os.path.isfile(filename): @@ -166,7 +168,7 @@ def _load_states_from_file_map( for filename, parts_to_load in required_files.items(): print(f"=> Loading {', '.join(parts_to_load)} from {filename}") checkpoint = utils.load_checkpoint(filename) - to_unpack = {part: getattr(state, part) for part in parts_to_load} + to_unpack = {part: getattr(runner, part) for part in parts_to_load} utils.unpack_checkpoint(checkpoint, **to_unpack) print(f" loaded: {', '.join(parts_to_load)}") @@ -192,25 +194,25 @@ def save_metric(self, logdir: Union[str, Path], metrics: Dict) -> None: metrics, f"{logdir}/checkpoints/{self.metrics_filename}" ) - def on_exception(self, state: State): - exception = state.exception + def on_exception(self, runner: _Runner): + exception = runner.exception if not utils.is_exception(exception): return try: - checkpoint = _pack_state(state) + checkpoint = _pack_runner(runner) suffix = self.get_checkpoint_suffix(checkpoint) suffix = f"{suffix}.exception_{exception.__class__.__name__}" utils.save_checkpoint( - logdir=Path(f"{state.logdir}/checkpoints/"), + logdir=Path(f"{runner.logdir}/checkpoints/"), checkpoint=checkpoint, suffix=suffix, is_best=False, is_last=False, ) metrics = self.metrics - metrics[suffix] = state.valid_metrics - self.save_metric(state.logdir, metrics) + metrics[suffix] = runner.valid_metrics + self.save_metric(runner.logdir, metrics) except Exception: pass @@ -301,7 +303,7 @@ def __init__( Logic for dict is the same as for ``load_on_stage_start``. If ``None`` then no action is required at stage end - and will be used the last state. + and will be used the last runner. **NOTE:** Loading will be performed always at stage end. """ @@ -481,8 +483,8 @@ def process_checkpoint( self.save_metric(logdir, metrics) @staticmethod - def _load_state( - state: State, + def _load_runner( + runner: _Runner, mapping: Union[str, Dict[str, str]], load_full: bool = False, ) -> None: @@ -490,7 +492,7 @@ def _load_state( Selects a loading method based on type of mapping. 
Args: - state (State): training state + runner (_Runner): current runner mapping (str or dict): mapping to use for loading load_full (bool): load a full model, used only when mapping type is string @@ -498,18 +500,18 @@ def _load_state( """ if isinstance(mapping, str): if mapping in {"best", "best_full", "last", "last_full"}: - checkpoint = f"{state.logdir}/checkpoints/{mapping}.pth" + checkpoint = f"{runner.logdir}/checkpoints/{mapping}.pth" else: checkpoint = mapping _load_checkpoint( - filename=checkpoint, state=state, load_full=load_full, + filename=checkpoint, runner=runner, load_full=load_full, ) elif isinstance(mapping, dict): _load_states_from_file_map( - state=state, load_map=mapping, + runner=runner, load_map=mapping, ) - def on_stage_start(self, state: State) -> None: + def on_stage_start(self, runner: _Runner) -> None: """ Setup model for stage. @@ -519,10 +521,10 @@ def on_stage_start(self, state: State) -> None: then will be performed loading checkpoint. Args: - state (State): training state + runner (_Runner): current runner """ for key in self._keys_from_state: - value = getattr(state, key, None) + value = getattr(runner, key, None) if value is not None: setattr(self, key, value) @@ -530,7 +532,7 @@ def on_stage_start(self, state: State) -> None: self.resume = str(self.resume_dir) + "/" + str(self.resume) if self.resume is not None: - self._load_state(state, mapping=self.resume, load_full=True) + self._load_runner(runner, mapping=self.resume, load_full=True) self.resume = None else: _exists_checkpoint = False @@ -538,69 +540,75 @@ def on_stage_start(self, state: State) -> None: if isinstance(self.load_on_stage_start, str): _exists_checkpoint = os.path.isfile( "{}/checkpoints/{}.pth".format( - state.logdir, self.load_on_stage_start + runner.logdir, self.load_on_stage_start ) ) _load_full = self.load_on_stage_start.endswith("full") elif isinstance(self.load_on_stage_start, dict): required_files = _required_files( - state.logdir, self.load_on_stage_start + runner.logdir, self.load_on_stage_start ).keys() _exists_checkpoint = all( os.path.isfile(file) for file in required_files ) if self.load_on_stage_start is not None and _exists_checkpoint: - self._load_state( - state, + self._load_runner( + runner, mapping=self.load_on_stage_start, load_full=_load_full, ) - def on_epoch_end(self, state: State) -> None: + def on_epoch_end(self, runner: _Runner) -> None: """ Collect and save checkpoint after epoch. Args: - state (State): training state + runner (_Runner): current runner """ - if state.stage_name.startswith("infer") or state.is_distributed_worker: + if ( + runner.stage_name.startswith("infer") + or runner.is_distributed_worker + ): return if self.save_n_best > 0: - checkpoint = _pack_state(state) + checkpoint = _pack_runner(runner) self.process_checkpoint( - logdir=state.logdir, + logdir=runner.logdir, checkpoint=checkpoint, - is_best=state.is_best_valid, - main_metric=state.main_metric, - minimize_metric=state.minimize_metric, + is_best=runner.is_best_valid, + main_metric=runner.main_metric, + minimize_metric=runner.minimize_metric, ) - def on_stage_end(self, state: State) -> None: + def on_stage_end(self, runner: _Runner) -> None: """ Show information about best checkpoints during the stage and load model specified in ``load_on_stage_end``. 
Args: - state (State): training state + runner (_Runner): current runner """ - if state.stage_name.startswith("infer") or state.is_distributed_worker: + if ( + runner.stage_name.startswith("infer") + or runner.is_distributed_worker + ): return log_message = "Top best models:\n" # store latest state if self.save_n_best == 0: - checkpoint = _pack_state(state) + checkpoint = _pack_runner(runner) _, filepath = self._save_checkpoint( - logdir=state.logdir, + logdir=runner.logdir, checkpoint=checkpoint, suffix="last", is_best=True, # will duplicate current (last) as best is_last=False, # don't need that because current state is last ) metrics = self.process_metrics(checkpoint["valid_metrics"]) - self.save_metric(state.logdir, metrics) - main_metric_value = metrics["last"][state.main_metric] + self.save_metric(runner.logdir, metrics) + main_metric_value = metrics["last"][runner.main_metric] log_message += "{filepath}\t{metric:3.4f}".format( filepath=filepath, metric=main_metric_value ) @@ -625,8 +633,8 @@ def on_stage_end(self, state: State) -> None: if isinstance(self.load_on_stage_end, str) else False ) - self._load_state( - state, mapping=self.load_on_stage_end, load_full=_load_full, + self._load_runner( + runner, mapping=self.load_on_stage_end, load_full=_load_full, ) elif isinstance(self.load_on_stage_end, dict) and self.save_n_best > 0: to_load = { @@ -634,7 +642,7 @@ def on_stage_end(self, state: State) -> None: for k, v in self.load_on_stage_end.items() if v not in _not_required_load_states } - self._load_state(state, mapping=to_load) + self._load_runner(runner, mapping=to_load) class IterationCheckpointCallback(BaseCheckpointCallback): @@ -743,45 +751,47 @@ def process_checkpoint( self.save_metric(logdir, metrics) print(f"\nSaved checkpoint at {filepath}") - def on_stage_start(self, state: State): + def on_stage_start(self, runner: _Runner): """ Reset iterations counter. Args: - state (State): training state + runner (_Runner): current runner """ if self.stage_restart: self._iteration_counter = 0 - def on_batch_end(self, state: State): + def on_batch_end(self, runner: _Runner): """ Save checkpoint based on batches count. Args: - state (State): training state + runner (_Runner): current runner """ self._iteration_counter += 1 if self._iteration_counter % self.period == 0: - checkpoint = _pack_state(state) + checkpoint = _pack_runner(runner) self.process_checkpoint( - logdir=state.logdir, + logdir=runner.logdir, checkpoint=checkpoint, - batch_metrics=state.batch_metrics, + batch_metrics=runner.batch_metrics, ) - def on_stage_end(self, state: State): + def on_stage_end(self, runner: _Runner): """ Load model specified in ``load_on_stage_end``. 
Args: - state (State): training state + runner (_Runner): current runner """ if self.load_on_stage_end in ["best", "best_full"]: - resume = f"{state.logdir}/checkpoints/{self.load_on_stage_end}.pth" + resume = ( + f"{runner.logdir}/checkpoints/{self.load_on_stage_end}.pth" + ) print(f"Loading {self.load_on_stage_end} model from {resume}") _load_checkpoint( filename=resume, - state=state, + runner=runner, load_full=self.load_on_stage_end.endswith("full"), ) diff --git a/catalyst/core/callbacks/criterion.py b/catalyst/core/callbacks/criterion.py index d1044b2c69..3eef0ce9bb 100644 --- a/catalyst/core/callbacks/criterion.py +++ b/catalyst/core/callbacks/criterion.py @@ -1,6 +1,6 @@ from typing import Dict, List, Union -from catalyst.core import State +from catalyst.core.runner import _Runner from .metrics import _MetricCallback @@ -28,7 +28,7 @@ def __init__( If '__all__', the whole output will be passed to the criterion If None, empty dict will be passed to the criterion. prefix (str): prefix for metrics and output key for loss - in ``state.batch_metrics`` dictionary + in ``runner.batch_metrics`` dictionary criterion_key (str): A key to take a criterion in case there are several of them and they are in a dictionary format. multiplier (float): scale factor for the output loss. @@ -48,9 +48,9 @@ def metric_fn(self): """@TODO: Docs. Contribution is welcome.""" return self._criterion - def on_stage_start(self, state: State): + def on_stage_start(self, runner: _Runner): """Checks that the current stage has correct criterion.""" - criterion = state.get_attr( + criterion = runner.get_attr( key="criterion", inner_key=self.criterion_key ) assert criterion is not None diff --git a/catalyst/core/callbacks/early_stop.py b/catalyst/core/callbacks/early_stop.py index 47a1c1192b..945abca2b1 100644 --- a/catalyst/core/callbacks/early_stop.py +++ b/catalyst/core/callbacks/early_stop.py @@ -1,4 +1,5 @@ -from catalyst.core import Callback, CallbackNode, CallbackOrder, State +from catalyst.core.callback import Callback, CallbackNode, CallbackOrder +from catalyst.core.runner import _Runner class CheckRunCallback(Callback): @@ -10,15 +11,15 @@ def __init__(self, num_batch_steps: int = 3, num_epoch_steps: int = 2): self.num_batch_steps = num_batch_steps self.num_epoch_steps = num_epoch_steps - def on_epoch_end(self, state: State): + def on_epoch_end(self, runner: _Runner): """@TODO: Docs. Contribution is welcome.""" - if state.epoch >= self.num_epoch_steps: - state.need_early_stop = True + if runner.epoch >= self.num_epoch_steps: + runner.need_early_stop = True - def on_batch_end(self, state: State): + def on_batch_end(self, runner: _Runner): """@TODO: Docs. Contribution is welcome.""" - if state.loader_batch_step >= self.num_batch_steps: - state.need_early_stop = True + if runner.loader_batch_step >= self.num_batch_steps: + runner.need_early_stop = True class EarlyStoppingCallback(Callback): @@ -44,12 +45,12 @@ def __init__( else: self.is_better = lambda score, best: score >= (best + min_delta) - def on_epoch_end(self, state: State) -> None: + def on_epoch_end(self, runner: _Runner) -> None: """@TODO: Docs. 
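To illustrate the ``criterion_key`` lookup performed in the ``CriterionCallback.on_stage_start`` hunk above: when an experiment provides several criteria in a dict, ``get_attr(key="criterion", inner_key=...)`` selects one of them. ``FakeRunner`` below is a toy stand-in, not the real Catalyst runner; only the lookup pattern is taken from this patch.

```python
import torch.nn as nn


class FakeRunner:
    def __init__(self):
        # multi-criterion setup: criteria stored in a dict
        self.criterion = {"ce": nn.CrossEntropyLoss(), "bce": nn.BCEWithLogitsLoss()}

    def get_attr(self, key, inner_key=None):
        value = getattr(self, key)
        return value if inner_key is None else value[inner_key]


runner = FakeRunner()
criterion = runner.get_attr(key="criterion", inner_key="bce")
assert criterion is not None  # mirrors the assert in on_stage_start above
```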
Contribution is welcome.""" - if state.stage_name.startswith("infer"): + if runner.stage_name.startswith("infer"): return - score = state.valid_metrics[self.metric] + score = runner.valid_metrics[self.metric] if self.best_score is None: self.best_score = score if self.is_better(score, self.best_score): @@ -59,5 +60,5 @@ def on_epoch_end(self, state: State) -> None: self.num_bad_epochs += 1 if self.num_bad_epochs >= self.patience: - print(f"Early stop at {state.epoch} epoch") - state.need_early_stop = True + print(f"Early stop at {runner.epoch} epoch") + runner.need_early_stop = True diff --git a/catalyst/core/callbacks/exception.py b/catalyst/core/callbacks/exception.py index a9fa492424..54984dbeca 100644 --- a/catalyst/core/callbacks/exception.py +++ b/catalyst/core/callbacks/exception.py @@ -1,4 +1,6 @@ -from catalyst.core import Callback, CallbackNode, CallbackOrder, State, utils +from catalyst.core import utils +from catalyst.core.callback import Callback, CallbackNode, CallbackOrder +from catalyst.core.runner import _Runner class ExceptionCallback(Callback): @@ -10,11 +12,11 @@ def __init__(self): order=CallbackOrder.External + 1, node=CallbackNode.All ) - def on_exception(self, state: State): + def on_exception(self, runner: _Runner): """@TODO: Docs. Contribution is welcome.""" - exception = state.exception + exception = runner.exception if not utils.is_exception(exception): return - if state.need_exception_reraise: + if runner.need_exception_reraise: raise exception diff --git a/catalyst/core/callbacks/formatters.py b/catalyst/core/callbacks/formatters.py index 1d159220f1..ccd3363346 100644 --- a/catalyst/core/callbacks/formatters.py +++ b/catalyst/core/callbacks/formatters.py @@ -2,7 +2,8 @@ from abc import ABC, abstractmethod import logging -from catalyst.core import State, utils +from catalyst.core import utils +from catalyst.core.runner import _Runner class MetricsFormatter(ABC, logging.Formatter): @@ -17,15 +18,15 @@ def __init__(self, message_prefix): super().__init__(f"{message_prefix}{{message}}", style="{") @abstractmethod - def _format_message(self, state: State): + def _format_message(self, runner: _Runner): pass def format(self, record: logging.LogRecord): """Format message string.""" # noinspection PyUnresolvedReferences - state = record.state + runner = record.runner - record.msg = self._format_message(state) + record.msg = self._format_message(runner) return super().format(record) @@ -59,18 +60,18 @@ def _format_metrics(self, metrics: Dict[str, Dict[str, float]]): return metrics_formatted - def _format_message(self, state: State): + def _format_message(self, runner: _Runner): message = [""] mode_metrics = utils.split_dict_to_subdicts( - dct=state.epoch_metrics, - prefixes=list(state.loaders.keys()), + dct=runner.epoch_metrics, + prefixes=list(runner.loaders.keys()), extra_key="_base", ) metrics = self._format_metrics(mode_metrics) for key, value in metrics.items(): message.append( - f"{state.epoch}/{state.num_epochs} " - f"* Epoch {state.global_epoch} ({key}): {value}" + f"{runner.epoch}/{runner.num_epochs} " + f"* Epoch {runner.global_epoch} ({key}): {value}" ) message = "\n".join(message) return message diff --git a/catalyst/core/callbacks/logging.py b/catalyst/core/callbacks/logging.py index 066a6f2da1..a2c8b8bffd 100644 --- a/catalyst/core/callbacks/logging.py +++ b/catalyst/core/callbacks/logging.py @@ -6,7 +6,9 @@ from tqdm import tqdm from catalyst.contrib.tools.tensorboard import SummaryWriter -from catalyst.core import Callback, CallbackNode, CallbackOrder, 
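A standalone replay of the patience bookkeeping in ``EarlyStoppingCallback.on_epoch_end`` above: ``num_bad_epochs`` resets on improvement and flips the stop flag once it reaches ``patience``. The hunk only shows the maximize comparator (``score >= best + min_delta``); the minimize comparator below is assumed to mirror it.

```python
patience, min_delta = 3, 0.0
best_score, num_bad_epochs = None, 0
need_early_stop = False

for epoch, score in enumerate([0.50, 0.40, 0.41, 0.42, 0.43], start=1):  # loss values
    is_better = best_score is None or score <= best_score - min_delta
    if is_better:
        best_score, num_bad_epochs = score, 0
    else:
        num_bad_epochs += 1
    if num_bad_epochs >= patience:
        print(f"Early stop at {epoch} epoch")
        need_early_stop = True
        break
```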
State, utils +from catalyst.core import utils +from catalyst.core.callback import Callback, CallbackNode, CallbackOrder +from catalyst.core.runner import _Runner from . import formatters @@ -50,19 +52,19 @@ def _need_show(self, key: str): return result - def on_loader_start(self, state: State): + def on_loader_start(self, runner: _Runner): """Init tqdm progress bar.""" self.step = 0 self.tqdm = tqdm( - total=state.loader_len, - desc=f"{state.epoch}/{state.num_epochs}" - f" * Epoch ({state.loader_name})", + total=runner.loader_len, + desc=f"{runner.epoch}/{runner.num_epochs}" + f" * Epoch ({runner.loader_name})", leave=True, ncols=0, file=sys.stdout, ) - def on_loader_end(self, state: State): + def on_loader_end(self, runner: _Runner): """Cleanup and close tqdm progress bar.""" # self.tqdm.visible = False # self.tqdm.leave = True @@ -72,31 +74,31 @@ def on_loader_end(self, state: State): self.tqdm = None self.step = 0 - def on_batch_end(self, state: State): + def on_batch_end(self, runner: _Runner): """Update tqdm progress bar at the end of each batch.""" self.tqdm.set_postfix( **{ k: "{:3.3f}".format(v) if v > 1e-3 else "{:1.3e}".format(v) - for k, v in sorted(state.batch_metrics.items()) + for k, v in sorted(runner.batch_metrics.items()) if self._need_show(k) } ) self.tqdm.update() - def on_exception(self, state: State): + def on_exception(self, runner: _Runner): """Called if an Exception was raised.""" - exception = state.exception + exception = runner.exception if not utils.is_exception(exception): return if isinstance(exception, KeyboardInterrupt): self.tqdm.write("Early exiting") - state.need_exception_reraise = False + runner.need_exception_reraise = False class ConsoleLogger(Callback): """Logger callback, - translates ``state.*_metrics`` to console and text file. + translates ``runner.*_metrics`` to console and text file. """ def __init__(self): @@ -127,28 +129,28 @@ def _get_logger(logdir): # logger.addHandler(jh) return logger - def on_stage_start(self, state: State): - """Prepare ``state.logdir`` for the current stage.""" - if state.logdir: - state.logdir.mkdir(parents=True, exist_ok=True) - self.logger = self._get_logger(state.logdir) + def on_stage_start(self, runner: _Runner): + """Prepare ``runner.logdir`` for the current stage.""" + if runner.logdir: + runner.logdir.mkdir(parents=True, exist_ok=True) + self.logger = self._get_logger(runner.logdir) - def on_stage_end(self, state: State): + def on_stage_end(self, runner: _Runner): """Called at the end of each stage.""" for handler in self.logger.handlers: handler.close() self.logger.handlers = [] - def on_epoch_end(self, state: State): + def on_epoch_end(self, runner: _Runner): """ - Translate ``state.metric_manager`` to console and text file + Translate ``runner.metric_manager`` to console and text file at the end of an epoch. """ - self.logger.info("", extra={"state": state}) + self.logger.info("", extra={"runner": runner}) class TensorboardLogger(Callback): - """Logger callback, translates ``state.metric_manager`` to tensorboard.""" + """Logger callback, translates ``runner.metric_manager`` to tensorboard.""" def __init__( self, @@ -187,44 +189,44 @@ def _log_metrics( f"{name}{suffix}", metrics[name], step ) - def on_stage_start(self, state: State): + def on_stage_start(self, runner: _Runner): """@TODO: Docs. 
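A small sketch of the metric formatting used for the tqdm postfix in ``VerboseLogger.on_batch_end`` above: values above ``1e-3`` are printed in fixed point, smaller ones in scientific notation. The metric values are made up for illustration.

```python
batch_metrics = {"loss": 0.6931, "lr": 0.0003, "accuracy01": 0.812}
postfix = {
    k: "{:3.3f}".format(v) if v > 1e-3 else "{:1.3e}".format(v)
    for k, v in sorted(batch_metrics.items())
}
print(postfix)  # {'accuracy01': '0.812', 'loss': '0.693', 'lr': '3.000e-04'}
```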
Contribution is welcome.""" - assert state.logdir is not None + assert runner.logdir is not None extra_mode = "_base" - log_dir = os.path.join(state.logdir, f"{extra_mode}_log") + log_dir = os.path.join(runner.logdir, f"{extra_mode}_log") self.loggers[extra_mode] = SummaryWriter(log_dir) - def on_loader_start(self, state: State): + def on_loader_start(self, runner: _Runner): """Prepare tensorboard writers for the current stage.""" - if state.loader_name not in self.loggers: - log_dir = os.path.join(state.logdir, f"{state.loader_name}_log") - self.loggers[state.loader_name] = SummaryWriter(log_dir) + if runner.loader_name not in self.loggers: + log_dir = os.path.join(runner.logdir, f"{runner.loader_name}_log") + self.loggers[runner.loader_name] = SummaryWriter(log_dir) - def on_batch_end(self, state: State): + def on_batch_end(self, runner: _Runner): """Translate batch metrics to tensorboard.""" - if state.logdir is None: + if runner.logdir is None: return if self.log_on_batch_end: - mode = state.loader_name - metrics_ = state.batch_metrics + mode = runner.loader_name + metrics_ = runner.batch_metrics self._log_metrics( metrics=metrics_, - step=state.global_sample_step, + step=runner.global_sample_step, mode=mode, suffix="/batch", ) - def on_epoch_end(self, state: "State"): + def on_epoch_end(self, runner: _Runner): """Translate epoch metrics to tensorboard.""" - if state.logdir is None: + if runner.logdir is None: return if self.log_on_epoch_end: per_mode_metrics = utils.split_dict_to_subdicts( - dct=state.epoch_metrics, - prefixes=list(state.loaders.keys()), + dct=runner.epoch_metrics, + prefixes=list(runner.loaders.keys()), extra_key="_base", ) @@ -232,7 +234,7 @@ def on_epoch_end(self, state: "State"): # suffix = "" if mode == "_base" else "/epoch" self._log_metrics( metrics=metrics, - step=state.global_epoch, + step=runner.global_epoch, mode=mode, suffix="/epoch", ) @@ -240,9 +242,9 @@ def on_epoch_end(self, state: "State"): for logger in self.loggers.values(): logger.flush() - def on_stage_end(self, state: State): + def on_stage_end(self, runner: _Runner): """Close opened tensorboard writers.""" - if state.logdir is None: + if runner.logdir is None: return for logger in self.loggers.values(): diff --git a/catalyst/core/callbacks/metrics.py b/catalyst/core/callbacks/metrics.py index 6f1720ad82..46e6248e36 100644 --- a/catalyst/core/callbacks/metrics.py +++ b/catalyst/core/callbacks/metrics.py @@ -5,7 +5,9 @@ import torch -from catalyst.core import Callback, CallbackNode, CallbackOrder, State, utils +from catalyst.core import utils +from catalyst.core.callback import Callback, CallbackNode, CallbackOrder +from catalyst.core.runner import _Runner from catalyst.tools import meters logger = logging.getLogger(__name__) @@ -65,28 +67,28 @@ def metric_fn(self): """@TODO: Docs. 
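A hedged re-implementation of what the ``utils.split_dict_to_subdicts`` calls above (in the formatter and in ``TensorboardLogger.on_epoch_end``) appear to do: epoch metrics whose keys start with a loader prefix are grouped per loader, and everything else (lr, momentum, ...) lands under the extra ``_base`` key. This is an illustration of the presumed semantics, not the library function itself.

```python
from typing import Dict, List


def split_by_prefixes(dct: Dict[str, float], prefixes: List[str], extra_key: str = "_base"):
    result = {extra_key: {}, **{p: {} for p in prefixes}}
    for key, value in dct.items():
        for prefix in prefixes:
            if key.startswith(f"{prefix}_"):
                result[prefix][key[len(prefix) + 1:]] = value
                break
        else:  # no loader prefix matched
            result[extra_key][key] = value
    return result


epoch_metrics = {"train_loss": 0.41, "valid_loss": 0.39, "lr": 0.001}
print(split_by_prefixes(epoch_metrics, prefixes=["train", "valid"]))
# {'_base': {'lr': 0.001}, 'train': {'loss': 0.41}, 'valid': {'loss': 0.39}}
```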
Contribution is welcome.""" pass - def _compute_metric_value(self, state: State): - output = self._get_output(state.output, self.output_key) - input = self._get_input(state.input, self.input_key) + def _compute_metric_value(self, runner: _Runner): + output = self._get_output(runner.output, self.output_key) + input = self._get_input(runner.input, self.input_key) metric = self.metric_fn(output, input, **self.metrics_kwargs) return metric - def _compute_metric_key_value(self, state: State): - output = self._get_output(state.output, self.output_key) - input = self._get_input(state.input, self.input_key) + def _compute_metric_key_value(self, runner: _Runner): + output = self._get_output(runner.output, self.output_key) + input = self._get_input(runner.input, self.input_key) metric = self.metric_fn(**output, **input, **self.metrics_kwargs) return metric - def on_batch_end(self, state: State) -> None: + def on_batch_end(self, runner: _Runner) -> None: """Computes the metric and add it to batch metrics.""" - metric = self._compute_metric(state) * self.multiplier - state.batch_metrics[self.prefix] = metric + metric = self._compute_metric(runner) * self.multiplier + runner.batch_metrics[self.prefix] = metric class MetricCallback(_MetricCallback): - """A callback that returns single metric on `state.on_batch_end`.""" + """A callback that returns single metric on `runner.on_batch_end`.""" def __init__( self, @@ -114,7 +116,7 @@ def metric_fn(self): class MultiMetricCallback(MetricCallback): - """A callback that returns multiple metrics on `state.on_batch_end`.""" + """A callback that returns multiple metrics on `runner.on_batch_end`.""" def __init__( self, @@ -137,20 +139,20 @@ def __init__( ) self.list_args = list_args - def on_batch_end(self, state: State) -> None: + def on_batch_end(self, runner: _Runner) -> None: """Batch end hook. Args: - state (State): current state + runner (_Runner): current runner """ - metrics_ = self._compute_metric(state) + metrics_ = self._compute_metric(runner) for arg, metric in zip(self.list_args, metrics_): if isinstance(arg, int): key = f"{self.prefix}{arg:02}" else: key = f"{self.prefix}_{arg}" - state.batch_metrics[key] = metric * self.multiplier + runner.batch_metrics[key] = metric * self.multiplier class MetricAggregationCallback(Callback): @@ -234,15 +236,15 @@ def _preprocess(self, metrics: Any) -> List[float]: result = list(metrics.values()) return result - def on_batch_end(self, state: State) -> None: + def on_batch_end(self, runner: _Runner) -> None: """Computes the metric and add it to the metrics. Args: - state (State): current state + runner (_Runner): current runner """ - metrics = self._preprocess(state.batch_metrics) + metrics = self._preprocess(runner.batch_metrics) metric = self.aggregation_fn(metrics) - state.batch_metrics[self.prefix] = metric + runner.batch_metrics[self.prefix] = metric class MetricManagerCallback(Callback): @@ -274,52 +276,52 @@ def _process_metrics(metrics: Dict[str, Any]): output[key] = value return output - def on_epoch_start(self, state: State) -> None: + def on_epoch_start(self, runner: _Runner) -> None: """Epoch start hook. Args: - state (State): current state + runner (_Runner): current runner """ - state.epoch_metrics = defaultdict(None) + runner.epoch_metrics = defaultdict(None) - def on_loader_start(self, state: State) -> None: + def on_loader_start(self, runner: _Runner) -> None: """Loader start hook. 
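A standalone sketch of the metric key naming in ``MultiMetricCallback.on_batch_end`` above: integer ``list_args`` are zero-padded and appended directly to the prefix, string args are joined with an underscore. The metric values are invented for the example.

```python
prefix, multiplier = "accuracy", 1.0
list_args = [1, 3, "mean"]
metrics_ = [0.81, 0.93, 0.87]

batch_metrics = {}
for arg, metric in zip(list_args, metrics_):
    key = f"{prefix}{arg:02}" if isinstance(arg, int) else f"{prefix}_{arg}"
    batch_metrics[key] = metric * multiplier

print(batch_metrics)
# {'accuracy01': 0.81, 'accuracy03': 0.93, 'accuracy_mean': 0.87}
```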
Args: - state (State): current state + runner (_Runner): current runner """ - state.loader_metrics = defaultdict(None) + runner.loader_metrics = defaultdict(None) self.meters = defaultdict(meters.AverageValueMeter) - def on_loader_end(self, state: State) -> None: + def on_loader_end(self, runner: _Runner) -> None: """Loader end hook. Args: - state (State): current state + runner (_Runner): current runner """ for key, value in self.meters.items(): value = value.mean - state.loader_metrics[key] = value - for key, value in state.loader_metrics.items(): - state.epoch_metrics[f"{state.loader_name}_{key}"] = value + runner.loader_metrics[key] = value + for key, value in runner.loader_metrics.items(): + runner.epoch_metrics[f"{runner.loader_name}_{key}"] = value - def on_batch_start(self, state: State) -> None: + def on_batch_start(self, runner: _Runner) -> None: """Batch start hook. Args: - state (State): current state + runner (_Runner): current runner """ - state.batch_metrics = defaultdict(None) + runner.batch_metrics = defaultdict(None) - def on_batch_end(self, state: State) -> None: + def on_batch_end(self, runner: _Runner) -> None: """Batch end hook. Args: - state (State): current state + runner (_Runner): current runner """ - state.batch_metrics = self._process_metrics(state.batch_metrics) - for key, value in state.batch_metrics.items(): - self.meters[key].add(value, state.batch_size) + runner.batch_metrics = self._process_metrics(runner.batch_metrics) + for key, value in runner.batch_metrics.items(): + self.meters[key].add(value, runner.batch_size) __all__ = [ diff --git a/catalyst/core/callbacks/optimizer.py b/catalyst/core/callbacks/optimizer.py index 1c187a1bbf..765eb80e61 100644 --- a/catalyst/core/callbacks/optimizer.py +++ b/catalyst/core/callbacks/optimizer.py @@ -2,14 +2,9 @@ import logging import warnings -from catalyst.core import ( - Callback, - CallbackNode, - CallbackOrder, - registry, - State, - utils, -) +from catalyst.core import registry, utils +from catalyst.core.callback import Callback, CallbackNode, CallbackOrder +from catalyst.core.runner import _Runner from catalyst.tools.typing import Optimizer logger = logging.getLogger(__name__) @@ -29,7 +24,7 @@ def __init__( ): """ Args: - loss_key (str): key to get loss from ``state.batch_metrics`` + loss_key (str): key to get loss from ``runner.batch_metrics`` optimizer_key (str): A key to take a optimizer in case there are several of them and they are in a dictionary format. accumulation_steps (int): number of steps before @@ -83,18 +78,18 @@ def grad_step( grad_clip_fn(group["params"]) optimizer.step() - def on_stage_start(self, state: State) -> None: + def on_stage_start(self, runner: _Runner) -> None: """Checks that the current stage has correct optimizer.""" - self._optimizer = state.get_attr( + self._optimizer = runner.get_attr( key="optimizer", inner_key=self.optimizer_key ) assert self._optimizer is not None - def on_epoch_start(self, state: State) -> None: + def on_epoch_start(self, runner: _Runner) -> None: """On epoch start event. Args: - state (State): current state + runner (_Runner): current runner """ if self.decouple_weight_decay: self._optimizer_wd = [ @@ -106,11 +101,11 @@ def on_epoch_start(self, state: State) -> None: else: self._optimizer_wd = [0.0] * len(self._optimizer.param_groups) - def on_epoch_end(self, state: State) -> None: + def on_epoch_end(self, runner: _Runner) -> None: """On epoch end event. 
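A sketch of the aggregation flow in ``MetricManagerCallback`` above: per-batch values are averaged by meters over the loader, then copied into ``epoch_metrics`` under ``{loader_name}_{key}``. ``AverageMeter`` below is a stand-in for ``catalyst.tools.meters.AverageValueMeter``; batch-size weighting is an assumption of this sketch.

```python
from collections import defaultdict


class AverageMeter:  # stand-in, not the library meter
    def __init__(self):
        self.sum, self.n = 0.0, 0

    def add(self, value, batch_size):
        self.sum += value * batch_size
        self.n += batch_size

    @property
    def mean(self):
        return self.sum / max(self.n, 1)


meters = defaultdict(AverageMeter)
for loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:  # on_batch_end
    meters["loss"].add(loss, batch_size)

loader_name, loader_metrics, epoch_metrics = "train", {}, {}
for key, meter in meters.items():  # on_loader_end
    loader_metrics[key] = meter.mean
    epoch_metrics[f"{loader_name}_{key}"] = loader_metrics[key]

print(epoch_metrics)  # {'train_loss': 0.74}
```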
Args: - state (State): current state + runner (_Runner): current runner """ if self.decouple_weight_decay: for i, wd in enumerate(self._optimizer_wd): @@ -122,7 +117,7 @@ def on_epoch_end(self, state: State) -> None: if self.optimizer_key is not None else "lr" ) - state.epoch_metrics[lr_name] = lr + runner.epoch_metrics[lr_name] = lr momentum = utils.get_optimizer_momentum(self._optimizer) if momentum is not None: @@ -131,18 +126,18 @@ def on_epoch_end(self, state: State) -> None: if self.optimizer_key is not None else "momentum" ) - state.epoch_metrics[momentum_name] = momentum + runner.epoch_metrics[momentum_name] = momentum - def on_batch_end(self, state: State) -> None: + def on_batch_end(self, runner: _Runner) -> None: """On batch end event Args: - state (State): current state + runner (_Runner): current runner """ - if not state.is_train_loader: + if not runner.is_train_loader: return - loss = state.batch_metrics[self.metric_key] + loss = runner.batch_metrics[self.metric_key] self._accumulation_counter += 1 need_gradient_step = ( diff --git a/catalyst/core/callbacks/scheduler.py b/catalyst/core/callbacks/scheduler.py index c505b97f69..10ca66b652 100644 --- a/catalyst/core/callbacks/scheduler.py +++ b/catalyst/core/callbacks/scheduler.py @@ -1,7 +1,11 @@ +from typing import Tuple + import torch from catalyst.contrib.nn.schedulers import BatchScheduler, OneCycleLRWithWarmup -from catalyst.core import Callback, CallbackNode, CallbackOrder, State, utils +from catalyst.core import utils +from catalyst.core.callback import Callback, CallbackNode, CallbackOrder +from catalyst.core.runner import _Runner class SchedulerCallback(Callback): @@ -34,56 +38,56 @@ def _scheduler_step( return lr, momentum - def step_batch(self, state: State) -> None: + def step_batch(self, runner: _Runner) -> None: """@TODO: Docs. Contribution is welcome. Args: - state (State): current state + runner (_Runner): current runner """ lr, momentum = self._scheduler_step(scheduler=self._scheduler) if self.scheduler_key is not None: - state.batch_metrics[f"lr/{self.scheduler_key}"] = lr + runner.batch_metrics[f"lr/{self.scheduler_key}"] = lr if momentum is not None: - state.batch_metrics[ + runner.batch_metrics[ f"momentum/{self.scheduler_key}" ] = momentum else: - state.batch_metrics["lr"] = lr + runner.batch_metrics["lr"] = lr if momentum is not None: - state.batch_metrics["momentum"] = momentum + runner.batch_metrics["momentum"] = momentum - def step_epoch(self, state: State) -> None: + def step_epoch(self, runner: _Runner) -> None: """@TODO: Docs. Contribution is welcome. Args: - state (State): current state + runner (_Runner): current runner """ - reduced_metric = state.valid_metrics[self.reduced_metric] + reduced_metric = runner.valid_metrics[self.reduced_metric] lr, momentum = self._scheduler_step( scheduler=self._scheduler, reduced_metric=reduced_metric ) if self.scheduler_key is not None: - state.epoch_metrics[f"lr/{self.scheduler_key}"] = lr + runner.epoch_metrics[f"lr/{self.scheduler_key}"] = lr if momentum is not None: - state.epoch_metrics[ + runner.epoch_metrics[ f"momentum/{self.scheduler_key}" ] = momentum else: - state.epoch_metrics["lr"] = lr + runner.epoch_metrics["lr"] = lr if momentum is not None: - state.epoch_metrics["momentum"] = momentum + runner.epoch_metrics["momentum"] = momentum - def on_stage_start(self, state: State) -> None: + def on_stage_start(self, runner: _Runner) -> None: """Stage start hook. 
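A standalone illustration of the lr/momentum bookkeeping in ``OptimizerCallback.on_epoch_end`` above: the epoch-metric name gets an ``_{optimizer_key}`` suffix only when several optimizers are in play. The concrete values are made up.

```python
optimizer_key = "adam"  # None when there is a single optimizer
lr, momentum = 0.001, 0.9
epoch_metrics = {}

lr_name = f"lr_{optimizer_key}" if optimizer_key is not None else "lr"
epoch_metrics[lr_name] = lr

if momentum is not None:
    momentum_name = (
        f"momentum_{optimizer_key}" if optimizer_key is not None else "momentum"
    )
    epoch_metrics[momentum_name] = momentum

print(epoch_metrics)  # {'lr_adam': 0.001, 'momentum_adam': 0.9}
```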
Args: - state (State): current state + runner (_Runner): current runner """ - self.reduced_metric = self.reduced_metric or state.main_metric + self.reduced_metric = self.reduced_metric or runner.main_metric - scheduler = state.get_attr( + scheduler = runner.get_attr( key="scheduler", inner_key=self.scheduler_key ) assert scheduler is not None @@ -102,38 +106,38 @@ def on_stage_start(self, state: State) -> None: scheduler.reset() assert self.mode is not None - def on_loader_start(self, state: State) -> None: + def on_loader_start(self, runner: _Runner) -> None: """Loader start hook. Args: - state (State): current state + runner (_Runner): current runner """ if ( - state.is_train_loader + runner.is_train_loader and isinstance(self._scheduler, OneCycleLRWithWarmup) and self.mode == "batch" ): self._scheduler.recalculate( - loader_len=state.loader_len, current_step=state.epoch + loader_len=runner.loader_len, current_step=runner.epoch ) - def on_batch_end(self, state: State) -> None: + def on_batch_end(self, runner: _Runner) -> None: """Batch end hook. Args: - state (State): current state + runner (_Runner): current runner """ - if state.is_train_loader and self.mode == "batch": - self.step_batch(state=state) + if runner.is_train_loader and self.mode == "batch": + self.step_batch(runner=runner) - def on_epoch_end(self, state: State) -> None: + def on_epoch_end(self, runner: _Runner) -> None: """Epoch end hook. Args: - state (State): current state + runner (_Runner): current runner """ if self.mode == "epoch": - self.step_epoch(state=state) + self.step_epoch(runner=runner) class LRUpdater(Callback): @@ -171,7 +175,7 @@ def _update_momentum(optimizer, new_momentum) -> None: for pg in optimizer.param_groups: pg["momentum"] = new_momentum - def _update_optimizer(self, optimizer) -> None: + def _update_optimizer(self, optimizer) -> Tuple[float, float]: new_lr = self.calc_lr() if new_lr is not None: self._update_lr(optimizer, new_lr) @@ -184,51 +188,51 @@ def _update_optimizer(self, optimizer) -> None: return new_lr, new_momentum - def update_optimizer(self, state: State) -> None: + def update_optimizer(self, runner: _Runner) -> None: """@TODO: Docs. Contribution is welcome. Args: - state (State): current state + runner (_Runner): current runner """ lr, momentum = self._update_optimizer(optimizer=self._optimizer) if self.optimizer_key is not None: - state.batch_metrics[f"lr_{self.optimizer_key}"] = lr - state.batch_metrics[f"momentum_{self.optimizer_key}"] = momentum + runner.batch_metrics[f"lr_{self.optimizer_key}"] = lr + runner.batch_metrics[f"momentum_{self.optimizer_key}"] = momentum else: - state.batch_metrics["lr"] = lr - state.batch_metrics["momentum"] = momentum + runner.batch_metrics["lr"] = lr + runner.batch_metrics["momentum"] = momentum - def on_stage_start(self, state: State) -> None: + def on_stage_start(self, runner: _Runner) -> None: """Stage start hook. Args: - state (State): current state + runner (_Runner): current runner """ - optimizer = state.get_attr( + optimizer = runner.get_attr( key="optimizer", inner_key=self.optimizer_key ) assert optimizer is not None self._optimizer = optimizer self.init_lr = optimizer.defaults["lr"] - def on_loader_start(self, state: State) -> None: + def on_loader_start(self, runner: _Runner) -> None: """Loader start hook. 
Args: - state (State): current state + runner (_Runner): current runner """ - if state.is_train_loader: - self.update_optimizer(state=state) + if runner.is_train_loader: + self.update_optimizer(runner=runner) - def on_batch_end(self, state: State) -> None: + def on_batch_end(self, runner: _Runner) -> None: """Batch end hook. Args: - state (State): current state + runner (_Runner): current runner """ - if state.is_train_loader: - self.update_optimizer(state=state) + if runner.is_train_loader: + self.update_optimizer(runner=runner) __all__ = ["SchedulerCallback", "LRUpdater"] diff --git a/catalyst/core/callbacks/timer.py b/catalyst/core/callbacks/timer.py index 0c18c7ced8..e067fe9176 100644 --- a/catalyst/core/callbacks/timer.py +++ b/catalyst/core/callbacks/timer.py @@ -1,4 +1,5 @@ -from catalyst.core import Callback, CallbackNode, CallbackOrder, State +from catalyst.core.callback import Callback, CallbackNode, CallbackOrder +from catalyst.core.runner import _Runner from catalyst.tools.time_manager import TimeManager EPS = 1e-8 @@ -12,48 +13,48 @@ def __init__(self): super().__init__(order=CallbackOrder.Metric + 1, node=CallbackNode.All) self.timer = TimeManager() - def on_loader_start(self, state: State) -> None: + def on_loader_start(self, runner: _Runner) -> None: """Loader start hook. Args: - state (State): current state + runner (_Runner): current runner """ self.timer.reset() self.timer.start("_timer/batch_time") self.timer.start("_timer/data_time") - def on_loader_end(self, state: State) -> None: + def on_loader_end(self, runner: _Runner) -> None: """Loader end hook. Args: - state (State): current state + runner (_Runner): current runner """ self.timer.reset() - def on_batch_start(self, state: State) -> None: + def on_batch_start(self, runner: _Runner) -> None: """Batch start hook. Args: - state (State): current state + runner (_Runner): current runner """ self.timer.stop("_timer/data_time") self.timer.start("_timer/model_time") - def on_batch_end(self, state: State) -> None: + def on_batch_end(self, runner: _Runner) -> None: """Batch end hook. Args: - state (State): current state + runner (_Runner): current runner """ self.timer.stop("_timer/model_time") self.timer.stop("_timer/batch_time") # @TODO: just a trick - self.timer.elapsed["_timer/_fps"] = state.batch_size / ( + self.timer.elapsed["_timer/_fps"] = runner.batch_size / ( self.timer.elapsed["_timer/batch_time"] + EPS ) for key, value in self.timer.elapsed.items(): - state.batch_metrics[key] = value + runner.batch_metrics[key] = value self.timer.reset() self.timer.start("_timer/batch_time") diff --git a/catalyst/core/callbacks/validation.py b/catalyst/core/callbacks/validation.py index 0cc6c6738c..4b0daac26a 100644 --- a/catalyst/core/callbacks/validation.py +++ b/catalyst/core/callbacks/validation.py @@ -1,10 +1,13 @@ from collections import defaultdict -from catalyst.core import Callback, CallbackNode, CallbackOrder, State +from catalyst.core.callback import Callback, CallbackNode, CallbackOrder +from catalyst.core.runner import _Runner class ValidationManagerCallback(Callback): - """A callback to aggregate state.valid_metrics from state.epoch_metrics.""" + """ + A callback to aggregate runner.valid_metrics from runner.epoch_metrics. + """ def __init__(self): """@TODO: Docs. Contribution is welcome.""" @@ -12,48 +15,48 @@ def __init__(self): order=CallbackOrder.Validation, node=CallbackNode.All, ) - def on_epoch_start(self, state: State) -> None: + def on_epoch_start(self, runner: _Runner) -> None: """Epoch start hook. 
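A tiny standalone sketch of the throughput metric computed in ``TimerCallback.on_batch_end`` above: samples per second for the batch, with ``EPS`` guarding against division by zero. The timings are invented for the example.

```python
EPS = 1e-8
batch_size = 64
elapsed = {
    "_timer/data_time": 0.010,
    "_timer/model_time": 0.070,
    "_timer/batch_time": 0.082,
}
elapsed["_timer/_fps"] = batch_size / (elapsed["_timer/batch_time"] + EPS)
print(round(elapsed["_timer/_fps"], 1))  # ~780.5 samples per second
```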
Args: - state (State): current state + runner (_Runner): current runner """ - state.valid_metrics = defaultdict(None) - state.is_best_valid = False + runner.valid_metrics = defaultdict(None) + runner.is_best_valid = False - def on_epoch_end(self, state: State) -> None: + def on_epoch_end(self, runner: _Runner) -> None: """Epoch end hook. Args: - state (State): current state + runner (_Runner): current runner """ - if state.stage_name.startswith("infer"): + if runner.stage_name.startswith("infer"): return - state.valid_metrics = { - k.replace(f"{state.valid_loader}_", ""): v - for k, v in state.epoch_metrics.items() - if k.startswith(state.valid_loader) + runner.valid_metrics = { + k.replace(f"{runner.valid_loader}_", ""): v + for k, v in runner.epoch_metrics.items() + if k.startswith(runner.valid_loader) } assert ( - state.main_metric in state.valid_metrics - ), f"{state.main_metric} value is not available by the epoch end" + runner.main_metric in runner.valid_metrics + ), f"{runner.main_metric} value is not available by the epoch end" - current_valid_metric = state.valid_metrics[state.main_metric] - if state.minimize_metric: - best_valid_metric = state.best_valid_metrics.get( - state.main_metric, float("+inf") + current_valid_metric = runner.valid_metrics[runner.main_metric] + if runner.minimize_metric: + best_valid_metric = runner.best_valid_metrics.get( + runner.main_metric, float("+inf") ) is_best = current_valid_metric < best_valid_metric else: - best_valid_metric = state.best_valid_metrics.get( - state.main_metric, float("-inf") + best_valid_metric = runner.best_valid_metrics.get( + runner.main_metric, float("-inf") ) is_best = current_valid_metric > best_valid_metric if is_best: - state.is_best_valid = True - state.best_valid_metrics = state.valid_metrics.copy() + runner.is_best_valid = True + runner.best_valid_metrics = runner.valid_metrics.copy() __all__ = ["ValidationManagerCallback"] diff --git a/catalyst/core/experiment.py b/catalyst/core/experiment.py index 605adea36c..538b1feef6 100644 --- a/catalyst/core/experiment.py +++ b/catalyst/core/experiment.py @@ -5,10 +5,9 @@ from torch import nn from torch.utils.data import DataLoader, Dataset +from catalyst.core.callback import Callback from catalyst.tools.typing import Criterion, Model, Optimizer, Scheduler -from .callback import Callback - class _Experiment(ABC): """ @@ -22,7 +21,6 @@ class _Experiment(ABC): - :py:mod:`catalyst.core.experiment._Experiment` - :py:mod:`catalyst.core.runner._Runner` - - :py:mod:`catalyst.core.state.State` - :py:mod:`catalyst.core.callback.Callback` Abstraction, please check out the implementations: @@ -39,7 +37,7 @@ def initial_seed(self) -> int: Experiment's initial seed, used to setup `global seed` at the beginning of each stage. Additionally, Catalyst Runner setups - `experiment.initial_seed + state.global_epoch + 1` + `experiment.initial_seed + runner.global_epoch + 1` as `global seed` each epoch. Used for experiment reproducibility. @@ -100,23 +98,21 @@ def distributed_params(self) -> Dict: pass @abstractmethod - def get_state_params(self, stage: str) -> Mapping[str, Any]: - """Returns State parameters for a given stage. - - To learn more about State, please follow - :py:mod:`catalyst.core.state.State` - documentation. + def get_stage_params(self, stage: str) -> Mapping[str, Any]: + """Returns extra stage parameters for a given stage. 
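A self-contained replay of the validation bookkeeping in ``ValidationManagerCallback.on_epoch_end`` above: epoch metrics for the validation loader are re-keyed without the loader prefix, then the main metric is compared against the best value seen so far. The metric values are invented for the example.

```python
valid_loader, main_metric, minimize_metric = "valid", "loss", True
epoch_metrics = {"train_loss": 0.48, "valid_loss": 0.39, "lr": 0.001}
best_valid_metrics = {"loss": 0.42}

valid_metrics = {
    k.replace(f"{valid_loader}_", ""): v
    for k, v in epoch_metrics.items()
    if k.startswith(valid_loader)
}
assert main_metric in valid_metrics, f"{main_metric} value is not available by the epoch end"

current = valid_metrics[main_metric]
default = float("+inf") if minimize_metric else float("-inf")
best = best_valid_metrics.get(main_metric, default)
is_best_valid = current < best if minimize_metric else current > best

print(valid_metrics, is_best_valid)  # {'loss': 0.39} True
```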
Example:: - >>> experiment.get_state_params(stage="training") + >>> experiment.get_stage_params(stage="training") { "logdir": "./logs/training", "num_epochs": 42, "valid_loader": "valid", "main_metric": "loss", "minimize_metric": True, - "checkpoint_data": {"comment": "we are going to make it!"} + "checkpoint_data": { + "comment": "break the cycle - use the Catalyst" + } } Args: @@ -124,7 +120,7 @@ def get_state_params(self, stage: str) -> Mapping[str, Any]: like "pretrain" / "train" / "finetune" / etc Returns: - dict: State parameters for a given stage. + dict: parameters for a given stage. """ pass @@ -208,30 +204,34 @@ def get_scheduler(self, stage: str, optimizer: Optimizer) -> Scheduler: pass def get_experiment_components( - self, model: nn.Module, stage: str - ) -> Tuple[Criterion, Optimizer, Scheduler]: + self, stage: str, model: nn.Module = None, + ) -> Tuple[Model, Criterion, Optimizer, Scheduler]: """ Returns the tuple containing criterion, optimizer and scheduler by giving model and stage. Aggregation method, based on, + - :py:mod:`catalyst.core.experiment._Experiment.get_model` - :py:mod:`catalyst.core.experiment._Experiment.get_criterion` - :py:mod:`catalyst.core.experiment._Experiment.get_optimizer` - :py:mod:`catalyst.core.experiment._Experiment.get_scheduler` Args: - model (Model): model to optimize with stage optimizer stage (str): stage name of interest, like "pretrain" / "train" / "finetune" / etc + model (Model): model to optimize with stage optimizer Returns: - tuple: criterion, optimizer, scheduler for a given stage and model + tuple: model, criterion, optimizer, scheduler + for a given stage and model """ + if model is None: + model = self.get_model(stage) criterion = self.get_criterion(stage) optimizer = self.get_optimizer(stage, model) scheduler = self.get_scheduler(stage, optimizer) - return criterion, optimizer, scheduler + return model, criterion, optimizer, scheduler def get_transforms(self, stage: str = None, dataset: str = None): """Returns the data transforms for a given stage and dataset. @@ -244,7 +244,7 @@ def get_transforms(self, stage: str = None, dataset: str = None): .. note:: For datasets/loaders nameing please follow - :py:mod:`catalyst.core.state.State` documentation. + :py:mod:`catalyst.core.runner` documentation. Returns: Data transformations to use for specified dataset. @@ -349,7 +349,6 @@ def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]": - :py:mod:`catalyst.core.experiment._Experiment` - :py:mod:`catalyst.core.runner._Runner` - - :py:mod:`catalyst.core.state.State` - :py:mod:`catalyst.core.callback.Callback` """ pass diff --git a/catalyst/core/legacy.py b/catalyst/core/legacy.py new file mode 100644 index 0000000000..54d997dfa8 --- /dev/null +++ b/catalyst/core/legacy.py @@ -0,0 +1,86 @@ +import warnings + + +class _RunnerLegacy: + """ + Special class to encapsulate all `catalyst.core.runner._Runner` + and `catalyst.core.runner.State` legacy into one place. + Used to make `catalyst.core.runner._Runner` cleaner + and easier to understand. + + Saved for backward compatibility. Should be removed someday. + """ + + @property + def batch_in(self): + """Alias for `runner.input`. + + .. warning:: + Deprecated, saved for backward compatibility. + Please use `runner.input` instead. + """ + warnings.warn( + "`runner.batch_in` was deprecated, " + "please use `runner.input` instead", + DeprecationWarning, + ) + return self.input + + @property + def batch_out(self): + """Alias for `runner.output`. + + .. 
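A hedged, self-contained sketch of the aggregation order described for the reworked ``get_experiment_components`` above: the model is fetched via ``get_model`` when not supplied, then criterion, optimizer and scheduler are built in turn. ``ToyExperiment`` is a toy stand-in, not the real ``_Experiment``.

```python
import torch
import torch.nn as nn


class ToyExperiment:
    def get_model(self, stage):
        return nn.Linear(10, 2)

    def get_criterion(self, stage):
        return nn.CrossEntropyLoss()

    def get_optimizer(self, stage, model):
        return torch.optim.Adam(model.parameters(), lr=1e-3)

    def get_scheduler(self, stage, optimizer):
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

    def get_experiment_components(self, stage, model=None):
        if model is None:
            model = self.get_model(stage)
        criterion = self.get_criterion(stage)
        optimizer = self.get_optimizer(stage, model)
        scheduler = self.get_scheduler(stage, optimizer)
        return model, criterion, optimizer, scheduler


model, criterion, optimizer, scheduler = ToyExperiment().get_experiment_components("train")
```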
warning:: + Deprecated, saved for backward compatibility. + Please use `runner.output` instead. + """ + warnings.warn( + "`runner.batch_out` was deprecated, " + "please use `runner.output` instead", + DeprecationWarning, + ) + return self.output + + @property + def need_backward_pass(self): + """Alias for `runner.is_train_loader`. + + .. warning:: + Deprecated, saved for backward compatibility. + Please use `runner.is_train_loader` instead. + """ + warnings.warn( + "`need_backward_pass` was deprecated, " + "please use `is_train_loader` instead", + DeprecationWarning, + ) + return self.is_train_loader + + @property + def loader_step(self): + """Alias for `runner.loader_batch_step`. + + .. warning:: + Deprecated, saved for backward compatibility. + Please use `runner.loader_batch_step` instead. + """ + warnings.warn( + "`loader_step` was deprecated, " + "please use `loader_batch_step` instead", + DeprecationWarning, + ) + return self.loader_batch_step + + @property + def state(self): + """Alias for `runner`. + + .. warning:: + Deprecated, saved for backward compatibility. + Please use `runner` instead. + """ + warnings.warn( + "`runner.state` was deprecated, " "please use `runner` instead", + DeprecationWarning, + ) + return self diff --git a/catalyst/core/runner.py b/catalyst/core/runner.py index 06845a41ec..d1f6452e55 100644 --- a/catalyst/core/runner.py +++ b/catalyst/core/runner.py @@ -1,28 +1,33 @@ -from typing import Any, Callable, Dict, Mapping, Tuple, Union +from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Union from abc import ABC, abstractmethod -from collections import OrderedDict +from collections import defaultdict, OrderedDict +from pathlib import Path import torch from torch import nn from torch.utils.data import DataLoader, DistributedSampler from catalyst.core import utils +from catalyst.core.callback import Callback, CallbackScope +from catalyst.core.experiment import _Experiment from catalyst.tools import settings +from catalyst.tools.frozen_class import FrozenClass from catalyst.tools.typing import ( Criterion, Device, Model, Optimizer, + RunnerCriterion, + RunnerModel, + RunnerOptimizer, + RunnerScheduler, Scheduler, ) -from .callback import Callback, CallbackScope -from .callbacks import ExceptionCallback -from .experiment import _Experiment -from .state import State +from .legacy import _RunnerLegacy -class _Runner(ABC): +class _Runner(ABC, _RunnerLegacy, FrozenClass): """ An abstraction that knows how to run an experiment. It contains all the logic of **how** to run the experiment, @@ -33,29 +38,445 @@ class _Runner(ABC): - :py:mod:`catalyst.core.experiment._Experiment` - :py:mod:`catalyst.core.runner._Runner` - - :py:mod:`catalyst.core.state.State` - :py:mod:`catalyst.core.callback.Callback` Abstraction, please check out the implementations: + - :py:mod:`catalyst.dl.runner.runner.Runner` - :py:mod:`catalyst.dl.runner.supervised.SupervisedRunner` + Runner also contains full information about experiment runner. 
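A minimal sketch of the backward-compatibility pattern that ``_RunnerLegacy`` uses above: old attribute names stay available as properties that emit a ``DeprecationWarning`` and forward to the new names. The class below is illustrative, not the library class.

```python
import warnings


class LegacyAliases:
    def __init__(self):
        self.input = {"features": [1, 2, 3]}

    @property
    def batch_in(self):
        warnings.warn(
            "`batch_in` was deprecated, please use `input` instead",
            DeprecationWarning,
        )
        return self.input


obj = LegacyAliases()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = obj.batch_in  # old name still works, but warns
assert caught[0].category is DeprecationWarning
```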
+ + + Runner section + + + **runner.model** - an instance of torch.nn.Module class, \ + (should implement ``forward`` method); \ + for example, + :: + + runner.model = torch.nn.Linear(10, 10) + + **runner.device** - an instance of torch.device (CPU, GPU, TPU); \ + for example, + :: + + runner.device = torch.device("cpu") + + + Experiment section + + + **runner.criterion** - an instance of torch.nn.Module class\ + or torch.nn.modules.loss._Loss (should implement ``forward`` method); \ + for example, + :: + + runner.criterion = torch.nn.CrossEntropyLoss() + + **runner.optimizer** - an instance of torch.optim.optimizer.Optimizer\ + (should implement ``step`` method); \ + for example, + :: + + runner.optimizer = torch.optim.Adam() + + **runner.scheduler** - + an instance of torch.optim.lr_scheduler._LRScheduler\ + (should implement ``step`` method); \ + for example, + :: + + runner.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau() + + **runner.callbacks** - + ordered dictionary with Catalyst.Callback instances;\ + for example, + :: + + runner.callbacks = { + "accuracy": AccuracyCallback(), + "criterion": CriterionCallback(), + "optim": OptimizerCallback(), + "saver": CheckpointCallback() + } + + + Dataflow section + + + **runner.loaders** - ordered dictionary with torch.DataLoaders; \ + for example, + :: + + runner.loaders = { + "train": MnistTrainLoader(), + "valid": MnistValidLoader() + } + + .. note:: + - "*train*" prefix is used for training loaders - \ + metrics computations, backward pass, optimization + - "*valid*" prefix is used for validation loaders - \ + metrics computations only + - "*infer*" prefix is used for inference loaders - \ + dataset prediction + + **runner.input** - dictionary, \ + containing batch of data from currents DataLoader; \ + for example, + :: + + runner.input = { + "images": np.ndarray(batch_size, c, h, w), + "targets": np.ndarray(batch_size, 1), + } + + **runner.output** - dictionary, \ + containing model output for current batch; \ + for example, + :: + + runner.output = {"logits": torch.Tensor(batch_size, num_classes)} + + + Metrics section + + + **runner.batch_metrics** - dictionary, flatten storage for batch metrics; \ + for example, + :: + + runner.batch_metrics = {"loss": ..., "accuracy": ..., "iou": ...} + + **runner.loader_metrics** - dictionary with aggregated batch statistics \ + for loader (mean over all batches) and global loader metrics, like AUC; \ + for example, + :: + + runner.loader_metrics = {"loss": ..., "accuracy": ..., "auc": ...} + + **runner.epoch_metrics** - dictionary with summarized metrics \ + for different loaders and global epoch metrics, like lr, momentum; \ + for example, + :: + + runner.epoch_metrics = { + "train_loss": ..., "train_auc": ..., "valid_loss": ..., + "lr": ..., "momentum": ..., + } + + + Validation metrics section + + + **runner.main_metric** - string, containing name of metric of interest \ + for optimization, validation and checkpointing during training + + **runner.minimize_metric** - bool, indicator flag + + - ``True`` if we need to minimize metric during training,\ + like `Cross Entropy loss` + - ``False`` if we need to maximize metric during training, \ + like `Accuracy` or `Intersection over Union` + + + Validation section + + + **runner.valid_loader** - string, name of validation loader \ + for metric selection, validation and model checkpoining + + **runner.valid_metrics** - dictionary with validation metrics\ + for currect epoch; \ + for example, + :: + + runner.valid_metrics = {"loss": ..., 
"accuracy": ..., "auc": ...} + + .. note:: + subdictionary of epoch_metrics + + **runner.is_best_valid** - bool, indicator flag + + - ``True`` if this training epoch is best over all epochs + - ``False`` if not + + **runner.best_valid_metrics** - dictionary with best validation metrics \ + during whole training process + + + Distributed section + + + **runner.distributed_rank** - distributed rank of current worker + + **runner.is_distributed_master** - bool, indicator flag + + - ``True`` if is master node (runner.distributed_rank == 0) + - ``False`` if is worker node (runner.distributed_rank != 0) + + **runner.is_distributed_worker** - bool, indicator flag + + - ``True`` if is worker node (runner.distributed_rank > 0) + - ``False`` if is master node (runner.distributed_rank <= 0) + + + Experiment info section + + + **runner.global_sample_step** - int, numerical indicator, counter for all\ + individual samples, that passes through our model during training,\ + validation and inference stages + + **runner.global_batch_step** - int, numerical indicator, counter for all + batches, that passes through our model during training, validation and\ + inference stages + + **runner.global_epoch** - int, numerical indicator, + counter for all epochs,\ + that have passed during model training, validation and\ + inference stages + + **runner.verbose** - bool, indicator flag + + **runner.is_check_run** - bool, indicator flag + + - ``True`` if you want to check you pipeline and \ + run only 2 batches per loader and 2 epochs per stage + - ``False`` (default) if you want to just the pipeline + + **runner.need_early_stop** - bool, indicator flag \ + used for EarlyStopping and CheckRun Callbacks + + - ``True`` if we need to stop the training + - ``False`` (default) otherwise + + **runner.need_exception_reraise** - bool, indicator flag + + - ``True`` (default) if you want to show exception \ + during pipeline and stop the training process + - ``False`` otherwise + + + Stage info section + + + **runner.stage_name** - string, current stage name,\ + for example, + :: + + runner.stage_name = "pretraining" / "training" / "finetuning" / etc + + **runner.num_epochs** - int, maximum number of epochs, \ + required for this stage + + **runner.is_infer_stage** - bool, indicator flag + + - ``True`` for inference stages + - ``False`` otherwise + + + Epoch info section + + + **runner.epoch** - int, numerical indicator for current stage epoch + + + Loader info section + + + **runner.loader_sample_step** - int, numerical indicator \ + for number of samples passed through our model in current loader + + **runner.loader_batch_step** - int, numerical indicator \ + for batch index in current loader + + + **runner.loader_name** - string, current loader name\ + for example, + :: + + runner.loader_name = "train_dataset1" / "valid_data2" / "infer_golden" + + **runner.loader_len** - int, maximum number of batches in current loader + + **runner.loader_batch_size** - int, batch size parameter in current loader + + **runner.is_train_loader** - bool, indicator flag + + - ``True`` for training loaders + - ``False`` otherwise + + **runner.is_valid_loader** - bool, indicator flag + + - ``True`` for validation loaders + - ``False`` otherwise + + **runner.is_infer_loader** - bool, indicator flag + + - ``True`` for inference loaders + - ``False`` otherwise + + + Batch info section + + + **runner.batch_size** - int, length of the current batch + + Logging section + + + **runner.logdir** - string, path to logging directory to save\ + all logs, 
metrics, checkpoints and artifacts + + **runner.checkpoint_data** - dictionary\ + with all extra data for experiment tracking + + Extra section + + + **runner.exception** - python Exception instance to raise (or not ;) ) + """ _experiment_fn: Callable = _Experiment - _state_fn: Callable = State def __init__( - self, model: Model = None, device: Device = None, + self, model: RunnerModel = None, device: Device = None, **kwargs, ): """ Args: - model (Model): Torch model object + model (RunnerModel): Torch model object device (Device): Torch device """ - self._model: Model = model - self._device: Device = device - self._init() + self._device = None + self._model = None + self.device: Device = device + self.model: RunnerModel = model + self._init(**kwargs) + self._freeze() + + def _prepare_inner_state( + self, + stage: str = settings.stage_infer_prefix, # @TODO: wtf? + device: Device = None, + model: RunnerModel = None, + criterion: RunnerCriterion = None, + optimizer: RunnerOptimizer = None, + scheduler: RunnerScheduler = None, + callbacks: Dict[str, "Callback"] = None, + logdir: str = None, + num_epochs: int = 1, + main_metric: str = "loss", + minimize_metric: bool = True, + valid_loader: str = settings.loader_valid_prefix, + checkpoint_data: Dict = None, + is_check_run: bool = False, + verbose: bool = False, + **kwargs, + ): + self._unfreeze() + + # main runner components: model and device to run + self.device: Device = device + self.model: RunnerModel = model + + # extra experiment components, + # use `catalyst.core._Experiment` to setup them + self.criterion: RunnerCriterion = criterion + self.optimizer: RunnerOptimizer = optimizer + self.scheduler: RunnerScheduler = scheduler + # and callbacks + self.callbacks: Dict[str, "Callback"] = callbacks or {} + + # the data + self.loaders: OrderedDict[str, DataLoader] = None + # and the dataflow - model input, model output + self.input = None + self.output = None + + # metrics flow - batch, loader, epoch metrics + # let's use flatten storage for batch metrics + # batch_metrics = {'loss': ..., 'accuracy': ..., 'iou': ...} + self.batch_metrics: Dict = defaultdict(None) + # just aggregated (aka mean over all batches) + # batch statistics for loader + # and global loader metrics, like AUC + # loader_metrics = {'loss': ..., 'accuracy': ..., `auc`: ...} + self.loader_metrics: Dict = defaultdict(None) + # summarized metrics for different loaders + # and global epoch metrics, like lr, momentum + # epoch_metrics = { + # 'train_loss': ..., 'train_auc': ..., 'valid_loss': ..., + # 'lr': ..., 'momentum': ..., + # } + self.epoch_metrics: Dict = defaultdict(None) + + # metrics & validation + self.main_metric: str = main_metric + self.minimize_metric: bool = minimize_metric + + # validation + self.valid_loader: str = valid_loader + self.valid_metrics: Dict = defaultdict(None) + self.is_best_valid: bool = False + self.best_valid_metrics: Dict = defaultdict(None) + + # distributed info + self.distributed_rank: int = utils.get_rank() + self.is_distributed_master: bool = ~(self.distributed_rank > 0) + self.is_distributed_worker: bool = self.distributed_rank > 0 + # experiment info + self.global_sample_step: int = 0 + self.global_batch_step: int = 0 + self.global_epoch: int = 1 + self.verbose: bool = verbose + self.is_check_run: bool = is_check_run + self.need_early_stop: bool = False + self.need_exception_reraise: bool = True + # stage info + self.num_epochs: int = num_epochs + self.stage_name: str = stage + self.is_infer_stage: bool = self.stage_name.startswith( + 
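``FrozenClass`` itself is not shown in this patch, so here is a hedged re-creation of the general freeze/unfreeze idea used by ``_prepare_inner_state`` above: after ``_freeze()``, assigning a brand-new attribute name raises, which protects against typos in runner attributes. This is an illustration of the pattern, not the library implementation.

```python
class Frozen:  # illustrative stand-in for catalyst.tools.frozen_class.FrozenClass
    __is_frozen = False

    def __setattr__(self, key, value):
        if self.__is_frozen and not hasattr(self, key):
            raise TypeError(f"cannot add new attribute {key!r} to a frozen object")
        object.__setattr__(self, key, value)

    def _freeze(self):
        self.__is_frozen = True

    def _unfreeze(self):
        self.__is_frozen = False


class MiniRunner(Frozen):
    def __init__(self):
        self.need_early_stop = False
        self._freeze()


runner = MiniRunner()
runner.need_early_stop = True       # fine: attribute already exists
try:
    runner.neeed_early_stop = True  # typo: rejected once frozen
except TypeError as e:
    print(e)
```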
settings.stage_infer_prefix + ) + # epoch info + self.epoch: int = 1 + # loader info + self.loader_sample_step: int = 0 + self.loader_batch_step: int = 0 + self.loader_name: str = None + self.loader_len: int = 0 + self.loader_batch_size = 0 + self.is_train_loader: bool = False + self.is_valid_loader: bool = False + self.is_infer_loader: bool = False + # batch info + self.batch_size: int = 0 + + # logging + self.expdir: Path = None + self.logdir: Path = Path(logdir) if logdir is not None else None + # extra checkpoint data for saving in checkpoint files + self.checkpoint_data: Dict = checkpoint_data or {} + + # extra + self.exception: Optional[Exception] = None + + # kwargs + for key, value in kwargs.items(): + setattr(self, key, value) + + self._freeze() + + def _init(self, **kwargs) -> None: + """ + Inner method for children's classes + to specify type for Runners' Experiment. + """ + self.experiment: _Experiment = None @property def model(self) -> Model: @@ -82,7 +503,8 @@ def model(self, value: Union[Model, Dict[str, Model]]): ) model = value - + elif isinstance(value, type(None)): + model = None else: raise TypeError( f"Invalid value type " @@ -90,7 +512,7 @@ def model(self, value: Union[Model, Dict[str, Model]]): f"got '{type(value)}'" ) - if self._device is not None: + if model is not None and self._device is not None: model: Model = utils.maybe_recursive_call( model, "to", device=self._device ) @@ -110,8 +532,12 @@ def device(self, value: Device): Args: value (Device): new torch device. """ - if isinstance(value, (str, torch.device)): + if isinstance(value, torch.device): self._device = value + elif isinstance(value, str): + self._device = torch.device(value) + elif isinstance(value, type(None)): + self._device = None else: raise TypeError( f"Invalid value type " @@ -121,19 +547,12 @@ def device(self, value: Device): if self._model is not None: self._model = utils.maybe_recursive_call( - self._model, "to", device=self._device + self._model, "to", device=self._device or "cpu" ) - def _init(self) -> None: - """ - Inner method for children's classes - to specify types for Runners' Experiment and State. - """ - self.experiment: _Experiment = None - self.state: State = None - + @staticmethod def _get_experiment_components( - self, stage: str = None + experiment: _Experiment, stage: str = None, device: Device = None, ) -> Tuple[Model, Criterion, Optimizer, Scheduler, Device]: """ Inner method for `Experiment` components preparation. @@ -149,14 +568,12 @@ def _get_experiment_components( tuple: model, criterion, optimizer, scheduler and device for a given stage and model """ - utils.set_global_seed(self.experiment.initial_seed) - model = self.experiment.get_model(stage) ( + model, criterion, optimizer, scheduler, - ) = self.experiment.get_experiment_components(model, stage) - + ) = experiment.get_experiment_components(stage) ( model, criterion, @@ -168,105 +585,86 @@ def _get_experiment_components( criterion=criterion, optimizer=optimizer, scheduler=scheduler, - distributed_params=self.experiment.distributed_params, - device=self.device, + distributed_params=experiment.distributed_params, + device=device, ) - return model, criterion, optimizer, scheduler, device - def _get_state( - self, - stage: str, - model: Model, - criterion: Criterion, - optimizer: Optimizer, - scheduler: Scheduler, - device: Device, - callbacks: Dict[str, Callback], - ) -> State: - """ - Inner method for `State` preparation. 
+ @staticmethod + def _get_experiment_callbacks( + experiment: _Experiment, stage: str, + ) -> Dict[str, Callback]: + """Inner method for `Callbacks` preparation. - Migrates State parameters from previous stage if possible, - create new State for current stage. + Takes callbacks from the Experiment + and filters them for distributed master/worker cases. Args: stage (str): stage name of interest, like "pretrain" / "train" / "finetune" / etc - model (Model): stage model - criterion (Criterion): stage criterion - optimizer (Optimizer): stage optimizer - scheduler (Scheduler): stage scheduler - device (Device): torch device - callbacks (dict): dictionary with stage callbacks Returns: - State: State instance for specified stage - - .. note:: - To learn more about Catalyst Core concepts, please check out + OrderedDict[str, Callback]: Ordered dictionary + with callbacks for current experiment stage. + """ + callbacks = experiment.get_callbacks(stage) + callbacks = utils.filter_callbacks_by_node(callbacks) + callbacks = utils.sort_callbacks_by_order(callbacks) + return callbacks - - :py:mod:`catalyst.core.experiment._Experiment` - - :py:mod:`catalyst.core.runner._Runner` - - :py:mod:`catalyst.core.state.State` - - :py:mod:`catalyst.core.callback.Callback` + def get_attr(self, key: str, inner_key: str = None) -> Any: """ - migrating_params = dict(**self.experiment.get_state_params(stage)) - migrate_from_previous_stage = migrating_params.get( - "migrate_from_previous_stage", True - ) + Alias for python `getattr` method. Useful for Callbacks preparation + and cases with multi-criterion, multi-optimizer setup. + For example, when you would like to train multi-task classification. - if ( - migrate_from_previous_stage - and self.state is not None - and self.state.callbacks is not None - ): - for key, value in self.state.callbacks.items(): - if value.scope == CallbackScope.Experiment: - callbacks[key] = value - callbacks = utils.sort_callbacks_by_order(callbacks) - - if self.state is not None and migrate_from_previous_stage: - migrating_params.update( - { - "global_batch_step": self.state.global_batch_step, - "global_sample_step": self.state.global_sample_step, - "global_epoch": self.state.global_epoch, - "resume": getattr(self.state, "resume", None), - } - ) + Used to get a named attribute from a `_Runner` by `key` keyword; + for example\ + :: - state = self._state_fn( - stage=stage, - model=model, - device=device, - criterion=criterion, - optimizer=optimizer, - scheduler=scheduler, - callbacks=callbacks, - **migrating_params, - ) + # example 1 + runner.get_attr("criterion") + # is equivalent to + runner.criterion - return state + # example 2 + runner.get_attr("optimizer") + # is equivalent to + runner.optimizer - def _get_callbacks(self, stage: str) -> Dict[str, Callback]: - """Inner method for `Callbacks` preparation. + # example 3 + runner.get_attr("scheduler") + # is equivalent to + runner.scheduler - Takes callbacks from the Experiment - and filters them for distributed master/worker cases. + With `inner_key` usage, it suppose to find a dictionary under `key`\ + and would get `inner_key` from this dict; for example, + :: - Args: - stage (str): stage name of interest, - like "pretrain" / "train" / "finetune" / etc + # example 1 + runner.get_attr("criterion", "bce") + # is equivalent to + runner.criterion["bce"] - Returns: - OrderedDict[str, Callback]: Ordered dictionary - with callbacks for current experiment stage. 
+ # example 2 + runner.get_attr("optimizer", "adam") + # is equivalent to + runner.optimizer["adam"] + + # example 3 + runner.get_attr("scheduler", "adam") + # is equivalent to + runner.scheduler["adam"] + + Args: + key (str): name for attribute of interest, + like `criterion`, `optimizer`, `scheduler` + inner_key (str): name of inner dictionary key """ - callbacks = self.experiment.get_callbacks(stage) - callbacks = utils.filter_callbacks_by_node(callbacks) - callbacks = utils.sort_callbacks_by_order(callbacks) - return callbacks + if inner_key is None: + return getattr(self, key) + else: + return getattr(self, key)[inner_key] def _prepare_for_stage(self, stage: str) -> None: """ @@ -274,8 +672,8 @@ def _prepare_for_stage(self, stage: str) -> None: Sets `Experiment` initial seed. Prepares experiment components with `self._get_experiment_components`. - Prepares callbacks with `self._get_callbacks`. - Prepares `State` with `self._get_state`. + Prepares callbacks with `self._get_experiment_callbacks`. + Prepares inner state with `self._prepare_inner_state` Args: stage (str): stage name of interest, @@ -283,25 +681,43 @@ def _prepare_for_stage(self, stage: str) -> None: """ utils.set_global_seed(self.experiment.initial_seed) ( - self.model, + model, criterion, optimizer, scheduler, - self.device, - ) = self._get_experiment_components(stage=stage) + device, + ) = self._get_experiment_components( + experiment=self.experiment, stage=stage, + ) utils.set_global_seed(self.experiment.initial_seed) - callbacks = self._get_callbacks(stage) + callbacks = self._get_experiment_callbacks( + experiment=self.experiment, stage=stage + ) - utils.set_global_seed(self.experiment.initial_seed) - self.state = self._get_state( + migrating_params = dict(**self.experiment.get_stage_params(stage)) + migrate_from_previous_stage = migrating_params.get( + "migrate_from_previous_stage", True + ) + if ( + migrate_from_previous_stage + and getattr(self, "callbacks", None) is not None + ): + for key, value in self.callbacks.items(): + if value.scope == CallbackScope.Experiment: + callbacks[key] = value + + callbacks = utils.sort_callbacks_by_order(callbacks) + + self._prepare_inner_state( stage=stage, - model=self.model, + model=model, + device=device, criterion=criterion, optimizer=optimizer, scheduler=scheduler, - device=self.device, callbacks=callbacks, + **migrating_params, ) def _prepare_for_epoch(self, stage: str, epoch: int) -> None: @@ -326,8 +742,8 @@ def _run_event(self, event: str) -> None: :py:mod:`catalyst.core.callback.Callback` documentation. """ - for callback in self.state.callbacks.values(): - getattr(callback, event)(self.state) + for callback in self.callbacks.values(): + getattr(callback, event)(self) def _batch2device( self, batch: Mapping[str, Any], device: Device, @@ -369,13 +785,13 @@ def _run_batch(self, batch: Mapping[str, Any]) -> None: from DataLoader. 
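Since `_run_event` above now passes the runner itself into each handler, user callbacks read everything straight from `runner`. A hypothetical callback sketch under that assumption (the class name and the `order=` keyword are illustrative)::

    from catalyst.core import Callback, CallbackOrder

    class SampleCounterCallback(Callback):
        """Counts samples seen per loader; purely illustrative."""

        def __init__(self):
            super().__init__(order=CallbackOrder.Metric)
            self.counter = 0

        def on_loader_start(self, runner):
            # events fire in the order: stage -> epoch -> loader -> batch
            self.counter = 0

        def on_batch_end(self, runner):
            # the former `state.batch_size` is now an attribute of the runner
            self.counter += runner.batch_size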
""" if isinstance(batch, dict): - self.state.batch_size = next(iter(batch.values())).shape[0] + self.batch_size = next(iter(batch.values())).shape[0] else: - self.state.batch_size = len(batch[0]) - self.state.global_sample_step += self.state.batch_size - self.state.loader_sample_step += self.state.batch_size + self.batch_size = len(batch[0]) + self.global_sample_step += self.batch_size + self.loader_sample_step += self.batch_size batch = self._batch2device(batch, self.device) - self.state.input = batch + self.input = batch self._run_event("on_batch_start") self._handle_batch(batch=batch) @@ -389,19 +805,19 @@ def _run_loader(self, loader: DataLoader) -> None: Args: loader (DataLoader): dataloader to iterate """ - self.state.loader_batch_size = ( + self.loader_batch_size = ( loader.batch_sampler.batch_size if loader.batch_sampler is not None else loader.batch_size ) - self.state.loader_sample_step = 0 + self.loader_sample_step = 0 for i, batch in enumerate(loader): - self.state.global_batch_step += 1 - self.state.loader_batch_step = i + 1 + self.global_batch_step += 1 + self.loader_batch_step = i + 1 self._run_batch(batch) - if self.state.need_early_stop: - self.state.need_early_stop = False + if self.need_early_stop: + self.need_early_stop = False break def _run_epoch(self, stage: str, epoch: int) -> None: @@ -415,52 +831,49 @@ def _run_epoch(self, stage: str, epoch: int) -> None: epoch (int): epoch index """ self._prepare_for_epoch(stage=stage, epoch=epoch) - state: State = self.state - - assert state.loaders is not None - loaders = state.loaders + assert self.loaders is not None # @TODO: better solution with train/inference handling ? - state.is_infer_stage = state.stage_name.startswith("infer") - if not state.is_infer_stage: - assert state.valid_loader in loaders.keys(), ( - f"'{state.valid_loader}' " - f"should be in provided loaders: {list(loaders.keys())}" + self.is_infer_stage = self.stage_name.startswith("infer") + if not self.is_infer_stage: + assert self.valid_loader in self.loaders.keys(), ( + f"'{self.valid_loader}' " + f"should be in provided loaders: {list(self.loaders.keys())}" ) else: # @TODO: add check for non distributed run for inference assert not any( x.startswith(settings.loader_train_prefix) - for x in loaders.keys() + for x in self.loaders.keys() ), "for inference no train loader should be passed" - for loader_name, loader in loaders.items(): - state.loader_name = loader_name - state.loader_len = len(loader) - state.is_train_loader = loader_name.startswith( + for loader_name, loader in self.loaders.items(): + self.loader_name = loader_name + self.loader_len = len(loader) + self.is_train_loader = loader_name.startswith( settings.loader_train_prefix ) - state.is_valid_loader = loader_name.startswith( + self.is_valid_loader = loader_name.startswith( settings.loader_valid_prefix ) - state.is_infer_loader = loader_name.startswith( + self.is_infer_loader = loader_name.startswith( settings.loader_infer_prefix ) utils.maybe_recursive_call( - self.model, "train", mode=state.is_train_loader, + self.model, "train", mode=self.is_train_loader, ) if ( isinstance(loader.sampler, DistributedSampler) - and not state.is_infer_stage + and not self.is_infer_stage ): - loader.sampler.set_epoch(state.epoch) + loader.sampler.set_epoch(self.epoch) utils.set_global_seed( - self.experiment.initial_seed + state.global_epoch + 1 + self.experiment.initial_seed + self.global_epoch + 1 ) self._run_event("on_loader_start") - with torch.set_grad_enabled(state.is_train_loader): + with 
torch.set_grad_enabled(self.is_train_loader): self._run_loader(loader) self._run_event("on_loader_end") @@ -476,23 +889,21 @@ def _run_stage(self, stage: str) -> None: """ self._prepare_for_stage(stage) - state: State = self.state - self._run_event("on_stage_start") - while state.epoch < state.num_epochs + 1: + while self.epoch < self.num_epochs + 1: utils.set_global_seed( - self.experiment.initial_seed + state.global_epoch + 1 + self.experiment.initial_seed + self.global_epoch + 1 ) self._run_event("on_epoch_start") - self._run_epoch(stage=stage, epoch=state.epoch) + self._run_epoch(stage=stage, epoch=self.epoch) self._run_event("on_epoch_end") - if state.need_early_stop: - state.need_early_stop = False + if self.need_early_stop: + self.need_early_stop = False break - state.global_epoch += 1 - state.epoch += 1 + self.global_epoch += 1 + self.epoch += 1 self._run_event("on_stage_end") def run_experiment(self, experiment: _Experiment = None) -> "_Runner": @@ -510,6 +921,7 @@ def run_experiment(self, experiment: _Experiment = None) -> "_Runner": for stage in self.experiment.stages: self._run_stage(stage) except (Exception, KeyboardInterrupt) as ex: + from catalyst.core.callbacks.exception import ExceptionCallback def _exception_handler_check(callbacks: Union[OrderedDict, Dict]): return callbacks is not None and any( @@ -517,10 +929,8 @@ def _exception_handler_check(callbacks: Union[OrderedDict, Dict]): for x in callbacks.values() ) - if self.state is not None and _exception_handler_check( - self.state.callbacks - ): - self.state.exception = ex + if _exception_handler_check(getattr(self, "callbacks", None)): + self.exception = ex self._run_event("on_exception") else: raise ex @@ -534,24 +944,14 @@ class _StageBasedRunner(_Runner): datasources per stage. """ - _experiment_fn: Callable = _Experiment - _state_fn: Callable = State - - def _init(self): - """ - Inner method for `experiment` and `state` linting. - """ - self.experiment: _Experiment = None - self.state: State = None - def _prepare_for_stage(self, stage: str): """ Inner method to prepare `Runner` for the specified stage. Sets `Experiment` initial seed. Prepares experiment components with `self._get_experiment_components`. - Prepares callbacks with `self._get_callbacks`. - Prepares `State` with `self._get_state`. + Prepares callbacks with `self._get_experiment_callbacks`. + Prepares inner state with `self._prepare_inner_state` Additionally sets `Experiment` datasources for specified stage. 
Args: @@ -563,7 +963,7 @@ def _prepare_for_stage(self, stage: str): utils.set_global_seed(self.experiment.initial_seed) loaders = self.experiment.get_loaders(stage=stage) loaders = utils.validate_loaders(loaders) - self.state.loaders = loaders + self.loaders = loaders __all__ = ["_Runner", "_StageBasedRunner"] diff --git a/catalyst/core/state.py b/catalyst/core/state.py index 7c83c52125..51a796f693 100644 --- a/catalyst/core/state.py +++ b/catalyst/core/state.py @@ -1,509 +1 @@ -from typing import Any, Dict, Optional, TYPE_CHECKING, Union -from collections import defaultdict, OrderedDict -from pathlib import Path -import warnings - -from torch.utils.data import DataLoader - -from catalyst.core import utils -from catalyst.tools import settings -from catalyst.tools.frozen_class import FrozenClass -from catalyst.tools.typing import ( - Criterion, - Device, - Model, - Optimizer, - Scheduler, -) - -if TYPE_CHECKING: - from .callback import Callback # noqa: F401 - -StateModel = Union[Model, Dict[str, Model]] -StateCriterion = Union[Criterion, Dict[str, Criterion]] -StateOptimizer = Union[Optimizer, Dict[str, Optimizer]] -StateScheduler = Union[Scheduler, Dict[str, Scheduler]] - - -class State(FrozenClass): - """ - Some intermediate storage between Experiment and Runner - that saves the current state of the Experiments – - model, criterion, optimizer, schedulers, metrics, loggers, loaders, etc - - .. note:: - To learn more about Catalyst Core concepts, please check out - - - :py:mod:`catalyst.core.experiment._Experiment` - - :py:mod:`catalyst.core.runner._Runner` - - :py:mod:`catalyst.core.state.State` - - :py:mod:`catalyst.core.callback.Callback` - - **state.loaders** - ordered dictionary with torch.DataLoaders; \ - for example, - :: - - state.loaders = { - "train": MnistTrainLoader(), - "valid": MnistValidLoader() - } - - .. 
note:: - - "*train*" prefix is used for training loaders - \ - metrics computations, backward pass, optimization - - "*valid*" prefix is used for validation loaders - \ - metrics computations only - - "*infer*" prefix is used for inference loaders - \ - dataset prediction - - - **state.model** - an instance of torch.nn.Module class, \ - (should implement ``forward`` method); \ - for example, - :: - - state.model = torch.nn.Linear(10, 10) - - **state.criterion** - an instance of torch.nn.Module class\ - or torch.nn.modules.loss._Loss (should implement ``forward`` method); \ - for example, - :: - - state.criterion = torch.nn.CrossEntropyLoss() - - **state.optimizer** - an instance of torch.optim.optimizer.Optimizer\ - (should implement ``step`` method); \ - for example, - :: - - state.optimizer = torch.optim.Adam() - - **state.scheduler** - an instance of torch.optim.lr_scheduler._LRScheduler\ - (should implement ``step`` method); \ - for example, - :: - - state.scheduler = htorch.optim.lr_scheduler.ReduceLROnPlateau() - - **state.device** - an instance of torch.device (CPU, GPU, TPU); \ - for example, - :: - - state.device = torch.device("cpu") - - **state.callbacks** - ordered dictionary with Catalyst.Callback instances;\ - for example, - :: - - state.callbacks = { - "accuracy": AccuracyCallback(), - "criterion": CriterionCallback(), - "optim": OptimizerCallback(), - "saver": CheckpointCallback() - } - - - **state.input** - dictionary, \ - containing batch of data from currents DataLoader; \ - for example, - :: - - state.input = { - "images": np.ndarray(batch_size, c, h, w), - "targets": np.ndarray(batch_size, 1), - } - - **state.output** - dictionary, \ - containing model output for current batch; \ - for example, - :: - - state.output = {"logits": torch.Tensor(batch_size, num_classes)} - - **state.batch_metrics** - dictionary, flatten storage for batch metrics; \ - for example, - :: - - state.batch_metrics = {"loss": ..., "accuracy": ..., "iou": ...} - - **state.loader_metrics** - dictionary with aggregated batch statistics \ - for loader (mean over all batches) and global loader metrics, like AUC; \ - for example, - :: - - state.loader_metrics = {"loss": ..., "accuracy": ..., "auc": ...} - - **state.epoch_metrics** - dictionary with summarized metrics \ - for different loaders and global epoch metrics, like lr, momentum; \ - for example, - :: - - state.epoch_metrics = { - "train_loss": ..., "train_auc": ..., "valid_loss": ..., - "lr": ..., "momentum": ..., - } - - - **state.is_best_valid** - bool, indicator flag - - - ``True`` if this training epoch is best over all epochs - - ``False`` if not - - **state.valid_metrics** - dictionary with validation metrics\ - for currect epoch; \ - for example, - :: - - state.valid_metrics = {"loss": ..., "accuracy": ..., "auc": ...} - - .. 
note:: - subdictionary of epoch_metrics - - **state.best_valid_metrics** - dictionary with best validation metrics \ - during whole training process - - - **state.distributed_rank** - distributed rank of current worker - - **state.is_distributed_master** - bool, indicator flag - - - ``True`` if is master node (state.distributed_rank == 0) - - ``False`` if is worker node (state.distributed_rank != 0) - - **state.is_distributed_worker** - bool, indicator flag - - - ``True`` if is worker node (state.distributed_rank > 0) - - ``False`` if is master node (state.distributed_rank <= 0) - - - **state.stage_name** - string, current stage name,\ - for example, - :: - - state.stage_name = "pretraining" / "training" / "finetuning" / etc - - **state.epoch** - int, numerical indicator for current stage epoch - - **state.num_epochs** - int, maximum number of epochs, \ - required for this stage - - - **state.loader_name** - string, current loader name\ - for example, - :: - - state.loader_name = "train_dataset1" / "valid_data2" / "infer_golden" - - **state.loader_batch_step** - int, numerical indicator \ - for batch index in current loader - - **state.loader_len** - int, maximum number of batches in current loader - - **state.loader_sample_step** - int, numerical indicator \ - for number of samples passed through our model in current loader - - **state.loader_batch_size** - int, batch size parameter in current loader - - - **state.batch_size** - int, length of the current batch - - - **state.global_batch_step** - int, numerical indicator, counter for all - batches, that passes through our model during training, validation and\ - inference stages - - **state.global_sample_step** - int, numerical indicator, counter for all\ - individual samples, that passes through our model during training,\ - validation and inference stages - - **state.global_epoch** - int, numerical indicator, counter for all epochs,\ - that have passed during model training, validation and\ - inference stages - - - **state.main_metric** - string, containing name of metric of interest \ - for optimization, validation and checkpointing during training - - **state.minimize_metric** - bool, indicator flag - - - ``True`` if we need to minimize metric during training,\ - like `Cross Entropy loss` - - ``False`` if we need to maximize metric during training, \ - like `Accuracy` or `Intersection over Union` - - **state.valid_loader** - string, name of validation loader \ - for metric selection, validation and model checkpoining - - - **state.logdir** - string, path to logging directory to save\ - all logs, metrics, checkpoints and artifacts - - **state.checkpoint_data** - dictionary\ - with all extra data for experiment tracking - - - **state.is_check_run** - bool, indicator flag - - - ``True`` if you want to check you pipeline and \ - run only 2 batches per loader and 2 epochs per stage - - ``False`` (default) if you want to just the pipeline - - **state.is_train_loader** - bool, indicator flag - - - ``True`` for training loaders - - ``False`` otherwise - - **state.is_valid_loader** - bool, indicator flag - - - ``True`` for validation loaders - - ``False`` otherwise - - **state.is_infer_loader** - bool, indicator flag - - - ``True`` for inference loaders - - ``False`` otherwise - - **state.is_infer_stage** - bool, indicator flag - - - ``True`` for inference stages - - ``False`` otherwise - - **state.need_early_stop** - bool, indicator flag \ - used for EarlyStopping and CheckRun Callbacks - - - ``True`` if we need to stop the training - - 
``False`` (default) otherwise - - **state.need_exception_reraise** - bool, indicator flag - - - ``True`` (default) if you want to show exception \ - during pipeline and stop the training process - - ``False`` otherwise - - **state.exception** - python Exception instance to raise (or not ;) ) - """ - - def __init__( - self, - *, - device: Device = None, - model: StateModel = None, - criterion: StateCriterion = None, - optimizer: StateOptimizer = None, - scheduler: StateScheduler = None, - callbacks: Dict[str, "Callback"] = None, - logdir: str = None, - stage: str = settings.stage_infer_prefix, # @TODO: wtf? - num_epochs: int = 1, - main_metric: str = "loss", - minimize_metric: bool = True, - valid_loader: str = settings.loader_valid_prefix, - checkpoint_data: Dict = None, - is_check_run: bool = False, - **kwargs, - ): - """ - Args: - @TODO: Docs. Contribution is welcome - """ - # main part - # data - self.loaders: OrderedDict[str, DataLoader] = None - # components - self.model: StateModel = model - self.criterion: StateCriterion = criterion - self.optimizer: StateOptimizer = optimizer - self.scheduler: StateScheduler = scheduler - # extra components - PyTorch device - self.device: Device = device - # extra components - Catalyst callbacks - self.callbacks: Dict[str, "Callback"] = callbacks - - # dataflow - model input, model output - self.input = None - self.output = None - - # metrics flow - batch, loader, epoch metrics - # let's use flatten storage for batch metrics - # batch_metrics = {'loss': ..., 'accuracy': ..., 'iou': ...} - self.batch_metrics = defaultdict(None) - # just aggregated (aka mean over all batches) - # batch statistics for loader - # and global loader metrics, like AUC - # loader_metrics = {'loss': ..., 'accuracy': ..., `auc`: ...} - self.loader_metrics = defaultdict(None) - # summarized metrics for different loaders - # and global epoch metrics, like lr, momentum - # epoch_metrics = { - # 'train_loss': ..., 'train_auc': ..., 'valid_loss': ..., - # 'lr': ..., 'momentum': ..., - # } - self.epoch_metrics = defaultdict(None) - - # validation - self.is_best_valid = False - self.valid_metrics = defaultdict(None) - self.best_valid_metrics = defaultdict(None) - - # pipeline info - self.distributed_rank = utils.get_rank() - self.is_distributed_master = ~(self.distributed_rank > 0) - self.is_distributed_worker = self.distributed_rank > 0 - - self.stage_name: str = stage - self.epoch: int = 1 - self.num_epochs: int = num_epochs - - self.loader_name: str = None - self.loader_batch_step: int = 0 - self.loader_sample_step: int = 0 - self.loader_len: int = 0 - self.loader_batch_size = 0 - - self.batch_size: int = 0 - - self.global_sample_step: int = 0 - self.global_batch_step: int = 0 - self.global_epoch: int = 1 - - # metrics & validation - self.main_metric: str = main_metric - self.minimize_metric: bool = minimize_metric - self.valid_loader: str = valid_loader - - # logging - self.logdir: Path = Path(logdir) if logdir is not None else None - # extra checkpoint data for saving in checkpoint files - self.checkpoint_data: Dict = checkpoint_data or {} - - # other - self.is_check_run: bool = is_check_run - self.is_train_loader: bool = False - self.is_valid_loader: bool = False - self.is_infer_loader: bool = False - self.is_infer_stage: bool = self.stage_name.startswith( - settings.stage_infer_prefix - ) - self.need_early_stop: bool = False - self.need_exception_reraise: bool = True - self.exception: Optional[Exception] = None - - # kwargs - for k, v in kwargs.items(): - setattr(self, k, v) 
- - self._freeze() - - @property - def batch_in(self): - """Alias for `state.input`. - - .. warning:: - Deprecated, saved for backward compatibility. - Please use `state.batch_in` instead. - """ - warnings.warn( - "`state.batch_in` was deprecated, " - "please use `state.input` instead", - DeprecationWarning, - ) - return self.input - - @property - def batch_out(self): - """Alias for `state.output`. - - .. warning:: - Deprecated, saved for backward compatibility. - Please use `state.batch_out` instead. - """ - warnings.warn( - "`state.batch_out` was deprecated, " - "please use `state.output` instead", - DeprecationWarning, - ) - return self.output - - @property - def need_backward_pass(self): - """Alias for `state.is_train_loader`. - - .. warning:: - Deprecated, saved for backward compatibility. - Please use `state.is_train_loader` instead. - """ - warnings.warn( - "`need_backward_pass` was deprecated, " - "please use `is_train_loader` instead", - DeprecationWarning, - ) - return self.is_train_loader - - @property - def loader_step(self): - """Alias for `state.loader_batch_step`. - - .. warning:: - Deprecated, saved for backward compatibility. - Please use `state.loader_batch_step` instead. - """ - warnings.warn( - "`loader_step` was deprecated, " - "please use `loader_batch_step` instead", - DeprecationWarning, - ) - return self.loader_batch_step - - def get_attr(self, key: str, inner_key: str = None) -> Any: - """ - Alias for python `getattr` method. Useful for Callbacks preparation - and cases with multi-criterion, multi-optimizer setup. - For example, when you would like to train multi-task classification. - - Used to get a named attribute from a `State` by `key` keyword; - for example\ - :: - - # example 1 - state.get_attr("criterion") - # is equivalent to - state.criterion - - # example 2 - state.get_attr("optimizer") - # is equivalent to - state.optimizer - - # example 3 - state.get_attr("scheduler") - # is equivalent to - state.scheduler - - With `inner_key` usage, it suppose to find a dictionary under `key`\ - and would get `inner_key` from this dict; for example, - :: - - # example 1 - state.get_attr("criterion", "bce") - # is equivalent to - state.criterion["bce"] - - # example 2 - state.get_attr("optimizer", "adam") - # is equivalent to - state.optimizer["adam"] - - # example 3 - state.get_attr("scheduler", "adam") - # is equivalent to - state.scheduler["adam"] - - Args: - key (str): name for attribute of interest, - like `criterion`, `optimizer`, `scheduler` - inner_key (str): name of inner dictionary key - """ - if inner_key is None: - return getattr(self, key) - else: - return getattr(self, key)[inner_key] +from catalyst.core.runner import _Runner as State # noqa: F401 diff --git a/catalyst/core/utils/callbacks.py b/catalyst/core/utils/callbacks.py index 4e3881bb53..237d07fb63 100644 --- a/catalyst/core/utils/callbacks.py +++ b/catalyst/core/utils/callbacks.py @@ -1,4 +1,4 @@ -from typing import Dict, Union +from typing import Dict, List, Union from collections import OrderedDict from catalyst.core.callback import CallbackNode @@ -6,7 +6,7 @@ def sort_callbacks_by_order( - callbacks: Union[list, OrderedDict] + callbacks: Union[List, Dict, OrderedDict] ) -> OrderedDict: """Creates an sequence of callbacks and sort them. 
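With `catalyst/core/state.py` reduced to the alias above, code that still imports `State` keeps working; for example::

    from catalyst.core.runner import _Runner
    from catalyst.core.state import State  # backward-compatible alias

    assert State is _Runner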
@@ -18,7 +18,7 @@ def sort_callbacks_by_order( """ if callbacks is None: output = OrderedDict() - elif isinstance(callbacks, (Dict, OrderedDict)): + elif isinstance(callbacks, (dict, OrderedDict)): output = [(k, v) for k, v in callbacks.items()] output = sorted(output, key=lambda x: x[1].order) output = OrderedDict(output) diff --git a/catalyst/dl/callbacks/confusion_matrix.py b/catalyst/dl/callbacks/confusion_matrix.py index 77921b7074..fad5b15e99 100644 --- a/catalyst/dl/callbacks/confusion_matrix.py +++ b/catalyst/dl/callbacks/confusion_matrix.py @@ -6,7 +6,8 @@ import torch import torch.distributed -from catalyst.dl import Callback, CallbackNode, CallbackOrder, State, utils +from catalyst.core import _Runner, Callback, CallbackNode, CallbackOrder +from catalyst.dl import utils from catalyst.tools import meters @@ -89,47 +90,47 @@ def _plot_confusion_matrix( fig = utils.render_figure_to_tensor(fig) logger.add_image(f"{self.prefix}/epoch", fig, global_step=epoch) - def on_loader_start(self, state: State): + def on_loader_start(self, runner: _Runner): """Loader start hook. Args: - state (State): current state + runner (_Runner): current runner """ self._reset_stats() - def on_batch_end(self, state: State): + def on_batch_end(self, runner: _Runner): """Batch end hook. Args: - state (State): current state + runner (_Runner): current runner """ self._add_to_stats( - state.output[self.output_key].detach(), - state.input[self.input_key].detach(), + runner.output[self.output_key].detach(), + runner.input[self.input_key].detach(), ) - def on_loader_end(self, state: State): + def on_loader_end(self, runner: _Runner): """Loader end hook. Args: - state (State): current state + runner (_Runner): current runner """ class_names = self.class_names or [ str(i) for i in range(self.num_classes) ] confusion_matrix = self._compute_confusion_matrix() - if state.distributed_rank >= 0: + if runner.distributed_rank >= 0: confusion_matrix = torch.from_numpy(confusion_matrix) confusion_matrix = confusion_matrix.to(utils.get_device()) torch.distributed.reduce(confusion_matrix, 0) confusion_matrix = confusion_matrix.cpu().numpy() - if state.distributed_rank <= 0: - tb_callback = state.callbacks[self.tensorboard_callback_name] + if runner.distributed_rank <= 0: + tb_callback = runner.callbacks[self.tensorboard_callback_name] self._plot_confusion_matrix( - logger=tb_callback.loggers[state.loader_name], - epoch=state.global_epoch, + logger=tb_callback.loggers[runner.loader_name], + epoch=runner.global_epoch, confusion_matrix=confusion_matrix, class_names=class_names, ) diff --git a/catalyst/dl/callbacks/inference.py b/catalyst/dl/callbacks/inference.py index 8f22560dd7..fffd778535 100644 --- a/catalyst/dl/callbacks/inference.py +++ b/catalyst/dl/callbacks/inference.py @@ -3,7 +3,7 @@ import numpy as np -from catalyst.dl import Callback, CallbackOrder, State +from catalyst.core import _Runner, Callback, CallbackOrder # @TODO: refactor @@ -19,16 +19,16 @@ def __init__(self, out_dir=None, out_prefix=None): self.out_dir = out_dir self.out_prefix = out_prefix self.predictions = defaultdict(lambda: []) - self._keys_from_state = ["out_dir", "out_prefix"] + self._keys_from_runner = ["out_dir", "out_prefix"] - def on_stage_start(self, state: State): + def on_stage_start(self, runner: _Runner): """Stage start hook. 
Args: - state (State): current state + runner (_Runner): current runner """ - for key in self._keys_from_state: - value = getattr(state, key, None) + for key in self._keys_from_runner: + value = getattr(runner, key, None) if value is not None: setattr(self, key, value) # assert self.out_prefix is not None @@ -37,30 +37,30 @@ def on_stage_start(self, state: State): if self.out_prefix is not None: os.makedirs(os.path.dirname(self.out_prefix), exist_ok=True) - def on_loader_start(self, state: State): + def on_loader_start(self, runner: _Runner): """Loader start hook. Args: - state (State): current state + runner (_Runner): current runner """ self.predictions = defaultdict(lambda: []) - def on_batch_end(self, state: State): + def on_batch_end(self, runner: _Runner): """Batch end hook. Args: - state (State): current state + runner (_Runner): current runner """ - dct = state.output + dct = runner.output dct = {key: value.detach().cpu().numpy() for key, value in dct.items()} for key, value in dct.items(): self.predictions[key].append(value) - def on_loader_end(self, state: State): + def on_loader_end(self, runner: _Runner): """Loader end hook. Args: - state (State): current state + runner (_Runner): current runner """ self.predictions = { key: np.concatenate(value, axis=0) @@ -68,7 +68,7 @@ def on_loader_end(self, state: State): } if self.out_prefix is not None: for key, value in self.predictions.items(): - suffix = ".".join([state.loader_name, key]) + suffix = ".".join([runner.loader_name, key]) np.save(f"{self.out_prefix}/{suffix}.npy", value) diff --git a/catalyst/dl/callbacks/meter.py b/catalyst/dl/callbacks/meter.py index 947367e9e6..b6f217d3c4 100644 --- a/catalyst/dl/callbacks/meter.py +++ b/catalyst/dl/callbacks/meter.py @@ -3,14 +3,14 @@ import numpy as np -from catalyst.core import Callback, CallbackOrder, State +from catalyst.core import _Runner, Callback, CallbackOrder from catalyst.dl.utils import get_activation_fn class MeterMetricsCallback(Callback): """ A callback that tracks metrics through meters and prints metrics for - each class on `state.on_loader_end`. + each class on `runner.on_loader_end`. .. note:: This callback works for both single metric and multi-metric meters. @@ -57,35 +57,35 @@ def _reset_stats(self): for meter in self.meters: meter.reset() - def on_loader_start(self, state): + def on_loader_start(self, runner: _Runner): """Loader start hook. Args: - state (State): current state + runner (_Runner): current runner """ self._reset_stats() - def on_batch_end(self, state: State): + def on_batch_end(self, runner: _Runner): """Batch end hook. Computes batch metrics. Args: - state (State): current state + runner (_Runner): current runner """ - logits = state.output[self.output_key].detach().float() - targets = state.input[self.input_key].detach().float() + logits = runner.output[self.output_key].detach().float() + targets = runner.input[self.input_key].detach().float() probabilities = self.activation_fn(logits) for i in range(self.num_classes): self.meters[i].add(probabilities[:, i], targets[:, i]) - def on_loader_end(self, state: State): + def on_loader_end(self, runner: _Runner): """Loader end hook. Computes loader metrics. 
Args: - state (State): current state + runner (_Runner): current runner """ metrics_tracker = defaultdict(list) - loader_values = state.loader_metrics + loader_values = runner.loader_metrics # Computing metrics for each class for i, meter in enumerate(self.meters): metrics = meter.value() diff --git a/catalyst/dl/callbacks/metrics/dice.py b/catalyst/dl/callbacks/metrics/dice.py index cd6c973ae7..010b53727c 100644 --- a/catalyst/dl/callbacks/metrics/dice.py +++ b/catalyst/dl/callbacks/metrics/dice.py @@ -1,6 +1,6 @@ import numpy as np -from catalyst.core import Callback, CallbackOrder, MetricCallback, State +from catalyst.core import _Runner, Callback, CallbackOrder, MetricCallback from catalyst.dl import utils from catalyst.utils import metrics @@ -74,14 +74,14 @@ def _reset_stats(self): """Resets the confusion matrix holding the epoch-wise stats.""" self.confusion_matrix = None - def on_batch_end(self, state: State): + def on_batch_end(self, runner: _Runner): """Records the confusion matrix at the end of each batch. Args: - state (State): current state + runner (_Runner): current runner """ - outputs = state.output[self.output_key] - targets = state.input[self.input_key] + outputs = runner.output[self.output_key] + targets = runner.input[self.input_key] confusion_matrix = utils.calculate_confusion_matrix_from_tensors( outputs, targets @@ -92,11 +92,11 @@ def on_batch_end(self, state: State): else: self.confusion_matrix += confusion_matrix - def on_loader_end(self, state: State): + def on_loader_end(self, runner: _Runner): """@TODO: Docs. Contribution is welcome. Args: - state (State): current state + runner (_Runner): current runner """ tp_fp_fn_dict = utils.calculate_tp_fp_fn(self.confusion_matrix) @@ -112,15 +112,15 @@ def on_loader_end(self, state: State): self.class_names[i] if self.class_names is not None else str(i) ) - state.loader_metrics[f"{self.prefix}_{postfix}"] = dice + runner.loader_metrics[f"{self.prefix}_{postfix}"] = dice # For supporting averaging of only classes specified in `class_names` values_to_avg = [ value - for key, value in state.loader_metrics.items() + for key, value in runner.loader_metrics.items() if key.startswith(f"{self.prefix}_") ] - state.loader_metrics[f"{self.prefix}_mean"] = np.mean(values_to_avg) + runner.loader_metrics[f"{self.prefix}_mean"] = np.mean(values_to_avg) self._reset_stats() diff --git a/catalyst/dl/callbacks/mixup.py b/catalyst/dl/callbacks/mixup.py index b1e50b5c50..e82df11237 100644 --- a/catalyst/dl/callbacks/mixup.py +++ b/catalyst/dl/callbacks/mixup.py @@ -4,7 +4,8 @@ import torch -from catalyst.dl import CriterionCallback, State +from catalyst.core import _Runner +from catalyst.dl import CriterionCallback class MixupCallback(CriterionCallback): @@ -58,32 +59,32 @@ def __init__( self.index = None self.is_needed = True - def _compute_loss_value(self, state: State, criterion): + def _compute_loss_value(self, runner: _Runner, criterion): if not self.is_needed: - return super()._compute_loss_value(state, criterion) + return super()._compute_loss_value(runner, criterion) - pred = state.output[self.output_key] - y_a = state.input[self.input_key] - y_b = state.input[self.input_key][self.index] + pred = runner.output[self.output_key] + y_a = runner.input[self.input_key] + y_b = runner.input[self.input_key][self.index] loss = self.lam * criterion(pred, y_a) + (1 - self.lam) * criterion( pred, y_b ) return loss - def on_loader_start(self, state: State): + def on_loader_start(self, runner: _Runner): """Loader start hook. 
Args: - state (State): current state + runner (_Runner): current runner """ - self.is_needed = not self.on_train_only or state.is_train_loader + self.is_needed = not self.on_train_only or runner.is_train_loader - def on_batch_start(self, state: State): + def on_batch_start(self, runner: _Runner): """Batch start hook. Args: - state (State): current state + runner (_Runner): current runner """ if not self.is_needed: return @@ -93,13 +94,13 @@ def on_batch_start(self, state: State): else: self.lam = 1 - self.index = torch.randperm(state.input[self.fields[0]].shape[0]) - self.index.to(state.device) + self.index = torch.randperm(runner.input[self.fields[0]].shape[0]) + self.index.to(runner.device) for f in self.fields: - state.input[f] = ( - self.lam * state.input[f] - + (1 - self.lam) * state.input[f][self.index] + runner.input[f] = ( + self.lam * runner.input[f] + + (1 - self.lam) * runner.input[f][self.index] ) diff --git a/catalyst/dl/callbacks/scheduler.py b/catalyst/dl/callbacks/scheduler.py index f86b8602cb..2325d17cec 100644 --- a/catalyst/dl/callbacks/scheduler.py +++ b/catalyst/dl/callbacks/scheduler.py @@ -1,7 +1,7 @@ from typing import Optional +from catalyst.core import _Runner from catalyst.core.callbacks import LRUpdater -from catalyst.dl import State class LRFinder(LRUpdater): @@ -63,27 +63,27 @@ def calc_lr(self): self.find_iter += 1 return res - def on_loader_start(self, state: State): + def on_loader_start(self, runner: _Runner): """@TODO: Docs. Contribution is welcome. Args: - state (State): current state + runner (_Runner): current runner """ - if state.is_train_loader: + if runner.is_train_loader: lr_ = self.final_lr / self.init_lr - self.num_steps = self.num_steps or state.loader_len + self.num_steps = self.num_steps or runner.loader_len self.multiplier = lr_ ** (1 / self.num_steps) self.lr_step = (self.final_lr - self.init_lr) / self.num_steps - super().on_loader_start(state=state) + super().on_loader_start(runner=runner) - def on_batch_end(self, state: State): + def on_batch_end(self, runner: _Runner): """@TODO: Docs. Contribution is welcome. 
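The `MixupCallback.on_batch_start` hunk above pairs each sample with a random partner from the same batch and blends the two; the arithmetic in isolation, on a toy tensor::

    import torch

    lam = 0.7                           # mixing coefficient, typically sampled from Beta(alpha, alpha)
    x = torch.randn(4, 3, 8, 8)         # toy image batch
    index = torch.randperm(x.shape[0])  # random in-batch pairing
    mixed = lam * x + (1 - lam) * x[index]
    assert mixed.shape == x.shape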
Args: - state (State): current state + runner (_Runner): current runner """ - super().on_batch_end(state=state) + super().on_batch_end(runner=runner) if self.find_iter > self.num_steps: raise NotImplementedError("End of LRFinder") diff --git a/catalyst/dl/experiment/__init__.py b/catalyst/dl/experiment/__init__.py index 45d32d0f3f..97d6d17400 100644 --- a/catalyst/dl/experiment/__init__.py +++ b/catalyst/dl/experiment/__init__.py @@ -1,5 +1,5 @@ # flake8: noqa from .config import ConfigExperiment -from .core import Experiment +from .experiment import Experiment from .supervised import SupervisedExperiment diff --git a/catalyst/dl/experiment/config.py b/catalyst/dl/experiment/config.py index f4402e3deb..bc9108866b 100644 --- a/catalyst/dl/experiment/config.py +++ b/catalyst/dl/experiment/config.py @@ -46,7 +46,7 @@ class ConfigExperiment(_Experiment): "scheduler_params", "data_params", "transform_params", - "state_params", + "stage_params", "callbacks_params", ] @@ -68,8 +68,11 @@ def __init__(self, config: Dict): ) self.__prepare_logdir() - self._config["stages"]["state_params"] = utils.merge_dicts( - deepcopy(self._config["stages"].get("state_params", {})), + self._config["stages"]["stage_params"] = utils.merge_dicts( + deepcopy( + self._config["stages"].get("state_params", {}) + ), # saved for backward compatibility + deepcopy(self._config["stages"].get("stage_params", {})), deepcopy(self._config.get("args", {})), {"logdir": self._logdir}, ) @@ -95,19 +98,36 @@ def _get_stages_config(self, stages_config: Dict): stages_defaults = {} stages_config_out = OrderedDict() for key in self.STAGE_KEYWORDS: - stages_defaults[key] = deepcopy(stages_config.get(key, {})) + if key == "stage_params": + # backward compatibility + stages_defaults[key] = utils.merge_dicts( + deepcopy(stages_config.get("state_params", {})), + deepcopy(stages_config.get(key, {})), + ) + else: + stages_defaults[key] = deepcopy(stages_config.get(key, {})) for stage in stages_config: if ( stage in self.STAGE_KEYWORDS + or stage == "state_params" or stages_config.get(stage) is None ): continue stages_config_out[stage] = {} for key in self.STAGE_KEYWORDS: - stages_config_out[stage][key] = utils.merge_dicts( - deepcopy(stages_defaults.get(key, {})), - deepcopy(stages_config[stage].get(key, {})), - ) + if key == "stage_params": + # backward compatibility + stages_config_out[stage][key] = utils.merge_dicts( + deepcopy(stages_defaults.get("state_params", {})), + deepcopy(stages_defaults.get(key, {})), + deepcopy(stages_config[stage].get("state_params", {})), + deepcopy(stages_config[stage].get(key, {})), + ) + else: + stages_config_out[stage][key] = utils.merge_dicts( + deepcopy(stages_defaults.get(key, {})), + deepcopy(stages_config[stage].get(key, {})), + ) return stages_config_out @@ -131,23 +151,6 @@ def logdir(self): def stages(self) -> List[str]: """Experiment's stage names.""" stages_keys = list(self.stages_config.keys()) - - # @TODO: return the feature - # # Change start `stages_keys` if resume data were founded - # state_params = self.get_state_params(stages_keys[0]) - # resume, resume_dir = [ - # state_params.get(key, None) for key in ["resume", "resume_dir"] - # ] - # - # if resume_dir is not None: - # resume = resume_dir / str(resume) - # - # if resume is not None and Path(resume).is_file(): - # checkpoint = utils.load_checkpoint(resume) - # start_stage = checkpoint["stage"] - # start_idx = stages_keys.index(start_stage) - # stages_keys = stages_keys[start_idx:] - return stages_keys @property @@ -155,16 +158,9 @@ def 
distributed_params(self) -> Dict: """Dict with the parameters for distributed and FP16 methond.""" return self._config.get("distributed_params", {}) - def get_state_params(self, stage: str) -> Mapping[str, Any]: + def get_stage_params(self, stage: str) -> Mapping[str, Any]: """Returns the state parameters for a given stage.""" - return self.stages_config[stage].get("state_params", {}) - - def _preprocess_model_for_stage(self, stage: str, model: Model): - # stage_index = self.stages.index(stage) - return model - - def _postprocess_model_for_stage(self, stage: str, model: Model): - return model + return self.stages_config[stage].get("stage_params", {}) @staticmethod def _get_model(**params): @@ -183,9 +179,6 @@ def get_model(self, stage: str): """Returns the model for a given stage.""" model_params = self._config["model_params"] model = self._get_model(**model_params) - - model = self._preprocess_model_for_stage(stage, model) - model = self._postprocess_model_for_stage(stage, model) return model @staticmethod @@ -285,9 +278,9 @@ def _get_optimizer( device = utils.get_device() for param in model_params: param = param["params"][0] - state = optimizer.state[param] - for key, value in state.items(): - state[key] = utils.any2device(value, device) + optimizer_state = optimizer.state[param] + for key, value in optimizer_state.items(): + optimizer_state[key] = utils.any2device(value, device) # update optimizer params for key, value in params.items(): diff --git a/catalyst/dl/experiment/core.py b/catalyst/dl/experiment/experiment.py similarity index 96% rename from catalyst/dl/experiment/core.py rename to catalyst/dl/experiment/experiment.py index 481fb8ad67..4d91e01b79 100644 --- a/catalyst/dl/experiment/core.py +++ b/catalyst/dl/experiment/experiment.py @@ -47,7 +47,7 @@ def __init__( verbose: bool = False, check_time: bool = False, check_run: bool = False, - state_kwargs: Dict = None, + stage_kwargs: Dict = None, checkpoint_data: Dict = None, distributed_params: Dict = None, initial_seed: int = 42, @@ -85,7 +85,7 @@ def __init__( of training process and displays it to the console. 
check_run (bool): if True, we run only 3 batches per loader and 3 epochs per stage to check pipeline correctness - state_kwargs (dict): additional state params to ``State`` + stage_kwargs (dict): additional stage params checkpoint_data (dict): additional data to save in checkpoint, for example: ``class_names``, ``date_of_training``, etc distributed_params (dict): dictionary with the parameters @@ -119,7 +119,7 @@ def __init__( self._verbose = verbose self._check_time = check_time self._check_run = check_run - self._state_kwargs = state_kwargs or {} + self._stage_kwargs = stage_kwargs or {} self._checkpoint_data = checkpoint_data or {} self._distributed_params = distributed_params or {} @@ -169,7 +169,7 @@ def process_loaders( ) return loaders, valid_loader - def get_state_params(self, stage: str) -> Mapping[str, Any]: + def get_stage_params(self, stage: str) -> Mapping[str, Any]: """Returns the state parameters for a given stage.""" default_params = { "logdir": self.logdir, @@ -180,8 +180,8 @@ def get_state_params(self, stage: str) -> Mapping[str, Any]: "minimize_metric": self._minimize_metric, "checkpoint_data": self._checkpoint_data, } - state_params = {**default_params, **self._state_kwargs} - return state_params + stage_params = {**default_params, **self._stage_kwargs} + return stage_params def get_model(self, stage: str) -> Model: """Returns the model for a given stage.""" diff --git a/catalyst/dl/experiment/supervised.py b/catalyst/dl/experiment/supervised.py index c92a989f7d..7fc896a013 100644 --- a/catalyst/dl/experiment/supervised.py +++ b/catalyst/dl/experiment/supervised.py @@ -10,7 +10,7 @@ ) from catalyst.tools.typing import Criterion, Optimizer, Scheduler -from .core import Experiment +from .experiment import Experiment class SupervisedExperiment(Experiment): @@ -33,9 +33,9 @@ class SupervisedExperiment(Experiment): your model/criterion/optimizer/metrics. 
ConsoleLogger: standard Catalyst logger, - translates ``state.*_metrics`` to console and text file + translates ``runner.*_metrics`` to console and text file TensorboardLogger: - will write ``state.*_metrics`` to tensorboard + will write ``runner.*_metrics`` to tensorboard RaiseExceptionCallback: will raise exception if needed """ diff --git a/catalyst/dl/experiment/tests/test_config.py b/catalyst/dl/experiment/tests/test_config.py index ad70098663..760c716c1a 100644 --- a/catalyst/dl/experiment/tests/test_config.py +++ b/catalyst/dl/experiment/tests/test_config.py @@ -94,7 +94,7 @@ def test_defaults(): assert exp.logdir == "./logdir" assert exp.stages == ["train"] assert exp.distributed_params == {} - assert exp.get_state_params("train") == { + assert exp.get_stage_params("train") == { "logdir": "./logdir", } assert isinstance(exp.get_model("train"), SomeModel) @@ -126,7 +126,7 @@ def test_defaults_criterion_optimizer_scheduler(): assert exp.logdir == "./logdir" assert exp.stages == ["train"] assert exp.distributed_params == {} - assert exp.get_state_params("train") == { + assert exp.get_stage_params("train") == { "logdir": "./logdir", } assert isinstance(exp.get_model("train"), SomeModel) diff --git a/catalyst/dl/experiment/tests/test_core.py b/catalyst/dl/experiment/tests/test_core.py index b79e2ba960..f8f5345299 100644 --- a/catalyst/dl/experiment/tests/test_core.py +++ b/catalyst/dl/experiment/tests/test_core.py @@ -8,7 +8,7 @@ MetricManagerCallback, ValidationManagerCallback, ) -from catalyst.dl.experiment.core import Experiment +from catalyst.dl.experiment.experiment import Experiment def _test_callbacks(test_callbacks, exp, stage="train"): @@ -51,7 +51,7 @@ def test_defaults(): assert exp.logdir is None assert exp.stages == ["train"] assert exp.distributed_params == {} - assert exp.get_state_params("") == { + assert exp.get_stage_params("") == { "logdir": None, "num_epochs": 1, "valid_loader": "train", diff --git a/catalyst/dl/runner/__init__.py b/catalyst/dl/runner/__init__.py index 76cd33be6c..18af4312b1 100644 --- a/catalyst/dl/runner/__init__.py +++ b/catalyst/dl/runner/__init__.py @@ -1,4 +1,4 @@ # flake8: noqa -from .core import Runner +from .runner import Runner from .supervised import SupervisedRunner diff --git a/catalyst/dl/runner/core.py b/catalyst/dl/runner/runner.py similarity index 95% rename from catalyst/dl/runner/core.py rename to catalyst/dl/runner/runner.py index 6b0e3c1eb2..0da9f0d582 100644 --- a/catalyst/dl/runner/core.py +++ b/catalyst/dl/runner/runner.py @@ -5,13 +5,9 @@ from torch.jit import ScriptModule from torch.utils.data import DataLoader, Dataset -from catalyst.core import ( - _StageBasedRunner, - Callback, - CheckpointCallback, - State, -) -from catalyst.dl import Experiment, utils +from catalyst.core import _StageBasedRunner, Callback, CheckpointCallback +from catalyst.dl import utils +from catalyst.dl.experiment.experiment import Experiment from catalyst.tools.typing import ( Criterion, Device, @@ -23,15 +19,13 @@ class Runner(_StageBasedRunner): """ - Deep Learning Runner for different supervised, unsupervised, gan, etc runs. + Deep Learning Runner for supervised, unsupervised, gan, etc runs. 
""" _experiment_fn: Callable = Experiment - _state_fn: Callable = State - def _init(self): + def _init(self, **kwargs): self.experiment: Experiment = None - self.state: State = None def train( self, @@ -50,7 +44,7 @@ def train( main_metric: str = "loss", minimize_metric: bool = True, verbose: bool = False, - state_kwargs: Dict = None, + stage_kwargs: Dict = None, checkpoint_data: Dict = None, fp16: Union[Dict, bool] = None, distributed: bool = False, @@ -58,6 +52,7 @@ def train( timeit: bool = False, load_best_on_end: bool = False, initial_seed: int = 42, + state_kwargs: Dict = None, ) -> None: """ Starts the train stage of the model. @@ -90,7 +85,7 @@ def train( the ``main_metric`` should be minimized. verbose (bool): if `True`, it displays the status of the training to the console. - state_kwargs (dict): additional state params for ``State`` + stage_kwargs (dict): additional params for stage checkpoint_data (dict): additional data to save in checkpoint, for example: ``class_names``, ``date_of_training``, etc fp16 (Union[Dict, bool]): If not None, then sets training to FP16. @@ -108,6 +103,8 @@ def train( according to validation metrics. Requires specified ``logdir``. initial_seed (int): experiment's initial seed value """ + assert state_kwargs is None or stage_kwargs is None + if isinstance(fp16, bool) and fp16: fp16 = {"opt_level": "O1"} @@ -147,7 +144,7 @@ def train( verbose=verbose, check_time=timeit, check_run=check, - state_kwargs=state_kwargs, + stage_kwargs=stage_kwargs or state_kwargs, checkpoint_data=checkpoint_data, distributed_params=fp16, initial_seed=initial_seed, @@ -165,11 +162,12 @@ def infer( logdir: str = None, resume: str = None, verbose: bool = False, - state_kwargs: Dict = None, + stage_kwargs: Dict = None, fp16: Union[Dict, bool] = None, check: bool = False, timeit: bool = False, initial_seed: int = 42, + state_kwargs: Dict = None, ) -> None: """ Starts the inference stage of the model. @@ -189,7 +187,7 @@ def infer( logdir (str): path to output directory verbose (bool): if `True`, it displays the status of the training to the console. - state_kwargs (dict): additional state params for ``State`` + stage_kwargs (dict): additional stage params checkpoint_data (dict): additional data to save in checkpoint, for example: ``class_names``, ``date_of_training``, etc fp16 (Union[Dict, bool]): If not None, then sets training to FP16. @@ -201,6 +199,8 @@ def infer( of training process and displays it to the console. 
initial_seed (int): experiment's initial seed value """ + assert state_kwargs is None or stage_kwargs is None + if isinstance(fp16, bool) and fp16: fp16 = {"opt_level": "O1"} @@ -224,7 +224,7 @@ def infer( verbose=verbose, check_time=timeit, check_run=check, - state_kwargs=state_kwargs, + stage_kwargs=stage_kwargs or state_kwargs, distributed_params=fp16, initial_seed=initial_seed, ) diff --git a/catalyst/dl/runner/supervised.py b/catalyst/dl/runner/supervised.py index e846e397fc..4f4f409f8b 100644 --- a/catalyst/dl/runner/supervised.py +++ b/catalyst/dl/runner/supervised.py @@ -3,10 +3,9 @@ import torch -from catalyst.dl import State, SupervisedExperiment -from catalyst.tools.typing import Device, Model - -from .core import Runner +from catalyst.dl.experiment.supervised import SupervisedExperiment +from catalyst.dl.runner.runner import Runner +from catalyst.tools.typing import Device, RunnerModel logger = logging.getLogger(__name__) @@ -18,7 +17,7 @@ class SupervisedRunner(Runner): def __init__( self, - model: Model = None, + model: RunnerModel = None, device: Device = None, input_key: Any = "features", output_key: Any = "logits", @@ -26,14 +25,36 @@ def __init__( ): """ Args: - model (Module): Torch model object + model (RunnerModel): Torch model object device (Device): Torch device input_key (Any): Key in batch dict mapping for model input output_key (Any): Key in output dict model output will be stored under input_target_key (str): Key in batch dict mapping for target """ - super().__init__(model=model, device=device) + super().__init__( + model=model, + device=device, + input_key=input_key, + output_key=output_key, + input_target_key=input_target_key, + ) + + def _init( + self, + input_key: Any = "features", + output_key: Any = "logits", + input_target_key: str = "targets", + ): + """ + Args: + input_key (Any): Key in batch dict mapping for model input + output_key (Any): Key in output dict model output + will be stored under + input_target_key (str): Key in batch dict mapping for target + """ + self.experiment: SupervisedExperiment = None + self.input_key = input_key self.output_key = output_key self.target_key = input_target_key @@ -62,10 +83,6 @@ def __init__( else: raise NotImplementedError() - def _init(self): - self.experiment: SupervisedExperiment = None - self.state: State = None - def _batch2device(self, batch: Mapping[str, Any], device: Device): if isinstance(batch, (tuple, list)): assert len(batch) == 2 @@ -121,7 +138,7 @@ def _handle_batch(self, batch: Mapping[str, Any]) -> None: batch (Mapping[str, Any]): dictionary with data batches from DataLoader. 
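Putting the renamed argument together: a minimal, hypothetical `SupervisedRunner.train` call under the new naming (data, model, and the `stage_kwargs` payload are placeholders)::

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from catalyst.dl import SupervisedRunner

    model = torch.nn.Linear(10, 2)
    dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
    loaders = {
        "train": DataLoader(dataset, batch_size=8),
        "valid": DataLoader(dataset, batch_size=8),
    }

    runner = SupervisedRunner()
    runner.train(
        model=model,
        criterion=torch.nn.CrossEntropyLoss(),
        optimizer=torch.optim.Adam(model.parameters()),
        loaders=loaders,
        num_epochs=1,
        # `stage_kwargs` replaces the former `state_kwargs`;
        # passing both at once is rejected by the assert above
        stage_kwargs={"migrate_from_previous_stage": True},
    )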
""" - self.state.output = self.forward(batch) + self.output = self.forward(batch) @torch.no_grad() def predict_batch( diff --git a/catalyst/dl/utils/__init__.py b/catalyst/dl/utils/__init__.py index 300d3ba690..859646d3ce 100644 --- a/catalyst/dl/utils/__init__.py +++ b/catalyst/dl/utils/__init__.py @@ -11,6 +11,6 @@ save_traced_model, trace_model, trace_model_from_checkpoint, - trace_model_from_state, + trace_model_from_runner, ) from .wizard import run_wizard, Wizard diff --git a/catalyst/dl/utils/trace.py b/catalyst/dl/utils/trace.py index 4bc622888c..c875ab5c82 100644 --- a/catalyst/dl/utils/trace.py +++ b/catalyst/dl/utils/trace.py @@ -3,7 +3,6 @@ Callable, Dict, List, - TYPE_CHECKING, Union, ) import inspect @@ -12,7 +11,7 @@ from torch import nn from torch.jit import load, save, ScriptModule, trace -from catalyst.core.state import State +from catalyst.core.runner import _Runner from catalyst.dl.experiment.config import ConfigExperiment from catalyst.tools.typing import Device, Model from catalyst.utils import ( @@ -30,9 +29,6 @@ unpack_checkpoint, ) -if TYPE_CHECKING: - from catalyst.dl import Runner # noqa: F401 - def _get_input_argnames( fn: Callable[..., Any], exclude: List[str] = None @@ -242,8 +238,8 @@ def predict_fn(model, inputs, **kwargs): return traced_model -def trace_model_from_state( - state: State, +def trace_model_from_runner( + runner: _Runner, checkpoint_name: str = None, method_name: str = "forward", mode: str = "eval", @@ -255,9 +251,9 @@ def trace_model_from_state( Traces model using created experiment and runner. Args: - state (State): Current runner state. + runner (Runner): Current runner. checkpoint_name (str): Name of model checkpoint to use, if None - traces current model from state + traces current model from runner method_name (str): Model's method name that will be used as entrypoint during tracing mode (str): Mode for model to trace (``train`` or ``eval``) @@ -268,8 +264,8 @@ def trace_model_from_state( Returns: (ScriptModule): Traced model """ - logdir = state.logdir - model = get_nn_from_ddp_module(state.model) + logdir = runner.logdir + model = get_nn_from_ddp_module(runner.model) if checkpoint_name is not None: dumped_checkpoint = pack_checkpoint(model=model) @@ -285,17 +281,17 @@ def trace_model_from_state( batch = {} for name in method_argnames: # TODO: We don't know input_keys without runner - assert name in state.input, ( + assert name in runner.input, ( "Input batch should contain the same keys as input argument " "names of `forward` function to be traced correctly" ) - batch[name] = state.input[name] + batch[name] = runner.input[name] batch = any2device(batch, device) - # Dumping previous state of the model, we will need it to restore + # Dumping previous runner of the model, we will need it to restore _device, _is_training, _requires_grad = ( - state.device, + runner.device, model.training, get_requires_grad(model), ) @@ -320,7 +316,7 @@ def predict_fn(model: Model, inputs, **kwargs): if checkpoint_name is not None: unpack_checkpoint(checkpoint=dumped_checkpoint, model=model) - # Restore previous state of the model + # Restore previous runner of the model getattr(model, "train" if _is_training else "eval")() set_requires_grad(model, _requires_grad) model.to(_device) @@ -454,7 +450,7 @@ def load_traced_model( __all__ = [ "trace_model", "trace_model_from_checkpoint", - "trace_model_from_state", + "trace_model_from_runner", "get_trace_name", "save_traced_model", "load_traced_model", diff --git a/catalyst/dl/utils/wizard.py 
b/catalyst/dl/utils/wizard.py index 09272e4b68..a97b80647d 100644 --- a/catalyst/dl/utils/wizard.py +++ b/catalyst/dl/utils/wizard.py @@ -219,14 +219,14 @@ def _basic_params_step(self, param, stage, optional=False): self.__res(opts, is_yaml=True) stage[f"{param}_params"] = opts - def _state_params_step(self, stage): + def _stage_params_step(self, stage): """ Step #5.b - ``state_params`` of Experiment. + ``stage_params`` of Experiment. """ - self.__sep(f"state_params") - if self._skip_override_stages_common("state_params"): + self.__sep(f"stage_params") + if self._skip_override_stages_common("stage_params"): return opts = OrderedDict() opts["num_epochs"] = int( @@ -244,7 +244,7 @@ def _state_params_step(self, stage): ) opts["minimize_metric"] = minimize self.__res(opts["minimize_metric"]) - stage["state_params"] = opts + stage["stage_params"] = opts def _data_params_step(self, stage): """ @@ -276,7 +276,7 @@ def _stage_step(self, stage): method to gather all we need to know about the stage and its settings """ self._data_params_step(stage) - self._state_params_step(stage) + self._stage_params_step(stage) self._basic_params_step("criterion", stage) self._basic_params_step("optimizer", stage) self._basic_params_step("scheduler", stage, optional=True) diff --git a/catalyst/tools/frozen_class.py b/catalyst/tools/frozen_class.py index 7a840b1c5b..ca6878edd3 100644 --- a/catalyst/tools/frozen_class.py +++ b/catalyst/tools/frozen_class.py @@ -1,6 +1,6 @@ """ Frozen class. -Example of usage can be found in :py:class:`catalyst.core.state.State`. +Example of usage can be found in :py:class:`catalyst.core.runner._Runner`. """ @@ -8,16 +8,19 @@ class FrozenClass: """Class which prohibit ``__setattr__`` on existing attributes. Examples: - >>> class State(FrozenClass): + >>> class _Runner(FrozenClass): """ - __isfrozen = False + __is_frozen = False def __setattr__(self, key, value): """@TODO: Docs. Contribution is welcome.""" - if self.__isfrozen and not hasattr(self, key): - raise TypeError("%r is a frozen class" % self) + if self.__is_frozen and not hasattr(self, key): + raise TypeError("%r is a frozen class for key %s" % (self, key)) object.__setattr__(self, key, value) def _freeze(self): - self.__isfrozen = True + self.__is_frozen = True + + def _unfreeze(self): + self.__is_frozen = False diff --git a/catalyst/tools/typing.py b/catalyst/tools/typing.py index 1492f6b2e5..14a7289864 100644 --- a/catalyst/tools/typing.py +++ b/catalyst/tools/typing.py @@ -1,7 +1,7 @@ """ All Catalyst custom types are defined in this module. """ -from typing import Union +from typing import Dict, Union import torch from torch import nn, optim @@ -15,6 +15,11 @@ Dataset = data.Dataset Device = Union[str, torch.device] +RunnerModel = Union[Model, Dict[str, Model]] +RunnerCriterion = Union[Criterion, Dict[str, Criterion]] +RunnerOptimizer = Union[Optimizer, Dict[str, Optimizer]] +RunnerScheduler = Union[Scheduler, Dict[str, Scheduler]] + __all__ = [ "Model", "Criterion", @@ -22,4 +27,8 @@ "Scheduler", "Dataset", "Device", + "RunnerModel", + "RunnerCriterion", + "RunnerOptimizer", + "RunnerScheduler", ] diff --git a/docs/api/contrib.rst b/docs/api/contrib.rst index 2d2d78085b..da81725bde 100644 --- a/docs/api/contrib.rst +++ b/docs/api/contrib.rst @@ -105,7 +105,7 @@ Callbacks AlchemyLogger """"""""""""" -.. automodule:: catalyst.contrib.dl.callbacks.alchemy +.. 
automodule:: catalyst.contrib.dl.callbacks.alchemy_logger :members: :undoc-members: :show-inheritance: @@ -117,37 +117,44 @@ CutmixCallback :undoc-members: :show-inheritance: -InferMaskCallback -""""""""""""""""" -.. automodule:: catalyst.contrib.dl.callbacks.inference +GradNormLogger +"""""""""""""""""""""" +.. automodule:: catalyst.contrib.dl.callbacks.gradnorm_logger :members: :undoc-members: :show-inheritance: KNNMetricCallback """"""""""""""""" -.. automodule:: catalyst.contrib.dl.callbacks.knn +.. automodule:: catalyst.contrib.dl.callbacks.knn_metric + :members: + :undoc-members: + :show-inheritance: + +InferMaskCallback +""""""""""""""""" +.. automodule:: catalyst.contrib.dl.callbacks.mask_inference :members: :undoc-members: :show-inheritance: NeptuneLogger """"""""""""" -.. automodule:: catalyst.contrib.dl.callbacks.neptune +.. automodule:: catalyst.contrib.dl.callbacks.neptune_logger :members: :undoc-members: :show-inheritance: -SaveModelGradsCallback +PeriodicLoaderCallback """""""""""""""""""""" -.. automodule:: catalyst.contrib.dl.callbacks.optimizer +.. automodule:: catalyst.contrib.dl.callbacks.periodic_loader_callback :members: :undoc-members: :show-inheritance: -PeriodicLoaderCallback -"""""""""""""""""""""" -.. automodule:: catalyst.contrib.dl.callbacks.periodic_loader +PerplexityMetricCallback +"""""""""""""""""""""""" +.. automodule:: catalyst.contrib.dl.callbacks.perplexity_metric :members: :undoc-members: :show-inheritance: @@ -161,7 +168,7 @@ TelegramLogger TracerCallback """""""""""""""""""""" -.. automodule:: catalyst.contrib.dl.callbacks.trace +.. automodule:: catalyst.contrib.dl.callbacks.tracer_callback :members: :undoc-members: :show-inheritance: @@ -175,7 +182,7 @@ VisdomLogger WandbLogger """""""""""""""""""""" -.. automodule:: catalyst.contrib.dl.callbacks.wandb +.. automodule:: catalyst.contrib.dl.callbacks.wandb_logger :members: :undoc-members: :show-inheritance: diff --git a/docs/api/core.rst b/docs/api/core.rst index c4cab6b90f..d154c9a7cf 100644 --- a/docs/api/core.rst +++ b/docs/api/core.rst @@ -51,13 +51,6 @@ Callback :undoc-members: :show-inheritance: -State -~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: catalyst.core.state - :members: - :undoc-members: - :show-inheritance: - Callbacks ---------------------- @@ -158,3 +151,14 @@ Utils :members: :undoc-members: :show-inheritance: + + +Legacy +---------------------- + +Runner +~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: catalyst.core.legacy._RunnerLegacy + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/api/dl.rst b/docs/api/dl.rst index a7372b3cdc..4f6c2ffb0d 100644 --- a/docs/api/dl.rst +++ b/docs/api/dl.rst @@ -16,7 +16,7 @@ Experiment Experiment ~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: catalyst.dl.experiment.core +.. automodule:: catalyst.dl.experiment.experiment :members: :undoc-members: :show-inheritance: @@ -45,7 +45,7 @@ Runner Runner ~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: catalyst.dl.runner.core +.. 
automodule:: catalyst.dl.runner.runner :members: :undoc-members: :show-inheritance: diff --git a/docs/index.rst b/docs/index.rst index e7828d08b0..b4b37f8de0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -64,14 +64,14 @@ Getting started loss = F.cross_entropy(y_hat, y) accuracy01, accuracy03 = metrics.accuracy(y_hat, y, topk=(1, 3)) - self.state.batch_metrics.update( + self.batch_metrics.update( {"loss": loss, "accuracy01": accuracy01, "accuracy03": accuracy03} ) - if self.state.is_train_loader: + if self.is_train_loader: loss.backward() - self.state.optimizer.step() - self.state.optimizer.zero_grad() + self.optimizer.step() + self.optimizer.zero_grad() runner = CustomRunner() # model training @@ -143,7 +143,7 @@ Features Structure ~~~~~~~~~~~~~~~~~~~~~~ -- **core** - framework core with main abstractions - Experiment, Runner, Callback and State. +- **core** - framework core with main abstractions - Experiment, Runner and Callback. - **data** - useful tools and scripts for data processing. - **dl** – runner for training and inference, all of the classic ML and CV/NLP/RecSys metrics and a variety of callbacks for training, validation and inference of neural networks. - **tools** - extra tools for Deep Learning research, class-based helpers. diff --git a/examples/_empty/configs/config.yml b/examples/_empty/configs/config.yml index 96dd4c8f8b..21b0465921 100644 --- a/examples/_empty/configs/config.yml +++ b/examples/_empty/configs/config.yml @@ -11,7 +11,7 @@ stages: batch_size: 64 # CHANGE ME num_workers: 1 # CHANGE ME - state_params: + stage_params: num_epochs: 2 # CHANGE ME main_metric: &reduced_metric loss # loss for scheduler and checkpoint saver minimize_metric: True # Change if you change `main_metric` diff --git a/examples/cifar_simple/config.yml b/examples/cifar_simple/config.yml index 39d18aba63..c134a12937 100644 --- a/examples/cifar_simple/config.yml +++ b/examples/cifar_simple/config.yml @@ -13,7 +13,7 @@ stages: batch_size: 64 num_workers: 1 - state_params: + stage_params: num_epochs: 2 main_metric: &reduced_metric accuracy01 minimize_metric: False diff --git a/examples/cifar_simple/config_experiment.yml b/examples/cifar_simple/config_experiment.yml index d7b41897ae..792466f3ad 100644 --- a/examples/cifar_simple/config_experiment.yml +++ b/examples/cifar_simple/config_experiment.yml @@ -17,7 +17,7 @@ stages: batch_size: 64 num_workers: 1 - state_params: + stage_params: num_epochs: 2 main_metric: &reduced_metric accuracy01 minimize_metric: False diff --git a/examples/cifar_stages/config.yml b/examples/cifar_stages/config.yml index d6c8e71a25..720cec9cbe 100644 --- a/examples/cifar_stages/config.yml +++ b/examples/cifar_stages/config.yml @@ -33,7 +33,7 @@ stages: - transform: A.Normalize - transform: catalyst.ToTensor - state_params: + stage_params: num_epochs: 3 main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -79,7 +79,7 @@ stages: # tune stage2: - state_params: + stage_params: num_epochs: 3 optimizer_params: diff --git a/examples/cifar_stages/experiment.py b/examples/cifar_stages/experiment.py index bc13a8f8f8..f77d8bb2e0 100644 --- a/examples/cifar_stages/experiment.py +++ b/examples/cifar_stages/experiment.py @@ -1,7 +1,6 @@ from collections import OrderedDict import torch -from torch import nn import torchvision from catalyst import utils @@ -31,16 +30,24 @@ def __getitem__(self, index: int): class Experiment(ConfigExperiment): """``ConfigExperiment`` with CIFAR10 dataset.""" - def _postprocess_model_for_stage(self, stage: str, model: nn.Module): - 
model_ = model + def get_model(self, stage: str): + """ + Model specification for the current stage + Args: + stage: current stage name + + Returns: + model + """ + model = super().get_model(stage=stage) if isinstance(model, torch.nn.DataParallel): - model_ = model_.module + model = model.module if stage == "stage2": for key in ["conv1", "pool", "conv2"]: - layer = getattr(model_, key) + layer = getattr(model, key) utils.set_requires_grad(layer, requires_grad=False) - return model_ + return model def get_datasets(self, stage: str, **kwargs): """Provides train/validation subsets from CIFAR10 dataset. diff --git a/examples/configs/config-description-eng.yml b/examples/configs/config-description-eng.yml index 9151590311..4c75ad81a1 100644 --- a/examples/configs/config-description-eng.yml +++ b/examples/configs/config-description-eng.yml @@ -87,7 +87,7 @@ stages: # REQUIRED KEYWORD, dictionary of all stages of Catalyst, for training infer: *transform # only for infer dataset - state_params: # REQUIRED KEYWORD, parameters for State (for all stages) + stage_params: # REQUIRED KEYWORD, parameters for all stages main_metric: &main_metric accuracy01 # REQUIRED KEYWORD, the name of the metric by which the checkpoints will be taken minimize_metric: False # REQUIRED KEYWORD, flag, should we minimize `main_metric` num_epochs: 2 # KEYWORD, The number of epochs in all the stages @@ -130,7 +130,7 @@ stages: # REQUIRED KEYWORD, dictionary of all stages of Catalyst, for training gamma: 0.3 stage1: # Anything that's not a keyword is considered a name for a stage. For training in Catalyst, at least one stage is required. The name can be anything. - state_params: # You can override any parameters for a particular stage, for example + stage_params: # You can override any parameters for a particular stage, for example num_epochs: 3 callbacks_params: # REQUIRED KEYWORD, The most important part. It's where all the callbacks are written down for this stage. @@ -148,7 +148,7 @@ stages: # REQUIRED KEYWORD, dictionary of all stages of Catalyst, for training save_n_best: 3 finetune: # Example of a second training stage, here we can change our parameters - state_params: # You can override any parameters for a particular stage, for example + stage_params: # You can override any parameters for a particular stage, for example num_epochs: 1 optimizer_params: # Example of an overridden optimizer diff --git a/examples/configs/config-description-rus.yml b/examples/configs/config-description-rus.yml index 8ef1743334..27522ecc41 100644 --- a/examples/configs/config-description-rus.yml +++ b/examples/configs/config-description-rus.yml @@ -87,7 +87,7 @@ stages: # REQUIRED KEYWORD, словарь всех стадий Catalyst, дл infer: *transform # трансформация только для infer датасета - state_params: # REQUIRED KEYWORD, параметры для State (для всех стейджей) + stage_params: # REQUIRED KEYWORD, параметры для всех стейджей main_metric: &main_metric accuracy01 # REQUIRED KEYWORD, имя метрики, по которой будут отбираться чекпоинты minimize_metric: False # REQUIRED KEYWORD, флаг, нужно ли минимизировать `main_metric` num_epochs: 2 # KEYWORD, Количество эпох во всех стейджах @@ -130,7 +130,7 @@ stages: # REQUIRED KEYWORD, словарь всех стадий Catalyst, дл gamma: 0.3 stage1: # Все, что не ключевое слово, расценивается, как имя стейджа. Для тренировки в Catalyst требуется хотябы один стейдж. 
Имя может быть произвольным - state_params: # Вы можете переопределить любые параметры, для конкретного стейджа, например + stage_params: # Вы можете переопределить любые параметры, для конкретного стейджа, например num_epochs: 3 callbacks_params: # REQUIRED KEYWORD, самая важная часть, тут записываются все коллбеки для данного стейджа @@ -148,7 +148,7 @@ stages: # REQUIRED KEYWORD, словарь всех стадий Catalyst, дл save_n_best: 3 finetune: # Пример второго стейджа обучения, тут мы можем изменить наши параметры - state_params: # Вы можете переопределить любые параметры, для конкретного стейджа, например + stage_params: # Вы можете переопределить любые параметры, для конкретного стейджа, например num_epochs: 1 optimizer_params: # Пример, переопределенного оптимизатора diff --git a/examples/distilbert_text_classification/config.yml b/examples/distilbert_text_classification/config.yml index 49140a24c2..c43327a730 100644 --- a/examples/distilbert_text_classification/config.yml +++ b/examples/distilbert_text_classification/config.yml @@ -28,7 +28,7 @@ stages: label_field: "label" max_sequence_length: 512 - state_params: + stage_params: main_metric: &reduced_metric loss minimize_metric: True @@ -59,7 +59,7 @@ stages: # params specific for stage 1 called "train_val" train_val: # overriding state params and specifying that we train for 2 epochs - state_params: + stage_params: num_epochs: 2 # optimizer params are specific only for this stage # in principle, we can define another stage with other optim params diff --git a/examples/notebooks/classification-tutorial.ipynb b/examples/notebooks/classification-tutorial.ipynb index 9033c67c44..f934ada1a8 100644 --- a/examples/notebooks/classification-tutorial.ipynb +++ b/examples/notebooks/classification-tutorial.ipynb @@ -1138,7 +1138,7 @@ "source": [ "import collections\n", "\n", - "from catalyst.dl import Callback, CallbackOrder, State\n", + "from catalyst.dl import Callback, CallbackOrder, _Runner\n", "\n", "\n", "class CustomInferCallback(Callback):\n", @@ -1146,13 +1146,13 @@ " super().__init__(CallbackOrder.Internal)\n", " self.class_counts = collections.defaultdict(lambda: 0)\n", "\n", - " def on_loader_start(self, state: State):\n", + " def on_loader_start(self, runner: _Runner):\n", " self.class_counts = collections.defaultdict(lambda: 0)\n", "\n", - " def on_batch_end(self, state: State):\n", + " def on_batch_end(self, runner: _Runner):\n", " # data from the Dataloader\n", - " # features, targets = state.input[\"features\"], state.input[\"targets\"]\n", - " logits = state.output[\"logits\"]\n", + " # features, targets = runner.input[\"features\"], runner.input[\"targets\"]\n", + " logits = runner.output[\"logits\"]\n", "\n", " labels = logits.argmax(axis=1)\n", " labels = labels.cpu().detach().numpy().tolist()\n", @@ -1208,7 +1208,7 @@ "import numpy as np\n", "%matplotlib inline\n", "\n", - "class_counts = runner.state._callbacks[\"infer\"].class_counts\n", + "class_counts = runner._callbacks[\"infer\"].class_counts\n", "counts = [class_counts[x] for x in range(len(class_names))]\n", "\n", "plt.figure(figsize=(20, 9))\n", @@ -1498,17 +1498,8 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" - }, - "pycharm": { - "stem_cell": { - "cell_type": "raw", - "source": [], - "metadata": { - "collapsed": false - } - } } }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/examples/notebooks/customizing_what_happens_in_train.ipynb b/examples/notebooks/customizing_what_happens_in_train.ipynb 
index 932986456e..812983424d 100644 --- a/examples/notebooks/customizing_what_happens_in_train.ipynb +++ b/examples/notebooks/customizing_what_happens_in_train.ipynb @@ -118,7 +118,7 @@ "\n", "The input argument `batch` is what gets passed to fit as training data. If you pass a `torch.utils.data.DataLoader`, by calling `train(loaders={\"train\": loader, \"valid\": loader}, ...)`, then `batch` will be what gets yielded by `loader` at each batch.\n", "\n", - "In the body of the `_handle_batch` method, we implement a regular training update, similar to what you are already familiar with. Importantly, **we log metrics via `self.state.batch_metrics`**, which passes them to the loggers." + "In the body of the `_handle_batch` method, we implement a regular training update, similar to what you are already familiar with. Importantly, **we log metrics via `self.batch_metrics`**, which passes them to the loggers." ] }, { @@ -147,15 +147,15 @@ " loss = F.mse_loss(y_pred, y)\n", "\n", " # Update metrics (includes the metric that tracks the loss)\n", - " self.state.batch_metrics.update({\"loss\": loss, \"mae\": F.l1_loss(y_pred, y)})\n", + " self.batch_metrics.update({\"loss\": loss, \"mae\": F.l1_loss(y_pred, y)})\n", "\n", - " if self.state.is_train_loader:\n", + " if self.is_train_loader:\n", " # Compute gradients\n", " loss.backward()\n", " # Update weights\n", " # (the optimizer is stored in `self.state`)\n", - " self.state.optimizer.step()\n", - " self.state.optimizer.zero_grad()" + " self.optimizer.step()\n", + " self.optimizer.zero_grad()" ] }, { @@ -256,10 +256,10 @@ "\n", " # Compute the loss value\n", " # (the criterion is stored in `self.state` also)\n", - " loss = self.state.criterion(y_pred, y)\n", + " loss = self.criterion(y_pred, y)\n", "\n", " # Update metrics (includes the metric that tracks the loss)\n", - " self.state.batch_metrics.update({\"loss\": loss, \"mae\": F.l1_loss(y_pred, y)})\n", + " self.batch_metrics.update({\"loss\": loss, \"mae\": F.l1_loss(y_pred, y)})\n", "\n", "\n", "# Construct custom data\n", @@ -288,7 +288,7 @@ " callbacks={\n", " \"optimizer\": dl.OptimizerCallback(\n", " metric_key=\"loss\", # you can also pass 'mae' to optimize it instead\n", - " # generaly, you can optimize any differentiable metric from `state.batch_metrics`\n", + " # generaly, you can optimize any differentiable metric from `runner.batch_metrics`\n", " accumulation_steps=1, # also you can pass any number of steps for gradient accumulation\n", " grad_clip_params=None, # or yor use `{\"func\": \"clip_grad_norm_\", max_norm=1, norm_type=2}`\n", " # or `{\"func\": \"clip_grad_value_\", clip_value=1}`\n", @@ -312,8 +312,8 @@ "Let's go even deeper! Could we transfer different metrics/criterions computation to `Callbacks` too? Of course! If you want to support different losses, you'd simply do the following:\n", "\n", "- Do your model forward pass as usual.\n", - "- Save model input to `state.input` and model output to `state.output`, so Callbacks can find it.\n", - "- Add extra callbacks, that will use data from `state.input` and `state.output` for computation.\n", + "- Save model input to `runner.input` and model output to `runner.output`, so Callbacks can find it.\n", + "- Add extra callbacks, that will use data from `runner.input` and `runner.output` for computation.\n", "\n", "That's it. That's the list. 
Let's see the example:" ] @@ -340,10 +340,10 @@ " y_pred = self.model(x) # Forward pass\n", " \n", " # pass network input to state `input`\n", - " self.state.input = {\"features\": x, \"targets\": y}\n", + " self.input = {\"features\": x, \"targets\": y}\n", " # and network output to state `output`\n", " # we recommend to use key-value storage to make it Callbacks-friendly\n", - " self.state.output = {\"logits\": y_pred}\n", + " self.output = {\"logits\": y_pred}\n", "\n", "\n", "# Construct custom data\n", @@ -371,16 +371,16 @@ " timeit=False,\n", " callbacks={\n", " \"criterion\": dl.CriterionCallback( # special Callback for criterion computation\n", - " input_key=\"targets\", # `input_key` specifies correct labels (or `y_true`) from `state.input` \n", - " output_key=\"logits\", # `output_key` specifies model predictions (`y_pred`) from `state.output`\n", - " prefix=\"loss\", # `prefix` - key to use with `state.batch_metrics`\n", - " ), # alias for `state.batch_metrics[prefix] = state.criterion(state.output[output_key], state.input[input_key])`\n", + " input_key=\"targets\", # `input_key` specifies correct labels (or `y_true`) from `runner.input` \n", + " output_key=\"logits\", # `output_key` specifies model predictions (`y_pred`) from `runner.output`\n", + " prefix=\"loss\", # `prefix` - key to use with `runner.batch_metrics`\n", + " ), # alias for `runner.batch_metrics[prefix] = runner.criterion(runner.output[output_key], runner.input[input_key])`\n", " \"metric\": dl.MetricCallback( # special Callback for metrics computation\n", " input_key=\"targets\", # shares logic with `CriterionCallback`\n", " output_key=\"logits\",\n", " prefix=\"loss_mae\",\n", " metric_fn=F.l1_loss, # metric function to use\n", - " ), # alias for `state.batch_metrics[prefix] = metric_fn(state.output[output_key], state.input[input_key])`\n", + " ), # alias for `runner.batch_metrics[prefix] = metric_fn(runner.output[output_key], runner.input[input_key])`\n", " \"optimizer\": dl.OptimizerCallback(\n", " metric_key=\"loss\", \n", " accumulation_steps=1,\n", @@ -427,12 +427,12 @@ "# Just use `train` as usual\n", "runner = dl.SupervisedRunner( # `SupervisedRunner` works with any model like `some_output = model(some_input)`\n", " input_key=\"features\", # if your dataloader yields (x, y) tuple, it will be transformed to \n", - " output_key=\"logits\", # {input_key: x, input_target_key: y} and stored to state.input\n", + " output_key=\"logits\", # {input_key: x, input_target_key: y} and stored to runner.input\n", " input_target_key=\"targets\", # then the model will be used like\n", - ") # state.output = model(state.input[input_key])\n", + ") # runner.output = model(runner.input[input_key])\n", " # loss computation suppose to looks like\n", - " # loss = criterion(state.output[input_target_key], state.output[output_key])\n", - " # and stored to `state.batch_metrics['loss']`\n", + " # loss = criterion(runner.output[input_target_key], runner.output[output_key])\n", + " # and stored to `runner.batch_metrics['loss']`\n", "\n", "runner.train(\n", " model=model, \n", @@ -503,15 +503,15 @@ " loss = F.mse_loss(y_pred, y)\n", "\n", " # Update metrics (includes the metric that tracks the loss)\n", - " self.state.batch_metrics.update({\"loss\": loss, \"mae\": F.l1_loss(y_pred, y)})\n", + " self.batch_metrics.update({\"loss\": loss, \"mae\": F.l1_loss(y_pred, y)})\n", "\n", - " if self.state.is_train_loader:\n", + " if self.is_train_loader:\n", " # Compute gradients\n", " loss.backward()\n", " # Update weights\n", " # (the 
optimizer is stored in `self.state`)\n", - " self.state.optimizer.step()\n", - " self.state.optimizer.zero_grad()" + " self.optimizer.step()\n", + " self.optimizer.zero_grad()" ] }, { @@ -715,7 +715,7 @@ " batch_metrics[\"loss_generator\"] = \\\n", " F.binary_cross_entropy_with_logits(predictions, misleading_labels)\n", " \n", - " self.state.batch_metrics.update(**batch_metrics)" + " self.batch_metrics.update(**batch_metrics)" ] }, { diff --git a/examples/notebooks/demo.ipynb b/examples/notebooks/demo.ipynb index 4a88a2cbf3..32031aa7e8 100644 --- a/examples/notebooks/demo.ipynb +++ b/examples/notebooks/demo.ipynb @@ -165,17 +165,17 @@ " loss = F.cross_entropy(y_hat, y)\n", " accuracy01, accuracy03, accuracy05 = metrics.accuracy(y_hat, y, topk=(1, 3, 5))\n", " \n", - " self.state.batch_metrics = {\n", + " self.batch_metrics = {\n", " \"loss\": loss,\n", " \"accuracy01\": accuracy01,\n", " \"accuracy03\": accuracy03,\n", " \"accuracy05\": accuracy05,\n", " }\n", " \n", - " if self.state.is_train_loader:\n", + " if self.is_train_loader:\n", " loss.backward()\n", - " self.state.optimizer.step()\n", - " self.state.optimizer.zero_grad()\n", + " self.optimizer.step()\n", + " self.optimizer.zero_grad()\n", " " ] }, @@ -273,7 +273,7 @@ " loss = loss_clf + loss_ae\n", " accuracy01, accuracy03, accuracy05 = metrics.accuracy(y_hat, y, topk=(1, 3, 5))\n", " \n", - " self.state.batch_metrics = {\n", + " self.batch_metrics = {\n", " \"loss_clf\": loss_clf,\n", " \"loss_ae\": loss_ae,\n", " \"loss\": loss,\n", @@ -282,10 +282,10 @@ " \"accuracy05\": accuracy05,\n", " }\n", " \n", - " if self.state.is_train_loader:\n", + " if self.is_train_loader:\n", " loss.backward()\n", - " self.state.optimizer.step()\n", - " self.state.optimizer.zero_grad()\n", + " self.optimizer.step()\n", + " self.optimizer.zero_grad()\n", " " ] }, @@ -424,7 +424,7 @@ " loss = loss_clf + loss_ae + loss_kld + loss_logprob\n", " accuracy01, accuracy03, accuracy05 = metrics.accuracy(y_hat, y, topk=(1, 3, 5))\n", " \n", - " self.state.batch_metrics = {\n", + " self.batch_metrics = {\n", " \"loss_clf\": loss_clf,\n", " \"loss_ae\": loss_ae,\n", " \"loss_kld\": loss_kld,\n", @@ -435,10 +435,10 @@ " \"accuracy05\": accuracy05,\n", " }\n", " \n", - " if self.state.is_train_loader:\n", + " if self.is_train_loader:\n", " loss.backward()\n", - " self.state.optimizer.step()\n", - " self.state.optimizer.zero_grad()\n", + " self.optimizer.step()\n", + " self.optimizer.zero_grad()\n", " " ] }, @@ -543,7 +543,7 @@ " loss = loss_clf + loss_iou\n", " accuracy01, accuracy03, accuracy05 = metrics.accuracy(y_hat, y, topk=(1, 3, 5))\n", " \n", - " self.state.batch_metrics = {\n", + " self.batch_metrics = {\n", " \"loss_clf\": loss_clf,\n", " \"loss_iou\": loss_iou,\n", " \"loss\": loss,\n", @@ -553,10 +553,10 @@ " \"accuracy05\": accuracy05,\n", " }\n", " \n", - " if self.state.is_train_loader:\n", + " if self.is_train_loader:\n", " loss.backward()\n", - " self.state.optimizer.step()\n", - " self.state.optimizer.zero_grad()\n", + " self.optimizer.step()\n", + " self.optimizer.zero_grad()\n", " " ], "metadata": { @@ -702,7 +702,7 @@ " batch_metrics[\"loss_generator\"] = \\\n", " F.binary_cross_entropy_with_logits(predictions, misleading_labels)\n", " \n", - " self.state.batch_metrics.update(**batch_metrics)" + " self.batch_metrics.update(**batch_metrics)" ] }, { @@ -1198,22 +1198,22 @@ "metadata": {}, "outputs": [], "source": [ - "from catalyst.dl import Callback, CallbackOrder, State\n", + "from catalyst.dl import Callback, CallbackOrder, _Runner\n", 
"\n", "class NdcgLoaderMetricCallback(Callback):\n", " def __init__(self):\n", " super().__init__(CallbackOrder.Metric)\n", "\n", - " def on_batch_end(self, state: State):\n", - " item = state.input[\"item\"]\n", - " predictions = state.output[\"logits\"]\n", + " def on_batch_end(self, runner: _Runner):\n", + " item = runner.input[\"item\"]\n", + " predictions = runner.output[\"logits\"]\n", "\n", " _, indices = torch.topk(predictions, top_k)\n", " recommended = torch.take(item, indices).cpu().numpy().tolist()\n", "\n", " item = item[0].item()\n", - " state.batch_metrics[\"hits\"] = hit_metric(recommended, item)\n", - " state.batch_metrics[\"dcgs\"] = dcg_metric(recommended, item)" + " runner.batch_metrics[\"hits\"] = hit_metric(recommended, item)\n", + " runner.batch_metrics[\"dcgs\"] = dcg_metric(recommended, item)" ] }, { @@ -1261,15 +1261,6 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" - }, - "pycharm": { - "stem_cell": { - "cell_type": "raw", - "source": [], - "metadata": { - "collapsed": false - } - } } }, "nbformat": 4, diff --git a/examples/notebooks/segmentation-tutorial.ipynb b/examples/notebooks/segmentation-tutorial.ipynb index 110a43ca2c..98b87ce20a 100644 --- a/examples/notebooks/segmentation-tutorial.ipynb +++ b/examples/notebooks/segmentation-tutorial.ipynb @@ -1162,7 +1162,7 @@ "source": [ "import collections\n", "\n", - "from catalyst.dl import Callback, CallbackOrder, State\n", + "from catalyst.dl import Callback, CallbackOrder, _Runner\n", "\n", "\n", "class CustomInferCallback(Callback):\n", @@ -1171,14 +1171,14 @@ " self.heatmap = None\n", " self.counter = 0\n", "\n", - " def on_loader_start(self, state: State):\n", + " def on_loader_start(self, runner: _Runner):\n", " self.predictions = None\n", " self.counter = 0\n", "\n", - " def on_batch_end(self, state: State):\n", + " def on_batch_end(self, runner: _Runner):\n", " # data from the Dataloader\n", - " # image, mask = state.input[\"image\"], state.input[\"mask\"]\n", - " logits = state.output[\"logits\"]\n", + " # image, mask = runner.input[\"image\"], runner.input[\"mask\"]\n", + " logits = runner.output[\"logits\"]\n", " probabilities = torch.sigmoid(logits)\n", "\n", " self.heatmap = (\n", @@ -1188,7 +1188,7 @@ " )\n", " self.counter += len(probabilities)\n", "\n", - " def on_loader_end(self, state: State):\n", + " def on_loader_end(self, runner: _Runner):\n", " self.heatmap = self.heatmap.sum(axis=0)\n", " self.heatmap /= self.counter" ] @@ -1242,7 +1242,7 @@ "%matplotlib inline \n", "import matplotlib.pyplot as plt\n", "\n", - "heatmap = utils.detach(runner.state.callbacks[\"infer\"].heatmap[0])\n", + "heatmap = utils.detach(runner.runner.callbacks[\"infer\"].heatmap[0])\n", "plt.figure(figsize=(20, 9))\n", "plt.imshow(heatmap, cmap=\"hot\", interpolation=\"nearest\")\n", "plt.show()" @@ -1366,17 +1366,8 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" - }, - "pycharm": { - "stem_cell": { - "cell_type": "raw", - "metadata": { - "collapsed": false - }, - "source": [] - } } }, "nbformat": 4, "nbformat_minor": 1 -} +} \ No newline at end of file diff --git a/tests/_tests_contrib_dl_callbacks/config0.yml b/tests/_tests_contrib_dl_callbacks/config0.yml index 6de0d24c55..e99278e281 100644 --- a/tests/_tests_contrib_dl_callbacks/config0.yml +++ b/tests/_tests_contrib_dl_callbacks/config0.yml @@ -18,7 +18,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ 
-31,7 +31,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: diff --git a/tests/_tests_contrib_dl_callbacks/config1.yml b/tests/_tests_contrib_dl_callbacks/config1.yml index 6a6dc4af65..721581f226 100644 --- a/tests/_tests_contrib_dl_callbacks/config1.yml +++ b/tests/_tests_contrib_dl_callbacks/config1.yml @@ -18,7 +18,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -31,7 +31,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: diff --git a/tests/_tests_contrib_dl_callbacks/config2.yml b/tests/_tests_contrib_dl_callbacks/config2.yml index db4b0dcfb5..9feddd8072 100644 --- a/tests/_tests_contrib_dl_callbacks/config2.yml +++ b/tests/_tests_contrib_dl_callbacks/config2.yml @@ -19,7 +19,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -32,7 +32,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 10 scheduler_params: diff --git a/tests/_tests_contrib_dl_callbacks/config3.yml b/tests/_tests_contrib_dl_callbacks/config3.yml index 96e053b543..d1ad40404b 100644 --- a/tests/_tests_contrib_dl_callbacks/config3.yml +++ b/tests/_tests_contrib_dl_callbacks/config3.yml @@ -19,7 +19,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -32,7 +32,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: 5 scheduler_params: @@ -60,7 +60,7 @@ stages: valid_additional: 0 stage2: - state_params: + stage_params: num_epochs: 10 scheduler_params: diff --git a/tests/_tests_contrib_dl_callbacks/config4.yml b/tests/_tests_contrib_dl_callbacks/config4.yml index fd402511c0..1d3580795c 100644 --- a/tests/_tests_contrib_dl_callbacks/config4.yml +++ b/tests/_tests_contrib_dl_callbacks/config4.yml @@ -19,7 +19,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -32,7 +32,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 10 scheduler_params: diff --git a/tests/_tests_contrib_dl_callbacks/config5.yml b/tests/_tests_contrib_dl_callbacks/config5.yml index 59ac28237f..335b746e57 100644 --- a/tests/_tests_contrib_dl_callbacks/config5.yml +++ b/tests/_tests_contrib_dl_callbacks/config5.yml @@ -19,7 +19,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -32,7 +32,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 10 scheduler_params: diff --git a/tests/_tests_cv_classification/config1.yml b/tests/_tests_cv_classification/config1.yml index 6e47a05977..35be306ba9 100644 --- a/tests/_tests_cv_classification/config1.yml +++ b/tests/_tests_cv_classification/config1.yml @@ -17,7 +17,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -30,7 +30,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 100 scheduler_params: diff --git a/tests/_tests_cv_classification/config2.yml b/tests/_tests_cv_classification/config2.yml index 2f9f36cd54..73867afd0d 100644 --- 
a/tests/_tests_cv_classification/config2.yml +++ b/tests/_tests_cv_classification/config2.yml @@ -11,7 +11,7 @@ stages: batch_size: 64 num_workers: 0 - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -29,7 +29,7 @@ stages: gamma: 0.3 stage1: - state_params: + stage_params: num_epochs: 100 callbacks_params: diff --git a/tests/_tests_cv_classification/config2_infer.yml b/tests/_tests_cv_classification/config2_infer.yml index 1d123725df..938c72a651 100644 --- a/tests/_tests_cv_classification/config2_infer.yml +++ b/tests/_tests_cv_classification/config2_infer.yml @@ -10,7 +10,7 @@ stages: batch_size: 64 num_workers: 0 - state_params: + stage_params: num_epochs: 1 infer: diff --git a/tests/_tests_cv_classification/config3.yml b/tests/_tests_cv_classification/config3.yml index dd04a712f8..91b6991769 100644 --- a/tests/_tests_cv_classification/config3.yml +++ b/tests/_tests_cv_classification/config3.yml @@ -11,7 +11,7 @@ stages: batch_size: 32 num_workers: 0 - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -31,7 +31,7 @@ stages: gamma: 0.3 stage1: - state_params: + stage_params: num_epochs: 100 callbacks_params: diff --git a/tests/_tests_cv_classification/config4.yml b/tests/_tests_cv_classification/config4.yml index 230bdaf4e0..a5f2a46fa8 100644 --- a/tests/_tests_cv_classification/config4.yml +++ b/tests/_tests_cv_classification/config4.yml @@ -11,7 +11,7 @@ stages: batch_size: 64 num_workers: 0 - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -40,7 +40,7 @@ stages: stage1: - state_params: + stage_params: num_epochs: 100 optimizer_params: @@ -54,7 +54,7 @@ stages: no_bias_weight_decay: True stage2: - state_params: + stage_params: num_epochs: 100 optimizer_params: diff --git a/tests/_tests_cv_classification/config5.yml b/tests/_tests_cv_classification/config5.yml index 41309f95c6..f8e819d7ff 100644 --- a/tests/_tests_cv_classification/config5.yml +++ b/tests/_tests_cv_classification/config5.yml @@ -11,7 +11,7 @@ stages: batch_size: 64 num_workers: 0 - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -40,7 +40,7 @@ stages: stage1: - state_params: + stage_params: num_epochs: 100 optimizer_params: @@ -54,7 +54,7 @@ stages: no_bias_weight_decay: True stage2: - state_params: + stage_params: num_epochs: 100 optimizer_params: diff --git a/tests/_tests_cv_classification/config6_finder.yml b/tests/_tests_cv_classification/config6_finder.yml index 0b622402ff..1f81627ca2 100644 --- a/tests/_tests_cv_classification/config6_finder.yml +++ b/tests/_tests_cv_classification/config6_finder.yml @@ -20,7 +20,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: 100 callbacks_params: diff --git a/tests/_tests_cv_classification_experiment_registry/test1/config1.yml b/tests/_tests_cv_classification_experiment_registry/test1/config1.yml index 65dcb04718..d8f7b05370 100644 --- a/tests/_tests_cv_classification_experiment_registry/test1/config1.yml +++ b/tests/_tests_cv_classification_experiment_registry/test1/config1.yml @@ -20,7 +20,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -33,7 +33,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 100 scheduler_params: diff --git a/tests/_tests_cv_classification_experiment_registry/test2/config1.yml 
b/tests/_tests_cv_classification_experiment_registry/test2/config1.yml index f9d3c7b3d1..0625e6ead4 100644 --- a/tests/_tests_cv_classification_experiment_registry/test2/config1.yml +++ b/tests/_tests_cv_classification_experiment_registry/test2/config1.yml @@ -20,7 +20,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -33,7 +33,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 100 scheduler_params: diff --git a/tests/_tests_cv_classification_experiment_registry/test2/config2.yml b/tests/_tests_cv_classification_experiment_registry/test2/config2.yml index f0ccacbcbd..645f06d4d7 100644 --- a/tests/_tests_cv_classification_experiment_registry/test2/config2.yml +++ b/tests/_tests_cv_classification_experiment_registry/test2/config2.yml @@ -20,7 +20,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -33,7 +33,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 100 scheduler_params: diff --git a/tests/_tests_cv_classification_transforms/config1.yml b/tests/_tests_cv_classification_transforms/config1.yml index cf813ed0ea..172872d5cf 100644 --- a/tests/_tests_cv_classification_transforms/config1.yml +++ b/tests/_tests_cv_classification_transforms/config1.yml @@ -23,7 +23,7 @@ stages: std: [0.3081] - transform: catalyst.ToTensor # the same as `C.ToTensor` - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -36,7 +36,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 100 scheduler_params: diff --git a/tests/_tests_cv_classification_transforms/config2.yml b/tests/_tests_cv_classification_transforms/config2.yml index e2ab3db059..cdda428685 100644 --- a/tests/_tests_cv_classification_transforms/config2.yml +++ b/tests/_tests_cv_classification_transforms/config2.yml @@ -34,7 +34,7 @@ stages: std: [0.3081] - transform: catalyst.ToTensor - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -52,7 +52,7 @@ stages: gamma: 0.3 stage1: - state_params: + stage_params: num_epochs: 100 callbacks_params: diff --git a/tests/_tests_cv_classification_transforms/config3.yml b/tests/_tests_cv_classification_transforms/config3.yml index ca86e42f9c..2bef877f5c 100644 --- a/tests/_tests_cv_classification_transforms/config3.yml +++ b/tests/_tests_cv_classification_transforms/config3.yml @@ -27,7 +27,7 @@ stages: - transform: catalyst.ToTensor valid: *transform - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -40,7 +40,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 100 scheduler_params: diff --git a/tests/_tests_cv_classification_transforms/config4_finder.yml b/tests/_tests_cv_classification_transforms/config4_finder.yml index 351851d014..f420047df7 100644 --- a/tests/_tests_cv_classification_transforms/config4_finder.yml +++ b/tests/_tests_cv_classification_transforms/config4_finder.yml @@ -32,7 +32,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: 100 callbacks_params: diff --git a/tests/_tests_cv_classification_transforms/config5_fp16.yml b/tests/_tests_cv_classification_transforms/config5_fp16.yml index 911d9b6f6e..29eb39b60e 100644 --- 
a/tests/_tests_cv_classification_transforms/config5_fp16.yml +++ b/tests/_tests_cv_classification_transforms/config5_fp16.yml @@ -41,7 +41,7 @@ stages: std: [0.3081] - transform: catalyst.ToTensor - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -59,7 +59,7 @@ stages: gamma: 0.3 stage1: - state_params: + stage_params: num_epochs: 100 callbacks_params: diff --git a/tests/_tests_cv_segmentation/config.yml b/tests/_tests_cv_segmentation/config.yml index 2b7c313763..888d596ebc 100644 --- a/tests/_tests_cv_segmentation/config.yml +++ b/tests/_tests_cv_segmentation/config.yml @@ -21,7 +21,7 @@ stages: mask_path: ./_tests_cv_segmentation/data/segmentation_data/train_masks valid_size: 0.2 - state_params: + stage_params: main_metric: iou minimize_metric: False @@ -46,7 +46,7 @@ stages: no_bias_weight_decay: True stage1: - state_params: + stage_params: num_epochs: 3 callbacks_params: diff --git a/tests/_tests_dl_callbacks/config0.yml b/tests/_tests_dl_callbacks/config0.yml index 73d41ffc75..25f2af3d17 100644 --- a/tests/_tests_dl_callbacks/config0.yml +++ b/tests/_tests_dl_callbacks/config0.yml @@ -17,7 +17,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -30,7 +30,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: diff --git a/tests/_tests_dl_callbacks/config1.yml b/tests/_tests_dl_callbacks/config1.yml index b07220c42b..52513206ec 100644 --- a/tests/_tests_dl_callbacks/config1.yml +++ b/tests/_tests_dl_callbacks/config1.yml @@ -17,7 +17,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -30,7 +30,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: diff --git a/tests/_tests_dl_callbacks/config10.yml b/tests/_tests_dl_callbacks/config10.yml index d59fcb9473..ee39f619e9 100644 --- a/tests/_tests_dl_callbacks/config10.yml +++ b/tests/_tests_dl_callbacks/config10.yml @@ -17,7 +17,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -30,7 +30,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: @@ -57,7 +57,7 @@ stages: load_on_stage_end: best stage2: - state_params: + stage_params: num_epochs: *num_epochs scheduler_params: diff --git a/tests/_tests_dl_callbacks/config11.yml b/tests/_tests_dl_callbacks/config11.yml index 3fc20c4a8f..884970b65f 100644 --- a/tests/_tests_dl_callbacks/config11.yml +++ b/tests/_tests_dl_callbacks/config11.yml @@ -17,7 +17,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -45,9 +45,9 @@ stages: reduced_metric: *reduced_metric stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 stage2: - state_params: + stage_params: num_epochs: *num_epochs diff --git a/tests/_tests_dl_callbacks/config12.yml b/tests/_tests_dl_callbacks/config12.yml index 72ea8b42d4..8d029cf998 100644 --- a/tests/_tests_dl_callbacks/config12.yml +++ b/tests/_tests_dl_callbacks/config12.yml @@ -17,7 +17,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -30,7 +30,7 @@ stages: weight_decay: 0.0001 stage1: - 
state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: diff --git a/tests/_tests_dl_callbacks/config13.yml b/tests/_tests_dl_callbacks/config13.yml index 9d3f6de09c..6d45ef7cfc 100644 --- a/tests/_tests_dl_callbacks/config13.yml +++ b/tests/_tests_dl_callbacks/config13.yml @@ -17,7 +17,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -30,7 +30,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: diff --git a/tests/_tests_dl_callbacks/config14.yml b/tests/_tests_dl_callbacks/config14.yml index adb19cdfd7..d0d1822c1d 100644 --- a/tests/_tests_dl_callbacks/config14.yml +++ b/tests/_tests_dl_callbacks/config14.yml @@ -17,7 +17,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -30,7 +30,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: @@ -56,7 +56,7 @@ stages: save_n_best: 1 stage2: - state_params: + stage_params: num_epochs: *num_epochs scheduler_params: @@ -83,7 +83,7 @@ stages: load_on_stage_start: last stage3: - state_params: + stage_params: num_epochs: *num_epochs scheduler_params: diff --git a/tests/_tests_dl_callbacks/config15.yml b/tests/_tests_dl_callbacks/config15.yml index 6415ad8f0b..3f88c3d01d 100644 --- a/tests/_tests_dl_callbacks/config15.yml +++ b/tests/_tests_dl_callbacks/config15.yml @@ -17,7 +17,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -30,7 +30,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: @@ -56,7 +56,7 @@ stages: save_n_best: 1 stage2: - state_params: + stage_params: num_epochs: *num_epochs scheduler_params: @@ -85,7 +85,7 @@ stages: optimizer: last stage3: - state_params: + stage_params: num_epochs: *num_epochs scheduler_params: diff --git a/tests/_tests_dl_callbacks/config16.yml b/tests/_tests_dl_callbacks/config16.yml index 52dc5a04eb..1187e9f701 100644 --- a/tests/_tests_dl_callbacks/config16.yml +++ b/tests/_tests_dl_callbacks/config16.yml @@ -17,7 +17,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -30,7 +30,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: @@ -56,7 +56,7 @@ stages: save_n_best: 1 stage2: - state_params: + stage_params: num_epochs: *num_epochs scheduler_params: @@ -84,7 +84,7 @@ stages: optimizer: last stage3: - state_params: + stage_params: num_epochs: *num_epochs scheduler_params: diff --git a/tests/_tests_dl_callbacks/config17.yml b/tests/_tests_dl_callbacks/config17.yml index 74674d551e..3925c4d580 100644 --- a/tests/_tests_dl_callbacks/config17.yml +++ b/tests/_tests_dl_callbacks/config17.yml @@ -17,7 +17,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -30,7 +30,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: @@ -56,7 +56,7 @@ stages: save_n_best: 1 stage2: - state_params: + stage_params: num_epochs: *num_epochs scheduler_params: diff --git a/tests/_tests_dl_callbacks/config18.yml b/tests/_tests_dl_callbacks/config18.yml 
index 99056e5ac8..85c8b3e74a 100644 --- a/tests/_tests_dl_callbacks/config18.yml +++ b/tests/_tests_dl_callbacks/config18.yml @@ -17,7 +17,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -30,7 +30,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: diff --git a/tests/_tests_dl_callbacks/config19.yml b/tests/_tests_dl_callbacks/config19.yml index 9e90793fb0..acfc367fe4 100644 --- a/tests/_tests_dl_callbacks/config19.yml +++ b/tests/_tests_dl_callbacks/config19.yml @@ -17,7 +17,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -30,7 +30,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: @@ -56,7 +56,7 @@ stages: stage2: - state_params: + stage_params: num_epochs: *num_epochs scheduler_params: diff --git a/tests/_tests_dl_callbacks/config2.yml b/tests/_tests_dl_callbacks/config2.yml index 7b7629ca86..271793a7e2 100644 --- a/tests/_tests_dl_callbacks/config2.yml +++ b/tests/_tests_dl_callbacks/config2.yml @@ -17,7 +17,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -30,7 +30,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: @@ -56,7 +56,7 @@ stages: save_n_best: &nbest 1 stage2: - state_params: + stage_params: num_epochs: *num_epochs scheduler_params: @@ -82,7 +82,7 @@ stages: save_n_best: *nbest stage3: - state_params: + stage_params: num_epochs: *num_epochs scheduler_params: diff --git a/tests/_tests_dl_callbacks/config20.yml b/tests/_tests_dl_callbacks/config20.yml index af3382f6a1..45587d30cb 100644 --- a/tests/_tests_dl_callbacks/config20.yml +++ b/tests/_tests_dl_callbacks/config20.yml @@ -20,7 +20,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -33,7 +33,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: diff --git a/tests/_tests_dl_callbacks/config3.yml b/tests/_tests_dl_callbacks/config3.yml index d98471bf23..8265f9a642 100644 --- a/tests/_tests_dl_callbacks/config3.yml +++ b/tests/_tests_dl_callbacks/config3.yml @@ -17,7 +17,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -30,7 +30,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: diff --git a/tests/_tests_dl_callbacks/config4.yml b/tests/_tests_dl_callbacks/config4.yml index 858c6fb185..9051ed9cd0 100644 --- a/tests/_tests_dl_callbacks/config4.yml +++ b/tests/_tests_dl_callbacks/config4.yml @@ -17,7 +17,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -30,7 +30,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: @@ -56,7 +56,7 @@ stages: save_n_best: &nbest 3 stage2: - state_params: + stage_params: num_epochs: *num_epochs scheduler_params: @@ -82,7 +82,7 @@ stages: save_n_best: *nbest stage3: - state_params: + stage_params: num_epochs: *num_epochs scheduler_params: diff --git 
a/tests/_tests_dl_callbacks/config5.yml b/tests/_tests_dl_callbacks/config5.yml index aad83c0e66..ea874a0e74 100644 --- a/tests/_tests_dl_callbacks/config5.yml +++ b/tests/_tests_dl_callbacks/config5.yml @@ -17,7 +17,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -30,7 +30,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: diff --git a/tests/_tests_dl_callbacks/config6.yml b/tests/_tests_dl_callbacks/config6.yml index 6575066a72..ec2db119ae 100644 --- a/tests/_tests_dl_callbacks/config6.yml +++ b/tests/_tests_dl_callbacks/config6.yml @@ -17,7 +17,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -30,7 +30,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: @@ -57,7 +57,7 @@ stages: load_on_stage_end: last_full stage2: - state_params: + stage_params: num_epochs: *num_epochs scheduler_params: @@ -84,7 +84,7 @@ stages: load_on_stage_end: last_full stage3: - state_params: + stage_params: num_epochs: *num_epochs scheduler_params: diff --git a/tests/_tests_dl_callbacks/config7.yml b/tests/_tests_dl_callbacks/config7.yml index abeb3ecbdc..a1bb25290c 100644 --- a/tests/_tests_dl_callbacks/config7.yml +++ b/tests/_tests_dl_callbacks/config7.yml @@ -17,7 +17,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -30,7 +30,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: diff --git a/tests/_tests_dl_callbacks/config8.yml b/tests/_tests_dl_callbacks/config8.yml index 2174985d66..f2d6722ee0 100644 --- a/tests/_tests_dl_callbacks/config8.yml +++ b/tests/_tests_dl_callbacks/config8.yml @@ -17,7 +17,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -30,7 +30,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: @@ -57,7 +57,7 @@ stages: load_on_stage_end: &end_load best_full stage2: - state_params: + stage_params: num_epochs: *num_epochs scheduler_params: @@ -84,7 +84,7 @@ stages: load_on_stage_end: *end_load stage3: - state_params: + stage_params: num_epochs: *num_epochs scheduler_params: diff --git a/tests/_tests_dl_callbacks/config9.yml b/tests/_tests_dl_callbacks/config9.yml index 9104008187..b0272edaba 100644 --- a/tests/_tests_dl_callbacks/config9.yml +++ b/tests/_tests_dl_callbacks/config9.yml @@ -17,7 +17,7 @@ stages: num_workers: 1 drop_last: True - state_params: + stage_params: main_metric: &reduced_metric accuracy01 minimize_metric: False @@ -30,7 +30,7 @@ stages: weight_decay: 0.0001 stage1: - state_params: + stage_params: num_epochs: &num_epochs 5 scheduler_params: @@ -57,7 +57,7 @@ stages: load_on_stage_end: best stage2: - state_params: + stage_params: num_epochs: *num_epochs scheduler_params: diff --git a/tests/_tests_nlp_classification/config1_basic.yml b/tests/_tests_nlp_classification/config1_basic.yml index c31fdcf38c..ed823f74e5 100644 --- a/tests/_tests_nlp_classification/config1_basic.yml +++ b/tests/_tests_nlp_classification/config1_basic.yml @@ -28,7 +28,7 @@ stages: label_field: "label" max_sequence_length: 512 - state_params: + stage_params: main_metric: 
&reduced_metric loss minimize_metric: True @@ -60,7 +60,7 @@ stages: # params specific for stage 1 called "train_val" train_val: # overriding state params and specifying that we train for 3 epochs - state_params: + stage_params: num_epochs: 3 # optimizer params are specific only for this stage # in principle, we can define another stage with other optim params diff --git a/tests/_tests_nlp_classification/config2_small_max_seq_length.yml b/tests/_tests_nlp_classification/config2_small_max_seq_length.yml index 66b3ff8f84..30e29702dc 100644 --- a/tests/_tests_nlp_classification/config2_small_max_seq_length.yml +++ b/tests/_tests_nlp_classification/config2_small_max_seq_length.yml @@ -28,7 +28,7 @@ stages: label_field: "label" max_sequence_length: 24 - state_params: + stage_params: main_metric: &reduced_metric loss minimize_metric: True @@ -55,7 +55,7 @@ stages: # params specific for stage 1 called "train_val" train_val: # overriding state params and specifying that we train for 2 epochs - state_params: + stage_params: num_epochs: 2 # optimizer params are specific only for this stage # in principle, we can define another stage with other optim params diff --git a/tests/_tests_scripts/dl_z_mvp_distributed_mnist_ae.py b/tests/_tests_scripts/dl_z_mvp_distributed_mnist_ae.py index 217dfcfc78..af6a60d9c5 100644 --- a/tests/_tests_scripts/dl_z_mvp_distributed_mnist_ae.py +++ b/tests/_tests_scripts/dl_z_mvp_distributed_mnist_ae.py @@ -58,7 +58,7 @@ def _handle_batch(self, batch): y_hat, y, topk=(1, 3, 5) ) - self.state.batch_metrics = { + self.batch_metrics = { "loss_clf": loss_clf, "loss_ae": loss_ae, "loss": loss, @@ -67,10 +67,10 @@ def _handle_batch(self, batch): "accuracy05": accuracy05, } - if self.state.is_train_loader: + if self.is_train_loader: loss.backward() - self.state.optimizer.step() - self.state.optimizer.zero_grad() + self.optimizer.step() + self.optimizer.zero_grad() def datasets_fn(): diff --git a/tests/_tests_scripts/dl_z_mvp_mnist.py b/tests/_tests_scripts/dl_z_mvp_mnist.py index 0e082f6ece..5338139e86 100644 --- a/tests/_tests_scripts/dl_z_mvp_mnist.py +++ b/tests/_tests_scripts/dl_z_mvp_mnist.py @@ -33,14 +33,14 @@ def _handle_batch(self, batch): loss = F.cross_entropy(y_hat, y) accuracy01, accuracy03 = metrics.accuracy(y_hat, y, topk=(1, 3)) - self.state.batch_metrics.update( + self.batch_metrics.update( {"loss": loss, "accuracy01": accuracy01, "accuracy03": accuracy03} ) - if self.state.is_train_loader: + if self.is_train_loader: loss.backward() - self.state.optimizer.step() - self.state.optimizer.zero_grad() + self.optimizer.step() + self.optimizer.zero_grad() def main(): diff --git a/tests/_tests_scripts/dl_z_mvp_mnist_ae.py b/tests/_tests_scripts/dl_z_mvp_mnist_ae.py index 76b012e590..e2d542ddde 100644 --- a/tests/_tests_scripts/dl_z_mvp_mnist_ae.py +++ b/tests/_tests_scripts/dl_z_mvp_mnist_ae.py @@ -42,7 +42,7 @@ def _handle_batch(self, batch): y_hat, y, topk=(1, 3, 5) ) - self.state.batch_metrics = { + self.batch_metrics = { "loss_clf": loss_clf, "loss_ae": loss_ae, "loss": loss, @@ -51,10 +51,10 @@ def _handle_batch(self, batch): "accuracy05": accuracy05, } - if self.state.is_train_loader: + if self.is_train_loader: loss.backward() - self.state.optimizer.step() - self.state.optimizer.zero_grad() + self.optimizer.step() + self.optimizer.zero_grad() def main(): diff --git a/tests/_tests_scripts/dl_z_mvp_mnist_gan.py b/tests/_tests_scripts/dl_z_mvp_mnist_gan.py index 3730f60c63..45844b0e04 100644 --- a/tests/_tests_scripts/dl_z_mvp_mnist_gan.py +++ 
b/tests/_tests_scripts/dl_z_mvp_mnist_gan.py @@ -59,7 +59,7 @@ def _handle_batch(self, batch): predictions, misleading_labels ) - self.state.batch_metrics.update(**batch_metrics) + self.batch_metrics.update(**batch_metrics) def main(): diff --git a/tests/_tests_scripts/dl_z_mvp_mnist_unet.py b/tests/_tests_scripts/dl_z_mvp_mnist_unet.py index 2c90c0a9fb..da1028e348 100644 --- a/tests/_tests_scripts/dl_z_mvp_mnist_unet.py +++ b/tests/_tests_scripts/dl_z_mvp_mnist_unet.py @@ -43,7 +43,7 @@ def _handle_batch(self, batch): y_hat, y, topk=(1, 3, 5) ) - self.state.batch_metrics = { + self.batch_metrics = { "loss_clf": loss_clf, "loss_iou": loss_iou, "loss": loss, @@ -53,10 +53,10 @@ def _handle_batch(self, batch): "accuracy05": accuracy05, } - if self.state.is_train_loader: + if self.is_train_loader: loss.backward() - self.state.optimizer.step() - self.state.optimizer.zero_grad() + self.optimizer.step() + self.optimizer.zero_grad() def main(): diff --git a/tests/_tests_scripts/dl_z_mvp_mnist_vae.py b/tests/_tests_scripts/dl_z_mvp_mnist_vae.py index c1a01285bc..84a1979450 100644 --- a/tests/_tests_scripts/dl_z_mvp_mnist_vae.py +++ b/tests/_tests_scripts/dl_z_mvp_mnist_vae.py @@ -80,7 +80,7 @@ def _handle_batch(self, batch): y_hat, y, topk=(1, 3, 5) ) - self.state.batch_metrics = { + self.batch_metrics = { "loss_clf": loss_clf, "loss_ae": loss_ae, "loss_kld": loss_kld, @@ -91,10 +91,10 @@ def _handle_batch(self, batch): "accuracy05": accuracy05, } - if self.state.is_train_loader: + if self.is_train_loader: loss.backward() - self.state.optimizer.step() - self.state.optimizer.zero_grad() + self.optimizer.step() + self.optimizer.zero_grad() def main():
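For reference, a minimal usage sketch of the renamed tracing helper from the catalyst/dl/utils/trace.py hunks above (illustrative only, not part of the patch): it assumes `runner` is a runner that has already processed at least one batch, so that `runner.model` and `runner.input` are populated, and it uses only the parameters visible in the changed signature.

    from catalyst.dl.utils import trace_model_from_runner

    # Trace the model currently attached to the runner
    # (checkpoint_name=None is the default shown in the signature above).
    traced = trace_model_from_runner(
        runner=runner,          # the runner now plays the role of the former `state` argument
        method_name="forward",  # model method used as the tracing entrypoint
        mode="eval",            # trace the model in eval mode
    )
    traced.save("traced_model.pth")  # per the docstring, the result is a torch.jit.ScriptModule

The file name passed to `save` is arbitrary; `save_traced_model` from the same module could be used instead if its naming conventions are preferred.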