-
Notifications
You must be signed in to change notification settings - Fork 40
/
callback.py
108 lines (80 loc) · 2.97 KB
/
callback.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from __future__ import annotations

from typing import List, Optional

from tensordict import TensorDictBase
class Callback:
    """
    A Callback that can be added to experiments.

    To create your callback, you can inherit from this class
    and reimplement just the functions you need. Every hook in this
    base class is a no-op.

    Attributes:
        experiment (Experiment): the experiment associated to the callback
    """

    def __init__(self):
        # Assigned by ``CallbackNotifier`` when the callback is registered;
        # ``None`` until then.
        self.experiment = None

    def on_setup(self):
        """A callback called at experiment setup."""
        pass

    def on_batch_collected(self, batch: TensorDictBase):
        """
        A callback called at the end of every collection step.

        Args:
            batch (TensorDictBase): batch of collected data
        """
        pass

    def on_train_step(
        self, batch: TensorDictBase, group: str
    ) -> Optional[TensorDictBase]:
        """
        A callback called for every training step.

        Args:
            batch (TensorDictBase): tensordict with the training batch
            group (str): group name

        Returns:
            TensorDictBase, optional: a new tensordict containing the loss
            values, or ``None`` if this callback does not contribute any.
            Non-``None`` results from all callbacks are merged together by
            ``CallbackNotifier``.
        """
        pass

    def on_train_end(self, training_td: TensorDictBase, group: str):
        """
        A callback called at the end of training.

        Args:
            training_td (TensorDictBase): tensordict containing the loss values
            group (str): group name
        """
        pass

    def on_evaluation_end(self, rollouts: List[TensorDictBase]):
        """
        A callback called at the end of every evaluation.

        Args:
            rollouts (list of TensorDictBase): the rollout tensordicts
                collected during the evaluation
        """
        pass
class CallbackNotifier:
    """
    Forwards experiment events to a list of :class:`Callback` objects.

    Each ``_on_*`` method calls the matching hook on every registered
    callback, in the order the callbacks were given.

    Args:
        experiment (Experiment): the experiment the callbacks belong to; it is
            assigned to the ``experiment`` attribute of every callback.
        callbacks (list of Callback, optional): the callbacks to notify.
            ``None`` is treated as an empty list.
    """

    def __init__(self, experiment, callbacks: Optional[List[Callback]] = None):
        # Accept ``None`` so that "no callbacks" does not crash the loop below.
        self.callbacks = callbacks if callbacks is not None else []
        for callback in self.callbacks:
            callback.experiment = experiment

    def _on_setup(self):
        """Call ``on_setup`` on every callback."""
        for callback in self.callbacks:
            callback.on_setup()

    def _on_batch_collected(self, batch: TensorDictBase):
        """
        Call ``on_batch_collected`` on every callback.

        Args:
            batch (TensorDictBase): batch of collected data
        """
        for callback in self.callbacks:
            callback.on_batch_collected(batch)

    def _on_train_step(
        self, batch: TensorDictBase, group: str
    ) -> Optional[TensorDictBase]:
        """
        Call ``on_train_step`` on every callback and merge their results.

        Args:
            batch (TensorDictBase): tensordict with the training batch
            group (str): group name

        Returns:
            TensorDictBase, optional: the non-``None`` results merged with
            ``update`` (on key clashes, later callbacks overwrite earlier
            ones), or ``None`` if no callback returned anything.
        """
        train_td = None
        for callback in self.callbacks:
            td = callback.on_train_step(batch, group)
            if td is None:
                continue
            if train_td is None:
                # NOTE: the first non-None result is used directly as the
                # accumulator, so it is updated in place with later results.
                train_td = td
            else:
                train_td.update(td)
        return train_td

    def _on_train_end(self, training_td: TensorDictBase, group: str):
        """
        Call ``on_train_end`` on every callback.

        Args:
            training_td (TensorDictBase): tensordict containing the loss values
            group (str): group name
        """
        for callback in self.callbacks:
            callback.on_train_end(training_td, group)

    def _on_evaluation_end(self, rollouts: List[TensorDictBase]):
        """
        Call ``on_evaluation_end`` on every callback.

        Args:
            rollouts (list of TensorDictBase): the rollout tensordicts
                collected during the evaluation
        """
        for callback in self.callbacks:
            callback.on_evaluation_end(rollouts)