diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py index 13b4952ab3..aded5fdcb7 100644 --- a/tests/test_dice_ce_loss.py +++ b/tests/test_dice_ce_loss.py @@ -97,7 +97,7 @@ def test_ill_reduction(self): def test_script(self): loss = DiceCELoss() - test_input = torch.ones(2, 1, 8, 8) + test_input = torch.ones(2, 2, 8, 8) test_script_save(loss, test_input, test_input)