Skip to content

Commit

Permalink
add stopping epoch
Browse files Browse the repository at this point in the history
  • Loading branch information
cybershiptrooper committed Sep 2, 2024
1 parent e0be350 commit 7984bca
Showing 1 changed file with 3 additions and 8 deletions.
11 changes: 3 additions & 8 deletions iit/model_pairs/base_model_pair.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
from abc import ABC, abstractmethod
from typing import Any, Callable, final, Type, Optional
from typing import Any, Callable, final, Optional

import numpy as np
import torch as t
Expand Down Expand Up @@ -32,6 +31,7 @@ class BaseModelPair(ABC):
wandb_method: str
rng: np.random.Generator
dataset_class: 'IITDataset'
stopping_epoch: int | None = None

##########################################
# Abstract methods you need to implement #
Expand Down Expand Up @@ -293,9 +293,8 @@ def train(
)

if early_stop and self._check_early_stop_condition(test_metrics):
self.stopping_epoch = epoch + 1
break

self._run_epoch_extras(epoch_number=epoch+1)

if use_wandb:
wandb.log({"final epoch": epoch})
Expand Down Expand Up @@ -395,7 +394,3 @@ def _print_and_log_metrics(
epoch_pbar.set_description(f"Epoch {epoch + 1}")

return current_epoch_log

def _run_epoch_extras(self, epoch_number: int) -> None:
""" Optional method for running extra code at the end of each epoch """
pass

0 comments on commit 7984bca

Please sign in to comment.