diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index f8eadcfb1b..8cd15083c9 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -493,8 +493,8 @@ def __init__( fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling to ensure that the output has the same mean as the input. channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied - on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the - channel of the image if True. + on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the + channel of the image if True. dtype: output data type, if None, same as input image. defaults to float32. """ self.factor = factor @@ -633,12 +633,20 @@ class RandScaleIntensity(RandomizableTransform): backend = ScaleIntensity.backend - def __init__(self, factors: tuple[float, float] | float, prob: float = 0.1, dtype: DtypeLike = np.float32) -> None: + def __init__( + self, + factors: tuple[float, float] | float, + prob: float = 0.1, + channel_wise: bool = False, + dtype: DtypeLike = np.float32, + ) -> None: """ Args: factors: factor range to randomly scale by ``v = v * (1 + factor)``. if single number, factor value is picked from (-factors, factors). prob: probability of scale. + channel_wise: if True, scale on each channel separately. Please ensure + that the first dimension represents the channel of the image if True. dtype: output data type, if None, same as input image. defaults to float32. """ @@ -650,13 +658,17 @@ def __init__(self, factors: tuple[float, float] | float, prob: float = 0.1, dtyp else: self.factors = (min(factors), max(factors)) self.factor = self.factors[0] + self.channel_wise = channel_wise self.dtype = dtype def randomize(self, data: Any | None = None) -> None: super().randomize(None) if not self._do_transform: return None - self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) + if self.channel_wise: + self.factor = [self.R.uniform(low=self.factors[0], high=self.factors[1]) for _ in range(data.shape[0])] # type: ignore + else: + self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ @@ -664,12 +676,21 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen """ img = convert_to_tensor(img, track_meta=get_track_meta()) if randomize: - self.randomize() + self.randomize(img) if not self._do_transform: return convert_data_type(img, dtype=self.dtype)[0] - return ScaleIntensity(minv=None, maxv=None, factor=self.factor, dtype=self.dtype)(img) + ret: NdarrayOrTensor + if self.channel_wise: + out = [] + for i, d in enumerate(img): + out_channel = ScaleIntensity(minv=None, maxv=None, factor=self.factor[i], dtype=self.dtype)(d) # type: ignore + out.append(out_channel) + ret = torch.stack(out) # type: ignore + else: + ret = ScaleIntensity(minv=None, maxv=None, factor=self.factor, dtype=self.dtype)(img) + return ret class RandBiasField(RandomizableTransform): diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 91acff0c3d..32052ad406 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -586,6 +586,7 @@ def __init__( keys: KeysCollection, factors: tuple[float, float] | float, prob: float = 0.1, + channel_wise: bool = False, dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: @@ -597,13 +598,15 @@ def __init__( if single number, factor value is picked from (-factors, factors). prob: probability of scale. (Default 0.1, with 10% probability it returns a scaled array.) + channel_wise: if True, scale on each channel separately. Please ensure + that the first dimension represents the channel of the image if True. dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.scaler = RandScaleIntensity(factors=factors, dtype=dtype, prob=1.0) + self.scaler = RandScaleIntensity(factors=factors, dtype=dtype, prob=1.0, channel_wise=channel_wise) def set_random_state( self, seed: int | None = None, state: np.random.RandomState | None = None @@ -620,8 +623,15 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d + # expect all the specified keys have same spatial shape and share same random holes + first_key: Hashable = self.first_key(d) + if first_key == (): + for key in self.key_iterator(d): + d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) + return d + # all the keys share the same random scale factor - self.scaler.randomize(None) + self.scaler.randomize(d[first_key]) for key in self.key_iterator(d): d[key] = self.scaler(d[key], randomize=False) return d diff --git a/tests/test_rand_scale_intensity.py b/tests/test_rand_scale_intensity.py index 5f5ca076a8..a857c0cefb 100644 --- a/tests/test_rand_scale_intensity.py +++ b/tests/test_rand_scale_intensity.py @@ -33,6 +33,22 @@ def test_value(self, p): expected = p((self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32)) assert_allclose(result, p(expected), rtol=1e-7, atol=0, type_test="tensor") + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_channel_wise(self, p): + scaler = RandScaleIntensity(factors=0.5, channel_wise=True, prob=1.0) + scaler.set_random_state(seed=0) + im = p(self.imt) + result = scaler(im) + np.random.seed(0) + # simulate the randomize() of transform + np.random.random() + channel_num = self.imt.shape[0] + factor = [np.random.uniform(low=-0.5, high=0.5) for _ in range(channel_num)] + expected = p( + np.stack([np.asarray((self.imt[i]) * (1 + factor[i])) for i in range(channel_num)]).astype(np.float32) + ) + assert_allclose(result, expected, atol=0, rtol=1e-5, type_test=False) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_scale_intensityd.py b/tests/test_rand_scale_intensityd.py index 6b5a04a8f3..8d928ac157 100644 --- a/tests/test_rand_scale_intensityd.py +++ b/tests/test_rand_scale_intensityd.py @@ -32,6 +32,22 @@ def test_value(self): expected = (self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32) assert_allclose(result[key], p(expected), type_test="tensor") + def test_channel_wise(self): + key = "img" + for p in TEST_NDARRAYS: + scaler = RandScaleIntensityd(keys=[key], factors=0.5, prob=1.0, channel_wise=True) + scaler.set_random_state(seed=0) + result = scaler({key: p(self.imt)}) + np.random.seed(0) + # simulate the randomize function of transform + np.random.random() + channel_num = self.imt.shape[0] + factor = [np.random.uniform(low=-0.5, high=0.5) for _ in range(channel_num)] + expected = p( + np.stack([np.asarray((self.imt[i]) * (1 + factor[i])) for i in range(channel_num)]).astype(np.float32) + ) + assert_allclose(result[key], p(expected), type_test="tensor") + if __name__ == "__main__": unittest.main()