diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py index da061a94..61baf4c8 100644 --- a/micro_sam/training/semantic_sam_trainer.py +++ b/micro_sam/training/semantic_sam_trainer.py @@ -51,6 +51,10 @@ def __init__( self.num_classes = num_classes self.compute_ce_loss = nn.CrossEntropyLoss() self.dice_weight = dice_weight + + if self.dice_weight is not None: + assert self.dice_weight > 0 and self.dice_weight < 1, "The weight factor should lie between 0 and 1." + self._kwargs = kwargs def _compute_loss(self, y, masks):