diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst
index 270083f717..b48869d01e 100644
--- a/docs/source/handlers.rst
+++ b/docs/source/handlers.rst
@@ -77,6 +77,12 @@ Panoptic Quality metrics handler
     :members:
 
 
+:math:`R^{2}` score
+-------------------
+.. autoclass:: R2Score
+    :members:
+
+
 Mean squared error metrics handler
 ----------------------------------
 .. autoclass:: MeanSquaredError
diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index 616f0fe385..751c624405 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -117,6 +117,13 @@ Metrics
 .. autoclass:: PanopticQualityMetric
     :members:
 
+:math:`R^{2}` score
+-------------------
+.. autofunction:: compute_r2_score
+
+.. autoclass:: R2Metric
+    :members:
+
 `Mean squared error`
 --------------------
 .. autoclass:: MSEMetric
diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py
index c1fa448f25..fed8504722 100644
--- a/monai/handlers/__init__.py
+++ b/monai/handlers/__init__.py
@@ -34,6 +34,7 @@
 from .parameter_scheduler import ParamSchedulerHandler
 from .postprocessing import PostProcessing
 from .probability_maps import ProbMapProducer
+from .r2_score import R2Score
 from .regression_metrics import MeanAbsoluteError, MeanSquaredError, PeakSignalToNoiseRatio, RootMeanSquaredError
 from .roc_auc import ROCAUC
 from .smartcache_handler import SmartCacheHandler
diff --git a/monai/handlers/r2_score.py b/monai/handlers/r2_score.py
new file mode 100644
index 0000000000..dc94182885
--- /dev/null
+++ b/monai/handlers/r2_score.py
@@ -0,0 +1,56 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Callable
+
+from monai.handlers.ignite_metric import IgniteMetricHandler
+from monai.metrics import R2Metric
+from monai.utils import MultiOutput
+
+
+class R2Score(IgniteMetricHandler):
+    """
+    Computes the :math:`R^{2}` score by accumulating predictions and the ground truth during an epoch, then applying `compute_r2_score`.
+
+    Args:
+        multi_output: {``"raw_values"``, ``"uniform_average"``, ``"variance_weighted"``}
+            Type of aggregation performed on multi-output scores.
+            Defaults to ``"uniform_average"``.
+
+            - ``"raw_values"``: the scores for each output are returned.
+            - ``"uniform_average"``: the scores of all outputs are averaged with uniform weight.
+            - ``"variance_weighted"``: the scores of all outputs are averaged, weighted by the variances
+              of each individual output.
+        p: non-negative integer.
+            Number of independent variables used for regression. ``p`` is used to compute the adjusted :math:`R^{2}` score.
+            Defaults to 0 (standard :math:`R^{2}` score).
+        output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` and then
+            construct the `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
+            lists of `channel-first` Tensors. The form of `(y_pred, y)` is required by `update()`.
+            `engine.state` and `output_transform` inherit from the ignite concept:
+            https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
+            https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
+
+    See also:
+        :py:class:`monai.metrics.R2Metric`
+
+    """
+
+    def __init__(
+        self,
+        multi_output: MultiOutput | str = MultiOutput.UNIFORM,
+        p: int = 0,
+        output_transform: Callable = lambda x: x,
+    ) -> None:
+        metric_fn = R2Metric(multi_output=multi_output, p=p)
+        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=False)
diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py
index 201acdfa50..db0de24eb0 100644
--- a/monai/metrics/__init__.py
+++ b/monai/metrics/__init__.py
@@ -25,6 +25,7 @@
 from .metric import Cumulative, CumulativeIterationMetric, IterationMetric, Metric
 from .mmd import MMDMetric, compute_mmd
 from .panoptic_quality import PanopticQualityMetric, compute_panoptic_quality
+from .r2_score import R2Metric, compute_r2_score
 from .regression import (
     MAEMetric,
     MSEMetric,
diff --git a/monai/metrics/r2_score.py b/monai/metrics/r2_score.py
new file mode 100644
index 0000000000..0ad2e133a5
--- /dev/null
+++ b/monai/metrics/r2_score.py
@@ -0,0 +1,184 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+
+if TYPE_CHECKING:
+    import numpy.typing as npt
+
+import torch
+
+from monai.utils import MultiOutput, look_up_option
+
+from .metric import CumulativeIterationMetric
+
+
+class R2Metric(CumulativeIterationMetric):
+    r"""Computes the :math:`R^{2}` score (coefficient of determination):
+
+    .. math::
+        \operatorname {R^{2}}\left(Y, \hat{Y}\right) = 1 - \frac {\sum _{i=1}^{n}\left(y_i-\hat{y_i} \right)^{2}}
+        {\sum _{i=1}^{n}\left(y_i-\bar{y} \right)^{2}},
+
+    where :math:`\bar{y}` is the mean of the observed :math:`y`; or the adjusted :math:`R^{2}` score:
+
+    .. math::
+        \operatorname {\bar{R}^{2}} = 1 - (1-R^{2}) \frac {n-1}{n-p-1},
+
+    where :math:`p` is the number of independent variables used for the regression.
+
+    More info: https://en.wikipedia.org/wiki/Coefficient_of_determination
+
+    Input `y_pred` is compared with ground truth `y`.
+    `y_pred` and `y` are expected to be 1D (single-output regression) or 2D (multi-output regression) real-valued
+    tensors of the same shape.
+
+    An example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
+
+    Args:
+        multi_output: {``"raw_values"``, ``"uniform_average"``, ``"variance_weighted"``}
+            Type of aggregation performed on multi-output scores.
+            Defaults to ``"uniform_average"``.
+
+            - ``"raw_values"``: the scores for each output are returned.
+            - ``"uniform_average"``: the scores of all outputs are averaged with uniform weight.
+            - ``"variance_weighted"``: the scores of all outputs are averaged, weighted by the variances of
+              each individual output.
+        p: non-negative integer.
+            Number of independent variables used for regression. ``p`` is used to compute the adjusted :math:`R^{2}` score.
+            Defaults to 0 (standard :math:`R^{2}` score).
+
+    """
+
+    def __init__(self, multi_output: MultiOutput | str = MultiOutput.UNIFORM, p: int = 0) -> None:
+        super().__init__()
+        multi_output, p = _check_r2_params(multi_output, p)
+        self.multi_output = multi_output
+        self.p = p
+
+    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:  # type: ignore[override]
+        _check_dim(y_pred, y)
+        return y_pred, y
+
+    def aggregate(self, multi_output: MultiOutput | str | None = None) -> np.ndarray | float | npt.ArrayLike:
+        """
+        Typically `y_pred` and `y` are stored in the cumulative buffers at each iteration.
+        This function reads the buffers and computes the :math:`R^{2}` score.
+
+        Args:
+            multi_output: {``"raw_values"``, ``"uniform_average"``, ``"variance_weighted"``}
+                Type of aggregation performed on multi-output scores. Defaults to `self.multi_output`.
+
+        """
+        y_pred, y = self.get_buffer()
+        return compute_r2_score(y_pred=y_pred, y=y, multi_output=multi_output or self.multi_output, p=self.p)
+
+
+def _check_dim(y_pred: torch.Tensor, y: torch.Tensor) -> None:
+    if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):
+        raise ValueError("y_pred and y must be PyTorch Tensors.")
+
+    if y.shape != y_pred.shape:
+        raise ValueError(f"data shapes of y_pred and y do not match, got {y_pred.shape} and {y.shape}.")
+
+    dim = y.ndimension()
+    if dim not in (1, 2):
+        raise ValueError(
+            f"predictions and ground truths should be of shape (batch_size, num_outputs) or (batch_size,), got {y.shape}."
+        )
+
+
+def _check_r2_params(multi_output: MultiOutput | str, p: int) -> tuple[MultiOutput | str, int]:
+    multi_output = look_up_option(multi_output, MultiOutput)
+    if not isinstance(p, int) or p < 0:
+        raise ValueError(f"`p` must be an integer greater than or equal to 0, got {p}.")
+
+    return multi_output, p
+
+
+def _calculate(y_pred: np.ndarray, y: np.ndarray, p: int) -> float:
+    num_obs = len(y)
+    rss = np.sum((y_pred - y) ** 2)  # residual sum of squares
+    tss = np.sum(y**2) - np.sum(y) ** 2 / num_obs  # total sum of squares
+    r2 = 1 - (rss / tss)
+    r2_adjusted = 1 - (1 - r2) * (num_obs - 1) / (num_obs - p - 1)
+
+    return r2_adjusted  # type: ignore[no-any-return]
+
+
+def compute_r2_score(
+    y_pred: torch.Tensor, y: torch.Tensor, multi_output: MultiOutput | str = MultiOutput.UNIFORM, p: int = 0
+) -> np.ndarray | float | npt.ArrayLike:
+    """Computes the :math:`R^{2}` score (coefficient of determination).
+
+    Args:
+        y_pred: input data to compute the :math:`R^{2}` score, the first dim must be batch.
+            For example: shape `[16]` or `[16, 1]` for a single-output regression, shape `[16, x]` for x output variables.
+        y: ground truth to compute the :math:`R^{2}` score, the first dim must be batch.
+            For example: shape `[16]` or `[16, 1]` for a single-output regression, shape `[16, x]` for x output variables.
+        multi_output: {``"raw_values"``, ``"uniform_average"``, ``"variance_weighted"``}
+            Type of aggregation performed on multi-output scores.
+            Defaults to ``"uniform_average"``.
+
+            - ``"raw_values"``: the scores for each output are returned.
+            - ``"uniform_average"``: the scores of all outputs are averaged with uniform weight.
+            - ``"variance_weighted"``: the scores of all outputs are averaged, weighted by the variances of
+              each individual output.
+        p: non-negative integer.
+            Number of independent variables used for regression. ``p`` is used to compute the adjusted :math:`R^{2}` score.
+            Defaults to 0 (standard :math:`R^{2}` score).
+
+    Raises:
+        ValueError: When ``multi_output`` is not one of ["raw_values", "uniform_average", "variance_weighted"].
+        ValueError: When ``p`` is not a non-negative integer.
+        ValueError: When ``y_pred`` or ``y`` is not a PyTorch Tensor.
+        ValueError: When ``y_pred`` and ``y`` don't have the same shape.
+        ValueError: When the ``y_pred`` or ``y`` dimension is not one of [1, 2].
+        ValueError: When n_samples is less than 2.
+        ValueError: When ``p`` is greater than or equal to n_samples - 1.
+
+    """
+    multi_output, p = _check_r2_params(multi_output, p)
+    _check_dim(y_pred, y)
+    dim = y.ndimension()
+    n = y.shape[0]
+    y = y.cpu().numpy()  # type: ignore[assignment]
+    y_pred = y_pred.cpu().numpy()  # type: ignore[assignment]
+
+    if n < 2:
+        raise ValueError("Not enough data for computing. At least two samples are needed to calculate the r2 score.")
+    if p >= n - 1:
+        raise ValueError(f"`p` must be smaller than n_samples - 1, got p={p}, n_samples={n}.")
+
+    if dim == 2 and y_pred.shape[1] == 1:
+        y_pred = np.squeeze(y_pred, axis=-1)  # type: ignore[assignment]
+        y = np.squeeze(y, axis=-1)  # type: ignore[assignment]
+        dim = 1
+
+    if dim == 1:
+        return _calculate(y_pred, y, p)  # type: ignore[arg-type]
+
+    y, y_pred = np.transpose(y, axes=(1, 0)), np.transpose(y_pred, axes=(1, 0))  # type: ignore[assignment]
+    r2_values = [_calculate(y_pred_, y_, p) for y_pred_, y_ in zip(y_pred, y)]
+    if multi_output == MultiOutput.RAW:
+        return r2_values
+    if multi_output == MultiOutput.UNIFORM:
+        return np.mean(r2_values)
+    if multi_output == MultiOutput.VARIANCE:
+        weights = np.var(y, axis=1)
+        return np.average(r2_values, weights=weights)  # type: ignore[no-any-return]
+    raise ValueError(
+        f'Unsupported multi_output: {multi_output}, available options are ["raw_values", "uniform_average", "variance_weighted"].'
+    )
diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py
index 4e36e3cd47..84c9ac8f82 100644
--- a/monai/utils/__init__.py
+++ b/monai/utils/__init__.py
@@ -47,6 +47,7 @@
     MetaKeys,
     Method,
     MetricReduction,
+    MultiOutput,
     NdimageMode,
     NumpyPadMode,
     OrderingTransformations,
diff --git a/monai/utils/enums.py b/monai/utils/enums.py
index 7838a2e741..05cc94500c 100644
--- a/monai/utils/enums.py
+++ b/monai/utils/enums.py
@@ -30,6 +30,7 @@
     "NdimageMode",
     "GridSamplePadMode",
     "Average",
+    "MultiOutput",
     "MetricReduction",
     "LossReduction",
     "DiceCEReduction",
@@ -230,6 +231,16 @@ class Average(StrEnum):
     NONE = "none"
 
 
+class MultiOutput(StrEnum):
+    """
+    See also: :py:func:`monai.metrics.r2_score.compute_r2_score`
+    """
+
+    RAW = "raw_values"
+    UNIFORM = "uniform_average"
+    VARIANCE = "variance_weighted"
+
+
 class MetricReduction(StrEnum):
     """
    See also: :py:func:`monai.metrics.utils.do_metric_reduction`
diff --git a/tests/test_compute_r2_score.py b/tests/test_compute_r2_score.py
new file mode 100644
index 0000000000..0cea11cf47
--- /dev/null
+++ b/tests/test_compute_r2_score.py
@@ -0,0 +1,150 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.metrics import R2Metric, compute_r2_score
+
+_device = "cuda:0" if torch.cuda.is_available() else "cpu"
+TEST_CASE_1 = [
+    torch.tensor([0.1, -0.25, 3.0, 0.99], device=_device),
+    torch.tensor([0.1, -0.2, -2.7, 1.58], device=_device),
+    "uniform_average",
+    0,
+    -2.469944,
+]
+
+TEST_CASE_2 = [
+    torch.tensor([0.1, -0.25, 3.0, 0.99]),
+    torch.tensor([0.1, -0.2, 2.7, 1.58]),
+    "uniform_average",
+    2,
+    0.75828,
+]
+
+TEST_CASE_3 = [
+    torch.tensor([[0.1], [-0.25], [3.0], [0.99]]),
+    torch.tensor([[0.1], [-0.2], [2.7], [1.58]]),
+    "raw_values",
+    2,
+    0.75828,
+]
+
+TEST_CASE_4 = [
+    torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]),
+    torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]),
+    "raw_values",
+    1,
+    [0.87914, 0.844375],
+]
+
+TEST_CASE_5 = [
+    torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]),
+    torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]),
+    "variance_weighted",
+    1,
+    0.867314,
+]
+
+TEST_CASE_6 = [
+    torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]),
+    torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]),
+    "uniform_average",
+    0,
+    0.907838,
+]
+
+TEST_CASE_ERROR_1 = [
+    torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]),
+    torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]),
+    "abc",
+    0,
+]
+
+TEST_CASE_ERROR_2 = [
+    torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]),
+    torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]),
+    "uniform_average",
+    -1,
+]
+
+TEST_CASE_ERROR_3 = [
+    torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]),
+    np.array([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]),
+    "uniform_average",
+    0,
+]
+
+TEST_CASE_ERROR_4 = [
+    torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]),
+    torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1]]),
+    "uniform_average",
+    0,
+]
+
+TEST_CASE_ERROR_5 = [
+    torch.tensor([[[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]]),
+    torch.tensor([[[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]]),
+    "uniform_average",
+    0,
+]
+
+TEST_CASE_ERROR_6 = [
+    torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]),
+    torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]),
+    "uniform_average",
+    3,
+]
+
+TEST_CASE_ERROR_7 = [torch.tensor([[0.1, 1.0]]), torch.tensor([[0.1, 0.82]]), "uniform_average", 0]
+
+
+class TestComputeR2Score(unittest.TestCase):
+
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
+    def test_value(self, y_pred, y, multi_output, p, expected_value):
+        result = compute_r2_score(y_pred=y_pred, y=y, multi_output=multi_output, p=p)
+        np.testing.assert_allclose(expected_value, result, rtol=1e-5)
+
+    @parameterized.expand(
+        [
+            TEST_CASE_ERROR_1,
+            TEST_CASE_ERROR_2,
+            TEST_CASE_ERROR_3,
+            TEST_CASE_ERROR_4,
+            TEST_CASE_ERROR_5,
+            TEST_CASE_ERROR_6,
+            TEST_CASE_ERROR_7,
+        ]
+    )
+    def test_error(self, y_pred, y, multi_output, p):
+        with self.assertRaises(ValueError):
+            _ = compute_r2_score(y_pred=y_pred, y=y, multi_output=multi_output, p=p)
+
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
+    def test_class_value(self, y_pred, y, multi_output, p, expected_value):
+        metric = R2Metric(multi_output=multi_output, p=p)
+        metric(y_pred=y_pred, y=y)
+        result = metric.aggregate()
+        np.testing.assert_allclose(expected_value, result, rtol=1e-5)
+        result = metric.aggregate(multi_output=multi_output)  # test optional argument
+        metric.reset()
+        np.testing.assert_allclose(expected_value, result, rtol=1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_handler_r2_score.py b/tests/test_handler_r2_score.py
new file mode 100644
index 0000000000..f2fa243719
--- /dev/null
+++ b/tests/test_handler_r2_score.py
@@ -0,0 +1,41 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import torch
+
+from monai.handlers import R2Score
+
+
+class TestHandlerR2Score(unittest.TestCase):
+
+    def test_compute(self):
+        r2_score = R2Score(multi_output="variance_weighted", p=1)
+
+        y_pred = [torch.Tensor([0.1, 1.0]), torch.Tensor([-0.25, 0.5])]
+        y = [torch.Tensor([0.1, 0.82]), torch.Tensor([-0.2, 0.01])]
+        r2_score.update([y_pred, y])
+
+        y_pred = [torch.Tensor([3.0, -0.2]), torch.Tensor([0.99, 2.1])]
+        y = [torch.Tensor([2.7, -0.1]), torch.Tensor([1.58, 2.0])]
+
+        r2_score.update([y_pred, y])
+
+        r2 = r2_score.compute()
+        np.testing.assert_allclose(0.867314, r2, rtol=1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_handler_r2_score_dist.py b/tests/test_handler_r2_score_dist.py
new file mode 100644
index 0000000000..378989d555
--- /dev/null
+++ b/tests/test_handler_r2_score_dist.py
@@ -0,0 +1,54 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+from monai.handlers import R2Score
+from tests.utils import DistCall, DistTestCase
+
+
+class DistributedR2Score(DistTestCase):
+
+    @DistCall(nnodes=1, nproc_per_node=2, node_rank=0)
+    def test_compute(self):
+        r2_score = R2Score(multi_output="variance_weighted", p=1)
+
+        device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu"
+        if dist.get_rank() == 0:
+            y_pred = [torch.tensor([0.1, 1.0], device=device), torch.tensor([-0.25, 0.5], device=device)]
+            y = [torch.tensor([0.1, 0.82], device=device), torch.tensor([-0.2, 0.01], device=device)]
+
+        if dist.get_rank() == 1:
+            y_pred = [
+                torch.tensor([3.0, -0.2], device=device),
+                torch.tensor([0.99, 2.1], device=device),
+                torch.tensor([-0.1, 0.0], device=device),
+            ]
+            y = [
+                torch.tensor([2.7, -0.1], device=device),
+                torch.tensor([1.58, 2.0], device=device),
+                torch.tensor([-1.0, -0.1], device=device),
+            ]
+
+        r2_score.update([y_pred, y])
+
+        result = r2_score.compute()
+        np.testing.assert_allclose(0.829185, result, rtol=1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main()
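Reviewer note: a minimal usage sketch of the new metric, not part of the diff itself. The tensors reuse the two-output example from TEST_CASE_4/5/6 above, and both call styles mirror what test_value and test_class_value exercise.

    import torch

    from monai.metrics import R2Metric, compute_r2_score

    y_pred = torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]])
    y = torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]])

    # functional form: one call over the whole batch, one score per output column
    print(compute_r2_score(y_pred=y_pred, y=y, multi_output="raw_values"))

    # cumulative form: accumulate mini-batches, then aggregate once per epoch
    metric = R2Metric(multi_output="uniform_average")
    for i in range(0, 4, 2):  # feed two mini-batches of size 2
        metric(y_pred=y_pred[i : i + 2], y=y[i : i + 2])
    print(metric.aggregate())  # same value as the one-shot call, ~0.907838
    metric.reset()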
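The adjusted-:math:`R^{2}` expectation in TEST_CASE_2 can be re-derived by hand from the two formulas in the R2Metric docstring; this standalone check mirrors what `_calculate` computes (only the numbers already present in the diff are used).

    import numpy as np

    y_pred = np.array([0.1, -0.25, 3.0, 0.99])
    y = np.array([0.1, -0.2, 2.7, 1.58])
    n, p = len(y), 2

    rss = np.sum((y_pred - y) ** 2)    # residual sum of squares, ~0.4406
    tss = np.sum((y - y.mean()) ** 2)  # total sum of squares, ~5.4683
    r2 = 1 - rss / tss                 # standard R^2, ~0.919426
    r2_adjusted = 1 - (1 - r2) * (n - 1) / (n - p - 1)

    np.testing.assert_allclose(r2_adjusted, 0.75828, rtol=1e-5)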
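Finally, a sketch of wiring the R2Score handler into an ignite evaluator. The `model` and `val_loader` names are placeholders, and the evaluation step is assumed to return the `(y_pred, y)` pair that the default `output_transform` passes through unchanged.

    from ignite.engine import Engine

    from monai.handlers import R2Score

    def eval_step(engine, batch):
        x, y = batch
        return model(x), y  # `model` is a placeholder regression network

    evaluator = Engine(eval_step)
    R2Score(multi_output="uniform_average").attach(evaluator, name="r2")

    # after evaluator.run(val_loader), the epoch-level score is available
    # as evaluator.state.metrics["r2"]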