Skip to content

Commit

Permalink
Fix generalized dice computation
Browse files Browse the repository at this point in the history
Signed-off-by: Suraj Pai <[email protected]>

Similar functionality to torchmetrics

Update

Lint and update sum_over_labels

Update docstring

Update docstring
  • Loading branch information
surajpaib committed Jul 31, 2024
1 parent f1ef3e8 commit 123c778
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 68 deletions.
112 changes: 58 additions & 54 deletions monai/metrics/generalized_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,109 +20,108 @@


class GeneralizedDiceScore(CumulativeIterationMetric):
"""Compute the Generalized Dice Score metric between tensors, as the complement of the Generalized Dice Loss defined in:
"""
Compute the Generalized Dice Score metric between tensors.
This metric is the complement of the Generalized Dice Loss defined in:
Sudre, C. et al. (2017) Generalised Dice overlap as a deep learning
loss function for highly unbalanced segmentations. DLMIA 2017.
loss function for highly unbalanced segmentations. DLMIA 2017.
The inputs `y_pred` and `y` are expected to be one-hot, binarized channel-first
or batch-first tensors, i.e., CHW[D] or BCHW[D].
The inputs `y_pred` and `y` are expected to be one-hot, binarized batch-first tensors, i.e., NCHW[D].
Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
Args:
include_background (bool, optional): whether to include the background class (assumed to be in channel 0), in the
include_background: Whether to include the background class (assumed to be in channel 0) in the
score computation. Defaults to True.
reduction (str, optional): define mode of reduction to the metrics. Available reduction modes:
{``"none"``, ``"mean_batch"``, ``"sum_batch"``}. Default to ``"mean_batch"``. If "none", will not do reduction.
weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
reduction: Define mode of reduction to the metrics. Available reduction modes:
{``"none"``, ``"mean_batch"``, ``"sum_batch"``, ``"mean"``, ``"sum"``}. Defaults to ``"mean"``.
If "none", will not do reduction.
weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
ground truth volume into a weight factor. Defaults to ``"square"``.
Raises:
ValueError: when the `weight_type` is not one of {``"none"``, ``"mean"``, ``"sum"``}.
ValueError: When `reduction` is not a valid member of the `MetricReduction` enum.
"""

def __init__(
self,
include_background: bool = True,
reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,
weight_type: Weight | str = Weight.SQUARE,
self, include_background: bool = True, reduction: str = "mean", weight_type: Weight | str = Weight.SQUARE
) -> None:
super().__init__()
self.include_background = include_background
reduction_options = [
"none",
"mean_batch",
"sum_batch",
MetricReduction.NONE,
MetricReduction.MEAN_BATCH,
MetricReduction.SUM_BATCH,
]
self.reduction = reduction
if self.reduction not in reduction_options:
raise ValueError(f"reduction must be one of {reduction_options}")
self.reduction = look_up_option(reduction, MetricReduction)
self.weight_type = look_up_option(weight_type, Weight)
self.sum_over_labels = self.reduction == MetricReduction.SUM or self.reduction == MetricReduction.MEAN

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
"""Computes the Generalized Dice Score and returns a tensor with its per image values.
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Computes the Generalized Dice Score and returns a tensor with its per image values.
Args:
y_pred (torch.Tensor): binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
y_pred: Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions.
y (torch.Tensor): binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.
y: Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.
Returns:
torch.Tensor: Per batch and per class Generalized Dice Score.
Raises:
ValueError: if `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
ValueError: If `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
"""
return compute_generalized_dice(
y_pred=y_pred, y=y, include_background=self.include_background, weight_type=self.weight_type
y_pred=y_pred,
y=y,
include_background=self.include_background,
weight_type=self.weight_type,
sum_over_labels=self.sum_over_labels,
)

def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Tensor:
def aggregate(self) -> torch.Tensor:
"""
Execute reduction logic for the output of `compute_generalized_dice`.
Args:
reduction (Union[MetricReduction, str, None], optional): define mode of reduction to the metrics.
Available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``}.
Defaults to ``"mean"``. If "none", will not do reduction.
Returns:
torch.Tensor: Aggregated metric value.
Raises:
ValueError: If the data to aggregate is not a PyTorch Tensor.
"""
data = self.get_buffer()
if not isinstance(data, torch.Tensor):
raise ValueError("The data to aggregate must be a PyTorch Tensor.")

# Validate reduction argument if specified
if reduction is not None:
reduction_options = ["none", "mean", "sum", "mean_batch", "sum_batch"]
if reduction not in reduction_options:
raise ValueError(f"reduction must be one of {reduction_options}")

# Do metric reduction and return
f, _ = do_metric_reduction(data, reduction or self.reduction)
f, _ = do_metric_reduction(data, self.reduction)

return f


def compute_generalized_dice(
y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, weight_type: Weight | str = Weight.SQUARE
y_pred: torch.Tensor,
y: torch.Tensor,
include_background: bool = True,
weight_type: Weight | str = Weight.SQUARE,
sum_over_labels: bool = False,
) -> torch.Tensor:
"""Computes the Generalized Dice Score and returns a tensor with its per image values.
"""
Computes the Generalized Dice Score and returns a tensor with its per image values.
Args:
y_pred (torch.Tensor): binarized segmentation model output. It should be binarized, in one-hot format
y_pred: Binarized segmentation model output. It should be binarized, in one-hot format
and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the
remaining are the spatial dimensions.
y (torch.Tensor): binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
include_background (bool, optional): whether to include score computation on the first channel of the
y: Binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
include_background: Whether to include score computation on the first channel of the
predicted output. Defaults to True.
weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to
weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to
transform ground truth volume into a weight factor. Defaults to ``"square"``.
sum_over_labels: Whether to sum the numerator and denominator across all labels before the final computation. Defaults to False.
Returns:
torch.Tensor: per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes].
torch.Tensor: Generalized Dice Score with the shape [batch_size, num_classes], or [batch_size, 1] when `sum_over_labels` is True.
Raises:
ValueError: if `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
ValueError: If `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
or `y_pred` and `y` don't have the same shape.
"""
# Ensure tensors have at least 3 dimensions and have the same shape
Expand Down Expand Up @@ -158,16 +157,21 @@ def compute_generalized_dice(
b[infs] = 0
b[infs] = torch.max(b)

# Compute the weighted numerator and denominator, summing along the class axis
numer = 2.0 * (intersection * w).sum(dim=1)
denom = (denominator * w).sum(dim=1)
# Compute the weighted numerator and denominator, summing along the class axis when sum_over_labels is True
if sum_over_labels:
numer = 2.0 * (intersection * w).sum(dim=1, keepdim=True)
denom = (denominator * w).sum(dim=1, keepdim=True)
y_pred_o = y_pred_o.sum(dim=-1, keepdim=True)
else:
numer = 2.0 * (intersection * w)
denom = denominator * w
y_pred_o = y_pred_o

# Compute the score
generalized_dice_score = numer / denom

# Handle zero division. Where denom == 0 and the prediction volume is 0, score is 1.
# Where denom == 0 but the prediction volume is not 0, score is 0
y_pred_o = y_pred_o.sum(dim=-1)
denom_zeros = denom == 0
generalized_dice_score[denom_zeros] = torch.where(
(y_pred_o == 0)[denom_zeros],
Expand Down
40 changes: 26 additions & 14 deletions tests/test_compute_generalized_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
_device = "cuda:0" if torch.cuda.is_available() else "cpu"

# keep background
TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1)
TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1, 1)
{
"y_pred": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=_device),
"y": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=_device),
Expand All @@ -32,7 +32,7 @@
]

# remove background
TEST_CASE_2 = [ # y (2, 1, 2, 2), y_pred (2, 3, 2, 2), expected out (2) (no background)
TEST_CASE_2 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (2, 2) (no background)
{
"y_pred": torch.tensor(
[
Expand All @@ -48,11 +48,11 @@
),
"include_background": False,
},
[0.1667, 0.6667],
[0.416667],
]

# should return 0 for both cases
TEST_CASE_3 = [
TEST_CASE_3 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (2, 3)
{
"y_pred": torch.tensor(
[
Expand All @@ -68,7 +68,7 @@
),
"include_background": True,
},
[0.0, 0.0],
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
]

TEST_CASE_4 = [
Expand All @@ -87,11 +87,11 @@
]
),
},
[0.5455],
[0.678571, 0.2, 0.333333],
]

TEST_CASE_5 = [
{"include_background": True, "reduction": "sum_batch"},
{"include_background": True, "reduction": "sum"},
{
"y_pred": torch.tensor(
[
Expand All @@ -106,16 +106,28 @@
]
),
},
1.0455,
[1.045455],
]

TEST_CASE_6 = [{"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, [1.0000, 1.0000]]
TEST_CASE_6 = [
{"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))},
[[1.0000, 1.0000], [1.0000, 1.0000]],
]

TEST_CASE_7 = [{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, [0.0000, 0.0000]]
TEST_CASE_7 = [
{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))},
[[0.0000, 0.0000], [0.0000, 0.0000]],
]

TEST_CASE_8 = [{"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, [0.0000, 0.0000]]
TEST_CASE_8 = [
{"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))},
[[0.0000, 0.0000], [0.0000, 0.0000]],
]

TEST_CASE_9 = [{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, [1.0000, 1.0000]]
TEST_CASE_9 = [
{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))},
[[1.0000, 1.0000], [1.0000, 1.0000]],
]


class TestComputeGeneralizedDiceScore(unittest.TestCase):
Expand All @@ -126,7 +138,7 @@ def test_device(self, input_data, _expected_value):
np.testing.assert_equal(result.device, input_data["y_pred"].device)

# Functional part tests
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9])
@parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9])
def test_value(self, input_data, expected_value):
result = compute_generalized_dice(**input_data)
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
Expand All @@ -146,7 +158,7 @@ def test_value_class(self, input_data, expected_value):
vals["y"] = input_data.pop("y")
generalized_dice_score = GeneralizedDiceScore(**input_data)
generalized_dice_score(**vals)
result = generalized_dice_score.aggregate(reduction="none")
result = generalized_dice_score.aggregate()
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)

# Aggregation tests
Expand Down

0 comments on commit 123c778

Please sign in to comment.