Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add R2 score to metrics #8093

Open
wants to merge 9 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/handlers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ Panoptic Quality metrics handler
:members:


:math:`R^{2}` score
-------------------
.. autoclass:: R2Score
:members:


Mean squared error metrics handler
----------------------------------
.. autoclass:: MeanSquaredError
Expand Down
7 changes: 7 additions & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,13 @@ Metrics
.. autoclass:: PanopticQualityMetric
:members:

:math:`R^{2}` score
-------------------
.. autofunction:: compute_r2_score

.. autoclass:: R2Metric
:members:

`Mean squared error`
--------------------
.. autoclass:: MSEMetric
Expand Down
1 change: 1 addition & 0 deletions monai/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from .parameter_scheduler import ParamSchedulerHandler
from .postprocessing import PostProcessing
from .probability_maps import ProbMapProducer
from .r2_score import R2Score
from .regression_metrics import MeanAbsoluteError, MeanSquaredError, PeakSignalToNoiseRatio, RootMeanSquaredError
from .roc_auc import ROCAUC
from .smartcache_handler import SmartCacheHandler
Expand Down
56 changes: 56 additions & 0 deletions monai/handlers/r2_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import R2Metric
from monai.utils import MultiOutput


class R2Score(IgniteMetricHandler):
"""
Computes :math:`R^{2}` score accumulating predictions and the ground-truth during an epoch and applying `compute_r2_score`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As elsewhere it would help to explain where this is used and what it would be for. The math is explained in the function which is fine so no need here, just a high level idea.


Args:
multi_output: {``"raw_values"``, ``"uniform_average"``, ``"variance_weighted"``}
Type of aggregation performed on multi-output scores.
Defaults to ``"uniform_average"``.

- ``"raw_values"``: the scores for each output are returned.
- ``"uniform_average"``: the scores of all outputs are averaged with uniform weight.
- ``"variance_weighted"``: the scores of all outputs are averaged, weighted by the variances
of each individual output.
p: non-negative integer.
Number of independent variables used for regression. ``p`` is used to compute adjusted :math:`R^{2}` score.
Defaults to 0 (standard :math:`R^{2}` score).
output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then
construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
lists of `channel-first` Tensors. The form of `(y_pred, y)` is required by the `update()`.
`engine.state` and `output_transform` inherit from the ignite concept:
https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.

See also:
:py:class:`monai.metrics.R2Metric`

"""

def __init__(
self,
multi_output: MultiOutput | str = MultiOutput.UNIFORM,
p: int = 0,
output_transform: Callable = lambda x: x,
) -> None:
metric_fn = R2Metric(multi_output=multi_output, p=p)
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=False)
1 change: 1 addition & 0 deletions monai/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .metric import Cumulative, CumulativeIterationMetric, IterationMetric, Metric
from .mmd import MMDMetric, compute_mmd
from .panoptic_quality import PanopticQualityMetric, compute_panoptic_quality
from .r2_score import R2Metric, compute_r2_score
from .regression import (
MAEMetric,
MSEMetric,
Expand Down
184 changes: 184 additions & 0 deletions monai/metrics/r2_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
import numpy.typing as npt

import torch

from monai.utils import MultiOutput, look_up_option

from .metric import CumulativeIterationMetric


class R2Metric(CumulativeIterationMetric):
r"""Computes :math:`R^{2}` score (coefficient of determination):

.. math::
\operatorname {R^{2}}\left(Y, \hat{Y}\right) = 1 - \frac {\sum _{i=1}^{n}\left(y_i-\hat{y_i} \right)^{2}}
{\sum _{i=1}^{n}\left(y_i-\bar{y} \right)^{2}},

where :math:`\bar{y}` is the mean of observed :math:`y` ; or adjusted :math:`R^{2}` score:

.. math::
\operatorname {\bar{R}^{2}} = 1 - (1-R^{2}) \frac {n-1}{n-p-1},

where :math:`p` is the number of independant variables used for the regression.

More info: https://en.wikipedia.org/wiki/Coefficient_of_determination

Input `y_pred` is compared with ground truth `y`.
`y_pred` and `y` are expected to be 1D (single-output regression) or 2D (multi-output regression) real-valued
tensors of same shape.

Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

Args:
multi_output: {``"raw_values"``, ``"uniform_average"``, ``"variance_weighted"``}
Type of aggregation performed on multi-output scores.
Defaults to ``"uniform_average"``.

- ``"raw_values"``: the scores for each output are returned.
- ``"uniform_average"``: the scores of all outputs are averaged with uniform weight.
- ``"variance_weighted"``: the scores of all outputs are averaged, weighted by the variances of
each individual output.
p: non-negative integer.
Number of independent variables used for regression. ``p`` is used to compute adjusted :math:`R^{2}` score.
Defaults to 0 (standard :math:`R^{2}` score).

"""

def __init__(self, multi_output: MultiOutput | str = MultiOutput.UNIFORM, p: int = 0) -> None:
super().__init__()
multi_output, p = _check_r2_params(multi_output, p)
self.multi_output = multi_output
self.p = p

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # type: ignore[override]
_check_dim(y_pred, y)
return y_pred, y

def aggregate(self, multi_output: MultiOutput | str | None = None) -> np.ndarray | float | npt.ArrayLike:
"""
Typically `y_pred` and `y` are stored in the cumulative buffers at each iteration,
This function reads the buffers and computes the :math:`R^{2}` score.

Args:
multi_output: {``"raw_values"``, ``"uniform_average"``, ``"variance_weighted"``}
Type of aggregation performed on multi-output scores. Defaults to `self.multi_output`.

"""
y_pred, y = self.get_buffer()
return compute_r2_score(y_pred=y_pred, y=y, multi_output=multi_output or self.multi_output, p=self.p)


def _check_dim(y_pred: torch.Tensor, y: torch.Tensor) -> None:
if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):
raise ValueError("y_pred and y must be PyTorch Tensor.")

if y.shape != y_pred.shape:
raise ValueError(f"data shapes of y_pred and y do not match, got {y_pred.shape} and {y.shape}.")

dim = y.ndimension()
if dim not in (1, 2):
raise ValueError(
f"predictions and ground truths should be of shape (batch_size, num_outputs) or (batch_size, ), got {y.shape}."
)


def _check_r2_params(multi_output: MultiOutput | str, p: int) -> tuple[MultiOutput | str, int]:
multi_output = look_up_option(multi_output, MultiOutput)
if not isinstance(p, int) or p < 0:
raise ValueError(f"`p` must be an integer larger or equal to 0, got {p}.")

return multi_output, p


def _calculate(y_pred: np.ndarray, y: np.ndarray, p: int) -> float:
num_obs = len(y)
rss = np.sum((y_pred - y) ** 2)
tss = np.sum(y**2) - np.sum(y) ** 2 / num_obs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be tss = np.sum((y - np.mean(y)) ** 2) here?

r2 = 1 - (rss / tss)
r2_adjusted = 1 - (1 - r2) * (num_obs - 1) / (num_obs - p - 1)

return r2_adjusted # type: ignore[no-any-return]


def compute_r2_score(
y_pred: torch.Tensor, y: torch.Tensor, multi_output: MultiOutput | str = MultiOutput.UNIFORM, p: int = 0
) -> np.ndarray | float | npt.ArrayLike:
"""Computes :math:`R^{2}` score (coefficient of determination).

Args:
y_pred: input data to compute :math:`R^{2}` score, the first dim must be batch.
For example: shape `[16]` or `[16, 1]` for a single-output regression, shape `[16, x]` for x output variables.
y: ground truth to compute :math:`R^{2}` score, the first dim must be batch.
For example: shape `[16]` or `[16, 1]` for a single-output regression, shape `[16, x]` for x output variables.
multi_output: {``"raw_values"``, ``"uniform_average"``, ``"variance_weighted"``}
Type of aggregation performed on multi-output scores.
Defaults to ``"uniform_average"``.

- ``"raw_values"``: the scores for each output are returned.
- ``"uniform_average"``: the scores of all outputs are averaged with uniform weight.
- ``"variance_weighted"``: the scores of all outputs are averaged, weighted by the variances
each individual output.
p: non-negative integer.
Number of independent variables used for regression. ``p`` is used to compute adjusted :math:`R^{2}` score.
Defaults to 0 (standard :math:`R^{2}` score).

Raises:
ValueError: When ``multi_output`` is not one of ["raw_values", "uniform_average", "variance_weighted"].
ValueError: When ``p`` is not a non-negative integer.
ValueError: When ``y_pred`` or ``y`` are not PyTorch tensors.
ValueError: When ``y_pred`` and ``y`` don't have the same shape.
ValueError: When ``y_pred`` or ``y`` dimension is not one of [1, 2].
ValueError: When n_samples is less than 2.
ValueError: When ``p`` is greater or equal to n_samples - 1.

"""
multi_output, p = _check_r2_params(multi_output, p)
_check_dim(y_pred, y)
dim = y.ndimension()
n = y.shape[0]
y = y.cpu().numpy() # type: ignore[assignment]
y_pred = y_pred.cpu().numpy() # type: ignore[assignment]

if n < 2:
raise ValueError("There is no enough data for computing. Needs at least two samples to calculate r2 score.")
if p >= n - 1:
raise ValueError("`p` must be smaller than n_samples - 1, " f"got p={p}, n_samples={n}.")

if dim == 2 and y_pred.shape[1] == 1:
y_pred = np.squeeze(y_pred, axis=-1) # type: ignore[assignment]
y = np.squeeze(y, axis=-1) # type: ignore[assignment]
dim = 1

if dim == 1:
return _calculate(y_pred, y, p) # type: ignore[arg-type]

y, y_pred = np.transpose(y, axes=(1, 0)), np.transpose(y_pred, axes=(1, 0)) # type: ignore[assignment]
r2_values = [_calculate(y_pred_, y_, p) for y_pred_, y_ in zip(y_pred, y)]
if multi_output == MultiOutput.RAW:
return r2_values
if multi_output == MultiOutput.UNIFORM:
return np.mean(r2_values)
if multi_output == MultiOutput.VARIANCE:
weights = np.var(y, axis=1)
return np.average(r2_values, weights=weights) # type: ignore[no-any-return]
raise ValueError(
f'Unsupported multi_output: {multi_output}, available options are ["raw_values", "uniform_average", "variance_weighted"].'
)
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
MetaKeys,
Method,
MetricReduction,
MultiOutput,
NdimageMode,
NumpyPadMode,
OrderingTransformations,
Expand Down
11 changes: 11 additions & 0 deletions monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"NdimageMode",
"GridSamplePadMode",
"Average",
"MultiOutput",
"MetricReduction",
"LossReduction",
"DiceCEReduction",
Expand Down Expand Up @@ -230,6 +231,16 @@ class Average(StrEnum):
NONE = "none"


class MultiOutput(StrEnum):
"""
See also: :py:func:`monai.metrics.r2_score.compute_r2_score`
"""

RAW = "raw_values"
UNIFORM = "uniform_average"
VARIANCE = "variance_weighted"


class MetricReduction(StrEnum):
"""
See also: :py:func:`monai.metrics.utils.do_metric_reduction`
Expand Down
Loading
Loading