From 19cc6f01766120132f964beecb06d1d561f83801 Mon Sep 17 00:00:00 2001 From: "Wei_Chuan, Chiang" <45346252+slicepaste@users.noreply.github.com> Date: Wed, 4 Sep 2024 18:42:49 +0800 Subject: [PATCH] Make MetaTensor optional printed in DataStats and DataStatsd #5905 (#7814) Fixes #5905 ### Description We add a `meta_info` argument to DataStats and DataStatsd so that printing the metadata of a MetaTensor input is optional (disabled by default). ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Wei_Chuan, Chiang Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Suraj Pai Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: Suraj Pai Co-authored-by: Ben Murray --- monai/transforms/utility/array.py | 7 ++++ monai/transforms/utility/dictionary.py | 28 +++++++++++-- tests/test_data_stats.py | 41 ++++++++++++++++++- tests/test_data_statsd.py | 54 +++++++++++++++++++++++++- 4 files changed, 123 insertions(+), 7 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index bfd2f506c2..72dd189009 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -656,6 +656,7 @@ def __init__( data_shape: bool = True, value_range: bool = True, data_value: bool = False, + meta_info: bool = False, additional_info: Callable | None = 
None, name: str = "DataStats", ) -> None: @@ -667,6 +668,7 @@ def __init__( value_range: whether to show the value range of input data. data_value: whether to show the raw value of input data. a typical example is to print some properties of Nifti image: affine, pixdim, etc. + meta_info: whether to show the data of MetaTensor. additional_info: user can define callable function to extract additional info from input data. name: identifier of `logging.logger` to use, defaulting to "DataStats". @@ -681,6 +683,7 @@ def __init__( self.data_shape = data_shape self.value_range = value_range self.data_value = data_value + self.meta_info = meta_info if additional_info is not None and not callable(additional_info): raise TypeError(f"additional_info must be None or callable but is {type(additional_info).__name__}.") self.additional_info = additional_info @@ -707,6 +710,7 @@ def __call__( data_shape: bool | None = None, value_range: bool | None = None, data_value: bool | None = None, + meta_info: bool | None = None, additional_info: Callable | None = None, ) -> NdarrayOrTensor: """ @@ -727,6 +731,9 @@ def __call__( lines.append(f"Value range: (not a PyTorch or Numpy array, type: {type(img)})") if self.data_value if data_value is None else data_value: lines.append(f"Value: {img}") + if self.meta_info if meta_info is None else meta_info: + metadata = getattr(img, "meta", "(input is not a MetaTensor)") + lines.append(f"Meta info: {repr(metadata)}") additional_info = self.additional_info if additional_info is None else additional_info if additional_info is not None: lines.append(f"Additional info: {additional_info(img)}") diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index db5f19c0de..79d0be522d 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -793,6 +793,7 @@ def __init__( data_shape: Sequence[bool] | bool = True, value_range: Sequence[bool] | bool = True, data_value: Sequence[bool] | 
bool = False, + meta_info: Sequence[bool] | bool = False, additional_info: Sequence[Callable] | Callable | None = None, name: str = "DataStats", allow_missing_keys: bool = False, @@ -812,6 +813,8 @@ def __init__( data_value: whether to show the raw value of input data. it also can be a sequence of bool, each element corresponds to a key in ``keys``. a typical example is to print some properties of Nifti image: affine, pixdim, etc. + meta_info: whether to show the data of MetaTensor. + it also can be a sequence of bool, each element corresponds to a key in ``keys``. additional_info: user can define callable function to extract additional info from input data. it also can be a sequence of string, each element corresponds to a key in ``keys``. @@ -825,15 +828,34 @@ def __init__( self.data_shape = ensure_tuple_rep(data_shape, len(self.keys)) self.value_range = ensure_tuple_rep(value_range, len(self.keys)) self.data_value = ensure_tuple_rep(data_value, len(self.keys)) + self.meta_info = ensure_tuple_rep(meta_info, len(self.keys)) self.additional_info = ensure_tuple_rep(additional_info, len(self.keys)) self.printer = DataStats(name=name) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) - for key, prefix, data_type, data_shape, value_range, data_value, additional_info in self.key_iterator( - d, self.prefix, self.data_type, self.data_shape, self.value_range, self.data_value, self.additional_info + for ( + key, + prefix, + data_type, + data_shape, + value_range, + data_value, + meta_info, + additional_info, + ) in self.key_iterator( + d, + self.prefix, + self.data_type, + self.data_shape, + self.value_range, + self.data_value, + self.meta_info, + self.additional_info, ): - d[key] = self.printer(d[key], prefix, data_type, data_shape, value_range, data_value, additional_info) + d[key] = self.printer( + d[key], prefix, data_type, data_shape, value_range, data_value, meta_info, additional_info + ) return d diff 
--git a/tests/test_data_stats.py b/tests/test_data_stats.py index 05453b0694..f9b424f8e1 100644 --- a/tests/test_data_stats.py +++ b/tests/test_data_stats.py @@ -23,6 +23,7 @@ import torch from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms import DataStats TEST_CASE_1 = [ @@ -130,20 +131,55 @@ ] TEST_CASE_8 = [ + { + "prefix": "test data", + "data_type": True, + "data_shape": True, + "value_range": True, + "data_value": True, + "additional_info": np.mean, + "name": "DataStats", + }, np.array([[0, 1], [1, 2]]), "test data statistics:\nType: int64\nShape: (2, 2)\nValue range: (0, 2)\n" "Value: [[0 1]\n [1 2]]\nAdditional info: 1.0\n", ] +TEST_CASE_9 = [ + np.array([[0, 1], [1, 2]]), + "test data statistics:\nType: int64\nShape: (2, 2)\nValue range: (0, 2)\n" + "Value: [[0 1]\n [1 2]]\n" + "Meta info: '(input is not a MetaTensor)'\n" + "Additional info: 1.0\n", +] + +TEST_CASE_10 = [ + MetaTensor( + torch.tensor([[0, 1], [1, 2]]), + affine=torch.as_tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]], dtype=torch.float64), + meta={"some": "info"}, + ), + "test data statistics:\nType: torch.int64\n" + "Shape: torch.Size([2, 2])\nValue range: (0, 2)\n" + "Value: tensor([[0, 1],\n [1, 2]])\n" + "Meta info: {'some': 'info', affine: tensor([[2., 0., 0., 0.],\n" + " [0., 2., 0., 0.],\n" + " [0., 0., 2., 0.],\n" + " [0., 0., 0., 1.]], dtype=torch.float64), space: RAS}\n" + "Additional info: 1.0\n", +] + class TestDataStats(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + @parameterized.expand( + [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8] + ) def test_value(self, input_param, input_data, expected_print): transform = DataStats(**input_param) _ = transform(input_data) - @parameterized.expand([TEST_CASE_8]) + @parameterized.expand([TEST_CASE_9, TEST_CASE_10]) 
def test_file(self, input_data, expected_print): with tempfile.TemporaryDirectory() as tempdir: filename = os.path.join(tempdir, "test_data_stats.log") @@ -158,6 +194,7 @@ def test_file(self, input_data, expected_print): "data_shape": True, "value_range": True, "data_value": True, + "meta_info": True, "additional_info": np.mean, "name": name, } diff --git a/tests/test_data_statsd.py b/tests/test_data_statsd.py index ef88300c10..a28a938c40 100644 --- a/tests/test_data_statsd.py +++ b/tests/test_data_statsd.py @@ -21,6 +21,7 @@ import torch from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms import DataStatsd TEST_CASE_1 = [ @@ -150,22 +151,70 @@ ] TEST_CASE_9 = [ + { + "keys": "img", + "prefix": "test data", + "data_shape": True, + "value_range": True, + "data_value": True, + "meta_info": False, + "additional_info": np.mean, + "name": "DataStats", + }, {"img": np.array([[0, 1], [1, 2]])}, "test data statistics:\nType: int64\nShape: (2, 2)\nValue range: (0, 2)\n" "Value: [[0 1]\n [1 2]]\nAdditional info: 1.0\n", ] +TEST_CASE_10 = [ + {"img": np.array([[0, 1], [1, 2]])}, + "test data statistics:\nType: int64\nShape: (2, 2)\nValue range: (0, 2)\n" + "Value: [[0 1]\n [1 2]]\n" + "Meta info: '(input is not a MetaTensor)'\n" + "Additional info: 1.0\n", +] + +TEST_CASE_11 = [ + { + "img": ( + MetaTensor( + torch.tensor([[0, 1], [1, 2]]), + affine=torch.as_tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]], dtype=torch.float64), + meta={"some": "info"}, + ) + ) + }, + "test data statistics:\nType: torch.int64\n" + "Shape: torch.Size([2, 2])\nValue range: (0, 2)\n" + "Value: tensor([[0, 1],\n [1, 2]])\n" + "Meta info: {'some': 'info', affine: tensor([[2., 0., 0., 0.],\n" + " [0., 2., 0., 0.],\n" + " [0., 0., 2., 0.],\n" + " [0., 0., 0., 1.]], dtype=torch.float64), space: RAS}\n" + "Additional info: 1.0\n", +] + class TestDataStatsd(unittest.TestCase): @parameterized.expand( - [TEST_CASE_1, TEST_CASE_2, 
TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8] + [ + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + TEST_CASE_9, + ] ) def test_value(self, input_param, input_data, expected_print): transform = DataStatsd(**input_param) _ = transform(input_data) - @parameterized.expand([TEST_CASE_9]) + @parameterized.expand([TEST_CASE_10, TEST_CASE_11]) def test_file(self, input_data, expected_print): with tempfile.TemporaryDirectory() as tempdir: filename = os.path.join(tempdir, "test_stats.log") @@ -180,6 +229,7 @@ def test_file(self, input_data, expected_print): "data_shape": True, "value_range": True, "data_value": True, + "meta_info": True, "additional_info": np.mean, "name": name, }