[Feature] Adding callback functionalities in train functions (#533)
chMoussa authored Aug 21, 2024
1 parent 424cf83 commit 0d4cc5d
Showing 8 changed files with 361 additions and 135 deletions.
8 changes: 7 additions & 1 deletion docs/tutorials/qml/ml_tools.md
@@ -78,19 +78,25 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d

The [`TrainConfig`][qadence.ml_tools.config.TrainConfig] tells `train_with_grad` which batch size to use,
how many epochs to train, at which intervals to print/log metrics, and how often to store intermediate checkpoints.
It is also possible to provide custom callback functions by instantiating a [`Callback`][qadence.ml_tools.config.Callback]
with a function `callback` that accepts an instance of [`OptimizeResult`][qadence.ml_tools.data.OptimizeResult], created within the `train` functions, as its only argument.
One can also provide a `callback_condition` function, likewise accepting only an instance of [`OptimizeResult`][qadence.ml_tools.data.OptimizeResult], which returns True if `callback` should be called. If no `callback_condition` is provided, `callback` is called every `called_every` epochs (as specified by `Callback`'s `called_every` argument). One can also specify in which part of the training function the `Callback` is applied; a further usage sketch follows the example below.

```python exec="on" source="material-block"
from qadence.ml_tools import TrainConfig, Callback

batch_size = 5
n_epochs = 100

custom_callback = Callback(
    lambda opt_res: print(opt_res.model.parameters()),
    callback_condition=lambda opt_res: opt_res.loss < 1.0e-03,
    called_every=10,
    call_end_epoch=True,
)

config = TrainConfig(
    folder="some_path/",
    max_iter=n_epochs,
    checkpoint_every=100,
    write_every=100,
    batch_size=batch_size,
    callbacks=[custom_callback],
)
```
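
As a further usage sketch, here is a hedged example of a callback that only logs when the loss falls below a threshold; the `log_loss` helper, the threshold of `0.1`, and the print body are placeholder choices for illustration, not part of the qadence API.

```python
from qadence.ml_tools import Callback

def log_loss(opt_res):
    # Placeholder body: report the current iteration and loss.
    print(f"iteration {opt_res.iteration}: loss = {opt_res.loss}")

# Fire at most every 10 epochs, and only once the loss drops below a
# (made-up) threshold of 0.1.
low_loss_logger = Callback(
    log_loss,
    callback_condition=lambda opt_res: opt_res.loss is not None and opt_res.loss < 0.1,
    called_every=10,
    call_end_epoch=True,
)
```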

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -25,7 +25,7 @@ authors = [
]
requires-python = ">=3.9"
license = { text = "Apache 2.0" }
version = "1.7.4"
version = "1.7.5"
classifiers = [
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python",
@@ -133,6 +133,7 @@ filterwarnings = [
[tool.hatch.envs.docs]
dependencies = [
"mkdocs",
"mkdocs_autorefs<1.1.0",
"mkdocs-material",
"mkdocstrings",
"mkdocstrings-python",
3 changes: 2 additions & 1 deletion qadence/ml_tools/__init__.py
@@ -1,6 +1,6 @@
from __future__ import annotations

-from .config import AnsatzConfig, FeatureMapConfig, TrainConfig
+from .config import AnsatzConfig, Callback, FeatureMapConfig, TrainConfig
from .constructors import create_ansatz, create_fm_blocks, observable_from_config
from .data import DictDataLoader, InfiniteTensorDataset, to_dataloader
from .models import QNN
@@ -23,6 +23,7 @@
"observable_from_config",
"QNN",
"TrainConfig",
"Callback",
"train_with_grad",
"train_gradient_free",
"write_checkpoint",
83 changes: 82 additions & 1 deletion qadence/ml_tools/config.py
@@ -5,14 +5,15 @@
from dataclasses import dataclass, field, fields
from logging import getLogger
from pathlib import Path
-from typing import Callable, Type
+from typing import Any, Callable, Type
from uuid import uuid4

from sympy import Basic
from torch import Tensor

from qadence.blocks.analog import AnalogBlock
from qadence.blocks.primitive import ParametricBlock
from qadence.ml_tools.data import OptimizeResult
from qadence.operations import RX, AnalogRX
from qadence.parameters import Parameter
from qadence.types import (
@@ -27,6 +28,84 @@

logger = getLogger(__file__)

CallbackFunction = Callable[[OptimizeResult], None]
CallbackConditionFunction = Callable[[OptimizeResult], bool]


class Callback:
    """Callback functions are called within the train functions.

    Each callback function should accept an OptimizeResult instance
    as its first argument.

    Attributes:
        callback (CallbackFunction): Callback function accepting an
            OptimizeResult as first argument.
        callback_condition (CallbackConditionFunction | None, optional): Function that
            conditions the call to callback. Defaults to None.
        called_every (int, optional): The callback is called every `called_every` epochs.
            Defaults to 1.
            If callback_condition is None, the callback fires whenever
            iteration % called_every == 0.
        call_before_opt (bool, optional): If True, the callback is applied before training.
            Defaults to False.
        call_end_epoch (bool, optional): If True, the callback is applied during training,
            after an epoch is performed. Defaults to True.
        call_after_opt (bool, optional): If True, the callback is applied after training.
            Defaults to False.
        call_during_eval (bool, optional): If True, the callback is applied during evaluation.
            Defaults to False.
    """

    def __init__(
        self,
        callback: CallbackFunction,
        callback_condition: CallbackConditionFunction | None = None,
        called_every: int = 1,
        call_before_opt: bool = False,
        call_end_epoch: bool = True,
        call_after_opt: bool = False,
        call_during_eval: bool = False,
    ) -> None:
        """Initializes the Callback.

        Args:
            callback (CallbackFunction): Callback function accepting an
                OptimizeResult as first argument.
            callback_condition (CallbackConditionFunction | None, optional): Function that
                conditions the call to callback. Defaults to None.
            called_every (int, optional): The callback is called every `called_every` epochs.
                Defaults to 1.
                If callback_condition is None, the callback fires whenever
                iteration % called_every == 0.
            call_before_opt (bool, optional): If True, the callback is applied before training.
                Defaults to False.
            call_end_epoch (bool, optional): If True, the callback is applied during training,
                after an epoch is performed. Defaults to True.
            call_after_opt (bool, optional): If True, the callback is applied after training.
                Defaults to False.
            call_during_eval (bool, optional): If True, the callback is applied during evaluation.
                Defaults to False.
        """
        self.callback = callback
        self.call_before_opt = call_before_opt
        self.call_end_epoch = call_end_epoch
        self.call_after_opt = call_after_opt
        self.call_during_eval = call_during_eval

        if called_every <= 0:
            raise ValueError("Please provide a strictly positive `called_every` argument.")
        self.called_every = called_every

        if callback_condition is None:
            self.callback_condition = lambda opt_result: True
        else:
            self.callback_condition = callback_condition

    def __call__(self, opt_result: OptimizeResult) -> Any:
        if opt_result.iteration % self.called_every == 0 and self.callback_condition(opt_result):
            return self.callback(opt_result)
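
For illustration, a minimal standalone sketch of the gating in `__call__` above; the linear model, Adam optimizer, and fixed loss value are placeholders, not part of this commit.

```python
import torch

from qadence.ml_tools import Callback
from qadence.ml_tools.data import OptimizeResult

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.Adam(model.parameters())

# No callback_condition given, so the condition defaults to always True
# and only the `called_every` gate applies.
cb = Callback(lambda res: print(f"fired at iteration {res.iteration}"), called_every=3)

for it in range(7):
    res = OptimizeResult(iteration=it, model=model, optimizer=optimizer, loss=0.5)
    cb(res)  # fires only at iterations 0, 3 and 6
```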


@dataclass
class TrainConfig:
@@ -64,6 +143,8 @@ class TrainConfig:
    Set to 0 to disable
    """
    callbacks: list[Callback] = field(default_factory=lambda: list())
    """List of callbacks."""
    log_model: bool = False
    """Logs a serialised version of the model."""
    folder: Path | None = None
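
A hedged sketch of wiring callbacks to the different hook points exposed above; the names `pre_opt`/`post_opt` and the print bodies are placeholders for illustration.

```python
from qadence.ml_tools import Callback, TrainConfig

# One callback before training, one after; disable the default end-of-epoch hook.
pre_opt = Callback(lambda res: print("starting training"), call_before_opt=True, call_end_epoch=False)
post_opt = Callback(lambda res: print("finished training"), call_after_opt=True, call_end_epoch=False)

config = TrainConfig(max_iter=100, callbacks=[pre_opt, post_opt])
```
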
28 changes: 27 additions & 1 deletion qadence/ml_tools/data.py
@@ -1,15 +1,41 @@
from __future__ import annotations

-from dataclasses import dataclass
+from dataclasses import dataclass, field
from functools import singledispatch
from itertools import cycle
from typing import Any, Iterator

from nevergrad.optimization.base import Optimizer as NGOptimizer
from torch import Tensor
from torch import device as torch_device
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader, IterableDataset, TensorDataset


@dataclass
class OptimizeResult:
    """OptimizeResult stores intermediate values of an optimization.

    At a given iteration, it stores the model, the optimizer, the loss value
    and the metrics. An extra dict can be used to save any other information
    needed by callbacks.
    """

    iteration: int
    """Current iteration number."""
    model: Module
    """Model at iteration."""
    optimizer: Optimizer | NGOptimizer
    """Optimizer at iteration."""
    loss: Tensor | float | None = None
    """Loss value."""
    metrics: dict = field(default_factory=lambda: dict())
    """Metrics that can be saved during training."""
    extra: dict = field(default_factory=lambda: dict())
    """Extra dict for saving anything else to be used in callbacks."""


@dataclass
class DictDataLoader:
"""This class only holds a dictionary of `DataLoader`s and samples from them."""
5 changes: 3 additions & 2 deletions qadence/ml_tools/optimize_step.py
Expand Up @@ -29,10 +29,11 @@ def optimize_step(
        xs (dict | list | torch.Tensor | None): the input data. If None it means
            that the given model does not require any input data
        device (torch.device): A target device to run computation on.
        dtype (torch.dtype): Data type for xs conversion.

    Returns:
-       tuple: tuple containing the model, the optimizer, a dictionary with
-           the collected metrics and the compute value loss
+       tuple: tuple containing the computed loss value, and a dictionary with
+           the collected metrics.
    """

    loss, metrics = None, {}
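
For orientation, a hedged sketch of how a caller could consume this (loss, metrics) contract and drive the end-of-epoch callbacks; `toy_optimize_step` is a stand-in honoring the documented return contract, not qadence's actual `optimize_step` or train loop.

```python
import torch

from qadence.ml_tools import Callback
from qadence.ml_tools.data import OptimizeResult

def toy_optimize_step(model, optimizer, x, y):
    # Stand-in: perform one gradient step and return (loss, metrics).
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    return loss, {}

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.Adam(model.parameters())
x, y = torch.randn(8, 2), torch.randn(8, 1)
callbacks = [Callback(lambda r: print(r.iteration, float(r.loss)), called_every=2)]

for it in range(4):
    loss, metrics = toy_optimize_step(model, optimizer, x, y)
    res = OptimizeResult(iteration=it, model=model, optimizer=optimizer, loss=loss, metrics=metrics)
    for cb in callbacks:
        if cb.call_end_epoch:
            cb(res)
```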
