Skip to content

Commit

Permalink
EnsureTyped flexible dtype (#7104)
Browse files Browse the repository at this point in the history
Fixes #7102

### Description
Make dtype in `EnsureTyped` configurable, allowing different dtypes for
different keys

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: KumoLiu <[email protected]>
  • Loading branch information
KumoLiu authored Oct 9, 2023
1 parent 2b0a95e commit 825021d
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 15 deletions.
7 changes: 4 additions & 3 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ class EnsureType(Transform):
def __init__(
self,
data_type: str = "tensor",
dtype: DtypeLike | torch.dtype | None = None,
dtype: DtypeLike | torch.dtype = None,
device: torch.device | None = None,
wrap_sequence: bool = True,
track_meta: bool | None = None,
Expand All @@ -435,13 +435,14 @@ def __init__(
self.wrap_sequence = wrap_sequence
self.track_meta = get_track_meta() if track_meta is None else bool(track_meta)

def __call__(self, data: NdarrayOrTensor):
def __call__(self, data: NdarrayOrTensor, dtype: DtypeLike | torch.dtype = None):
"""
Args:
data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.
will ensure Tensor, Numpy array, float, int, bool as Tensors or numpy arrays, strings and
objects keep the original. for dictionary, list or tuple, ensure every item as expected type
if applicable and `wrap_sequence=False`.
dtype: target data content type to convert, for example: np.float32, torch.float, etc.
"""
if self.data_type == "tensor":
Expand All @@ -452,7 +453,7 @@ def __call__(self, data: NdarrayOrTensor):
out, *_ = convert_data_type(
data=data,
output_type=output_type, # type: ignore
dtype=self.dtype,
dtype=self.dtype if dtype is None else dtype,
device=self.device,
wrap_sequence=self.wrap_sequence,
)
Expand Down
10 changes: 6 additions & 4 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def __init__(
self,
keys: KeysCollection,
data_type: str = "tensor",
dtype: DtypeLike | torch.dtype = None,
dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = None,
device: torch.device | None = None,
wrap_sequence: bool = True,
track_meta: bool | None = None,
Expand All @@ -500,6 +500,7 @@ def __init__(
See also: :py:class:`monai.transforms.compose.MapTransform`
data_type: target data type to convert, should be "tensor" or "numpy".
dtype: target data content type to convert, for example: np.float32, torch.float, etc.
It also can be a sequence of dtype, each element corresponds to a key in ``keys``.
device: for Tensor data type, specify the target device.
wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.
E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`.
Expand All @@ -508,14 +509,15 @@ def __init__(
allow_missing_keys: don't raise exception if key is missing.
"""
super().__init__(keys, allow_missing_keys)
self.dtype = ensure_tuple_rep(dtype, len(self.keys))
self.converter = EnsureType(
data_type=data_type, dtype=dtype, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta
data_type=data_type, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta
)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.converter(d[key])
for key, dtype in self.key_iterator(d, self.dtype):
d[key] = self.converter(d[key], dtype)
return d


Expand Down
26 changes: 18 additions & 8 deletions tests/test_ensure_typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,24 @@ def test_dict(self):
"extra": None,
}
for dtype in ("tensor", "numpy"):
result = EnsureTyped(keys="data", data_type=dtype, device="cpu")({"data": test_data})["data"]
self.assertTrue(isinstance(result, dict))
self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray))
assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]), type_test=False)
self.assertTrue(isinstance(result["meta"]["size"], torch.Tensor if dtype == "tensor" else np.ndarray))
assert_allclose(result["meta"]["size"], torch.as_tensor([1, 2, 3]), type_test=False)
self.assertEqual(result["meta"]["path"], "temp/test")
self.assertEqual(result["extra"], None)
trans = EnsureTyped(keys=["data", "label"], data_type=dtype, dtype=[np.float32, np.int8], device="cpu")(
{"data": test_data, "label": test_data}
)
for key in ("data", "label"):
result = trans[key]
self.assertTrue(isinstance(result, dict))
self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray))
self.assertTrue(isinstance(result["meta"]["size"], torch.Tensor if dtype == "tensor" else np.ndarray))
self.assertEqual(result["meta"]["path"], "temp/test")
self.assertEqual(result["extra"], None)
assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]), type_test=False)
assert_allclose(result["meta"]["size"], torch.as_tensor([1, 2, 3]), type_test=False)
if dtype == "numpy":
self.assertTrue(trans["data"]["img"].dtype == np.float32)
self.assertTrue(trans["label"]["img"].dtype == np.int8)
else:
self.assertTrue(trans["data"]["img"].dtype == torch.float32)
self.assertTrue(trans["label"]["img"].dtype == torch.int8)


if __name__ == "__main__":
Expand Down

0 comments on commit 825021d

Please sign in to comment.