diff --git a/iit/model_pairs/base_model_pair.py b/iit/model_pairs/base_model_pair.py index 78434b5..f61f3eb 100644 --- a/iit/model_pairs/base_model_pair.py +++ b/iit/model_pairs/base_model_pair.py @@ -374,12 +374,17 @@ def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bo """ Returns True if all types of accuracy metrics reach 100% """ + assert "early_stop_accuracy_threshold" in self.training_args, ValueError( + "early_stop_accuracy_threshold not found in training_args" + ) + early_stop_accuracy_threshold = float(self.training_args["early_stop_accuracy_threshold"]) + got_accuracy_metric = False for metric in test_metrics: if metric.type == MetricType.ACCURACY: got_accuracy_metric = True val = metric.get_value() - if isinstance(val, float) and val < 99.5: + if isinstance(val, float) and val < early_stop_accuracy_threshold: return False if not got_accuracy_metric: raise ValueError("No accuracy metric found in test_metrics!") diff --git a/iit/model_pairs/iit_model_pair.py b/iit/model_pairs/iit_model_pair.py index 0e0af62..a97fd4e 100644 --- a/iit/model_pairs/iit_model_pair.py +++ b/iit/model_pairs/iit_model_pair.py @@ -28,6 +28,7 @@ def __init__( "batch_size": 256, "num_workers": 0, "early_stop": True, + "early_stop_accuracy_threshold": 99.5, "lr_scheduler": None, "scheduler_val_metric": ["val/accuracy", "val/IIA"], "scheduler_mode": "max",