diff --git a/monai/losses/dice.py b/monai/losses/dice.py
index 4f53e3a7d8..214265499c 100644
--- a/monai/losses/dice.py
+++ b/monai/losses/dice.py
@@ -614,8 +614,8 @@ class DiceCELoss(_Loss):
     """
     Compute both Dice loss and Cross Entropy Loss, and return the weighted sum of these two losses.
     The details of Dice loss is shown in ``monai.losses.DiceLoss``.
-    The details of Cross Entropy Loss is shown in ``torch.nn.CrossEntropyLoss``. In this implementation,
-    two deprecated parameters ``size_average`` and ``reduce``, and the parameter ``ignore_index`` are
+    The details of Cross Entropy Loss are shown in ``torch.nn.CrossEntropyLoss`` and ``torch.nn.BCEWithLogitsLoss``.
+    In this implementation, two deprecated parameters ``size_average`` and ``reduce``, and the parameter ``ignore_index`` are
     not supported.
 
     """
@@ -646,11 +646,11 @@ def __init__(
             to_onehot_y: whether to convert the ``target`` into the one-hot format,
                 using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
             sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
-                don't need to specify activation function for `CrossEntropyLoss`.
+                don't need to specify activation function for `CrossEntropyLoss` and `BCEWithLogitsLoss`.
             softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
-                don't need to specify activation function for `CrossEntropyLoss`.
+                don't need to specify activation function for `CrossEntropyLoss` and `BCEWithLogitsLoss`.
             other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
-                ``other_act = torch.tanh``. only used by the `DiceLoss`, not for the `CrossEntropyLoss`.
+                ``other_act = torch.tanh``. only used by the `DiceLoss`, not for the `CrossEntropyLoss` or `BCEWithLogitsLoss`.
             squared_pred: use squared versions of targets and predictions in the denominator or not.
             jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
             reduction: {``"mean"``, ``"sum"``}
@@ -666,8 +666,9 @@ def __init__(
             batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                 Defaults to False, a Dice loss value is computed independently from each item in the batch
                 before any `reduction`.
-            ce_weight: a rescaling weight given to each class for cross entropy loss.
-                See ``torch.nn.CrossEntropyLoss()`` for more information.
+            ce_weight: a rescaling weight given to each class for `CrossEntropyLoss`,
+                or a rescaling weight given to the loss of each batch element for `BCEWithLogitsLoss`.
+                See ``torch.nn.CrossEntropyLoss()`` or ``torch.nn.BCEWithLogitsLoss()`` for more information.
             lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
                 Defaults to 1.0.
             lambda_ce: the trade-off weight value for cross entropy loss. The value should be no less than 0.0.
@@ -690,6 +691,7 @@ def __init__(
             batch=batch,
         )
         self.cross_entropy = nn.CrossEntropyLoss(weight=ce_weight, reduction=reduction)
+        self.binary_cross_entropy = nn.BCEWithLogitsLoss(weight=ce_weight, reduction=reduction)
         if lambda_dice < 0.0:
             raise ValueError("lambda_dice should be no less than 0.0.")
         if lambda_ce < 0.0:
@@ -700,7 +702,7 @@ def __init__(
 
     def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
-        Compute CrossEntropy loss for the input and target.
+        Compute CrossEntropy loss for the input logits and target.
        Will remove the channel dim according to PyTorch CrossEntropyLoss:
         https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?#torch.nn.CrossEntropyLoss.
 
@@ -720,6 +722,16 @@ def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
 
         return self.cross_entropy(input, target)  # type: ignore[no-any-return]
 
+    def bce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Compute Binary CrossEntropy loss for the single-channel input logits and target.
+
+        """
+        if not torch.is_floating_point(target):
+            target = target.to(dtype=input.dtype)
+
+        return self.binary_cross_entropy(input, target)  # type: ignore[no-any-return]
+
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
         Args:
@@ -738,7 +750,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             )
 
         dice_loss = self.dice(input, target)
-        ce_loss = self.ce(input, target)
+        ce_loss = self.ce(input, target) if input.shape[1] != 1 else self.bce(input, target)
         total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_ce * ce_loss
 
         return total_loss
diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py
index 13b4952ab3..334bcc946b 100644
--- a/tests/test_dice_ce_loss.py
+++ b/tests/test_dice_ce_loss.py
@@ -75,6 +75,14 @@
         },
         0.3133,
     ],
+    [  # shape: (2, 1, 3), (2, 1, 3), bceloss
+        {"ce_weight": torch.tensor([1.0, 1.0, 1.0]), "sigmoid": True},
+        {
+            "input": torch.tensor([[[0.8, 0.6, 0.0]], [[0.0, 0.0, 0.9]]]),
+            "target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),
+        },
+        1.5608,
+    ],
 ]
 
 
@@ -97,7 +105,7 @@ def test_ill_reduction(self):
 
     def test_script(self):
         loss = DiceCELoss()
-        test_input = torch.ones(2, 1, 8, 8)
+        test_input = torch.ones(2, 2, 8, 8)
         test_script_save(loss, test_input, test_input)
 
 
diff --git a/tests/test_ds_loss.py b/tests/test_ds_loss.py
index 51200d9584..de7aec1ced 100644
--- a/tests/test_ds_loss.py
+++ b/tests/test_ds_loss.py
@@ -154,7 +154,7 @@ def test_ill_reduction(self):
 
     @SkipIfBeforePyTorchVersion((1, 10))
     def test_script(self):
         loss = DeepSupervisionLoss(DiceCELoss())
-        test_input = torch.ones(2, 1, 8, 8)
+        test_input = torch.ones(2, 2, 8, 8)
         test_script_save(loss, test_input, test_input)
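
Note (reviewer sketch, not part of the patch): below is a minimal usage example of the dispatch this change introduces, assuming the patch is applied. A single-channel prediction (``input.shape[1] == 1``) now routes the cross-entropy term through ``nn.BCEWithLogitsLoss``, while multi-channel predictions keep the original ``nn.CrossEntropyLoss`` path; that is also why the scripting tests switch their dummy input from ``torch.ones(2, 1, 8, 8)`` to ``torch.ones(2, 2, 8, 8)``. The tensors mirror the new test case; the printed loss values are illustrative, not asserted.

    import torch
    from monai.losses import DiceCELoss

    # Single-channel logits, shape (2, 1, 3): forward() calls self.bce(),
    # i.e. nn.BCEWithLogitsLoss. bce() casts integer targets to the input
    # dtype first; sigmoid=True only affects the Dice term, since
    # BCEWithLogitsLoss already applies a sigmoid internally.
    loss_fn = DiceCELoss(sigmoid=True)
    pred = torch.tensor([[[0.8, 0.6, 0.0]], [[0.0, 0.0, 0.9]]])
    target = torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]])
    print(loss_fn(pred, target))  # lambda_dice * dice + lambda_ce * bce

    # Multi-channel logits (shape[1] != 1): unchanged CrossEntropyLoss path.
    # A float target of the same shape is treated as soft class probabilities
    # (supported by PyTorch >= 1.10).
    loss_fn_mc = DiceCELoss(softmax=True)
    pred_mc = torch.ones(2, 2, 8, 8)
    print(loss_fn_mc(pred_mc, pred_mc))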