From 0d4cc5df6f56a1a32b3b22539a450ffdfbd7d543 Mon Sep 17 00:00:00 2001 From: chMoussa Date: Wed, 21 Aug 2024 11:55:37 +0200 Subject: [PATCH] [Feature] Adding callback functionalities in train functions (#533) --- docs/tutorials/qml/ml_tools.md | 8 +- pyproject.toml | 3 +- qadence/ml_tools/__init__.py | 3 +- qadence/ml_tools/config.py | 83 +++++++++- qadence/ml_tools/data.py | 28 +++- qadence/ml_tools/optimize_step.py | 5 +- qadence/ml_tools/train_grad.py | 248 ++++++++++++++++++------ qadence/ml_tools/train_no_grad.py | 118 ++++++++++---- 8 files changed, 361 insertions(+), 135 deletions(-) diff --git a/docs/tutorials/qml/ml_tools.md b/docs/tutorials/qml/ml_tools.md index f2d762ab9..dcb7b6762 100644 --- a/docs/tutorials/qml/ml_tools.md +++ b/docs/tutorials/qml/ml_tools.md @@ -78,19 +78,25 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d The [`TrainConfig`][qadence.ml_tools.config.TrainConfig] tells `train_with_grad` what batch_size should be used, how many epochs to train, in which intervals to print/log metrics and how often to store intermediate checkpoints. +It is also possible to provide custom callback functions by instantiating a [`Callback`][qadence.ml_tools.config.Callback] +with a function `callback` that only accepts as argument an instance of [`OptimizeResult`][qadence.ml_tools.data.OptimizeResult] created within the `train` functions. +One can also provide a `callback_condition` function, also only accepting an instance of [`OptimizeResult`][qadence.ml_tools.data.OptimizeResult], which returns True if `callback` should be called. If no `callback_condition` is provided, `callback` is called at every x epochs (specified by `Callback`'s `called_every` argument). We can also specify in which part of the training function the `Callback` will be applied. 
```python exec="on" source="material-block" -from qadence.ml_tools import TrainConfig +from qadence.ml_tools import TrainConfig, Callback batch_size = 5 n_epochs = 100 +custom_callback = Callback(lambda opt_res: print(opt_res.model.parameters()), callback_condition=lambda opt_res: opt_res.loss < 1.0e-03, called_every=10, call_end_epoch=True) + config = TrainConfig( folder="some_path/", max_iter=n_epochs, checkpoint_every=100, write_every=100, batch_size=batch_size, + callbacks = [custom_callback] ) ``` diff --git a/pyproject.toml b/pyproject.toml index 937c7b5f7..aff876831 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ authors = [ ] requires-python = ">=3.9" license = { text = "Apache 2.0" } -version = "1.7.4" +version = "1.7.5" classifiers = [ "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", @@ -133,6 +133,7 @@ filterwarnings = [ [tool.hatch.envs.docs] dependencies = [ "mkdocs", + "mkdocs_autorefs<1.1.0", "mkdocs-material", "mkdocstrings", "mkdocstrings-python", diff --git a/qadence/ml_tools/__init__.py b/qadence/ml_tools/__init__.py index ed6336ed6..a06871773 100644 --- a/qadence/ml_tools/__init__.py +++ b/qadence/ml_tools/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -from .config import AnsatzConfig, FeatureMapConfig, TrainConfig +from .config import AnsatzConfig, Callback, FeatureMapConfig, TrainConfig from .constructors import create_ansatz, create_fm_blocks, observable_from_config from .data import DictDataLoader, InfiniteTensorDataset, to_dataloader from .models import QNN @@ -23,6 +23,7 @@ "observable_from_config", "QNN", "TrainConfig", + "Callback", "train_with_grad", "train_gradient_free", "write_checkpoint", diff --git a/qadence/ml_tools/config.py b/qadence/ml_tools/config.py index 9b35ed660..e1ce89d7b 100644 --- a/qadence/ml_tools/config.py +++ b/qadence/ml_tools/config.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field, fields from logging import getLogger from 
pathlib import Path -from typing import Callable, Type +from typing import Any, Callable, Type from uuid import uuid4 from sympy import Basic @@ -13,6 +13,7 @@ from qadence.blocks.analog import AnalogBlock from qadence.blocks.primitive import ParametricBlock +from qadence.ml_tools.data import OptimizeResult from qadence.operations import RX, AnalogRX from qadence.parameters import Parameter from qadence.types import ( @@ -27,6 +28,84 @@ logger = getLogger(__file__) +CallbackFunction = Callable[[OptimizeResult], None] +CallbackConditionFunction = Callable[[OptimizeResult], bool] + + +class Callback: + """Callback functions are called in train functions. + + Each callback function should take at least as first input + an OptimizeResult instance. + + Attributes: + callback (CallbackFunction): Callback function accepting an + OptimizeResult as first argument. + callback_condition (CallbackConditionFunction | None, optional): Function that + conditions the call to callback. Defaults to None. + called_every (int, optional): Callback to be called each `called_every` epoch. + Defaults to 1. + If callback_condition is None, we set + callback_condition to return True when iteration % every == 0. + call_before_opt (bool, optional): If true, callback is applied before training. + Defaults to False. + call_end_epoch (bool, optional): If true, callback is applied during training, + after an epoch is performed. Defaults to True. + call_after_opt (bool, optional): If true, callback is applied after training. + Defaults to False. + call_during_eval (bool, optional): If true, callback is applied during evaluation. + Defaults to False. + """ + + def __init__( + self, + callback: CallbackFunction, + callback_condition: CallbackConditionFunction | None = None, + called_every: int = 1, + call_before_opt: bool = False, + call_end_epoch: bool = True, + call_after_opt: bool = False, + call_during_eval: bool = False, + ) -> None: + """Initialize Callback. 
+ + Args: + callback (CallbackFunction): Callback function accepting an + OptimizeResult as first argument. + callback_condition (CallbackConditionFunction | None, optional): Function that + conditions the call to callback. Defaults to None. + called_every (int, optional): Callback to be called each `called_every` epoch. + Defaults to 1. + If callback_condition is None, we set + callback_condition to return True when iteration % every == 0. + call_before_opt (bool, optional): If true, callback is applied before training. + Defaults to False. + call_end_epoch (bool, optional): If true, callback is applied during training, + after an epoch is performed. Defaults to True. + call_after_opt (bool, optional): If true, callback is applied after training. + Defaults to False. + call_during_eval (bool, optional): If true, callback is applied during evaluation. + Defaults to False. + """ + self.callback = callback + self.call_before_opt = call_before_opt + self.call_end_epoch = call_end_epoch + self.call_after_opt = call_after_opt + self.call_during_eval = call_during_eval + + if called_every <= 0: + raise ValueError("Please provide a strictly positive `called_every` argument.") + self.called_every = called_every + + if callback_condition is None: + self.callback_condition = lambda opt_result: True + else: + self.callback_condition = callback_condition + + def __call__(self, opt_result: OptimizeResult) -> Any: + if opt_result.iteration % self.called_every == 0 and self.callback_condition(opt_result): + return self.callback(opt_result) + @dataclass class TrainConfig: @@ -64,6 +143,8 @@ class TrainConfig: Set to 0 to disable """ + callbacks: list[Callback] = field(default_factory=lambda: list()) + """List of callbacks.""" log_model: bool = False """Logs a serialised version of the model.""" folder: Path | None = None diff --git a/qadence/ml_tools/data.py b/qadence/ml_tools/data.py index 80ad24f62..8e92e7f94 100644 --- a/qadence/ml_tools/data.py +++ b/qadence/ml_tools/data.py 
@@ -1,15 +1,41 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import singledispatch from itertools import cycle from typing import Any, Iterator +from nevergrad.optimization.base import Optimizer as NGOptimizer from torch import Tensor from torch import device as torch_device +from torch.nn import Module +from torch.optim import Optimizer from torch.utils.data import DataLoader, IterableDataset, TensorDataset +@dataclass +class OptimizeResult: + """OptimizeResult stores many optimization intermediate values. + + We store at a current iteration, + the model, optimizer, loss values, metrics. An extra dict + can be used for saving other information to be used for callbacks. + """ + + iteration: int + """Current iteration number.""" + model: Module + """Model at iteration.""" + optimizer: Optimizer | NGOptimizer + """Optimizer at iteration.""" + loss: Tensor | float | None = None + """Loss value.""" + metrics: dict = field(default_factory=lambda: dict()) + """Metrics that can be saved during training.""" + extra: dict = field(default_factory=lambda: dict()) + """Extra dict for saving anything else to be used in callbacks.""" + + @dataclass class DictDataLoader: """This class only holds a dictionary of `DataLoader`s and samples from them.""" diff --git a/qadence/ml_tools/optimize_step.py b/qadence/ml_tools/optimize_step.py index c067f0923..93bd9ac5c 100644 --- a/qadence/ml_tools/optimize_step.py +++ b/qadence/ml_tools/optimize_step.py @@ -29,10 +29,11 @@ def optimize_step( xs (dict | list | torch.Tensor | None): the input data. If None it means that the given model does not require any input data device (torch.device): A target device to run computation on. + dtype (torch.dtype): Data type for xs conversion. 
Returns: - tuple: tuple containing the model, the optimizer, a dictionary with - the collected metrics and the compute value loss + tuple: tuple containing the computed loss value, and a dictionary with + the collected metrics. """ loss, metrics = None, {} diff --git a/qadence/ml_tools/train_grad.py b/qadence/ml_tools/train_grad.py index eb20d3f56..12b607935 100644 --- a/qadence/ml_tools/train_grad.py +++ b/qadence/ml_tools/train_grad.py @@ -3,10 +3,10 @@ import importlib import math from logging import getLogger -from typing import Callable, Union +from typing import Any, Callable, Union from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn -from torch import complex128, float32, float64 +from torch import Tensor, complex128, float32, float64 from torch import device as torch_device from torch import dtype as torch_dtype from torch.nn import DataParallel, Module @@ -14,8 +14,8 @@ from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from qadence.ml_tools.config import TrainConfig -from qadence.ml_tools.data import DictDataLoader, data_to_device +from qadence.ml_tools.config import Callback, TrainConfig +from qadence.ml_tools.data import DictDataLoader, OptimizeResult, data_to_device from qadence.ml_tools.optimize_step import optimize_step from qadence.ml_tools.printing import ( log_model_tracker, @@ -160,115 +160,178 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d best_val_loss = math.inf + if not ((dataloader is None) or isinstance(dataloader, (DictDataLoader, DataLoader))): + raise NotImplementedError( + f"Unsupported dataloader type: {type(dataloader)}. " + "You can use e.g. `qadence.ml_tools.to_dataloader` to build a dataloader." + ) + + def next_loss_iter(dl_iter: Union[None, DataLoader, DictDataLoader]) -> Any: + """Get loss on the next batch of a dataloader. + + loaded on device if not None. 
+ + Args: + dl_iter (Union[None, DataLoader, DictDataLoader]): Dataloader. + + Returns: + Any: Loss value + """ + xs = next(dl_iter) if dl_iter is not None else None + xs_to_device = data_to_device(xs, device=device, dtype=data_dtype) + return loss_fn(model, xs_to_device) + + # populate callbacks with already available internal functions + # printing, writing and plotting + callbacks = config.callbacks + + # printing + if config.verbose and config.print_every > 0: + # Note that the loss returned by optimize_step + # is the value before doing the training step + # which is printed accordingly by the previous iteration number + callbacks += [ + Callback( + lambda opt_res: print_metrics(opt_res.loss, opt_res.metrics, opt_res.iteration - 1), + called_every=config.print_every, + call_after_opt=True, + ) + ] + + # plotting + callbacks += [ + Callback( + lambda opt_res: plot_tracker( + writer, + opt_res.model, + opt_res.iteration, + config.plotting_functions, + tracking_tool=config.tracking_tool, + ), + called_every=config.plot_every, + call_before_opt=True, + ) + ] + + # writing metrics + callbacks += [ + Callback( + lambda opt_res: write_tracker( + writer, + opt_res.loss, + opt_res.metrics, + opt_res.iteration, + tracking_tool=config.tracking_tool, + ), + called_every=config.write_every, + call_before_opt=False, + call_after_opt=True, + call_during_eval=True, + ) + ] + + # checkpointing + if config.folder and config.checkpoint_every > 0 and not config.checkpoint_best_only: + callbacks += [ + Callback( + lambda opt_res: write_checkpoint( + config.folder, # type: ignore[arg-type] + opt_res.model, + opt_res.optimizer, + opt_res.iteration, + ), + called_every=config.checkpoint_every, + call_before_opt=False, + call_after_opt=True, + ) + ] + + if config.folder and config.checkpoint_best_only: + callbacks += [ + Callback( + lambda opt_res: write_checkpoint( + config.folder, # type: ignore[arg-type] + opt_res.model, + opt_res.optimizer, + "best", + ), + 
called_every=config.checkpoint_every, + call_before_opt=True, + call_after_opt=True, + call_during_eval=True, + ) + ] + + def run_callbacks(callback_iterable: list[Callback], opt_res: OptimizeResult) -> None: + for callback in callback_iterable: + callback(opt_res) + + callbacks_before_opt = [ + callback + for callback in callbacks + if callback.call_before_opt and not callback.call_during_eval + ] + callbacks_before_opt_eval = [ + callback for callback in callbacks if callback.call_before_opt and callback.call_during_eval + ] + with progress: dl_iter = iter(dataloader) if dataloader is not None else None # Initial validation evaluation try: + opt_result = OptimizeResult(init_iter, model, optimizer) if perform_val: dl_iter_val = iter(val_dataloader) if val_dataloader is not None else None - xs = next(dl_iter_val) - xs_to_device = data_to_device(xs, device=device, dtype=data_dtype) - best_val_loss, metrics = loss_fn(model, xs_to_device) - + best_val_loss, metrics, *_ = next_loss_iter(dl_iter_val) metrics["val_loss"] = best_val_loss - write_tracker(writer, None, metrics, init_iter, tracking_tool=config.tracking_tool) + opt_result.metrics = metrics + run_callbacks(callbacks_before_opt_eval, opt_result) - if config.folder: - if config.checkpoint_best_only: - write_checkpoint(config.folder, model, optimizer, iteration="best") - else: - write_checkpoint(config.folder, model, optimizer, init_iter) - - plot_tracker( - writer, - model, - init_iter, - config.plotting_functions, - tracking_tool=config.tracking_tool, - ) + run_callbacks(callbacks_before_opt, opt_result) except KeyboardInterrupt: logger.info("Terminating training gracefully after the current iteration.") # outer epoch loop init_iter += 1 + callbacks_end_epoch = [ + callback + for callback in callbacks + if callback.call_end_epoch and not callback.call_during_eval + ] + callbacks_end_epoch_eval = [ + callback + for callback in callbacks + if callback.call_end_epoch and callback.call_during_eval + ] for iteration 
in progress.track(range(init_iter, init_iter + config.max_iter)): try: # in case there is not data needed by the model # this is the case, for example, of quantum models # which do not have classical input data (e.g. chemistry) - if dataloader is None: - loss, metrics = optimize_step( - model=model, - optimizer=optimizer, - loss_fn=loss_fn, - xs=None, - device=device, - dtype=data_dtype, - ) + loss, metrics = optimize_step( + model=model, + optimizer=optimizer, + loss_fn=loss_fn, + xs=None if dataloader is None else next(dl_iter), # type: ignore[arg-type] + device=device, + dtype=data_dtype, + ) + if isinstance(loss, Tensor): loss = loss.item() + opt_result = OptimizeResult(iteration, model, optimizer, loss, metrics) + run_callbacks(callbacks_end_epoch, opt_result) - elif isinstance(dataloader, (DictDataLoader, DataLoader)): - loss, metrics = optimize_step( - model=model, - optimizer=optimizer, - loss_fn=loss_fn, - xs=next(dl_iter), # type: ignore[arg-type] - device=device, - dtype=data_dtype, - ) - - else: - raise NotImplementedError( - f"Unsupported dataloader type: {type(dataloader)}. " - "You can use e.g. `qadence.ml_tools.to_dataloader` to build a dataloader." 
- ) - - if ( - config.print_every > 0 - and iteration % config.print_every == 0 - and config.verbose - ): - # Note that the loss returned by optimize_step - # is the value before doing the training step - # which is printed accordingly by the previous iteration number - print_metrics(loss, metrics, iteration - 1) - - if config.write_every > 0 and iteration % config.write_every == 0: - write_tracker( - writer, loss, metrics, iteration, tracking_tool=config.tracking_tool - ) - - if config.plot_every > 0 and iteration % config.plot_every == 0: - plot_tracker( - writer, - model, - iteration, - config.plotting_functions, - tracking_tool=config.tracking_tool, - ) if perform_val: if iteration % config.val_every == 0: - xs = next(dl_iter_val) - xs_to_device = data_to_device(xs, device=device, dtype=data_dtype) - val_loss, *_ = loss_fn(model, xs_to_device) + val_loss, *_ = next_loss_iter(dl_iter_val) if config.validation_criterion(val_loss, best_val_loss, config.val_epsilon): # type: ignore[misc] best_val_loss = val_loss - if config.folder and config.checkpoint_best_only: - write_checkpoint(config.folder, model, optimizer, iteration="best") metrics["val_loss"] = val_loss - write_tracker( - writer, loss, metrics, iteration, tracking_tool=config.tracking_tool - ) - - if config.folder: - if ( - config.checkpoint_every > 0 - and iteration % config.checkpoint_every == 0 - and not config.checkpoint_best_only - ): - write_checkpoint(config.folder, model, optimizer, iteration) + opt_result.metrics = metrics + + run_callbacks(callbacks_end_epoch_eval, opt_result) except KeyboardInterrupt: logger.info("Terminating training gracefully after the current iteration.") @@ -277,21 +340,16 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d # Handling printing the last training loss # as optimize_step does not give the loss value at the last iteration try: - xs = next(dl_iter) if dataloader is not None else None # type: ignore[arg-type] - xs_to_device = 
data_to_device(xs, device=device, dtype=data_dtype) - loss, metrics, *_ = loss_fn(model, xs_to_device) - if dataloader is None: - loss = loss.item() + loss, metrics, *_ = next_loss_iter(dl_iter) if iteration % config.print_every == 0 and config.verbose: print_metrics(loss, metrics, iteration) except KeyboardInterrupt: logger.info("Terminating training gracefully after the current iteration.") - # Final checkpointing and writing - if config.folder and not config.checkpoint_best_only: - write_checkpoint(config.folder, model, optimizer, iteration) - write_tracker(writer, loss, metrics, iteration, tracking_tool=config.tracking_tool) + # Final callbacks, by default checkpointing and writing + callbacks_after_opt = [callback for callback in callbacks if callback.call_after_opt] + run_callbacks(callbacks_after_opt, opt_result) # writing hyperparameters if config.hyperparams: diff --git a/qadence/ml_tools/train_no_grad.py b/qadence/ml_tools/train_no_grad.py index 41d325300..e312244a7 100644 --- a/qadence/ml_tools/train_no_grad.py +++ b/qadence/ml_tools/train_no_grad.py @@ -12,8 +12,8 @@ from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from qadence.ml_tools.config import TrainConfig -from qadence.ml_tools.data import DictDataLoader +from qadence.ml_tools.config import Callback, TrainConfig +from qadence.ml_tools.data import DictDataLoader, OptimizeResult from qadence.ml_tools.parameters import get_parameters, set_parameters from qadence.ml_tools.printing import ( log_model_tracker, @@ -86,6 +86,12 @@ def _update_parameters( params = get_parameters(model).detach().numpy() ng_params = ng.p.Array(init=params) + if not ((dataloader is None) or isinstance(dataloader, (DictDataLoader, DataLoader))): + raise NotImplementedError( + f"Unsupported dataloader type: {type(dataloader)}. " + "You can use e.g. `qadence.ml_tools.to_dataloader` to build a dataloader." 
+ ) + # serial training # TODO: Add a parallelization using the num_workers argument in Nevergrad progress = Progress( @@ -94,38 +100,85 @@ def _update_parameters( TaskProgressColumn(), TimeRemainingColumn(elapsed_when_finished=True), ) - with progress: - dl_iter = iter(dataloader) if dataloader is not None else None - - for iteration in progress.track(range(init_iter, init_iter + config.max_iter)): - if dataloader is None: - loss, metrics, ng_params = _update_parameters(None, ng_params) - - elif isinstance(dataloader, (DictDataLoader, DataLoader)): - data = next(dl_iter) # type: ignore[arg-type] - loss, metrics, ng_params = _update_parameters(data, ng_params) - - else: - raise NotImplementedError("Unsupported dataloader type!") - - if config.print_every > 0 and iteration % config.print_every == 0 and config.verbose: - print_metrics(loss, metrics, iteration) - if config.write_every > 0 and iteration % config.write_every == 0: - write_tracker(writer, loss, metrics, iteration, tracking_tool=config.tracking_tool) - - if config.plot_every > 0 and iteration % config.plot_every == 0: - plot_tracker( + # populate callbacks with already available internal functions + # printing, writing and plotting + callbacks = config.callbacks + + # printing + if config.verbose and config.print_every > 0: + callbacks += [ + Callback( + lambda opt_res: print_metrics(opt_res.loss, opt_res.metrics, opt_res.iteration), + called_every=config.print_every, + ) + ] + + # writing metrics + if config.write_every > 0: + callbacks += [ + Callback( + lambda opt_res: write_tracker( + writer, + opt_res.loss, + opt_res.metrics, + opt_res.iteration, + tracking_tool=config.tracking_tool, + ), + called_every=config.write_every, + call_after_opt=True, + ) + ] + + # plot tracker + if config.plot_every > 0: + callbacks += [ + Callback( + lambda opt_res: plot_tracker( writer, - model, - iteration, + opt_res.model, + opt_res.iteration, config.plotting_functions, tracking_tool=config.tracking_tool, - ) + ), + 
called_every=config.plot_every, + ) + ] + + # checkpointing + if config.folder and config.checkpoint_every > 0: + callbacks += [ + Callback( + lambda opt_res: write_checkpoint( + config.folder, # type: ignore[arg-type] + opt_res.model, + opt_res.optimizer, + opt_res.iteration, + ), + called_every=config.checkpoint_every, + call_after_opt=True, + ) + ] + + def run_callbacks(callback_iterable: list[Callback], opt_res: OptimizeResult) -> None: + for callback in callback_iterable: + callback(opt_res) + + callbacks_end_opt = [ + callback + for callback in callbacks + if callback.call_end_epoch and not callback.call_during_eval + ] + + with progress: + dl_iter = iter(dataloader) if dataloader is not None else None - if config.folder: - if config.checkpoint_every > 0 and iteration % config.checkpoint_every == 0: - write_checkpoint(config.folder, model, optimizer, iteration) + for iteration in progress.track(range(init_iter, init_iter + config.max_iter)): + loss, metrics, ng_params = _update_parameters( + None if dataloader is None else next(dl_iter), ng_params # type: ignore[arg-type] + ) + opt_result = OptimizeResult(iteration, model, optimizer, loss, metrics) + run_callbacks(callbacks_end_opt, opt_result) if iteration >= init_iter + config.max_iter: break @@ -137,10 +190,9 @@ def _update_parameters( if config.log_model: log_model_tracker(writer, model, dataloader, tracking_tool=config.tracking_tool) - # Final writing and checkpointing - if config.folder: - write_checkpoint(config.folder, model, optimizer, iteration) - write_tracker(writer, loss, metrics, iteration, tracking_tool=config.tracking_tool) + # Final callbacks + callbacks_after_opt = [callback for callback in callbacks if callback.call_after_opt] + run_callbacks(callbacks_after_opt, opt_result) # close tracker if config.tracking_tool == ExperimentTrackingTool.TENSORBOARD: