diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 291d0d4294c..348575c9abf 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -171,7 +171,7 @@ def main(cfg: "DictConfig"): # noqa: F821 scaler3.step(value_opt) scaler3.update() - metrics_to_log = {"reward": ep_reward.item()} + metrics_to_log = {"reward": ep_reward.mean().item()} if collected_frames >= init_random_frames: loss_metrics = { "loss_model_kl": model_loss_td["loss_model_kl"].item(),