Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify Dice, Jaccard and Tversky losses #8138

Open
wants to merge 7 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 31 additions & 22 deletions monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from monai.losses.focal_loss import FocalLoss
from monai.losses.spatial_mask import MaskedLoss
from monai.losses.utils import compute_tp_fp_fn
from monai.networks import one_hot
from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option, pytorch_after

Expand All @@ -39,8 +40,16 @@ class DiceLoss(_Loss):
The `smooth_nr` and `smooth_dr` parameters are values added to the intersection and union components of
the inter-over-union calculation to smooth results respectively, these values should be small.

The original paper: Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks forVolumetric
Medical Image Segmentation, 3DV, 2016.
The original papers:

Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks for Volumetric
Medical Image Segmentation. 3DV 2016.

Wang, Z. et. al. (2023) Jaccard Metric Losses: Optimizing the Jaccard Index with
Soft Labels. NeurIPS 2023.

Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with
Soft Labels. MICCAI 2023.

"""

Expand All @@ -58,6 +67,7 @@ def __init__(
smooth_dr: float = 1e-5,
batch: bool = False,
weight: Sequence[float] | float | int | torch.Tensor | None = None,
soft_label: bool = False,
) -> None:
"""
Args:
Expand Down Expand Up @@ -89,6 +99,7 @@ def __init__(
of the sequence should be the same as the number of classes. If not ``include_background``,
the number of classes should not include the background category class 0).
The value/values should be no less than 0. Defaults to None.
soft_label: whether the target contains non-binary values or not
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
soft_label: whether the target contains non-binary values or not
soft_label: whether the target contains non-binary values (soft labels) or not. If True a soft label formulation of the loss will be used.

This clarifies a little bit I feel, the same should be done with the other modified losses.


Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
Expand All @@ -114,6 +125,7 @@ def __init__(
weight = torch.as_tensor(weight) if weight is not None else None
self.register_buffer("class_weight", weight)
self.class_weight: None | torch.Tensor
self.soft_label = soft_label

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -174,21 +186,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
# reducing spatial dimensions and batch
reduce_axis = [0] + reduce_axis

intersection = torch.sum(target * input, dim=reduce_axis)

if self.squared_pred:
ground_o = torch.sum(target**2, dim=reduce_axis)
pred_o = torch.sum(input**2, dim=reduce_axis)
else:
ground_o = torch.sum(target, dim=reduce_axis)
pred_o = torch.sum(input, dim=reduce_axis)

denominator = ground_o + pred_o

if self.jaccard:
denominator = 2.0 * (denominator - intersection)
ord = 2 if self.squared_pred else 1
tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, ord, self.soft_label)
if not self.jaccard:
fp *= 0.5
fn *= 0.5
numerator = 2 * tp + self.smooth_nr
denominator = 2 * (tp + fp + fn) + self.smooth_dr

f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr)
f: torch.Tensor = 1 - numerator / denominator

num_of_classes = target.shape[1]
if self.class_weight is not None and num_of_classes != 1:
Expand Down Expand Up @@ -272,6 +278,7 @@ def __init__(
smooth_nr: float = 1e-5,
smooth_dr: float = 1e-5,
batch: bool = False,
soft_label: bool = False,
) -> None:
"""
Args:
Expand All @@ -295,6 +302,7 @@ def __init__(
batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
Defaults to False, intersection over union is computed from each item in the batch.
If True, the class-weighted intersection and union areas are first summed across the batches.
soft_label: whether the target contains non-binary values or not

Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
Expand All @@ -319,6 +327,7 @@ def __init__(
self.smooth_nr = float(smooth_nr)
self.smooth_dr = float(smooth_dr)
self.batch = batch
self.soft_label = soft_label

def w_func(self, grnd):
if self.w_type == str(Weight.SIMPLE):
Expand Down Expand Up @@ -370,13 +379,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
if self.batch:
reduce_axis = [0] + reduce_axis
intersection = torch.sum(target * input, reduce_axis)

ground_o = torch.sum(target, reduce_axis)
pred_o = torch.sum(input, reduce_axis)

denominator = ground_o + pred_o
tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label)
fp *= 0.5
fn *= 0.5
denominator = 2 * (tp + fp + fn)

ground_o = torch.sum(target, reduce_axis)
w = self.w_func(ground_o.float())
infs = torch.isinf(w)
if self.batch:
Expand All @@ -388,7 +397,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
w = w + infs * max_values

final_reduce_dim = 0 if self.batch else 1
numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
numer = 2.0 * (tp * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr
f: torch.Tensor = 1.0 - (numer / denom)

Expand Down
18 changes: 10 additions & 8 deletions monai/losses/tversky.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch
from torch.nn.modules.loss import _Loss

from monai.losses.utils import compute_tp_fp_fn
from monai.networks import one_hot
from monai.utils import LossReduction

Expand All @@ -28,6 +29,9 @@ class TverskyLoss(_Loss):
Sadegh et al. (2017) Tversky loss function for image segmentation
using 3D fully convolutional deep networks. (https://arxiv.org/abs/1706.05721)

Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with
Soft Labels. MICCAI 2023.

Adapted from:
https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/layer/loss_segmentation.py#L631

Expand All @@ -46,6 +50,7 @@ def __init__(
smooth_nr: float = 1e-5,
smooth_dr: float = 1e-5,
batch: bool = False,
soft_label: bool = False,
) -> None:
"""
Args:
Expand All @@ -70,6 +75,7 @@ def __init__(
batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
Defaults to False, a Dice loss value is computed independently from each item in the batch
before any `reduction`.
soft_label: whether the target contains non-binary values or not

Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
Expand All @@ -93,6 +99,7 @@ def __init__(
self.smooth_nr = float(smooth_nr)
self.smooth_dr = float(smooth_dr)
self.batch = batch
self.soft_label = soft_label

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -134,20 +141,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if target.shape != input.shape:
raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

p0 = input
p1 = 1 - p0
g0 = target
g1 = 1 - g0

# reducing only spatial dimensions (not batch nor channels)
reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
if self.batch:
# reducing spatial dimensions and batch
reduce_axis = [0] + reduce_axis

tp = torch.sum(p0 * g0, reduce_axis)
fp = self.alpha * torch.sum(p0 * g1, reduce_axis)
fn = self.beta * torch.sum(p1 * g0, reduce_axis)
tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label, False)
fp *= self.alpha
fn *= self.beta
numerator = tp + self.smooth_nr
denominator = tp + fp + fn + self.smooth_dr

Expand Down
60 changes: 60 additions & 0 deletions monai/losses/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings

import torch
import torch.linalg as LA


def compute_tp_fp_fn(
input: torch.Tensor,
target: torch.Tensor,
reduce_axis: list[int],
ord: int,
soft_label: bool,
decoupled: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Adapted from:
https://github.com/zifuwanggg/JDTLosses
"""
if torch.unique(target).shape[0] > 2 and not soft_label:
warnings.warn("soft labels are used, but `soft_label == False`.")
Comment on lines +32 to +33
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little worried that torch.unique is an expensive calculation to be making every time. It would help detect mistakes but it may not be worth having here for speed reasons.


# the original implementation that is erroneous with soft labels
if ord == 1 and not soft_label:
tp = torch.sum(input * target, dim=reduce_axis)
# the original implementation of Dice and Jaccard loss
if decoupled:
fp = torch.sum(input, dim=reduce_axis) - tp
fn = torch.sum(target, dim=reduce_axis) - tp
# the original implementation of Tversky loss
else:
fp = torch.sum(input * (1 - target), dim=reduce_axis)
fn = torch.sum((1 - input) * target, dim=reduce_axis)
else:
pred_o = LA.vector_norm(input, ord=ord, dim=reduce_axis)
ground_o = LA.vector_norm(target, ord=ord, dim=reduce_axis)
difference = LA.vector_norm(input - target, ord=ord, dim=reduce_axis)

if ord > 1:
pred_o = torch.pow(pred_o, exponent=ord)
ground_o = torch.pow(ground_o, exponent=ord)
difference = torch.pow(difference, exponent=ord)

tp = (pred_o + ground_o - difference) / 2
fp = pred_o - tp
fn = ground_o - tp

return tp, fp, fn
Loading