Skip to content

Commit

Permalink
Merge branch 'fix/sanity_check_metrics' into 'main'
Browse files Browse the repository at this point in the history
Skip metric extraction when sanity checking

See merge request es/ai/hannah/hannah!400
  • Loading branch information
moreib committed Aug 12, 2024
2 parents 2ac38b6 + 90b281d commit e9530ee
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions hannah/callbacks/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
monitor_val = callback_metrics[monitor] * direction
if monitor.startswith("train"):
self._curves[monitor][trainer.global_step] = monitor_val

self.values[monitor] = monitor_val

def on_test_end(self, trainer, pl_module):
"""
Expand Down Expand Up @@ -151,6 +153,10 @@ def on_validation_end(self, trainer, pl_module):
Returns:
"""
# Skip evaluation of validation metrics during sanity check
if trainer.sanity_checking:
return

callback_metrics = trainer.callback_metrics

for k, v in callback_metrics.items():
Expand All @@ -162,11 +168,8 @@ def on_validation_end(self, trainer, pl_module):
try:
monitor_val = float(callback_metrics[monitor])
directed_monitor_val = monitor_val * direction
if (
monitor not in self.values
or directed_monitor_val < self.values[monitor]
):
self.values[monitor] = directed_monitor_val

self.values[monitor] = directed_monitor_val
self._curves[monitor][trainer.global_step] = directed_monitor_val
except Exception:
pass
Expand Down

0 comments on commit e9530ee

Please sign in to comment.