Describe the bug
The SoftDiceclDiceLoss implementation differs from DiceLoss and, in its current form, cannot be swapped in for DiceLoss or the other popular losses MONAI offers. There are no options for excluding the background, applying activation functions, etc. There is also no description of the expected input (y_pred): should it be probabilities, logits, or a binary mask like the ground truth (y_true)?
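On the input question: assuming the loss follows the convention of the original clDice reference code (soft probabilities in [0, 1], with channel 0 treated as background), a binary network with a single logit channel would need its output expanded to two channels before the call, and the binary mask likewise. A minimal sketch; `logits_to_two_channel` and `mask_to_two_channel` are hypothetical helpers, not MONAI functions:

```python
import torch


def logits_to_two_channel(logits: torch.Tensor) -> torch.Tensor:
    """[B, 1, ...] logits -> [B, 2, ...] soft probabilities (channel 0 = background)."""
    fg = torch.sigmoid(logits)
    return torch.cat([1.0 - fg, fg], dim=1)


def mask_to_two_channel(mask: torch.Tensor) -> torch.Tensor:
    """[B, 1, ...] binary mask -> [B, 2, ...] one-hot mask (channel 0 = background)."""
    mask = mask.float()
    return torch.cat([1.0 - mask, mask], dim=1)


logits = torch.randn(2, 1, 8, 8, 8)         # e.g. raw output of a 3D network
target = torch.rand(2, 1, 8, 8, 8) > 0.5    # binary ground-truth mask
y_pred = logits_to_two_channel(logits)
y_true = mask_to_two_channel(target)
print(y_pred.shape, y_true.shape)
```

For multi-class outputs, `torch.softmax(logits, dim=1)` together with a one-hot target (for example via `monai.networks.one_hot`) already places the background in channel 0, provided label 0 is the background. In MONAI 1.3.0 the loss also appears to be called as `loss(y_true, y_pred)`, the reverse of `DiceLoss(input, target)`, so the argument order is worth double-checking.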
To Reproduce
Use this enhanced class, which mimics the interface of other MONAI losses:
```python
import warnings
from typing import Callable, Optional

import torch
from torch.nn.modules.loss import _Loss

from monai.losses import SoftDiceclDiceLoss
from monai.networks import one_hot


class EnhancedSoftDiceClDiceLoss(_Loss):
    """
    Enhanced version of SoftDiceclDiceLoss with support for:
    - Excluding the background channel
    - Applying activations (sigmoid, softmax, or custom)
    - Handling one-hot encoded targets
    - Flexible reduction (mean, sum, none)
    """

    def __init__(
        self,
        iter_: int = 3,
        alpha: float = 0.5,
        smooth: float = 1.0,
        include_background: bool = True,
        to_onehot_y: bool = False,
        sigmoid: bool = False,
        softmax: bool = False,
        other_act: Optional[Callable] = None,
        reduction: str = "mean",
    ) -> None:
        """
        Args:
            iter_: Number of iterations for skeletonization.
            alpha: Weighting factor for the clDice term.
            smooth: Smoothing parameter.
            include_background: If False, excludes the background channel from the loss computation.
            to_onehot_y: If True, converts `y_true` into one-hot format. Defaults to False.
            sigmoid: If True, applies sigmoid activation to predictions.
            softmax: If True, applies softmax activation to predictions.
            other_act: Callable for a custom activation (e.g., torch.tanh).
            reduction: {``"none"``, ``"mean"``, ``"sum"``} Specifies the reduction applied to the output.
        """
        super().__init__()
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.sigmoid = sigmoid
        self.softmax = softmax
        self.other_act = other_act
        self.reduction = reduction.lower()

        # Validate activation and reduction settings
        if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
            raise ValueError("Only one of [sigmoid=True, softmax=True, other_act] can be set.")
        if self.reduction not in ["mean", "sum", "none"]:
            raise ValueError(f"Unsupported reduction mode: {self.reduction}")

        # Create an instance of the original SoftDiceclDiceLoss
        self.base_loss = SoftDiceclDiceLoss(iter_=iter_, alpha=alpha, smooth=smooth)

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """
        Args:
            y_pred: Predicted tensor with shape [B, C, H, W, ...].
            y_true: Ground truth tensor with shape [B, C, H, W, ...]
                (or [B, 1, H, W, ...] label indices when to_onehot_y=True).
        Returns:
            Computed loss value.
        """
        n_pred_ch = y_pred.shape[1]

        # Apply the activation first (as MONAI's DiceLoss does), so that softmax
        # is computed over all channels, including the background
        if self.sigmoid:
            y_pred = torch.sigmoid(y_pred)
            y_pred = torch.sigmoid((y_pred - 0.5) * 10)  # Differentiable approximation of thresholding
        elif self.softmax:
            if n_pred_ch == 1:
                warnings.warn("Single channel prediction, softmax=True ignored and sigmoid applied.")
                y_pred = torch.sigmoid(y_pred)
            else:
                y_pred = torch.softmax(y_pred, dim=1)
        elif self.other_act is not None:
            y_pred = self.other_act(y_pred)

        # Convert ground truth to one-hot if necessary
        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("Single channel prediction, to_onehot_y=True ignored.")
            else:
                y_true = one_hot(y_true, num_classes=n_pred_ch)

        # Exclude background channel if specified
        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn("Single channel prediction, include_background=False ignored.")
            else:
                y_pred = y_pred[:, 1:]
                y_true = y_true[:, 1:]

        # Ensure shapes match
        if y_pred.shape != y_true.shape:
            raise AssertionError(f"Shape mismatch: y_pred {y_pred.shape}, y_true {y_true.shape}")

        # Delegate loss computation to the original SoftDiceclDiceLoss
        # (its forward takes the ground truth first, then the prediction)
        loss = self.base_loss(y_true, y_pred)

        # Apply reduction if necessary
        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        return loss
```
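One caveat about the wrapper's reduction step: if, as it appears from the MONAI 1.3.0 source, `SoftDiceclDiceLoss` sums over the whole batch and returns a 0-dim tensor, then `"mean"`, `"sum"` and `"none"` all yield the same scalar. A hedged way to get per-sample values is to call the base loss once per batch element. Both `per_sample_loss` and the stand-in `batch_soft_dice` (which only mimics that batch-level summation over channels `1:`) are hypothetical:

```python
from typing import Callable

import torch

LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def per_sample_loss(loss_fn: LossFn, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
    """Call a batch-reducing loss once per sample and return a tensor of shape [B]."""
    # Slicing with i : i + 1 keeps the batch dimension, so each call still sees [1, C, ...]
    losses = [loss_fn(y_true[i : i + 1], y_pred[i : i + 1]) for i in range(y_true.shape[0])]
    return torch.stack(losses)


def batch_soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
    """Hypothetical stand-in: soft Dice over channels 1:, summed across the whole batch."""
    inter = torch.sum((y_true * y_pred)[:, 1:])
    denom = torch.sum(y_true[:, 1:]) + torch.sum(y_pred[:, 1:])
    return 1.0 - (2.0 * inter + smooth) / (denom + smooth)


y_true = torch.zeros(2, 2, 4, 4)
y_true[:, 1, 1:3, 1:3] = 1.0       # 2x2 foreground square in both samples
y_true[:, 0] = 1.0 - y_true[:, 1]  # channel 0 = background
y_pred = y_true.clone()
y_pred[1] = 1.0 - y_true[1]        # sample 1 is completely wrong

batch = batch_soft_dice(y_true, y_pred)             # 0-dim: mean() and sum() are no-ops
per = per_sample_loss(batch_soft_dice, y_true, y_pred)
print(batch.item(), per.tolist())
```

Because `smooth` is then added once per sample instead of once per batch, the mean of the per-sample losses (about 0.47 here) differs from the batch-level value (0.64).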
Expected behavior
The loss should be non-zero for imperfect predictions, but the computed loss is zero, even with the enhanced custom class above. I need input on how to avoid these zero losses.
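A possible explanation, assuming the MONAI 1.3.0 implementation follows the original clDice reference code: the soft Dice and clDice terms are summed only over `[:, 1:, ...]`, so channel 0 is always dropped as background. With a single-channel (sigmoid) prediction, or after `include_background=False` has already removed channel 0 in the wrapper, those slices are empty, every ratio becomes `smooth / smooth = 1`, and the loss is exactly 0 regardless of the prediction. The sketch below re-implements only the soft Dice part with the same slicing (`soft_dice_skip_ch0` is a hypothetical name) to show the effect:

```python
import torch


def soft_dice_skip_ch0(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
    """Soft Dice that, like the clDice reference code, ignores channel 0."""
    intersection = torch.sum((y_true * y_pred)[:, 1:, ...])
    denom = torch.sum(y_true[:, 1:, ...]) + torch.sum(y_pred[:, 1:, ...])
    return 1.0 - (2.0 * intersection + smooth) / (denom + smooth)


torch.manual_seed(0)
target = (torch.rand(1, 1, 16, 16) > 0.5).float()  # single-channel binary mask
bad_pred = 1.0 - target                             # completely wrong prediction

# Single channel: the [:, 1:] slice is empty, so the loss is 0 despite the wrong prediction
single = soft_dice_skip_ch0(target, bad_pred)

# Two channels (background, foreground): the foreground channel is now counted
target_2ch = torch.cat([1.0 - target, target], dim=1)
pred_2ch = torch.cat([1.0 - bad_pred, bad_pred], dim=1)
two = soft_dice_skip_ch0(target_2ch, pred_2ch)

# The single-channel loss does not depend on the prediction, so it provides no gradient
logits = torch.zeros(1, 1, 16, 16, requires_grad=True)
soft_dice_skip_ch0(target, torch.sigmoid(logits)).backward()
print(single.item(), two.item(), logits.grad.abs().max().item())
```

If this is the cause, feeding `[B, 2, ...]` tensors with the background in channel 0 (and keeping `include_background=True` in the wrapper) should give a non-zero, trainable loss. It is worth confirming against the installed `monai/losses/cldice.py` before relying on it.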
Screenshots
Environment
Ensuring you use the relevant python executable, please paste the output of:
/root/.cache/pypoetry/virtualenvs/segmentation-codebase-os60uNmW-py3.10/lib/python3.10/site-packages/_distutils_hack/__init__.py:55: UserWarning: Reliance on distutils from stdlib is deprecated. Users must rely on setuptools to provide the distutils module. Avoid importing distutils or import setuptools first, and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml
warnings.warn(
/root/.cache/pypoetry/virtualenvs/segmentation-codebase-os60uNmW-py3.10/lib/python3.10/site-packages/ignite/handlers/checkpoint.py:17: DeprecationWarning: `TorchScript` support for functional optimizers is deprecated and will be removed in a future PyTorch release. Consider using the `torch.compile` optimizer instead.
from torch.distributed.optim import ZeroRedundancyOptimizer
================================
Printing MONAI config...
================================
MONAI version: 1.3.0
Numpy version: 1.26.4
Pytorch version: 2.4.0+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 865972f7a791bf7b42efbcd87c8402bd865b329e
MONAI __file__: /<username>/.cache/pypoetry/virtualenvs/segmentation-codebase-os60uNmW-py3.10/lib/python3.10/site-packages/monai/__init__.py
Optional dependencies:
Pytorch Ignite version: 0.4.11
ITK version: 5.4.0
Nibabel version: 5.2.1
scikit-image version: 0.24.0
scipy version: 1.14.0
Pillow version: 10.4.0
Tensorboard version: 2.17.0
gdown version: 5.2.0
TorchVision version: 0.19.0+cu121
tqdm version: 4.66.5
lmdb version: 1.5.1
psutil version: 6.0.0
pandas version: 2.2.2
einops version: 0.8.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: 2.15.1
pynrrd version: 1.0.0
clearml version: NOT INSTALLED or UNKNOWN VERSION.
For details about installing the optional dependencies, please visit:
https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
================================
Printing system config...
================================
System: Linux
Linux version: Ubuntu 22.04.2 LTS
Platform: Linux-5.4.0-169-generic-x86_64-with-glibc2.35
Processor: x86_64
Machine: x86_64
Python version: 3.10.12
Process name: pt_main_thread
Command: ['python', '-c', 'import monai; monai.config.print_debug_info()']
Open files: []
Num physical CPUs: 128
Num logical CPUs: 256
Num usable CPUs: 256
CPU usage (%): [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.3, 0.0, 3.3, 0.0, 0.0, 0.0, 0.0, 0.0, 2.8, 1.7, 0.6, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.8, 2.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.3, 3.3, 0.0, 0.0, 3.3, 3.1, 3.3, 2.8, 0.0, 0.0, 0.0, 3.1, 3.1, 1.1, 3.3, 0.0, 2.8, 3.1, 3.1, 0.0, 3.1, 0.0, 0.0, 0.0, 0.6, 0.0, 0.3, 0.0, 5.0, 0.0, 0.0, 3.3, 0.0, 0.0, 2.8, 0.0, 3.1, 2.8, 2.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.3, 3.3, 2.8, 3.3, 3.1, 3.9, 3.3, 3.9, 3.4, 3.3, 3.6, 3.6, 0.6, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1, 0.0, 0.0, 3.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.1, 0.0, 0.3, 0.0, 0.0, 0.0, 3.1, 0.0, 3.3, 3.6, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.1, 0.0, 3.1, 0.0, 0.0, 0.0, 3.3, 0.0, 0.0, 3.3, 3.3, 0.0, 0.0, 0.0, 0.3, 0.0, 0.0, 3.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.9, 0.0, 0.0, 3.3, 3.6, 3.3, 3.3, 3.1, 3.3, 3.6, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 0.0, 1.9, 99.2]
CPU freq. (MHz): 2924
Load avg. in last 1, 5, 15 mins (%): [0.1, 0.1, 0.2]
Disk usage (%): 95.3
Avg. sensor temp. (Celsius): UNKNOWN for given OS
Total physical memory (GB): 1007.7
Available memory (GB): 965.3
Used memory (GB): 34.6
================================
Printing GPU config...
================================
Num GPUs: 8
Has CUDA: True
CUDA version: 12.1
cuDNN enabled: True
NVIDIA_TF32_OVERRIDE: None
TORCH_ALLOW_TF32_CUBLAS_OVERRIDE: 1
cuDNN version: 90100
Current device: 0
Library compiled for CUDA architectures: ['sm_50', 'sm_60', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90']
GPU 0 Name: NVIDIA A100-SXM4-40GB
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 108
GPU 0 Total memory (GB): 39.4
GPU 0 CUDA capability (maj.min): 8.0
GPU 1 Name: NVIDIA A100-SXM4-40GB
GPU 1 Is integrated: False
GPU 1 Is multi GPU board: False
GPU 1 Multi processor count: 108
GPU 1 Total memory (GB): 39.4
GPU 1 CUDA capability (maj.min): 8.0
GPU 2 Name: NVIDIA A100-SXM4-40GB
GPU 2 Is integrated: False
GPU 2 Is multi GPU board: False
GPU 2 Multi processor count: 108
GPU 2 Total memory (GB): 39.4
GPU 2 CUDA capability (maj.min): 8.0
GPU 3 Name: NVIDIA A100-SXM4-40GB
GPU 3 Is integrated: False
GPU 3 Is multi GPU board: False
GPU 3 Multi processor count: 108
GPU 3 Total memory (GB): 39.4
GPU 3 CUDA capability (maj.min): 8.0
GPU 4 Name: NVIDIA A100-SXM4-40GB
GPU 4 Is integrated: False
GPU 4 Is multi GPU board: False
GPU 4 Multi processor count: 108
GPU 4 Total memory (GB): 39.4
GPU 4 CUDA capability (maj.min): 8.0
GPU 5 Name: NVIDIA A100-SXM4-40GB
GPU 5 Is integrated: False
GPU 5 Is multi GPU board: False
GPU 5 Multi processor count: 108
GPU 5 Total memory (GB): 39.4
GPU 5 CUDA capability (maj.min): 8.0
GPU 6 Name: NVIDIA A100-SXM4-40GB
GPU 6 Is integrated: False
GPU 6 Is multi GPU board: False
GPU 6 Multi processor count: 108
GPU 6 Total memory (GB): 39.4
GPU 6 CUDA capability (maj.min): 8.0
GPU 7 Name: NVIDIA A100-SXM4-40GB
GPU 7 Is integrated: False
GPU 7 Is multi GPU board: False
GPU 7 Multi processor count: 108
GPU 7 Total memory (GB): 39.4
GPU 7 CUDA capability (maj.min): 8.0
Additional context
Trying to use this loss for airway segmentation.
I am using this loss for vascular system segmentation. I noticed that if I drop the learning rate very low (1e-12), the loss is no longer 0.000 but a very small number that barely changes.
As a temporary solution, I am using an implementation from another Git repository, https://github.com/PengchengShi1220/cbDice/tree/main, which works. It proposes an extension of the Centerline-Dice loss: the Centerline-Boundary-Dice loss.
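For reference, the "centerline" in Centerline-Dice (clDice) is a soft skeleton computed from the prediction and from the ground truth, and `iter_` sets how many soft erosion steps are used to build it. The sketch below follows the soft-skeleton construction described in the clDice paper (min-pooling as erosion, max-pooling as dilation) for 2D inputs only; it is not copied from MONAI, whose helper may differ in details such as 3D handling:

```python
import torch
import torch.nn.functional as F


def soft_erode(img: torch.Tensor) -> torch.Tensor:
    # Min-pooling written as -maxpool(-x), along each axis (cross-shaped neighborhood)
    p1 = -F.max_pool2d(-img, kernel_size=(3, 1), stride=1, padding=(1, 0))
    p2 = -F.max_pool2d(-img, kernel_size=(1, 3), stride=1, padding=(0, 1))
    return torch.min(p1, p2)


def soft_dilate(img: torch.Tensor) -> torch.Tensor:
    # 3x3 max-pooling acts as a soft dilation
    return F.max_pool2d(img, kernel_size=3, stride=1, padding=1)


def soft_open(img: torch.Tensor) -> torch.Tensor:
    return soft_dilate(soft_erode(img))


def soft_skel(img: torch.Tensor, iter_: int = 3) -> torch.Tensor:
    # Collect what each opening removes, eroding the input once per iteration
    skel = F.relu(img - soft_open(img))
    for _ in range(iter_):
        img = soft_erode(img)
        delta = F.relu(img - soft_open(img))
        skel = skel + F.relu(delta - skel * delta)
    return skel


bar = torch.zeros(1, 1, 16, 16)
bar[:, :, 6:11, 2:14] = 1.0  # a 5-pixel-thick horizontal "vessel"
skel = soft_skel(bar, iter_=3)
print(bar.sum().item(), skel.sum().item())  # 60.0 8.0
```

In this toy example, `iter_=1` yields an all-zero skeleton for the same bar, which suggests `iter_` should be at least roughly half the thickness (in pixels) of the thickest structures of interest.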
@YannBov1 Thanks for the insights. Yes, I also came across the repo you shared and was going to implement it in my codebase. I am trying to use this loss for airway segmentation in very noisy MRI images.