Skip to content

Commit

Permalink
MapLabelValue support tensor/gpu backend (#6872)
Browse files Browse the repository at this point in the history
### Description
- adding pytorch backend support for `MapLabelValue`
- fixes #6869 

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli authored Aug 15, 2023
1 parent 5833b1c commit e24b969
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 21 deletions.
47 changes: 28 additions & 19 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1207,41 +1207,50 @@ class MapLabelValue:
"""

backend = [TransformBackends.NUMPY]
backend = [TransformBackends.NUMPY, TransformBackends.TORCH]

def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeLike = np.float32) -> None:
"""
Args:
orig_labels: original labels that map to others.
target_labels: expected label values, 1: 1 map to the `orig_labels`.
dtype: convert the output data to dtype, default to float32.
if dtype is from PyTorch, the transform will use the pytorch backend, else with numpy backend.
"""
if len(orig_labels) != len(target_labels):
raise ValueError("orig_labels and target_labels must have the same length.")
if all(o == z for o, z in zip(orig_labels, target_labels)):
raise ValueError("orig_labels and target_labels are exactly the same, should be different to map.")

self.orig_labels = orig_labels
self.target_labels = target_labels
self.dtype = get_equivalent_dtype(dtype, data_type=np.ndarray)
self.pair = tuple((o, t) for o, t in zip(self.orig_labels, self.target_labels) if o != t)
type_dtype = type(dtype)
if getattr(type_dtype, "__module__", "") == "torch":
self.use_numpy = False
self.dtype = get_equivalent_dtype(dtype, data_type=torch.Tensor)
else:
self.use_numpy = True
self.dtype = get_equivalent_dtype(dtype, data_type=np.ndarray)

def __call__(self, img: NdarrayOrTensor):
img_np, *_ = convert_data_type(img, np.ndarray)
img_flat = img_np.flatten()
try:
out_flat = np.array(img_flat, dtype=self.dtype)
except ValueError:
# can't copy unchanged labels as the expected dtype is not supported, must map all the label values
out_flat = np.zeros(shape=img_flat.shape, dtype=self.dtype)

for o, t in zip(self.orig_labels, self.target_labels):
if o == t:
continue
np.place(out_flat, img_flat == o, t)

reshaped = out_flat.reshape(img_np.shape)
out, *_ = convert_to_dst_type(src=reshaped, dst=img, dtype=self.dtype)
if self.use_numpy:
img_np, *_ = convert_data_type(img, np.ndarray)
_out_shape = img_np.shape
img_flat = img_np.flatten()
try:
out_flat = img_flat.astype(self.dtype)
except ValueError:
# can't copy unchanged labels as the expected dtype is not supported, must map all the label values
out_flat = np.zeros(shape=img_flat.shape, dtype=self.dtype)
for o, t in self.pair:
out_flat[img_flat == o] = t
out_t = out_flat.reshape(_out_shape)
else:
img_t, *_ = convert_data_type(img, torch.Tensor)
out_t = img_t.detach().clone().to(self.dtype) # type: ignore
for o, t in self.pair:
out_t[img_t == o] = t
out, *_ = convert_to_dst_type(src=out_t, dst=img, dtype=self.dtype)
return out


Expand Down
1 change: 1 addition & 0 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -1523,6 +1523,7 @@ def __init__(
orig_labels: original labels that map to others.
target_labels: expected label values, 1: 1 map to the `orig_labels`.
dtype: convert the output data to dtype, default to float32.
if dtype is from PyTorch, the transform will use the pytorch backend, else with numpy backend.
allow_missing_keys: don't raise exception if key is missing.
"""
Expand Down
7 changes: 7 additions & 0 deletions tests/test_map_label_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@
p([2, 0, 0, 1]),
]
)
TESTS.append(
[
{"orig_labels": [1.5, 2.5, 3.5], "target_labels": [0, 1, 2], "dtype": torch.int8},
p([3.5, 1.5, 1.5, 2.5]),
p([2, 0, 0, 1]),
]
)
TESTS.extend(
[
[
Expand Down
16 changes: 14 additions & 2 deletions tests/test_map_label_valued.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.transforms import MapLabelValued
from tests.utils import assert_allclose

TEST_CASE_1 = [
{"keys": "seg", "orig_labels": [3, 2, 1], "target_labels": [0, 1, 2]},
Expand Down Expand Up @@ -47,6 +49,11 @@
{"seg": np.array([3.5, 1.5, 1.5, 2.5])},
np.array([2, 0, 0, 1]),
]
TEST_CASE_5_1 = [
{"keys": "seg", "orig_labels": [1.5, 2.5, 3.5], "target_labels": [0, 1, 2], "dtype": torch.int8},
{"seg": torch.as_tensor([3.5, 1.5, 1.5, 2.5])},
torch.as_tensor([2.0, 0.0, 0.0, 1.0]),
]

TEST_CASE_6 = [
{"keys": "seg", "orig_labels": ["label3", "label2", "label1"], "target_labels": [0, 1, 2]},
Expand All @@ -62,10 +69,15 @@


class TestMapLabelValued(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
@parameterized.expand(
[TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_5_1, TEST_CASE_6, TEST_CASE_7]
)
def test_shape(self, input_param, input_data, expected_value):
result = MapLabelValued(**input_param)(input_data)
np.testing.assert_equal(result["seg"], expected_value)
if isinstance(expected_value, torch.Tensor):
assert_allclose(result["seg"], expected_value)
else:
np.testing.assert_equal(result["seg"], expected_value)
self.assertTupleEqual(result["seg"].shape, expected_value.shape)


Expand Down

0 comments on commit e24b969

Please sign in to comment.