From 825021d46d6c1a97d6c38bcbfbb0e5e0e148995e Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Mon, 9 Oct 2023 17:02:14 +0800 Subject: [PATCH] `EnsureTyped` flexible dtype (#7104) Fixes #7102 ### Description Make dtype in `EnsureTyped` configurable so that different dtypes can be used for different keys ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu --- monai/transforms/utility/array.py | 7 ++++--- monai/transforms/utility/dictionary.py | 10 ++++++---- tests/test_ensure_typed.py | 26 ++++++++++++++++++-------- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index f9b81865e0..9aad12ef90 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -424,7 +424,7 @@ class EnsureType(Transform): def __init__( self, data_type: str = "tensor", - dtype: DtypeLike | torch.dtype | None = None, + dtype: DtypeLike | torch.dtype = None, device: torch.device | None = None, wrap_sequence: bool = True, track_meta: bool | None = None, @@ -435,13 +435,14 @@ def __init__( self.wrap_sequence = wrap_sequence self.track_meta = get_track_meta() if track_meta is None else bool(track_meta) - def __call__(self, data: NdarrayOrTensor): + def __call__(self, data: NdarrayOrTensor, dtype: DtypeLike | torch.dtype = None): """ Args: data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. 
will ensure Tensor, Numpy array, float, int, bool as Tensors or numpy arrays, strings and objects keep the original. for dictionary, list or tuple, ensure every item as expected type if applicable and `wrap_sequence=False`. + dtype: target data content type to convert, for example: np.float32, torch.float, etc. """ if self.data_type == "tensor": @@ -452,7 +453,7 @@ def __call__(self, data: NdarrayOrTensor): out, *_ = convert_data_type( data=data, output_type=output_type, # type: ignore - dtype=self.dtype, + dtype=self.dtype if dtype is None else dtype, device=self.device, wrap_sequence=self.wrap_sequence, ) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 692e648935..ec10bd8537 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -488,7 +488,7 @@ def __init__( self, keys: KeysCollection, data_type: str = "tensor", - dtype: DtypeLike | torch.dtype = None, + dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = None, device: torch.device | None = None, wrap_sequence: bool = True, track_meta: bool | None = None, @@ -500,6 +500,7 @@ def __init__( See also: :py:class:`monai.transforms.compose.MapTransform` data_type: target data type to convert, should be "tensor" or "numpy". dtype: target data content type to convert, for example: np.float32, torch.float, etc. + It also can be a sequence of dtype, each element corresponds to a key in ``keys``. device: for Tensor data type, specify the target device. wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`. @@ -508,14 +509,15 @@ def __init__( allow_missing_keys: don't raise exception if key is missing. 
""" super().__init__(keys, allow_missing_keys) + self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.converter = EnsureType( - data_type=data_type, dtype=dtype, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta + data_type=data_type, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta ) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) - for key in self.key_iterator(d): - d[key] = self.converter(d[key]) + for key, dtype in self.key_iterator(d, self.dtype): + d[key] = self.converter(d[key], dtype) return d diff --git a/tests/test_ensure_typed.py b/tests/test_ensure_typed.py index 98a41b5430..4fa942e742 100644 --- a/tests/test_ensure_typed.py +++ b/tests/test_ensure_typed.py @@ -86,14 +86,24 @@ def test_dict(self): "extra": None, } for dtype in ("tensor", "numpy"): - result = EnsureTyped(keys="data", data_type=dtype, device="cpu")({"data": test_data})["data"] - self.assertTrue(isinstance(result, dict)) - self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray)) - assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]), type_test=False) - self.assertTrue(isinstance(result["meta"]["size"], torch.Tensor if dtype == "tensor" else np.ndarray)) - assert_allclose(result["meta"]["size"], torch.as_tensor([1, 2, 3]), type_test=False) - self.assertEqual(result["meta"]["path"], "temp/test") - self.assertEqual(result["extra"], None) + trans = EnsureTyped(keys=["data", "label"], data_type=dtype, dtype=[np.float32, np.int8], device="cpu")( + {"data": test_data, "label": test_data} + ) + for key in ("data", "label"): + result = trans[key] + self.assertTrue(isinstance(result, dict)) + self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray)) + self.assertTrue(isinstance(result["meta"]["size"], torch.Tensor if dtype == "tensor" else np.ndarray)) + self.assertEqual(result["meta"]["path"], "temp/test") + 
self.assertEqual(result["extra"], None) + assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]), type_test=False) + assert_allclose(result["meta"]["size"], torch.as_tensor([1, 2, 3]), type_test=False) + if dtype == "numpy": + self.assertTrue(trans["data"]["img"].dtype == np.float32) + self.assertTrue(trans["label"]["img"].dtype == np.int8) + else: + self.assertTrue(trans["data"]["img"].dtype == torch.float32) + self.assertTrue(trans["label"]["img"].dtype == torch.int8) if __name__ == "__main__":