diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst
index 270083f717..49c84dab28 100644
--- a/docs/source/handlers.rst
+++ b/docs/source/handlers.rst
@@ -53,6 +53,12 @@ ROC AUC metrics handler
     :members:
 
 
+Average Precision metric handler
+--------------------------------
+.. autoclass:: AveragePrecision
+    :members:
+
+
 Confusion matrix metrics handler
 --------------------------------
 .. autoclass:: ConfusionMatrix
diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index 616f0fe385..16593a7d0a 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -80,6 +80,13 @@ Metrics
 .. autoclass:: ROCAUCMetric
     :members:
 
+`Average Precision`
+--------------------------
+.. autofunction:: compute_average_precision
+
+.. autoclass:: AveragePrecisionMetric
+    :members:
+
 `Confusion matrix`
 ------------------
 .. autofunction:: get_confusion_matrix
diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py
index c1fa448f25..ed5db8a7f3 100644
--- a/monai/handlers/__init__.py
+++ b/monai/handlers/__init__.py
@@ -11,6 +11,7 @@
 
 from __future__ import annotations
 
+from .average_precision import AveragePrecision
 from .checkpoint_loader import CheckpointLoader
 from .checkpoint_saver import CheckpointSaver
 from .classification_saver import ClassificationSaver
diff --git a/monai/handlers/average_precision.py b/monai/handlers/average_precision.py
new file mode 100644
index 0000000000..608d7eea72
--- /dev/null
+++ b/monai/handlers/average_precision.py
@@ -0,0 +1,53 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Callable
+
+from monai.handlers.ignite_metric import IgniteMetricHandler
+from monai.metrics import AveragePrecisionMetric
+from monai.utils import Average
+
+
+class AveragePrecision(IgniteMetricHandler):
+    """
+    Computes Average Precision (AP),
+    accumulating predictions and the ground truth during an epoch and applying `compute_average_precision`.
+
+    Args:
+        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
+            Type of averaging performed if not binary classification. Defaults to ``"macro"``.
+
+            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
+                This does not take label imbalance into account.
+            - ``"weighted"``: calculate metrics for each label, and find their average,
+                weighted by support (the number of true instances for each label).
+            - ``"micro"``: calculate metrics globally by considering each element of the label
+                indicator matrix as a label.
+            - ``"none"``: the scores for each class are returned.
+
+        output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then
+            construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
+            lists of `channel-first` Tensors. The form of `(y_pred, y)` is required by the `update()`.
+            `engine.state` and `output_transform` inherit from the ignite concept:
+            https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
+            https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
+
+    Note:
+        Average Precision expects y to be comprised of 0's and 1's.
+        y_pred must be either probability estimates or confidence values.
+
+    """
+
+    def __init__(self, average: Average | str = Average.MACRO, output_transform: Callable = lambda x: x) -> None:
+        metric_fn = AveragePrecisionMetric(average=Average(average))
+        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=False)
diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py
index 201acdfa50..7176f3311f 100644
--- a/monai/metrics/__init__.py
+++ b/monai/metrics/__init__.py
@@ -12,6 +12,7 @@
 from __future__ import annotations
 
 from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score
+from .average_precision import AveragePrecisionMetric, compute_average_precision
 from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix
 from .cumulative_average import CumulativeAverage
 from .f_beta_score import FBetaScore
diff --git a/monai/metrics/average_precision.py b/monai/metrics/average_precision.py
new file mode 100644
index 0000000000..0d9d3fe228
--- /dev/null
+++ b/monai/metrics/average_precision.py
@@ -0,0 +1,173 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import warnings
+from typing import TYPE_CHECKING, cast
+
+import numpy as np
+
+if TYPE_CHECKING:
+    import numpy.typing as npt
+
+import torch
+
+from monai.utils import Average, look_up_option
+
+from .metric import CumulativeIterationMetric
+
+
+class AveragePrecisionMetric(CumulativeIterationMetric):
+    """
+    Computes Average Precision (AP). Referring to: `sklearn.metrics.average_precision_score
+    <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html>`_.
+    The input `y_pred` and `y` can be a list of `channel-first` Tensor or a `batch-first` Tensor.
+
+    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
+
+    Args:
+        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
+            Type of averaging performed if not binary classification.
+            Defaults to ``"macro"``.
+
+            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
+                This does not take label imbalance into account.
+            - ``"weighted"``: calculate metrics for each label, and find their average,
+                weighted by support (the number of true instances for each label).
+            - ``"micro"``: calculate metrics globally by considering each element of the label
+                indicator matrix as a label.
+            - ``"none"``: the scores for each class are returned.
+
+    """
+
+    def __init__(self, average: Average | str = Average.MACRO) -> None:
+        super().__init__()
+        self.average = average
+
+    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:  # type: ignore[override]
+        return y_pred, y
+
+    def aggregate(self, average: Average | str | None = None) -> np.ndarray | float | npt.ArrayLike:
+        """
+        Typically `y_pred` and `y` are stored in the cumulative buffers at each iteration.
+        This function reads the buffers and computes the Average Precision.
+
+        Args:
+            average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
+                Type of averaging performed if not binary classification. Defaults to `self.average`.
+
+        """
+        y_pred, y = self.get_buffer()
+        # compute final value and do metric reduction
+        if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):
+            raise ValueError("y_pred and y must be PyTorch Tensor.")
+
+        return compute_average_precision(y_pred=y_pred, y=y, average=average or self.average)
+
+
+def _calculate(y_pred: torch.Tensor, y: torch.Tensor) -> float:
+    if not (y.ndimension() == y_pred.ndimension() == 1 and len(y) == len(y_pred)):
+        raise AssertionError("y and y_pred must be 1 dimension data with same length.")
+    y_unique = y.unique()
+    if len(y_unique) == 1:
+        warnings.warn(f"y values can not be all {y_unique.item()}, skip AP computation and return `Nan`.")
+        return float("nan")
+    if not y_unique.equal(torch.tensor([0, 1], dtype=y.dtype, device=y.device)):
+        warnings.warn(f"y values must be 0 or 1, but in {y_unique.tolist()}, skip AP computation and return `Nan`.")
+        return float("nan")
+
+    n = len(y)
+    indices = y_pred.argsort(descending=True)
+    y = y[indices].cpu().numpy()  # type: ignore[assignment]
+    y_pred = y_pred[indices].cpu().numpy()  # type: ignore[assignment]
+    npos = ap = tmp_pos = 0.0
+
+    for i in range(n):
+        y_i = cast(float, y[i])
+        if i + 1 < n and y_pred[i] == y_pred[i + 1]:
+            tmp_pos += y_i
+        else:
+            tmp_pos += y_i
+            npos += tmp_pos
+            ap += tmp_pos * npos / (i + 1)
+            tmp_pos = 0
+
+    return ap / npos
+
+
+def compute_average_precision(
+    y_pred: torch.Tensor, y: torch.Tensor, average: Average | str = Average.MACRO
+) -> np.ndarray | float | npt.ArrayLike:
+    """Computes Average Precision (AP). Referring to: `sklearn.metrics.average_precision_score
+    <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html>`_.
+
+    Args:
+        y_pred: input data to compute, typical classification model output.
+            the first dim must be batch, if multi-class, it must be in One-Hot format.
+            for example: shape `[16]` or `[16, 1]` for binary data, shape `[16, 2]` for 2-class data.
+        y: ground truth to compute AP metric, the first dim must be batch.
+            if multi-class, it must be in One-Hot format.
+            for example: shape `[16]` or `[16, 1]` for binary data, shape `[16, 2]` for 2-class data.
+        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
+            Type of averaging performed if not binary classification.
+            Defaults to ``"macro"``.
+
+            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
+                This does not take label imbalance into account.
+            - ``"weighted"``: calculate metrics for each label, and find their average,
+                weighted by support (the number of true instances for each label).
+            - ``"micro"``: calculate metrics globally by considering each element of the label
+                indicator matrix as a label.
+            - ``"none"``: the scores for each class are returned.
+
+    Raises:
+        ValueError: When ``y_pred`` dimension is not one of [1, 2].
+        ValueError: When ``y`` dimension is not one of [1, 2].
+        ValueError: When ``average`` is not one of ["macro", "weighted", "micro", "none"].
+
+    Note:
+        Average Precision expects y to be comprised of 0's and 1's. `y_pred` must be either probability estimates or confidence values.
+
+    """
+    y_pred_ndim = y_pred.ndimension()
+    y_ndim = y.ndimension()
+    if y_pred_ndim not in (1, 2):
+        raise ValueError(
+            f"Predictions should be of shape (batch_size, num_classes) or (batch_size, ), got {y_pred.shape}."
+        )
+    if y_ndim not in (1, 2):
+        raise ValueError(f"Targets should be of shape (batch_size, num_classes) or (batch_size, ), got {y.shape}.")
+    if y_pred_ndim == 2 and y_pred.shape[1] == 1:
+        y_pred = y_pred.squeeze(dim=-1)
+        y_pred_ndim = 1
+    if y_ndim == 2 and y.shape[1] == 1:
+        y = y.squeeze(dim=-1)
+
+    if y_pred_ndim == 1:
+        return _calculate(y_pred, y)
+
+    if y.shape != y_pred.shape:
+        raise ValueError(f"data shapes of y_pred and y do not match, got {y_pred.shape} and {y.shape}.")
+
+    average = look_up_option(average, Average)
+    if average == Average.MICRO:
+        return _calculate(y_pred.flatten(), y.flatten())
+    y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1)
+    ap_values = [_calculate(y_pred_, y_) for y_pred_, y_ in zip(y_pred, y)]
+    if average == Average.NONE:
+        return ap_values
+    if average == Average.MACRO:
+        return np.mean(ap_values)
+    if average == Average.WEIGHTED:
+        weights = [sum(y_) for y_ in y]
+        return np.average(ap_values, weights=weights)  # type: ignore[no-any-return]
+    raise ValueError(f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].')
diff --git a/monai/utils/enums.py b/monai/utils/enums.py
index 1fbf3ffa05..01c62c751c 100644
--- a/monai/utils/enums.py
+++ b/monai/utils/enums.py
@@ -213,7 +213,8 @@ class GridSamplePadMode(StrEnum):
 
 class Average(StrEnum):
     """
-    See also: :py:class:`monai.metrics.rocauc.compute_roc_auc`
+    See also: :py:class:`monai.metrics.rocauc.compute_roc_auc` or
+    :py:class:`monai.metrics.average_precision.compute_average_precision`
     """
 
     MACRO = "macro"
diff --git a/tests/min_tests.py b/tests/min_tests.py
index f39d3f9843..6b8f7df7a7 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -76,6 +76,8 @@ def run_testsuit():
         "test_grid_patch",
         "test_gmm",
         "test_handler_metrics_reloaded",
+        "test_handler_average_precision",
+        "test_handler_average_precision_dist",
         "test_handler_checkpoint_loader",
         "test_handler_checkpoint_saver",
         "test_handler_classification_saver",
diff --git a/tests/test_compute_average_precision.py b/tests/test_compute_average_precision.py
new file mode 100644
index 0000000000..ed3841bc30
--- /dev/null
+++ b/tests/test_compute_average_precision.py
@@ -0,0 +1,177 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.data import decollate_batch
+from monai.metrics import AveragePrecisionMetric, compute_average_precision
+from monai.transforms import Activations, AsDiscrete, Compose, ToTensor
+
+_device = "cuda:0" if torch.cuda.is_available() else "cpu"
+TEST_CASE_1 = [
+    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]], device=_device),
+    torch.tensor([[0], [0], [1], [1]], device=_device),
+    True,
+    2,
+    "macro",
+    0.41667,
+]
+
+TEST_CASE_2 = [
+    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]], device=_device),
+    torch.tensor([[1], [1], [0], [0]], device=_device),
+    True,
+    2,
+    "micro",
+    0.85417,
+]
+
+TEST_CASE_3 = [
+    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]], device=_device),
+    torch.tensor([[0], [1], [0], [1]], device=_device),
+    True,
+    2,
+    "macro",
+    0.83333,
+]
+
+TEST_CASE_4 = [
+    torch.tensor([[0.5], [0.5], [0.2], [8.3]]),
+    torch.tensor([[0], [1], [0], [1]]),
+    False,
+    None,
+    "macro",
+    0.83333,
+]
+
+TEST_CASE_5 = [torch.tensor([[0.5], [0.5], [0.2], [8.3]]), torch.tensor([0, 1, 0, 1]), False, None, "macro", 0.83333]
+
+TEST_CASE_6 = [torch.tensor([0.5, 0.5, 0.2, 8.3]), torch.tensor([0, 1, 0, 1]), False, None, "macro", 0.83333]
+
+TEST_CASE_7 = [
+    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),
+    torch.tensor([[0], [1], [0], [1]]),
+    True,
+    2,
+    "none",
+    [0.83333, 0.83333],
+]
+
+TEST_CASE_8 = [
+    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]),
+    torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]),
+    True,
+    None,
+    "weighted",
+    0.66667,
+]
+
+TEST_CASE_9 = [
+    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]),
+    torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]),
+    True,
+    None,
+    "micro",
+    0.71111,
+]
+
+TEST_CASE_10 = [
+    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),
+    torch.tensor([[0], [0], [0], [0]]),
+    True,
+    2,
+    "macro",
+    float("nan"),
+]
+
+TEST_CASE_11 = [
+    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),
+    torch.tensor([[1], [1], [1], [1]]),
+    True,
+    2,
+    "macro",
+    float("nan"),
+]
+
+TEST_CASE_12 = [
+    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),
+    torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3]]),
+    True,
+    None,
+    "macro",
+    float("nan"),
+]
+
+
+class TestComputeAveragePrecision(unittest.TestCase):
+
+    @parameterized.expand(
+        [
+            TEST_CASE_1,
+            TEST_CASE_2,
+            TEST_CASE_3,
+            TEST_CASE_4,
+            TEST_CASE_5,
+            TEST_CASE_6,
+            TEST_CASE_7,
+            TEST_CASE_8,
+            TEST_CASE_9,
+            TEST_CASE_10,
+            TEST_CASE_11,
+            TEST_CASE_12,
+        ]
+    )
+    def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value):
+        y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)])
+        y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot)])
+        y_pred = torch.stack([y_pred_trans(i) for i in decollate_batch(y_pred)], dim=0)
+        y = torch.stack([y_trans(i) for i in decollate_batch(y)], dim=0)
+        result = compute_average_precision(y_pred=y_pred, y=y, average=average)
+        np.testing.assert_allclose(expected_value, result, rtol=1e-5)
+
+    @parameterized.expand(
+        [
+            TEST_CASE_1,
+            TEST_CASE_2,
+            TEST_CASE_3,
+            TEST_CASE_4,
+            TEST_CASE_5,
+            TEST_CASE_6,
+            TEST_CASE_7,
+            TEST_CASE_8,
+            TEST_CASE_9,
+            TEST_CASE_10,
+            TEST_CASE_11,
+            TEST_CASE_12,
+        ]
+    )
+    def test_class_value(self, y_pred, y, softmax, to_onehot, average, expected_value):
+        y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)])
+        y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot)])
+        y_pred = [y_pred_trans(i) for i in decollate_batch(y_pred)]
+        y = [y_trans(i) for i in decollate_batch(y)]
+        metric = AveragePrecisionMetric(average=average)
+        metric(y_pred=y_pred, y=y)
+        result = metric.aggregate()
+        np.testing.assert_allclose(expected_value, result, rtol=1e-5)
+        result = metric.aggregate(average=average)  # test optional argument
+        metric.reset()
+        np.testing.assert_allclose(expected_value, result, rtol=1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_handler_average_precision.py b/tests/test_handler_average_precision.py
new file mode 100644
index 0000000000..d771cd1c1b
--- /dev/null
+++ b/tests/test_handler_average_precision.py
@@ -0,0 +1,48 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import torch
+
+from monai.handlers import AveragePrecision
+from monai.transforms import Activations, AsDiscrete
+
+
+class TestHandlerAveragePrecision(unittest.TestCase):
+
+    def test_compute(self):
+        ap_metric = AveragePrecision()
+        act = Activations(softmax=True)
+        to_onehot = AsDiscrete(to_onehot=2)
+
+        y_pred = [torch.Tensor([0.1, 0.9]), torch.Tensor([0.3, 1.4])]
+        y = [torch.Tensor([0]), torch.Tensor([1])]
+        y_pred = [act(p) for p in y_pred]
+        y = [to_onehot(y_) for y_ in y]
+        ap_metric.update([y_pred, y])
+
+        y_pred = [torch.Tensor([0.2, 0.1]), torch.Tensor([0.1, 0.5])]
+        y = [torch.Tensor([0]), torch.Tensor([1])]
+        y_pred = [act(p) for p in y_pred]
+        y = [to_onehot(y_) for y_ in y]
+
+        ap_metric.update([y_pred, y])
+
+        ap = ap_metric.compute()
+        np.testing.assert_allclose(0.8333333, ap)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_handler_average_precision_dist.py b/tests/test_handler_average_precision_dist.py
new file mode 100644
index 0000000000..4305993c5a
--- /dev/null
+++ b/tests/test_handler_average_precision_dist.py
@@ -0,0 +1,55 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+from monai.handlers import AveragePrecision
+from monai.transforms import Activations, AsDiscrete
+from tests.utils import DistCall, DistTestCase
+
+
+class DistributedAveragePrecision(DistTestCase):
+
+    @DistCall(nnodes=1, nproc_per_node=2, node_rank=0)
+    def test_compute(self):
+        ap_metric = AveragePrecision()
+        act = Activations(softmax=True)
+        to_onehot = AsDiscrete(to_onehot=2)
+
+        device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu"
+        if dist.get_rank() == 0:
+            y_pred = [torch.tensor([0.1, 0.9], device=device), torch.tensor([0.3, 1.4], device=device)]
+            y = [torch.tensor([0], device=device), torch.tensor([1], device=device)]
+
+        if dist.get_rank() == 1:
+            y_pred = [
+                torch.tensor([0.2, 0.1], device=device),
+                torch.tensor([0.1, 0.5], device=device),
+                torch.tensor([0.3, 0.4], device=device),
+            ]
+            y = [torch.tensor([0], device=device), torch.tensor([1], device=device), torch.tensor([1], device=device)]
+
+        y_pred = [act(p) for p in y_pred]
+        y = [to_onehot(y_) for y_ in y]
+        ap_metric.update([y_pred, y])
+
+        result = ap_metric.compute()
+        np.testing.assert_allclose(0.7778, result, rtol=1e-4)
+
+
+if __name__ == "__main__":
+    unittest.main()
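
Reviewer note (not part of the patch): the sketch below is a minimal standalone check of the new metric, assuming a build with this diff applied so that `compute_average_precision` and `AveragePrecisionMetric` are importable from `monai.metrics`. It reuses the ranking exercised by TEST_CASE_3/TEST_CASE_4 above: with the scores sorted in descending order the positives land at ranks 1 and 3, so AP = (1/1 + 2/3) / 2 ~= 0.83333, which is the value both calls should print.

import torch

from monai.metrics import AveragePrecisionMetric, compute_average_precision

# confidence scores for the positive class and binary ground truth (batch-first)
y_pred = torch.tensor([0.9, 1.4, 0.1, 0.5])
y = torch.tensor([0, 1, 0, 1])

# functional form: one call over the whole batch
print(compute_average_precision(y_pred=y_pred, y=y))  # ~0.83333

# cumulative class form: accumulate two iterations, then aggregate
metric = AveragePrecisionMetric(average="macro")
metric(y_pred=torch.tensor([[0.9], [1.4]]), y=torch.tensor([[0], [1]]))
metric(y_pred=torch.tensor([[0.1], [0.5]]), y=torch.tensor([[0], [1]]))
print(metric.aggregate())  # same value as the functional call
metric.reset()

The `AveragePrecision` handler added in monai/handlers/average_precision.py wraps the same computation for ignite-based workflows, as exercised by tests/test_handler_average_precision.py.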