Modify Dice, Jaccard and Tversky losses #8138
base: dev
Conversation
Signed-off-by: Zifu Wang <[email protected]>
Hi @zifuwanggg thanks for the contribution. I have an issue with this change in that the behaviour of the losses is very different now, as seen in the CICD errors. I would instead suggest adding new loss functions in a "soft_losses.py" file or something like that instead of changing the existing losses. Other users may rely on the existing behaviour, and in situations where non-binarised values are accidentally used due to incorrect postprocessing there would be less feedback about the problem.
Hi @ericspod, thank you for reviewing my code. While adding new loss functions as separate .py files could be a workaround, my concern is that this approach would lead to a lot of duplicated code, as the core differences are only in 2-3 lines. Would it make sense to add an attribute to the existing loss classes and create a new helper function, so that the default behavior remains unchanged? Something like the following.
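(The snippet originally posted here is not preserved in this copy of the page; the sketch below is only an illustration of the idea being proposed. The helper name, its signature and the reduction dims are assumed for illustration, not taken from the PR.)

```python
import torch

def dice_intersection(input: torch.Tensor, target: torch.Tensor,
                      soft_label: bool = False, dims=(1, 2, 3)) -> torch.Tensor:
    """Hypothetical shared helper for the intersection term of Dice-style losses.

    With soft_label=False the usual product-based intersection is returned, so the
    default behaviour of the existing losses stays unchanged. With soft_label=True
    the (|x| + |y| - |x - y|) / 2 reformulation (the p=1 case) is used, which equals
    the product form whenever the target is binary.
    """
    if not soft_label:
        return torch.sum(input * target, dim=dims)
    return (input.sum(dim=dims) + target.sum(dim=dims)
            - (input - target).abs().sum(dim=dims)) / 2
```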
Hi @zifuwanggg I appreciate wanting to reduce duplicate code; we have too much of that in these loss functions as it stands, so yes, adding more isn't great. I think we can try to parameterise the loss functions in some way, either with a helper function as you suggest or some other way, so long as the default behaviour is preserved. If you want to have a go at refactoring to do that, we can return to it; I think in the future we do need to refactor all these loss functions to reduce duplication anyway.
Signed-off-by: Zifu Wang <[email protected]>
Signed-off-by: Zifu Wang <[email protected]>
I have a minor comment about the check in the helper function being expensive to compute, but otherwise we also need tests for soft labels to ensure that this formulation of the losses works. I do want to get others to review this as well, to be doubly sure the changes are compatible. Thanks again.
```python
if torch.unique(target).shape[0] > 2 and not soft_label:
    warnings.warn("soft labels are used, but `soft_label == False`.")
```
I'm a little worried that `torch.unique` is an expensive calculation to be making every time. It would help detect mistakes, but it may not be worth having here for speed reasons.
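(For comparison, a plain boolean reduction could flag non-binary values without the sort that `torch.unique` performs. This is only an illustrative alternative, not something proposed in the PR, and the helper name below is made up.)

```python
import warnings

import torch

def warn_if_soft_target(target: torch.Tensor, soft_label: bool) -> None:
    # Elementwise comparisons plus a single any() reduction; avoids the
    # sort/uniquification pass that torch.unique performs on every call.
    if not soft_label and bool(((target != 0) & (target != 1)).any()):
        warnings.warn("soft labels are used, but `soft_label == False`.")
```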
```diff
@@ -89,6 +99,7 @@ def __init__(
             of the sequence should be the same as the number of classes. If not ``include_background``,
             the number of classes should not include the background category class 0).
             The value/values should be no less than 0. Defaults to None.
+            soft_label: whether the target contains non-binary values or not
```
```diff
-            soft_label: whether the target contains non-binary values or not
+            soft_label: whether the target contains non-binary values (soft labels) or not. If True a soft label formulation of the loss will be used.
```
This clarifies it a little, I feel; the same should be done with the other modified losses.
Fixes #8094.
Description
The Dice, Jaccard and Tversky losses in `monai.losses.dice` and `monai.losses.tversky` are modified based on JDTLoss and segmentation_models.pytorch.

In the original versions, when `squared_pred=False`, the loss functions are incompatible with soft labels. For example, with a ground truth value of 0.5 for a single pixel, the Dice loss is minimized when the predicted value is 1, which is clearly erroneous. To address this, the intersection term is rewritten as $\frac{|x|_p^p + |y|_p^p - |x-y|_p^p}{2}$. When $p$ is 2 (`squared_pred=True`), this reformulation becomes the classical inner product $\langle x, y \rangle$. When $p$ is 1 (`squared_pred=False`), the reformulation has been proven to retain equivalence with the original versions when the ground truth is binary (i.e. one-hot hard labels). Moreover, since the new versions are minimized if and only if the prediction is identical to the ground truth, even when the ground truth includes fractional values, they resolve the issue with soft labels [1, 2].

In summary, there are three scenarios:
Due to these differences, particularly in Scenarios 2 and 3, some tests fail with the new versions:

- `test_multi_scale`
- `test_dice_loss`, `test_tversky_loss`, `test_generalized_dice_loss`, `test_masked_loss`, `test_seg_loss_integration`

The failures in `test_multi_scale` are expected since the original versions are incorrectly defined for non-binary targets. Furthermore, because Dice, Jaccard, and Tversky losses are fundamentally defined over probabilities, which should be nonnegative, the new versions should not be tested against negative input or target values.

Example
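(The code for this example section is not preserved in this copy of the page. The sketch below is only an assumed illustration of the behaviour described above; it uses the `soft_label` flag introduced by this PR and made-up tensor values.)

```python
import torch
from monai.losses import DiceLoss

# One channel with a single pixel whose soft ground-truth value is 0.5.
target = torch.tensor([[[0.5]]])
half = torch.tensor([[[0.5]]])  # prediction equal to the target
one = torch.tensor([[[1.0]]])   # over-confident prediction

# Original formulation (squared_pred=False): the loss is smaller for a
# prediction of 1.0 than for 0.5, even though the target is 0.5.
loss = DiceLoss()
print(loss(half, target).item(), loss(one, target).item())

# With soft_label=True the loss is minimized when the prediction equals the target.
loss_soft = DiceLoss(soft_label=True)
print(loss_soft(half, target).item(), loss_soft(one, target).item())
```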
References
[1] Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels. Zifu Wang, Teodora Popordanoska, Jeroen Bertels, Robin Lemmens, Matthew B. Blaschko. MICCAI 2023.
[2] Jaccard Metric Losses: Optimizing the Jaccard Index with Soft Labels. Zifu Wang, Xuefei Ning, Matthew B. Blaschko. NeurIPS 2023.
Types of changes

- Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- Documentation updated, tested `make html` command in the `docs/` folder.