Skip to content

Commit

Permalink
Add options to skip operations for RestoreLabeld Transform (#8125)
Browse files Browse the repository at this point in the history
Fixes #6380 

### Description

Four new bool parameters are added into `RestoreLabeld` to allow users
to selectively enable or disable each restoration operation as needed,
and a corresponding test case is added to verify that the function runs
correctly.

This design allows users to selectively enable or disable each
restoration operation as needed, providing greater flexibility.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Hsin Tong <[email protected]>
Signed-off-by: Hsin-Tong Hsieh <[email protected]>
Signed-off-by: kbbbbkb <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: kbbbbkb <[email protected]>
  • Loading branch information
4 people authored Oct 12, 2024
1 parent 76ef9f4 commit 796271c
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 24 deletions.
69 changes: 46 additions & 23 deletions monai/apps/deepgrow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,14 @@ class RestoreLabeld(MapTransform):
original_shape_key: key that records original shape for foreground.
cropped_shape_key: key that records cropped shape for foreground.
allow_missing_keys: don't raise exception if key is missing.
restore_resizing: used to enable or disable resizing restoration, default is True.
If True, the transform will resize the items back to its original shape.
restore_cropping: used to enable or disable cropping restoration, default is True.
If True, the transform will restore the items to its uncropped size.
restore_spacing: used to enable or disable spacing restoration, default is True.
If True, the transform will resample the items back to the spacing it had before being altered.
restore_slicing: used to enable or disable slicing restoration, default is True.
If True, the transform will reassemble the full volume by restoring the slices to their original positions.
"""

def __init__(
Expand All @@ -819,6 +827,10 @@ def __init__(
original_shape_key: str = "foreground_original_shape",
cropped_shape_key: str = "foreground_cropped_shape",
allow_missing_keys: bool = False,
restore_resizing: bool = True,
restore_cropping: bool = True,
restore_spacing: bool = True,
restore_slicing: bool = True,
) -> None:
super().__init__(keys, allow_missing_keys)
self.ref_image = ref_image
Expand All @@ -833,6 +845,10 @@ def __init__(
self.end_coord_key = end_coord_key
self.original_shape_key = original_shape_key
self.cropped_shape_key = cropped_shape_key
self.restore_resizing = restore_resizing
self.restore_cropping = restore_cropping
self.restore_spacing = restore_spacing
self.restore_slicing = restore_slicing

def __call__(self, data: Any) -> dict:
d = dict(data)
Expand All @@ -842,38 +858,45 @@ def __call__(self, data: Any) -> dict:
image = d[key]

# Undo Resize
current_shape = image.shape
cropped_shape = meta_dict[self.cropped_shape_key]
if np.any(np.not_equal(current_shape, cropped_shape)):
resizer = Resize(spatial_size=cropped_shape[1:], mode=mode)
image = resizer(image, mode=mode, align_corners=align_corners)
if self.restore_resizing:
current_shape = image.shape
cropped_shape = meta_dict[self.cropped_shape_key]
if np.any(np.not_equal(current_shape, cropped_shape)):
resizer = Resize(spatial_size=cropped_shape[1:], mode=mode)
image = resizer(image, mode=mode, align_corners=align_corners)

# Undo Crop
original_shape = meta_dict[self.original_shape_key]
result = np.zeros(original_shape, dtype=np.float32)
box_start = meta_dict[self.start_coord_key]
box_end = meta_dict[self.end_coord_key]

spatial_dims = min(len(box_start), len(image.shape[1:]))
slices = tuple(
[slice(None)] + [slice(s, e) for s, e in zip(box_start[:spatial_dims], box_end[:spatial_dims])]
)
result[slices] = image
if self.restore_cropping:
original_shape = meta_dict[self.original_shape_key]
result = np.zeros(original_shape, dtype=np.float32)
box_start = meta_dict[self.start_coord_key]
box_end = meta_dict[self.end_coord_key]

spatial_dims = min(len(box_start), len(image.shape[1:]))
slices = tuple(
[slice(None)] + [slice(s, e) for s, e in zip(box_start[:spatial_dims], box_end[:spatial_dims])]
)
result[slices] = image
else:
result = image

# Undo Spacing
current_size = result.shape[1:]
# change spatial_shape from HWD to DHW
spatial_shape = list(np.roll(meta_dict["spatial_shape"], 1))
spatial_size = spatial_shape[-len(current_size) :]
if self.restore_spacing:
current_size = result.shape[1:]
# change spatial_shape from HWD to DHW
spatial_shape = list(np.roll(meta_dict["spatial_shape"], 1))
spatial_size = spatial_shape[-len(current_size) :]

if np.any(np.not_equal(current_size, spatial_size)):
resizer = Resize(spatial_size=spatial_size, mode=mode)
result = resizer(result, mode=mode, align_corners=align_corners) # type: ignore
if np.any(np.not_equal(current_size, spatial_size)):
resizer = Resize(spatial_size=spatial_size, mode=mode)
result = resizer(result, mode=mode, align_corners=align_corners) # type: ignore

# Undo Slicing
slice_idx = meta_dict.get("slice_idx")
final_result: NdarrayOrTensor
if slice_idx is None or self.slice_only:
if not self.restore_slicing: # do nothing if restore slicing isn't requested
final_result = result
elif slice_idx is None or self.slice_only:
final_result = result if len(result.shape) <= 3 else result[0]
else:
slice_idx = meta_dict["slice_idx"][0]
Expand Down
95 changes: 94 additions & 1 deletion tests/test_deepgrow_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,21 @@

DATA_12 = {"image": np.arange(27).reshape(3, 3, 3), PostFix.meta("image"): {}, "guidance": [[0, 0, 0], [0, 1, 1], 1]}

DATA_13 = {
"image": np.arange(64).reshape((1, 4, 4, 4)),
PostFix.meta("image"): {
"spatial_shape": [8, 8, 4],
"foreground_start_coord": np.array([1, 1, 1]),
"foreground_end_coord": np.array([3, 3, 3]),
"foreground_original_shape": (1, 4, 4, 4),
"foreground_cropped_shape": (1, 2, 2, 2),
"original_affine": np.array(
[[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]]
),
},
"pred": np.array([[[[10, 20], [30, 40]], [[50, 60], [70, 80]]]]),
}

FIND_SLICE_TEST_CASE_1 = [{"label": "label", "sids": "sids"}, DATA_1, [0]]

FIND_SLICE_TEST_CASE_2 = [{"label": "label", "sids": "sids"}, DATA_2, [0, 1]]
Expand Down Expand Up @@ -329,6 +344,74 @@

RESTORE_LABEL_TEST_CASE_2 = [{"keys": ["pred"], "ref_image": "image", "mode": "nearest"}, DATA_11, RESULT]

RESTORE_LABEL_TEST_CASE_3_RESULT = np.zeros((10, 20, 20))
RESTORE_LABEL_TEST_CASE_3_RESULT[:5, 0:10, 0:10] = 1
RESTORE_LABEL_TEST_CASE_3_RESULT[:5, 0:10, 10:20] = 2
RESTORE_LABEL_TEST_CASE_3_RESULT[:5, 10:20, 0:10] = 3
RESTORE_LABEL_TEST_CASE_3_RESULT[:5, 10:20, 10:20] = 4
RESTORE_LABEL_TEST_CASE_3_RESULT[5:10, 0:10, 0:10] = 5
RESTORE_LABEL_TEST_CASE_3_RESULT[5:10, 0:10, 10:20] = 6
RESTORE_LABEL_TEST_CASE_3_RESULT[5:10, 10:20, 0:10] = 7
RESTORE_LABEL_TEST_CASE_3_RESULT[5:10, 10:20, 10:20] = 8

RESTORE_LABEL_TEST_CASE_3 = [
{"keys": ["pred"], "ref_image": "image", "mode": "nearest", "restore_cropping": False},
DATA_11,
RESTORE_LABEL_TEST_CASE_3_RESULT,
]

RESTORE_LABEL_TEST_CASE_4_RESULT = np.zeros((4, 8, 8))
RESTORE_LABEL_TEST_CASE_4_RESULT[1, 2:6, 2:6] = np.array(
[[10.0, 10.0, 20.0, 20.0], [10.0, 10.0, 20.0, 20.0], [30.0, 30.0, 40.0, 40.0], [30.0, 30.0, 40.0, 40.0]]
)
RESTORE_LABEL_TEST_CASE_4_RESULT[2, 2:6, 2:6] = np.array(
[[50.0, 50.0, 60.0, 60.0], [50.0, 50.0, 60.0, 60.0], [70.0, 70.0, 80.0, 80.0], [70.0, 70.0, 80.0, 80.0]]
)

RESTORE_LABEL_TEST_CASE_4 = [
{"keys": ["pred"], "ref_image": "image", "mode": "nearest", "restore_resizing": False},
DATA_13,
RESTORE_LABEL_TEST_CASE_4_RESULT,
]

RESTORE_LABEL_TEST_CASE_5_RESULT = np.zeros((4, 4, 4))
RESTORE_LABEL_TEST_CASE_5_RESULT[1, 1:3, 1:3] = np.array([[10.0, 20.0], [30.0, 40.0]])
RESTORE_LABEL_TEST_CASE_5_RESULT[2, 1:3, 1:3] = np.array([[50.0, 60.0], [70.0, 80.0]])

RESTORE_LABEL_TEST_CASE_5 = [
{"keys": ["pred"], "ref_image": "image", "mode": "nearest", "restore_spacing": False},
DATA_13,
RESTORE_LABEL_TEST_CASE_5_RESULT,
]

RESTORE_LABEL_TEST_CASE_6_RESULT = np.zeros((1, 4, 8, 8))
RESTORE_LABEL_TEST_CASE_6_RESULT[-1, 1, 2:6, 2:6] = np.array(
[[10.0, 10.0, 20.0, 20.0], [10.0, 10.0, 20.0, 20.0], [30.0, 30.0, 40.0, 40.0], [30.0, 30.0, 40.0, 40.0]]
)
RESTORE_LABEL_TEST_CASE_6_RESULT[-1, 2, 2:6, 2:6] = np.array(
[[50.0, 50.0, 60.0, 60.0], [50.0, 50.0, 60.0, 60.0], [70.0, 70.0, 80.0, 80.0], [70.0, 70.0, 80.0, 80.0]]
)

RESTORE_LABEL_TEST_CASE_6 = [
{"keys": ["pred"], "ref_image": "image", "mode": "nearest", "restore_slicing": False},
DATA_13,
RESTORE_LABEL_TEST_CASE_6_RESULT,
]

RESTORE_LABEL_TEST_CASE_7 = [
{
"keys": ["pred"],
"ref_image": "image",
"mode": "nearest",
"restore_resizing": False,
"restore_cropping": False,
"restore_spacing": False,
"restore_slicing": False,
},
DATA_11,
np.array([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]]),
]

FETCH_2D_SLICE_TEST_CASE_1 = [
{"keys": ["image"], "guidance": "guidance"},
DATA_12,
Expand Down Expand Up @@ -445,7 +528,17 @@ def test_correct_results(self, arguments, input_data, expected_result):

class TestRestoreLabeld(unittest.TestCase):

@parameterized.expand([RESTORE_LABEL_TEST_CASE_1, RESTORE_LABEL_TEST_CASE_2])
@parameterized.expand(
[
RESTORE_LABEL_TEST_CASE_1,
RESTORE_LABEL_TEST_CASE_2,
RESTORE_LABEL_TEST_CASE_3,
RESTORE_LABEL_TEST_CASE_4,
RESTORE_LABEL_TEST_CASE_5,
RESTORE_LABEL_TEST_CASE_6,
RESTORE_LABEL_TEST_CASE_7,
]
)
def test_correct_results(self, arguments, input_data, expected_result):
result = RestoreLabeld(**arguments)(input_data)
np.testing.assert_allclose(result["pred"], expected_result)
Expand Down

0 comments on commit 796271c

Please sign in to comment.