Skip to content

Commit

Permalink
Merge pull request #20 from FlyingPumba/feature/early_stop_accuracy_t…
Browse files Browse the repository at this point in the history
…hreshold

Allow configuring accuracy threshold for early stop
  • Loading branch information
cybershiptrooper authored Sep 16, 2024
2 parents f3c96e8 + 903b854 commit 70f5b36
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
7 changes: 6 additions & 1 deletion iit/model_pairs/base_model_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,12 +374,17 @@ def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bo
"""
Returns True if all types of accuracy metrics reach 100%
"""
assert "early_stop_accuracy_threshold" in self.training_args, ValueError(
"early_stop_accuracy_threshold not found in training_args"
)
early_stop_accuracy_threshold = float(self.training_args["early_stop_accuracy_threshold"])

got_accuracy_metric = False
for metric in test_metrics:
if metric.type == MetricType.ACCURACY:
got_accuracy_metric = True
val = metric.get_value()
if isinstance(val, float) and val < 99.5:
if isinstance(val, float) and val < early_stop_accuracy_threshold:
return False
if not got_accuracy_metric:
raise ValueError("No accuracy metric found in test_metrics!")
Expand Down
1 change: 1 addition & 0 deletions iit/model_pairs/iit_model_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
"batch_size": 256,
"num_workers": 0,
"early_stop": True,
"early_stop_accuracy_threshold": 99.5,
"lr_scheduler": None,
"scheduler_val_metric": ["val/accuracy", "val/IIA"],
"scheduler_mode": "max",
Expand Down

0 comments on commit 70f5b36

Please sign in to comment.