diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index f9b81865e0..9aad12ef90 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -424,7 +424,7 @@ class EnsureType(Transform): def __init__( self, data_type: str = "tensor", - dtype: DtypeLike | torch.dtype | None = None, + dtype: DtypeLike | torch.dtype = None, device: torch.device | None = None, wrap_sequence: bool = True, track_meta: bool | None = None, @@ -435,13 +435,14 @@ def __init__( self.wrap_sequence = wrap_sequence self.track_meta = get_track_meta() if track_meta is None else bool(track_meta) - def __call__(self, data: NdarrayOrTensor): + def __call__(self, data: NdarrayOrTensor, dtype: DtypeLike | torch.dtype = None): """ Args: data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. will ensure Tensor, Numpy array, float, int, bool as Tensors or numpy arrays, strings and objects keep the original. for dictionary, list or tuple, ensure every item as expected type if applicable and `wrap_sequence=False`. + dtype: target data content type to convert, for example: np.float32, torch.float, etc. """ if self.data_type == "tensor": @@ -452,7 +453,7 @@ def __call__(self, data: NdarrayOrTensor): out, *_ = convert_data_type( data=data, output_type=output_type, # type: ignore - dtype=self.dtype, + dtype=self.dtype if dtype is None else dtype, device=self.device, wrap_sequence=self.wrap_sequence, ) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 692e648935..ec10bd8537 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -488,7 +488,7 @@ def __init__( self, keys: KeysCollection, data_type: str = "tensor", - dtype: DtypeLike | torch.dtype = None, + dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = None, device: torch.device | None = None, wrap_sequence: bool = True, track_meta: bool | None = None, @@ -500,6 +500,7 @@ def __init__( See also: :py:class:`monai.transforms.compose.MapTransform` data_type: target data type to convert, should be "tensor" or "numpy". dtype: target data content type to convert, for example: np.float32, torch.float, etc. + It also can be a sequence of dtype, each element corresponds to a key in ``keys``. device: for Tensor data type, specify the target device. wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`. @@ -508,14 +509,15 @@ def __init__( allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) + self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.converter = EnsureType( - data_type=data_type, dtype=dtype, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta + data_type=data_type, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta ) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) - for key in self.key_iterator(d): - d[key] = self.converter(d[key]) + for key, dtype in self.key_iterator(d, self.dtype): + d[key] = self.converter(d[key], dtype) return d diff --git a/tests/test_ensure_typed.py b/tests/test_ensure_typed.py index 98a41b5430..4fa942e742 100644 --- a/tests/test_ensure_typed.py +++ b/tests/test_ensure_typed.py @@ -86,14 +86,24 @@ def test_dict(self): "extra": None, } for dtype in ("tensor", "numpy"): - result = EnsureTyped(keys="data", data_type=dtype, device="cpu")({"data": test_data})["data"] - self.assertTrue(isinstance(result, dict)) - self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray)) - assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]), type_test=False) - self.assertTrue(isinstance(result["meta"]["size"], torch.Tensor if dtype == "tensor" else np.ndarray)) - assert_allclose(result["meta"]["size"], torch.as_tensor([1, 2, 3]), type_test=False) - self.assertEqual(result["meta"]["path"], "temp/test") - self.assertEqual(result["extra"], None) + trans = EnsureTyped(keys=["data", "label"], data_type=dtype, dtype=[np.float32, np.int8], device="cpu")( + {"data": test_data, "label": test_data} + ) + for key in ("data", "label"): + result = trans[key] + self.assertTrue(isinstance(result, dict)) + self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray)) + self.assertTrue(isinstance(result["meta"]["size"], torch.Tensor if dtype == "tensor" else np.ndarray)) + self.assertEqual(result["meta"]["path"], "temp/test") + self.assertEqual(result["extra"], None) + assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]), type_test=False) + assert_allclose(result["meta"]["size"], torch.as_tensor([1, 2, 3]), type_test=False) + if dtype == "numpy": + self.assertTrue(trans["data"]["img"].dtype == np.float32) + self.assertTrue(trans["label"]["img"].dtype == np.int8) + else: + self.assertTrue(trans["data"]["img"].dtype == torch.float32) + self.assertTrue(trans["label"]["img"].dtype == torch.int8) if __name__ == "__main__":