From ac6439969f0403838bbec05c5fea869e954d63f9 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Sun, 8 Oct 2023 17:31:16 +0800 Subject: [PATCH 1/4] fix #7102 Signed-off-by: KumoLiu --- monai/transforms/utility/array.py | 6 +++--- monai/transforms/utility/dictionary.py | 9 +++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index f9b81865e0..97e5a4a1cb 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -424,7 +424,7 @@ class EnsureType(Transform): def __init__( self, data_type: str = "tensor", - dtype: DtypeLike | torch.dtype | None = None, + dtype: DtypeLike | torch.dtype = None, device: torch.device | None = None, wrap_sequence: bool = True, track_meta: bool | None = None, @@ -435,7 +435,7 @@ def __init__( self.wrap_sequence = wrap_sequence self.track_meta = get_track_meta() if track_meta is None else bool(track_meta) - def __call__(self, data: NdarrayOrTensor): + def __call__(self, data: NdarrayOrTensor, dtype: DtypeLike | torch.dtype = None): """ Args: data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. @@ -452,7 +452,7 @@ def __call__(self, data: NdarrayOrTensor): out, *_ = convert_data_type( data=data, output_type=output_type, # type: ignore - dtype=self.dtype, + dtype=self.dtype if dtype is None else dtype, device=self.device, wrap_sequence=self.wrap_sequence, ) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 692e648935..4b0cfad661 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -488,7 +488,7 @@ def __init__( self, keys: KeysCollection, data_type: str = "tensor", - dtype: DtypeLike | torch.dtype = None, + dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = None, device: torch.device | None = None, wrap_sequence: bool = True, track_meta: bool | None = None, @@ -508,14 +508,15 @@ def __init__( allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) + self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.converter = EnsureType( - data_type=data_type, dtype=dtype, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta + data_type=data_type, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta ) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) - for key in self.key_iterator(d): - d[key] = self.converter(d[key]) + for key, dtype in self.key_iterator(d, self.dtype): + d[key] = self.converter(d[key], dtype) return d From dd91c075bf710a9abf414e1878e507d5fd46b6fb Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Mon, 9 Oct 2023 09:50:46 +0800 Subject: [PATCH 2/4] add unittest Signed-off-by: KumoLiu --- tests/test_ensure_typed.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/tests/test_ensure_typed.py b/tests/test_ensure_typed.py index 98a41b5430..4fa942e742 100644 --- a/tests/test_ensure_typed.py +++ b/tests/test_ensure_typed.py @@ -86,14 +86,24 @@ def test_dict(self): "extra": None, } for dtype in ("tensor", "numpy"): - result = EnsureTyped(keys="data", data_type=dtype, device="cpu")({"data": test_data})["data"] - self.assertTrue(isinstance(result, dict)) - self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray)) - assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]), type_test=False) - self.assertTrue(isinstance(result["meta"]["size"], torch.Tensor if dtype == "tensor" else np.ndarray)) - assert_allclose(result["meta"]["size"], torch.as_tensor([1, 2, 3]), type_test=False) - self.assertEqual(result["meta"]["path"], "temp/test") - self.assertEqual(result["extra"], None) + trans = EnsureTyped(keys=["data", "label"], data_type=dtype, dtype=[np.float32, np.int8], device="cpu")( + {"data": test_data, "label": test_data} + ) + for key in ("data", "label"): + result = trans[key] + self.assertTrue(isinstance(result, dict)) + self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray)) + self.assertTrue(isinstance(result["meta"]["size"], torch.Tensor if dtype == "tensor" else np.ndarray)) + self.assertEqual(result["meta"]["path"], "temp/test") + self.assertEqual(result["extra"], None) + assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]), type_test=False) + assert_allclose(result["meta"]["size"], torch.as_tensor([1, 2, 3]), type_test=False) + if dtype == "numpy": + self.assertTrue(trans["data"]["img"].dtype == np.float32) + self.assertTrue(trans["label"]["img"].dtype == np.int8) + else: + self.assertTrue(trans["data"]["img"].dtype == torch.float32) + self.assertTrue(trans["label"]["img"].dtype == torch.int8) if __name__ == "__main__": From 8678a67c1e2eee0c3a87fcd3ac15c0411ef26689 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Mon, 9 Oct 2023 14:44:39 +0800 Subject: [PATCH 3/4] update docstring Signed-off-by: KumoLiu --- monai/transforms/utility/dictionary.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 4b0cfad661..ec10bd8537 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -500,6 +500,7 @@ def __init__( See also: :py:class:`monai.transforms.compose.MapTransform` data_type: target data type to convert, should be "tensor" or "numpy". dtype: target data content type to convert, for example: np.float32, torch.float, etc. + It also can be a sequence of dtype, each element corresponds to a key in ``keys``. device: for Tensor data type, specify the target device. wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`. From c38b6fcfb057bf360f85bde009608d6330054986 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Mon, 9 Oct 2023 15:51:20 +0800 Subject: [PATCH 4/4] update docstring Signed-off-by: KumoLiu --- monai/transforms/utility/array.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 97e5a4a1cb..9aad12ef90 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -442,6 +442,7 @@ def __call__(self, data: NdarrayOrTensor, dtype: DtypeLike | torch.dtype = None) will ensure Tensor, Numpy array, float, int, bool as Tensors or numpy arrays, strings and objects keep the original. for dictionary, list or tuple, ensure every item as expected type if applicable and `wrap_sequence=False`. + dtype: target data content type to convert, for example: np.float32, torch.float, etc. """ if self.data_type == "tensor":