From bfb1ec75f78c96fc96b005710203539f9dbb3515 Mon Sep 17 00:00:00 2001
From: KumoLiu
Date: Fri, 1 Sep 2023 11:07:25 +0800
Subject: [PATCH 1/6] fix #6923

Signed-off-by: KumoLiu
---
 monai/losses/dice.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/monai/losses/dice.py b/monai/losses/dice.py
index 4f53e3a7d8..4068098e2f 100644
--- a/monai/losses/dice.py
+++ b/monai/losses/dice.py
@@ -729,6 +729,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         Raises:
             ValueError: When number of dimensions for input and target are different.
             ValueError: When number of channels for target is neither 1 nor the same as input.
+            ValueError: When number of channels for input is equal to 1.
 
         """
         if len(input.shape) != len(target.shape):
@@ -736,6 +737,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
                 "the number of dimensions for input and target should be the same, "
                 f"got shape {input.shape} and {target.shape}."
             )
+        if input.shape[1] == 1:
+            raise ValueError(
+                "the number of channels for input should be larger than 1,"
+                f"got shape {input.shape}."
+            )
 
         dice_loss = self.dice(input, target)
         ce_loss = self.ce(input, target)

From b1fb99f406edb92949a1d0bc362dfd68ba82d504 Mon Sep 17 00:00:00 2001
From: KumoLiu
Date: Fri, 1 Sep 2023 11:25:17 +0800
Subject: [PATCH 2/6] fix flake8

Signed-off-by: KumoLiu
---
 monai/losses/dice.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/monai/losses/dice.py b/monai/losses/dice.py
index 4068098e2f..9fdba7f37b 100644
--- a/monai/losses/dice.py
+++ b/monai/losses/dice.py
@@ -738,10 +738,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
                 f"got shape {input.shape} and {target.shape}."
             )
         if input.shape[1] == 1:
-            raise ValueError(
-                "the number of channels for input should be larger than 1,"
-                f"got shape {input.shape}."
-            )
+            raise ValueError("the number of channels for input should be larger than 1," f"got shape {input.shape}.")
 
         dice_loss = self.dice(input, target)
         ce_loss = self.ce(input, target)

From 954332c684b7c6f70fa18338f093202826cf96de Mon Sep 17 00:00:00 2001
From: KumoLiu
Date: Fri, 1 Sep 2023 13:41:04 +0800
Subject: [PATCH 3/6] fix ci

Signed-off-by: KumoLiu
---
 tests/test_dice_ce_loss.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py
index 13b4952ab3..aded5fdcb7 100644
--- a/tests/test_dice_ce_loss.py
+++ b/tests/test_dice_ce_loss.py
@@ -97,7 +97,7 @@ def test_ill_reduction(self):
 
     def test_script(self):
         loss = DiceCELoss()
-        test_input = torch.ones(2, 1, 8, 8)
+        test_input = torch.ones(2, 2, 8, 8)
         test_script_save(loss, test_input, test_input)

From c8f0da92c68276dac2f80789b8ed6aae3c122763 Mon Sep 17 00:00:00 2001
From: KumoLiu
Date: Fri, 1 Sep 2023 14:00:52 +0800
Subject: [PATCH 4/6] fix ci

Signed-off-by: KumoLiu
---
 tests/test_ds_loss.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_ds_loss.py b/tests/test_ds_loss.py
index 51200d9584..de7aec1ced 100644
--- a/tests/test_ds_loss.py
+++ b/tests/test_ds_loss.py
@@ -154,7 +154,7 @@ def test_ill_reduction(self):
     @SkipIfBeforePyTorchVersion((1, 10))
     def test_script(self):
         loss = DeepSupervisionLoss(DiceCELoss())
-        test_input = torch.ones(2, 1, 8, 8)
+        test_input = torch.ones(2, 2, 8, 8)
         test_script_save(loss, test_input, test_input)

From cac5e5213ee56cd304a7c16ef9d78914bb1bf7fd Mon Sep 17 00:00:00 2001
From: KumoLiu
Date: Fri, 1 Sep 2023 15:06:55 +0800
Subject: [PATCH 5/6] add support for BCEWithLogitsLoss

Signed-off-by: KumoLiu
---
 monai/losses/dice.py | 30 +++++++++++++++++++-----------
 1 file changed, 19 insertions(+), 11 deletions(-)

diff --git a/monai/losses/dice.py b/monai/losses/dice.py
index 9fdba7f37b..c76e5c9b1c 100644
--- a/monai/losses/dice.py
+++ b/monai/losses/dice.py
@@ -614,8 +614,8 @@ class DiceCELoss(_Loss):
     """
     Compute both Dice loss and Cross Entropy Loss, and return the weighted sum of these two losses.
     The details of Dice loss is shown in ``monai.losses.DiceLoss``.
-    The details of Cross Entropy Loss is shown in ``torch.nn.CrossEntropyLoss``. In this implementation,
-    two deprecated parameters ``size_average`` and ``reduce``, and the parameter ``ignore_index`` are
+    The details of Cross Entropy Loss is shown in ``torch.nn.CrossEntropyLoss`` and ``torch.nn.BCEWithLogitsLoss``.
+    In this implementation, two deprecated parameters ``size_average`` and ``reduce``, and the parameter ``ignore_index`` are
     not supported.
@@ -646,11 +646,11 @@ def __init__(
             to_onehot_y: whether to convert the ``target`` into the one-hot format,
                 using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
             sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
-                don't need to specify activation function for `CrossEntropyLoss`.
+                don't need to specify activation function for `CrossEntropyLoss` or `BCEWithLogitsLoss`.
             softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
-                don't need to specify activation function for `CrossEntropyLoss`.
+                don't need to specify activation function for `CrossEntropyLoss` or `BCEWithLogitsLoss`.
             other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
-                ``other_act = torch.tanh``. only used by the `DiceLoss`, not for the `CrossEntropyLoss`.
+                ``other_act = torch.tanh``. only used by the `DiceLoss`, not for `CrossEntropyLoss` or `BCEWithLogitsLoss`.
             squared_pred: use squared versions of targets and predictions in the denominator or not.
             jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
             reduction: {``"mean"``, ``"sum"``}
@@ -667,7 +667,7 @@ def __init__(
                 Defaults to False, a Dice loss value is computed independently from each item in the batch
                 before any `reduction`.
             ce_weight: a rescaling weight given to each class for cross entropy loss.
-                See ``torch.nn.CrossEntropyLoss()`` for more information.
+                See ``torch.nn.CrossEntropyLoss()`` or ``torch.nn.BCEWithLogitsLoss()`` for more information.
             lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
                 Defaults to 1.0.
             lambda_ce: the trade-off weight value for cross entropy loss. The value should be no less than 0.0.
@@ -690,6 +690,7 @@
             batch=batch,
         )
         self.cross_entropy = nn.CrossEntropyLoss(weight=ce_weight, reduction=reduction)
+        self.binary_cross_entropy = nn.BCEWithLogitsLoss(weight=ce_weight, reduction=reduction)
         if lambda_dice < 0.0:
             raise ValueError("lambda_dice should be no less than 0.0.")
         if lambda_ce < 0.0:
@@ -700,7 +701,7 @@ def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
-        Compute CrossEntropy loss for the input and target.
+        Compute CrossEntropy loss for the input logits and target.
         Will remove the channel dim according to PyTorch CrossEntropyLoss:
         https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?#torch.nn.CrossEntropyLoss.
@@ -720,6 +721,16 @@ def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
 
         return self.cross_entropy(input, target)  # type: ignore[no-any-return]
 
+    def bce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Compute Binary CrossEntropy loss for the input logits and target when the input has a single channel.
+
+        """
+        if not torch.is_floating_point(target):
+            target = target.to(dtype=input.dtype)
+
+        return self.binary_cross_entropy(input, target)
+
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
         Args:
             input: the shape should be BNH[WD].
             target: the shape should be BNH[WD] or B1H[WD].
 
         Raises:
             ValueError: When number of dimensions for input and target are different.
             ValueError: When number of channels for target is neither 1 nor the same as input.
-            ValueError: When number of channels for input is equal to 1.
 
         """
         if len(input.shape) != len(target.shape):
@@ -737,11 +747,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
                 "the number of dimensions for input and target should be the same, "
                 f"got shape {input.shape} and {target.shape}."
             )
-        if input.shape[1] == 1:
-            raise ValueError("the number of channels for input should be larger than 1," f"got shape {input.shape}.")
 
         dice_loss = self.dice(input, target)
-        ce_loss = self.ce(input, target)
+        ce_loss = self.ce(input, target) if input.shape[1] != 1 else self.bce(input, target)
         total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_ce * ce_loss
 
         return total_loss

From a7a690ffb781c806b73500e68ba05be6fcdfd11a Mon Sep 17 00:00:00 2001
From: KumoLiu
Date: Fri, 1 Sep 2023 15:28:20 +0800
Subject: [PATCH 6/6] add unittests

Signed-off-by: KumoLiu
---
 monai/losses/dice.py       | 5 +++--
 tests/test_dice_ce_loss.py | 8 ++++++++
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/monai/losses/dice.py b/monai/losses/dice.py
index c76e5c9b1c..214265499c 100644
--- a/monai/losses/dice.py
+++ b/monai/losses/dice.py
@@ -666,7 +666,8 @@ def __init__(
             batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                 Defaults to False, a Dice loss value is computed independently from each item in the batch
                 before any `reduction`.
-            ce_weight: a rescaling weight given to each class for cross entropy loss.
+            ce_weight: a rescaling weight given to each class for the cross entropy loss when using `CrossEntropyLoss`,
+                or a rescaling weight given to the loss of each batch element when using `BCEWithLogitsLoss`.
                 See ``torch.nn.CrossEntropyLoss()`` or ``torch.nn.BCEWithLogitsLoss()`` for more information.
             lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
                 Defaults to 1.0.
             lambda_ce: the trade-off weight value for cross entropy loss. The value should be no less than 0.0.
@@ -729,7 +730,7 @@ def bce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         if not torch.is_floating_point(target):
             target = target.to(dtype=input.dtype)
 
-        return self.binary_cross_entropy(input, target)
+        return self.binary_cross_entropy(input, target)  # type: ignore[no-any-return]
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py
index aded5fdcb7..334bcc946b 100644
--- a/tests/test_dice_ce_loss.py
+++ b/tests/test_dice_ce_loss.py
@@ -75,6 +75,14 @@
         },
         0.3133,
     ],
+    [  # shape: (2, 1, 3), (2, 1, 3), bceloss
+        {"ce_weight": torch.tensor([1.0, 1.0, 1.0]), "sigmoid": True},
+        {
+            "input": torch.tensor([[[0.8, 0.6, 0.0]], [[0.0, 0.0, 0.9]]]),
+            "target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),
+        },
+        1.5608,
+    ],
 ]
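
A minimal usage sketch of the behaviour this series introduces (not part of the patches; it assumes a MONAI build that already includes them, and the tensors and shapes below are illustrative only):

# Usage sketch: with these patches applied, DiceCELoss is expected to route
# single-channel (binary) logits through BCEWithLogitsLoss instead of raising,
# while multi-channel input keeps using CrossEntropyLoss.
import torch
from monai.losses import DiceCELoss

# Binary / single-channel case: shape (B, 1, H, W) -> Dice + BCEWithLogitsLoss.
# `sigmoid=True` only affects the Dice term; the BCE term receives raw logits.
binary_loss = DiceCELoss(sigmoid=True)
logits = torch.randn(2, 1, 8, 8)
target = torch.randint(0, 2, (2, 1, 8, 8)).float()
print(binary_loss(logits, target))

# Multi-channel case: shape (B, C, H, W) -> Dice + CrossEntropyLoss, as before.
multi_loss = DiceCELoss(softmax=True, to_onehot_y=True)
logits_mc = torch.randn(2, 3, 8, 8)
target_mc = torch.randint(0, 3, (2, 1, 8, 8))  # class indices in a single channel
print(multi_loss(logits_mc, target_mc))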