diff --git a/hannah/callbacks/optimization.py b/hannah/callbacks/optimization.py index 36d7e791..57784b92 100644 --- a/hannah/callbacks/optimization.py +++ b/hannah/callbacks/optimization.py @@ -112,6 +112,8 @@ def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu monitor_val = callback_metrics[monitor] * direction if monitor.startswith("train"): self._curves[monitor][trainer.global_step] = monitor_val + + self.values[monitor] = monitor_val def on_test_end(self, trainer, pl_module): """ @@ -151,6 +153,10 @@ def on_validation_end(self, trainer, pl_module): Returns: """ + # Skip evaluation of validation metrics during sanity check + if trainer.sanity_checking: + return + callback_metrics = trainer.callback_metrics for k, v in callback_metrics.items(): @@ -162,11 +168,8 @@ def on_validation_end(self, trainer, pl_module): try: monitor_val = float(callback_metrics[monitor]) directed_monitor_val = monitor_val * direction - if ( - monitor not in self.values - or directed_monitor_val < self.values[monitor] - ): - self.values[monitor] = directed_monitor_val + + self.values[monitor] = directed_monitor_val self._curves[monitor][trainer.global_step] = directed_monitor_val except Exception: pass