Add Average Precision to metrics #8089

Open · wants to merge 10 commits into base: dev
Changes from 9 commits
6 changes: 6 additions & 0 deletions docs/source/handlers.rst
@@ -53,6 +53,12 @@ ROC AUC metrics handler
:members:


Average Precision metric handler
--------------------------------
.. autoclass:: AveragePrecision
:members:


Confusion matrix metrics handler
--------------------------------
.. autoclass:: ConfusionMatrix
7 changes: 7 additions & 0 deletions docs/source/metrics.rst
@@ -80,6 +80,13 @@ Metrics
.. autoclass:: ROCAUCMetric
:members:

`Average Precision`
--------------------------
Contributor commented on lines +83 to +84, with a suggested change shortening the underline to match the heading:

`Average Precision`
-------------------

.. autofunction:: compute_average_precision

.. autoclass:: AveragePrecisionMetric
:members:

`Confusion matrix`
------------------
.. autofunction:: get_confusion_matrix
1 change: 1 addition & 0 deletions monai/handlers/__init__.py
@@ -11,6 +11,7 @@

from __future__ import annotations

from .average_precision import AveragePrecision
from .checkpoint_loader import CheckpointLoader
from .checkpoint_saver import CheckpointSaver
from .classification_saver import ClassificationSaver
53 changes: 53 additions & 0 deletions monai/handlers/average_precision.py
@@ -0,0 +1,53 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import AveragePrecisionMetric
from monai.utils import Average


class AveragePrecision(IgniteMetricHandler):
"""
Computes Average Precision (AP) by accumulating predictions and the ground truth during an epoch
and applying `compute_average_precision`.

Args:
average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
Type of averaging performed if not binary classification. Defaults to ``"macro"``.

- ``"macro"``: calculate metrics for each label, and find their unweighted mean.
This does not take label imbalance into account.
- ``"weighted"``: calculate metrics for each label, and find their average,
weighted by support (the number of true instances for each label).
- ``"micro"``: calculate metrics globally by considering each element of the label
indicator matrix as a label.
- ``"none"``: the scores for each class are returned.

output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then
construct the `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
lists of `channel-first` Tensors. The form of `(y_pred, y)` is required by `update()`.
`engine.state` and `output_transform` inherit from the ignite concept:
https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.

Note:
Average Precision expects `y` to consist of 0s and 1s.
`y_pred` must be either probability estimates or confidence values.

"""

def __init__(self, average: Average | str = Average.MACRO, output_transform: Callable = lambda x: x) -> None:
metric_fn = AveragePrecisionMetric(average=Average(average))
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=False)
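A minimal usage sketch of the new handler (a sketch only: the step function, the dict keys "pred"/"label", and the metric name "AP" are illustrative assumptions, not part of this PR):

from ignite.engine import Engine

from monai.handlers import AveragePrecision

def eval_step(engine, batch):
    # no model here; this stub assumes predictions were computed upstream
    return {"pred": batch["pred"], "label": batch["label"]}

evaluator = Engine(eval_step)
ap = AveragePrecision(output_transform=lambda output: (output["pred"], output["label"]))
ap.attach(evaluator, name="AP")  # after evaluator.run(data), see evaluator.state.metrics["AP"]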
1 change: 1 addition & 0 deletions monai/metrics/__init__.py
@@ -12,6 +12,7 @@
from __future__ import annotations

from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score
from .average_precision import AveragePrecisionMetric, compute_average_precision
from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix
from .cumulative_average import CumulativeAverage
from .f_beta_score import FBetaScore
173 changes: 173 additions & 0 deletions monai/metrics/average_precision.py
@@ -0,0 +1,173 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, cast

import numpy as np

if TYPE_CHECKING:
import numpy.typing as npt

import torch

from monai.utils import Average, look_up_option

from .metric import CumulativeIterationMetric


class AveragePrecisionMetric(CumulativeIterationMetric):
"""
Computes Average Precision (AP). Referring to: `sklearn.metrics.average_precision_score
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score>`_.

Member commented on this line: For the lazy it would be nice to explain here what this metric is and its intended application area.
The input `y_pred` and `y` can be a list of `channel-first` Tensor or a `batch-first` Tensor.

Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

Args:
average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
Type of averaging performed if not binary classification.
Defaults to ``"macro"``.

- ``"macro"``: calculate metrics for each label, and find their unweighted mean.
This does not take label imbalance into account.
- ``"weighted"``: calculate metrics for each label, and find their average,
weighted by support (the number of true instances for each label).
- ``"micro"``: calculate metrics globally by considering each element of the label
indicator matrix as a label.
- ``"none"``: the scores for each class are returned.

"""

def __init__(self, average: Average | str = Average.MACRO) -> None:
super().__init__()
self.average = average

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # type: ignore[override]
return y_pred, y

def aggregate(self, average: Average | str | None = None) -> np.ndarray | float | npt.ArrayLike:
"""
Typically `y_pred` and `y` are stored in the cumulative buffers at each iteration.
This function reads the buffers and computes the Average Precision.

Args:
average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
Type of averaging performed if not binary classification. Defaults to `self.average`.

"""
y_pred, y = self.get_buffer()
# compute final value and do metric reduction
if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):
raise ValueError("y_pred and y must be PyTorch Tensors.")

return compute_average_precision(y_pred=y_pred, y=y, average=average or self.average)
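For reference, a short sketch of the `Cumulative` workflow this class follows (tensor values are illustrative):

import torch

from monai.metrics import AveragePrecisionMetric

metric = AveragePrecisionMetric(average="macro")
# call once per iteration; y_pred and y accumulate in the buffers
metric(y_pred=torch.tensor([[0.1], [0.8]]), y=torch.tensor([[0], [1]]))
metric(y_pred=torch.tensor([[0.6], [0.3]]), y=torch.tensor([[1], [0]]))
print(metric.aggregate())  # AP over all accumulated iterations
metric.reset()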


def _calculate(y_pred: torch.Tensor, y: torch.Tensor) -> float:
if not (y.ndimension() == y_pred.ndimension() == 1 and len(y) == len(y_pred)):
raise AssertionError("y and y_pred must be 1-dimensional data with the same length.")
y_unique = y.unique()
if len(y_unique) == 1:
warnings.warn(f"y values can not be all {y_unique.item()}, skip AP computation and return `Nan`.")
return float("nan")
if not y_unique.equal(torch.tensor([0, 1], dtype=y.dtype, device=y.device)):
warnings.warn(f"y values must be 0 or 1, but in {y_unique.tolist()}, skip AP computation and return `Nan`.")
return float("nan")

n = len(y)
indices = y_pred.argsort(descending=True)
y = y[indices].cpu().numpy() # type: ignore[assignment]
y_pred = y_pred[indices].cpu().numpy() # type: ignore[assignment]
npos = ap = tmp_pos = 0.0

for i in range(n):
    # accumulate positives that share the current prediction score
    tmp_pos += cast(float, y[i])
    # finalize the tied group once the next score differs (or at the end of the ranking)
    if i + 1 >= n or y_pred[i] != y_pred[i + 1]:
        npos += tmp_pos
        ap += tmp_pos * npos / (i + 1)
        tmp_pos = 0

return ap / npos
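For intuition, a small worked example (made-up values): with predictions [0.9, 0.8, 0.4, 0.2] and labels [1, 0, 1, 0], the two positives sit at ranks 1 and 3 of the sorted ranking, so AP = (1/1 + 2/3) / 2 ≈ 0.8333:

import torch

y_pred = torch.tensor([0.9, 0.8, 0.4, 0.2])
y = torch.tensor([1, 0, 1, 0])
print(_calculate(y_pred, y))  # ≈ 0.8333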


def compute_average_precision(
y_pred: torch.Tensor, y: torch.Tensor, average: Average | str = Average.MACRO
) -> np.ndarray | float | npt.ArrayLike:
"""Computes Average Precision (AP). Referring to: `sklearn.metrics.average_precision_score
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score>`_.

Args:
y_pred: input data to compute, typically classification model output.
The first dim must be batch; if multi-class, it must be in one-hot format.
For example: shape `[16]` or `[16, 1]` for binary data, shape `[16, 2]` for 2-class data.
y: ground truth to compute the AP metric. The first dim must be batch;
if multi-class, it must be in one-hot format.
For example: shape `[16]` or `[16, 1]` for binary data, shape `[16, 2]` for 2-class data.
average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
Type of averaging performed if not binary classification.
Defaults to ``"macro"``.

- ``"macro"``: calculate metrics for each label, and find their unweighted mean.
This does not take label imbalance into account.
- ``"weighted"``: calculate metrics for each label, and find their average,
weighted by support (the number of true instances for each label).
- ``"micro"``: calculate metrics globally by considering each element of the label
indicator matrix as a label.
- ``"none"``: the scores for each class are returned.

Raises:
ValueError: When ``y_pred`` dimension is not one of [1, 2].
ValueError: When ``y`` dimension is not one of [1, 2].
ValueError: When ``average`` is not one of ["macro", "weighted", "micro", "none"].

Note:
Average Precision expects `y` to consist of 0s and 1s. `y_pred` must be either probability estimates or confidence values.

"""
y_pred_ndim = y_pred.ndimension()
y_ndim = y.ndimension()
if y_pred_ndim not in (1, 2):
raise ValueError(
f"Predictions should be of shape (batch_size, num_classes) or (batch_size, ), got {y_pred.shape}."
)
if y_ndim not in (1, 2):
raise ValueError(f"Targets should be of shape (batch_size, num_classes) or (batch_size, ), got {y.shape}.")
if y_pred_ndim == 2 and y_pred.shape[1] == 1:
y_pred = y_pred.squeeze(dim=-1)
y_pred_ndim = 1
if y_ndim == 2 and y.shape[1] == 1:
y = y.squeeze(dim=-1)

if y_pred_ndim == 1:
return _calculate(y_pred, y)

if y.shape != y_pred.shape:
raise ValueError(f"data shapes of y_pred and y do not match, got {y_pred.shape} and {y.shape}.")

average = look_up_option(average, Average)
if average == Average.MICRO:
return _calculate(y_pred.flatten(), y.flatten())
y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1)
ap_values = [_calculate(y_pred_, y_) for y_pred_, y_ in zip(y_pred, y)]
if average == Average.NONE:
return ap_values
if average == Average.MACRO:
return np.mean(ap_values)
if average == Average.WEIGHTED:
weights = [sum(y_) for y_ in y]
return np.average(ap_values, weights=weights) # type: ignore[no-any-return]
raise ValueError(f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].')
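And a sketch of calling the functional API directly on 2-class one-hot data (tensor values are illustrative):

import torch

from monai.metrics import compute_average_precision

y_pred = torch.tensor([[0.8, 0.2], [0.3, 0.7], [0.4, 0.6], [0.9, 0.1]])
y = torch.tensor([[1, 0], [0, 1], [1, 0], [0, 1]])
print(compute_average_precision(y_pred, y, average="macro"))  # mean of per-class APs, ≈ 0.6667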
3 changes: 2 additions & 1 deletion monai/utils/enums.py
@@ -221,7 +221,8 @@ class GridSamplePadMode(StrEnum):

class Average(StrEnum):
"""
See also: :py:class:`monai.metrics.rocauc.compute_roc_auc`
See also: :py:class:`monai.metrics.rocauc.compute_roc_auc` or
:py:class:`monai.metrics.average_precision.compute_average_precision`
"""

MACRO = "macro"
2 changes: 2 additions & 0 deletions tests/min_tests.py
@@ -76,6 +76,8 @@ def run_testsuit():
"test_grid_patch",
"test_gmm",
"test_handler_metrics_reloaded",
"test_handler_average_precision",
"test_handler_average_precision_dist",
"test_handler_checkpoint_loader",
"test_handler_checkpoint_saver",
"test_handler_classification_saver",