Add support for BCEWithLogitsLoss in DiceCELoss #6924

Merged · 8 commits · Sep 2, 2023
30 changes: 21 additions & 9 deletions monai/losses/dice.py
@@ -614,8 +614,8 @@ class DiceCELoss(_Loss):
     """
     Compute both Dice loss and Cross Entropy Loss, and return the weighted sum of these two losses.
     The details of Dice loss is shown in ``monai.losses.DiceLoss``.
-    The details of Cross Entropy Loss is shown in ``torch.nn.CrossEntropyLoss``. In this implementation,
-    two deprecated parameters ``size_average`` and ``reduce``, and the parameter ``ignore_index`` are
+    The details of Cross Entropy Loss is shown in ``torch.nn.CrossEntropyLoss`` and ``torch.nn.BCEWithLogitsLoss()``.
+    In this implementation, two deprecated parameters ``size_average`` and ``reduce``, and the parameter ``ignore_index`` are
     not supported.

     """
@@ -646,11 +646,11 @@ def __init__(
             to_onehot_y: whether to convert the ``target`` into the one-hot format,
                 using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
             sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
-                don't need to specify activation function for `CrossEntropyLoss`.
+                don't need to specify activation function for `CrossEntropyLoss` and `BCEWithLogitsLoss`.
             softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
-                don't need to specify activation function for `CrossEntropyLoss`.
+                don't need to specify activation function for `CrossEntropyLoss` and `BCEWithLogitsLoss`.
             other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
-                ``other_act = torch.tanh``. only used by the `DiceLoss`, not for the `CrossEntropyLoss`.
+                ``other_act = torch.tanh``. only used by the `DiceLoss`, not for the `CrossEntropyLoss` and `BCEWithLogitsLoss`.
             squared_pred: use squared versions of targets and predictions in the denominator or not.
             jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
             reduction: {``"mean"``, ``"sum"``}
@@ -666,8 +666,9 @@ def __init__(
             batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                 Defaults to False, a Dice loss value is computed independently from each item in the batch
                 before any `reduction`.
-            ce_weight: a rescaling weight given to each class for cross entropy loss.
-                See ``torch.nn.CrossEntropyLoss()`` for more information.
+            ce_weight: a rescaling weight given to each class for cross entropy loss for `CrossEntropyLoss`,
+                or a rescaling weight given to the loss of each batch element for `BCEWithLogitsLoss`.
+                See ``torch.nn.CrossEntropyLoss()`` or ``torch.nn.BCEWithLogitsLoss()`` for more information.
             lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
                 Defaults to 1.0.
             lambda_ce: the trade-off weight value for cross entropy loss. The value should be no less than 0.0.
@@ -690,6 +691,7 @@ def __init__(
             batch=batch,
         )
         self.cross_entropy = nn.CrossEntropyLoss(weight=ce_weight, reduction=reduction)
+        self.binary_cross_entropy = nn.BCEWithLogitsLoss(weight=ce_weight, reduction=reduction)
         if lambda_dice < 0.0:
             raise ValueError("lambda_dice should be no less than 0.0.")
         if lambda_ce < 0.0:
@@ -700,7 +702,7 @@ def __init__(

     def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
-        Compute CrossEntropy loss for the input and target.
+        Compute CrossEntropy loss for the input logits and target.
         Will remove the channel dim according to PyTorch CrossEntropyLoss:
         https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?#torch.nn.CrossEntropyLoss.

@@ -720,6 +722,16 @@ def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

         return self.cross_entropy(input, target) # type: ignore[no-any-return]

+    def bce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Compute Binary CrossEntropy loss for the input logits and target in the single-class (binary) case.
+
+        """
+        if not torch.is_floating_point(target):
+            target = target.to(dtype=input.dtype)
+
+        return self.binary_cross_entropy(input, target) # type: ignore[no-any-return]
+
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
         Args:
@@ -738,7 +750,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            )

        dice_loss = self.dice(input, target)
-        ce_loss = self.ce(input, target)
+        ce_loss = self.ce(input, target) if input.shape[1] != 1 else self.bce(input, target)
        total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_ce * ce_loss

        return total_loss
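With this change, `DiceCELoss.forward` dispatches on the channel dimension: multi-channel inputs keep using `CrossEntropyLoss`, while single-channel inputs go through the new `bce` path. A minimal usage sketch (not part of the diff; assumes a MONAI build that includes this PR):

```python
import torch
from monai.losses import DiceCELoss

# Binary case: one prediction channel, so forward() dispatches to
# BCEWithLogitsLoss. sigmoid=True only affects the DiceLoss branch;
# BCEWithLogitsLoss applies its own sigmoid to the raw logits.
binary_loss = DiceCELoss(sigmoid=True)
logits = torch.randn(4, 1, 32, 32)                    # (B, 1, H, W) logits
target = torch.randint(0, 2, (4, 1, 32, 32)).float()  # binary mask
print(binary_loss(logits, target))

# Multi-class case: more than one channel, so CrossEntropyLoss is used.
multi_loss = DiceCELoss(softmax=True, to_onehot_y=True)
logits = torch.randn(4, 3, 32, 32)                    # 3 classes
target = torch.randint(0, 3, (4, 1, 32, 32))          # class indices
print(multi_loss(logits, target))
```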
10 changes: 9 additions & 1 deletion tests/test_dice_ce_loss.py
@@ -75,6 +75,14 @@
         },
         0.3133,
     ],
+    [  # shape: (2, 1, 3), (2, 1, 3), bceloss
+        {"ce_weight": torch.tensor([1.0, 1.0, 1.0]), "sigmoid": True},
+        {
+            "input": torch.tensor([[[0.8, 0.6, 0.0]], [[0.0, 0.0, 0.9]]]),
+            "target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),
+        },
+        1.5608,
+    ],
 ]


@@ -97,7 +105,7 @@ def test_ill_reduction(self):

     def test_script(self):
         loss = DiceCELoss()
-        test_input = torch.ones(2, 1, 8, 8)
+        test_input = torch.ones(2, 2, 8, 8)
         test_script_save(loss, test_input, test_input)

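The new case exercises the binary path end to end: a `(2, 1, 3)` prediction with `sigmoid=True`, where `ce_weight` is forwarded as the element-wise `weight` of `BCEWithLogitsLoss`. A quick sketch to reproduce the expected value (again assuming this PR is installed):

```python
import torch
from monai.losses import DiceCELoss

# Single channel => the BCEWithLogitsLoss branch is taken in forward().
loss = DiceCELoss(ce_weight=torch.tensor([1.0, 1.0, 1.0]), sigmoid=True)
pred = torch.tensor([[[0.8, 0.6, 0.0]], [[0.0, 0.0, 0.9]]])  # (2, 1, 3) logits
grnd = torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]])  # float targets
print(loss(pred, grnd))  # ~1.5608 = lambda_dice * Dice + lambda_ce * BCE
```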
2 changes: 1 addition & 1 deletion tests/test_ds_loss.py
@@ -154,7 +154,7 @@ def test_ill_reduction(self):
     @SkipIfBeforePyTorchVersion((1, 10))
     def test_script(self):
         loss = DeepSupervisionLoss(DiceCELoss())
-        test_input = torch.ones(2, 1, 8, 8)
+        test_input = torch.ones(2, 2, 8, 8)
         test_script_save(loss, test_input, test_input)

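Both scripted tests switch from one channel to two, presumably so they keep exercising the `CrossEntropyLoss` path now that single-channel inputs dispatch to `bce`. A short sketch of the wrapped use (the multi-scale shapes here are illustrative, not from the diff):

```python
import torch
from monai.losses import DeepSupervisionLoss, DiceCELoss

# DeepSupervisionLoss applies the wrapped DiceCELoss at every output
# scale, resizing the target to each prediction's spatial shape.
loss = DeepSupervisionLoss(DiceCELoss(softmax=True, to_onehot_y=True))
outputs = [torch.randn(2, 2, 8, 8), torch.randn(2, 2, 4, 4)]  # two heads
target = torch.randint(0, 2, (2, 1, 8, 8)).float()            # class indices
print(loss(outputs, target))
```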