Skip to content

Commit

Permalink
refactor: Compute the optimal threshold as average between the max good and min bad score when the F1 is 1
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenzomammana committed Apr 19, 2024
1 parent 04e2db7 commit da6cdc2
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
20 changes: 17 additions & 3 deletions src/anomalib/utils/metrics/anomaly_score_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,31 @@ def compute(self) -> Tensor:
Value of the F1 score at the optimal threshold.
"""
current_targets = torch.concat(self.target)
current_preds = torch.concat(self.preds)

epsilon = 1e-3

if len(current_targets.unique()) == 1:
if current_targets.max() == 0:
self.value = torch.concat(self.preds).max() + epsilon
self.value = torch.concat(current_preds).max() + epsilon
else:
self.value = torch.concat(self.preds).min()
self.value = torch.concat(current_preds).min() - epsilon
else:
precision, recall, thresholds = super().compute()
f1_score = (2 * precision * recall) / (precision + recall + 1e-10)
self.value = thresholds[torch.argmax(f1_score)]
optimal_f1_score = torch.max(f1_score)

if thresholds.nelement() == 1:
# Particular case when f1 score is 1 and the threshold is unique
self.value = thresholds
else:
if optimal_f1_score == 1:
# If there is a good boundary between good and bads we pick the average of the highest good
# and lowest bad
max_good_score = current_preds[torch.where(current_targets == 0)].max()
min_bad_score = current_preds[torch.where(current_targets == 1)].min()
self.value = (max_good_score + min_bad_score) / 2
else:
self.value = thresholds[torch.argmax(f1_score)]

return self.value
18 changes: 14 additions & 4 deletions src/anomalib/utils/metrics/optimal_f1.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,26 +50,36 @@ def compute(self) -> Tensor:
recall: torch.Tensor
thresholds: torch.Tensor
current_targets = torch.concat(self.precision_recall_curve.target)
current_preds = torch.concat(self.precision_recall_curve.preds)

epsilon = 1e-3
if len(current_targets.unique()) == 1:
optimal_f1_score = torch.tensor(1.0)

if current_targets.max() == 0:
self.threshold = torch.concat(self.precision_recall_curve.preds).max() + epsilon
self.threshold = current_preds.max() + epsilon
else:
self.threshold = torch.concat(self.precision_recall_curve.preds).min()
self.threshold = current_preds.min() - epsilon

return optimal_f1_score
else:
precision, recall, thresholds = self.precision_recall_curve.compute()
f1_score = (2 * precision * recall) / (precision + recall + 1e-10)
optimal_f1_score = torch.max(f1_score)

if thresholds.nelement() == 1:
# Particular case when f1 score is 1 and the threshold is unique
self.threshold = thresholds
else:
self.threshold = thresholds[torch.argmax(f1_score)]
optimal_f1_score = torch.max(f1_score)
if optimal_f1_score == 1:
# If there is a good boundary between good and bads we pick the average of the highest good
# and lowest bad
max_good_score = current_preds[torch.where(current_targets == 0)].max()
min_bad_score = current_preds[torch.where(current_targets == 1)].min()
self.threshold = (max_good_score + min_bad_score) / 2
else:
self.threshold = thresholds[torch.argmax(f1_score)]

return optimal_f1_score

def reset(self) -> None:
Expand Down

0 comments on commit da6cdc2

Please sign in to comment.