
Modify Dice, Jaccard and Tversky losses #8138

Open
wants to merge 7 commits into dev

Conversation

@zifuwanggg commented Oct 10, 2024

Fixes #8094.

Description

The Dice, Jaccard and Tversky losses in monai.losses.dice and monai.losses.tversky are modified based on JDTLoss and segmentation_models.pytorch.

In the original versions, when squared_pred=False, the loss functions are incompatible with soft labels. For example, with a ground-truth value of 0.5 for a single pixel, the Dice loss is minimized when the predicted value is 1, which is clearly erroneous. To address this, the intersection term is rewritten as $\frac{|x|_p^p + |y|_p^p - |x-y|_p^p}{2}$. When $p$ is 2 (squared_pred=True), this reformulation reduces to the classical inner product $\langle x,y \rangle$. When $p$ is 1 (squared_pred=False), the reformulation has been proven to remain equivalent to the original versions when the ground truth is binary (i.e. one-hot hard labels). Moreover, since the new versions are minimized if and only if the prediction is identical to the ground truth, even when the ground truth includes fractional values, they resolve the issue with soft labels [1, 2].
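(For concreteness, a short check of the $p=1$ equivalence for binary labels, spelled out here rather than in the original comment: when $y_i \in \{0,1\}$ and $x_i \in [0,1]$ elementwise, $|x_i - y_i| = x_i + y_i - 2 x_i y_i$ in both cases $y_i = 0$ and $y_i = 1$, so $\frac{|x|_1 + |y|_1 - |x-y|_1}{2} = \sum_i \frac{x_i + y_i - (x_i + y_i - 2 x_i y_i)}{2} = \sum_i x_i y_i = \langle x, y \rangle$, recovering the original intersection term.)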

In summary, there are three scenarios:

  • [Scenario 1] $x$ is nonnegative and $y$ is binary: The new versions are the same as the original versions.
  • [Scenario 2] Both $x$ and $y$ are nonnegative: The new versions differ from the original versions. The new versions are minimized if and only if $x=y$, whereas the original versions may not be minimized at $x=y$, making them incorrect.
  • [Scenario 3] Either $x$ or $y$ is negative: The new versions differ from the original versions. The new versions are minimized if and only if $x=y$, whereas the original versions may not be minimized at $x=y$, making them incorrect.

Due to these differences, particularly in Scenarios 2 and 3, some tests fail with the new versions:

  • The target is non-binary: test_multi_scale
  • The input is negative: test_dice_loss, test_tversky_loss, test_generalized_dice_loss, test_masked_loss, test_seg_loss_integration

The failures in test_multi_scale are expected, since the original versions are incorrectly defined for non-binary targets. Furthermore, because the Dice, Jaccard, and Tversky losses are fundamentally defined over probabilities, which should be nonnegative, the new versions should not be tested against negative input or target values.

Example

import torch
import torch.linalg as LA
import torch.nn.functional as F

torch.manual_seed(0)

b, c, h, w = 4, 3, 32, 32
dims = (0, 2, 3)

pred = torch.rand(b, c, h, w).softmax(dim=1)
soft_label = torch.rand(b, c, h, w).softmax(dim=1)
hard_label = torch.randint(low=0, high=c, size=(b, h, w))
one_hot_label = F.one_hot(hard_label, c).permute(0, 3, 1, 2).float()

def dice_old(x, y, ord, dims):
    # Original formulation: the intersection is the inner product <x, y>.
    cardinality = LA.vector_norm(x, ord=ord, dim=dims) ** ord + LA.vector_norm(y, ord=ord, dim=dims) ** ord
    intersection = torch.sum(x * y, dim=dims)
    return 2 * intersection / cardinality

def dice_new(x, y, ord, dims):
    # Reformulation: the intersection is (|x|_p^p + |y|_p^p - |x - y|_p^p) / 2,
    # so the score reaches 1 iff x == y, even for soft labels.
    cardinality = LA.vector_norm(x, ord=ord, dim=dims) ** ord + LA.vector_norm(y, ord=ord, dim=dims) ** ord
    difference = LA.vector_norm(x - y, ord=ord, dim=dims) ** ord
    intersection = (cardinality - difference) / 2
    return 2 * intersection / cardinality

print(dice_old(pred, one_hot_label, 1, dims), dice_new(pred, one_hot_label, 1, dims))
print(dice_old(pred, soft_label, 1, dims), dice_new(pred, soft_label, 1, dims))
print(dice_old(pred, pred, 1, dims), dice_new(pred, pred, 1, dims))

print(dice_old(pred, one_hot_label, 2, dims), dice_new(pred, one_hot_label, 2, dims))
print(dice_old(pred, soft_label, 2, dims), dice_new(pred, soft_label, 2, dims))
print(dice_old(pred, pred, 2, dims), dice_new(pred, pred, 2, dims))

# tensor([0.3345, 0.3310, 0.3317]) tensor([0.3345, 0.3310, 0.3317])
# tensor([0.3321, 0.3333, 0.3350]) tensor([0.8680, 0.8690, 0.8700])
# tensor([0.3487, 0.3502, 0.3544]) tensor([1., 1., 1.])

# tensor([0.4921, 0.4904, 0.4935]) tensor([0.4921, 0.4904, 0.4935])
# tensor([0.9489, 0.9499, 0.9503]) tensor([0.9489, 0.9499, 0.9503])
# tensor([1., 1., 1.]) tensor([1., 1., 1.])
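Note that with $p=1$ the old version fails to reach 1 even when the prediction equals the target exactly, whereas the new version returns 1; with $p=2$ the two formulations coincide for any inputs, since $\|x\|_2^2 + \|y\|_2^2 - \|x-y\|_2^2 = 2\langle x, y \rangle$.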

References

[1] Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels. Zifu Wang, Teodora Popordanoska, Jeroen Bertels, Robin Lemmens, Matthew B. Blaschko. MICCAI 2023.

[2] Jaccard Metric Losses: Optimizing the Jaccard Index with Soft Labels. Zifu Wang, Xuefei Ning, Matthew B. Blaschko. NeurIPS 2023.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

@ericspod (Member)

Hi @zifuwanggg, thanks for the contribution. I have an issue with this change in that the behaviour of the losses is now very different, as seen in the CICD errors. I would instead suggest adding new loss functions in a "soft_losses.py" file or something like that, instead of changing the existing losses. Other users may rely on the existing behaviour, and in situations where non-binarised values are accidentally used due to incorrect postprocessing there would be less feedback about the problem.

@zifuwanggg (Author)

Hi @ericspod, thank you for reviewing my code. While adding new loss functions as separate .py files could be a workaround, my concern is that this approach would lead to a lot of duplicated code, as the core differences are only in 2-3 lines.

Would it make sense to add an attribute to the existing loss classes and create a new helper function, so that the default behavior remains unchanged? Something like the following.

class DiceLoss(_Loss):
    def __init__(
        ...
        binary_label: bool = True,
    ):
        ...
        self.binary_label = binary_label

    def forward(...):
        ...
        f = compute_score(self.binary_label)
        ...


class GeneralizedDiceLoss(_Loss):
    def __init__(
        ...
        binary_label: bool = True,
    ):
        ...
        self.binary_label = binary_label
    
    def forward(...):
        ...
        f = compute_score(self.binary_label)
        ...


def compute_score(binary_label):
    if binary_label:
        ...
    else:
        ...
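For illustration, here is a minimal runnable sketch of what such a helper could look like (the name compute_intersection and its signature are hypothetical, not from this PR):

import torch
import torch.linalg as LA

def compute_intersection(x, y, dims, binary_label=True):
    # Hypothetical helper; both branches assume x and y are nonnegative.
    if binary_label:
        # Original behaviour: plain inner product, correct for one-hot targets.
        return torch.sum(x * y, dim=dims)
    # Reformulated intersection (|x|_1 + |y|_1 - |x - y|_1) / 2, which matches
    # the inner product for binary targets and remains correct for soft labels.
    cardinality = LA.vector_norm(x, ord=1, dim=dims) + LA.vector_norm(y, ord=1, dim=dims)
    difference = LA.vector_norm(x - y, ord=1, dim=dims)
    return (cardinality - difference) / 2

Each loss class would then keep its default behaviour with binary_label=True and switch to the reformulation only when asked.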

@ericspod (Member)

> Hi @ericspod, thank you for reviewing my code. While adding new loss functions as separate .py files could be a workaround, my concern is that this approach would lead to a lot of duplicated code, as the core differences are only in 2-3 lines.

Hi @zifuwanggg, I appreciate wanting to reduce duplicate code; we have too much of that in these loss functions as it stands, so yes, adding more isn't great. I think we can try to parameterise the loss functions in some way, either with a function as you suggest or some other mechanism, so long as the default behaviour is preserved. If you want to have a go at refactoring to do that, we can return to it; I think in the future we need to refactor all of these loss functions to reduce duplication anyway.

@zifuwanggg (Author) commented Oct 21, 2024

Hi @ericspod, I've created losses/utils.py with a helper function that is shared by both dice.py and tversky.py.

Unit tests pass, but mypy tests fail. This seems related to #8149 and #8161.

@zifuwanggg (Author)

Hi @ericspod, all CICD tests pass. @KumoLiu, thanks for the commit.

@ericspod (Member) left a comment

I have a minor comment about the check in the helper function being expensive to compute, but otherwise we also need tests for soft labels to ensure that this formulation of the losses works. I do want to get others to review this as well, to be doubly sure the changes are compatible. Thanks again.

Comment on lines +32 to +33
if torch.unique(target).shape[0] > 2 and not soft_label:
    warnings.warn("soft labels are used, but `soft_label == False`.")

I'm a little worried that torch.unique is an expensive calculation to be making every time. It would help detect mistakes but it may not be worth having here for speed reasons.
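For reference, one cheaper check that avoids the sort inside torch.unique might look like this (a sketch, not from the thread; it flags any value other than 0 or 1, which is slightly stricter than the unique-count test):

# Sketch: an elementwise test is O(n) and avoids the sort in torch.unique;
# it warns on any target value other than 0 or 1.
if not soft_label and ((target != 0) & (target != 1)).any():
    warnings.warn("soft labels are used, but `soft_label == False`.")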

@@ -89,6 +99,7 @@ def __init__(
of the sequence should be the same as the number of classes. If not ``include_background``,
the number of classes should not include the background category class 0).
The value/values should be no less than 0. Defaults to None.
soft_label: whether the target contains non-binary values or not

Suggested change:
- soft_label: whether the target contains non-binary values or not
+ soft_label: whether the target contains non-binary values (soft labels) or not. If True a soft label formulation of the loss will be used.

This clarifies things a little, I feel; the same should be done for the other modified losses.

Successfully merging this pull request may close these issues:

Jaccard, Dice and Tversky losses are incompatible with soft labels (#8094)