Add channel_wise in RandScaleIntensity (#6793)
Part of #6629.

### Description
Add `channel_wise` in `RandScaleIntensity` and `RandScaleIntensityd`.
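
For reference, a minimal usage sketch of the new option; the image shape, seed, and key name below are illustrative assumptions, not part of this change:

```python
import numpy as np
from monai.transforms import RandScaleIntensity, RandScaleIntensityd

# a 3-channel image; with channel_wise=True the first dimension must be the channel axis
img = np.random.rand(3, 32, 32).astype(np.float32)

# each channel is scaled by its own factor drawn from (-0.5, 0.5)
scaler = RandScaleIntensity(factors=0.5, prob=1.0, channel_wise=True)
scaler.set_random_state(seed=0)
out = scaler(img)  # same shape as the input, per-channel scaling applied

# dictionary variant with the same new flag
scaler_d = RandScaleIntensityd(keys="img", factors=0.5, prob=1.0, channel_wise=True)
out_d = scaler_d({"img": img})
```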


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: KumoLiu <[email protected]>
KumoLiu authored Jul 28, 2023
1 parent 11546e8 commit e2fa53b
Showing 4 changed files with 71 additions and 8 deletions.
33 changes: 27 additions & 6 deletions monai/transforms/intensity/array.py
@@ -493,8 +493,8 @@ def __init__(
fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling
to ensure that the output has the same mean as the input.
channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied
on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
channel of the image if True.
dtype: output data type; if None, same as the input image. Defaults to float32.
"""
self.factor = factor
@@ -633,12 +633,20 @@ class RandScaleIntensity(RandomizableTransform):

backend = ScaleIntensity.backend

def __init__(self, factors: tuple[float, float] | float, prob: float = 0.1, dtype: DtypeLike = np.float32) -> None:
def __init__(
self,
factors: tuple[float, float] | float,
prob: float = 0.1,
channel_wise: bool = False,
dtype: DtypeLike = np.float32,
) -> None:
"""
Args:
factors: factor range to randomly scale by ``v = v * (1 + factor)``.
if single number, factor value is picked from (-factors, factors).
prob: probability of scale.
channel_wise: if True, scale on each channel separately. Please ensure
that the first dimension represents the channel of the image if True.
dtype: output data type; if None, same as the input image. Defaults to float32.
"""
@@ -650,26 +658,39 @@ def __init__(self, factors: tuple[float, float] | float, prob: float = 0.1, dtyp
else:
self.factors = (min(factors), max(factors))
self.factor = self.factors[0]
self.channel_wise = channel_wise
self.dtype = dtype

def randomize(self, data: Any | None = None) -> None:
super().randomize(None)
if not self._do_transform:
return None
self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])
if self.channel_wise:
self.factor = [self.R.uniform(low=self.factors[0], high=self.factors[1]) for _ in range(data.shape[0])] # type: ignore
else:
self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])

def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
if randomize:
self.randomize()
self.randomize(img)

if not self._do_transform:
return convert_data_type(img, dtype=self.dtype)[0]

return ScaleIntensity(minv=None, maxv=None, factor=self.factor, dtype=self.dtype)(img)
ret: NdarrayOrTensor
if self.channel_wise:
out = []
for i, d in enumerate(img):
out_channel = ScaleIntensity(minv=None, maxv=None, factor=self.factor[i], dtype=self.dtype)(d) # type: ignore
out.append(out_channel)
ret = torch.stack(out) # type: ignore
else:
ret = ScaleIntensity(minv=None, maxv=None, factor=self.factor, dtype=self.dtype)(img)
return ret


class RandBiasField(RandomizableTransform):
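Outside of MONAI, the per-channel branch added above boils down to drawing one factor per channel and applying ``v = v * (1 + factor)`` channel by channel. A standalone NumPy sketch of that behaviour (function name and shapes are illustrative, not library code):

```python
import numpy as np

def scale_channel_wise(img: np.ndarray, factors: float, rng: np.random.RandomState) -> np.ndarray:
    """Illustrative equivalent of channel_wise=True: one random factor per channel."""
    out = np.empty_like(img, dtype=np.float32)
    for i, channel in enumerate(img):  # iterate over the first (channel) dimension
        factor = rng.uniform(low=-factors, high=factors)
        out[i] = channel * (1 + factor)  # same formula as ScaleIntensity with minv=maxv=None
    return out

rng = np.random.RandomState(0)
img = np.ones((3, 4, 4), dtype=np.float32)
print(scale_channel_wise(img, 0.5, rng)[:, 0, 0])  # three different values, one per channel
```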
14 changes: 12 additions & 2 deletions monai/transforms/intensity/dictionary.py
@@ -586,6 +586,7 @@ def __init__(
keys: KeysCollection,
factors: tuple[float, float] | float,
prob: float = 0.1,
channel_wise: bool = False,
dtype: DtypeLike = np.float32,
allow_missing_keys: bool = False,
) -> None:
@@ -597,13 +598,15 @@ def __init__(
if single number, factor value is picked from (-factors, factors).
prob: probability of scale.
(Default 0.1, with 10% probability it returns a scaled array.)
channel_wise: if True, scale on each channel separately. Please ensure
that the first dimension represents the channel of the image if True.
dtype: output data type; if None, same as the input image. Defaults to float32.
allow_missing_keys: don't raise exception if key is missing.
"""
MapTransform.__init__(self, keys, allow_missing_keys)
RandomizableTransform.__init__(self, prob)
self.scaler = RandScaleIntensity(factors=factors, dtype=dtype, prob=1.0)
self.scaler = RandScaleIntensity(factors=factors, dtype=dtype, prob=1.0, channel_wise=channel_wise)

def set_random_state(
self, seed: int | None = None, state: np.random.RandomState | None = None
@@ -620,8 +623,15 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
return d

# expect all the specified keys to have the same channel dimension so they can share the same random factors
first_key: Hashable = self.first_key(d)
if first_key == ():
for key in self.key_iterator(d):
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
return d

# all the keys share the same random scale factor
self.scaler.randomize(None)
self.scaler.randomize(d[first_key])
for key in self.key_iterator(d):
d[key] = self.scaler(d[key], randomize=False)
return d
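As the `__call__` changes above show, `RandScaleIntensityd` randomizes once per call using the first key and then applies the same factors to every key. A small sketch of that shared-factor behaviour (key names, shapes, and the seed are assumptions):

```python
import numpy as np
import torch
from monai.transforms import RandScaleIntensityd

data = {
    "img1": np.ones((2, 8, 8), dtype=np.float32),
    "img2": np.ones((2, 8, 8), dtype=np.float32),
}
t = RandScaleIntensityd(keys=["img1", "img2"], factors=0.5, prob=1.0, channel_wise=True)
t.set_random_state(seed=42)
out = t(data)

# both keys were scaled with the same per-channel factors,
# so identical inputs produce identical outputs
assert torch.allclose(out["img1"], out["img2"])
```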
16 changes: 16 additions & 0 deletions tests/test_rand_scale_intensity.py
@@ -33,6 +33,22 @@ def test_value(self, p):
expected = p((self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32))
assert_allclose(result, p(expected), rtol=1e-7, atol=0, type_test="tensor")

@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_channel_wise(self, p):
scaler = RandScaleIntensity(factors=0.5, channel_wise=True, prob=1.0)
scaler.set_random_state(seed=0)
im = p(self.imt)
result = scaler(im)
np.random.seed(0)
# simulate the transform's randomize(): one draw for the probability check, then one factor per channel
np.random.random()
channel_num = self.imt.shape[0]
factor = [np.random.uniform(low=-0.5, high=0.5) for _ in range(channel_num)]
expected = p(
np.stack([np.asarray((self.imt[i]) * (1 + factor[i])) for i in range(channel_num)]).astype(np.float32)
)
assert_allclose(result, expected, atol=0, rtol=1e-5, type_test=False)


if __name__ == "__main__":
unittest.main()
16 changes: 16 additions & 0 deletions tests/test_rand_scale_intensityd.py
@@ -32,6 +32,22 @@ def test_value(self):
expected = (self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32)
assert_allclose(result[key], p(expected), type_test="tensor")

def test_channel_wise(self):
key = "img"
for p in TEST_NDARRAYS:
scaler = RandScaleIntensityd(keys=[key], factors=0.5, prob=1.0, channel_wise=True)
scaler.set_random_state(seed=0)
result = scaler({key: p(self.imt)})
np.random.seed(0)
# simulate the transform's randomize(): one draw for the probability check, then one factor per channel
np.random.random()
channel_num = self.imt.shape[0]
factor = [np.random.uniform(low=-0.5, high=0.5) for _ in range(channel_num)]
expected = p(
np.stack([np.asarray((self.imt[i]) * (1 + factor[i])) for i in range(channel_num)]).astype(np.float32)
)
assert_allclose(result[key], p(expected), type_test="tensor")


if __name__ == "__main__":
unittest.main()
