diff --git a/monai/losses/ssim_loss.py b/monai/losses/ssim_loss.py index 8ea3eb116b..8ee1da7267 100644 --- a/monai/losses/ssim_loss.py +++ b/monai/losses/ssim_loss.py @@ -61,7 +61,7 @@ def __init__( """ super().__init__(reduction=LossReduction(reduction).value) self.spatial_dims = spatial_dims - self.data_range = data_range + self._data_range = data_range self.kernel_type = kernel_type if not isinstance(win_size, Sequence): @@ -77,7 +77,7 @@ def __init__( self.ssim_metric = SSIMMetric( spatial_dims=self.spatial_dims, - data_range=self.data_range, + data_range=self._data_range, kernel_type=self.kernel_type, win_size=self.kernel_size, kernel_sigma=self.kernel_sigma, @@ -85,6 +85,15 @@ def __init__( k2=self.k2, ) + @property + def data_range(self) -> float: + return self._data_range + + @data_range.setter + def data_range(self, value: float) -> None: + self._data_range = value + self.ssim_metric.data_range = value + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: