Skip to content

Commit

Permalink
Make MetaTensor optional printed in DataStats and DataStatsd #5905 (#…
Browse files Browse the repository at this point in the history
…7814)

Fixes #5905 

### Description

We simply add one argument for DataStats and DataStatsd to make
MetaTensor optional printed.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Wei_Chuan, Chiang <[email protected]>
Signed-off-by: YunLiu <[email protected]>
Signed-off-by: Suraj Pai <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: Suraj Pai <[email protected]>
Co-authored-by: Ben Murray <[email protected]>
  • Loading branch information
6 people authored Sep 4, 2024
1 parent 4e70bf6 commit 19cc6f0
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 7 deletions.
7 changes: 7 additions & 0 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,7 @@ def __init__(
data_shape: bool = True,
value_range: bool = True,
data_value: bool = False,
meta_info: bool = False,
additional_info: Callable | None = None,
name: str = "DataStats",
) -> None:
Expand All @@ -667,6 +668,7 @@ def __init__(
value_range: whether to show the value range of input data.
data_value: whether to show the raw value of input data.
a typical example is to print some properties of Nifti image: affine, pixdim, etc.
meta_info: whether to show the data of MetaTensor.
additional_info: user can define callable function to extract additional info from input data.
name: identifier of `logging.logger` to use, defaulting to "DataStats".
Expand All @@ -681,6 +683,7 @@ def __init__(
self.data_shape = data_shape
self.value_range = value_range
self.data_value = data_value
self.meta_info = meta_info
if additional_info is not None and not callable(additional_info):
raise TypeError(f"additional_info must be None or callable but is {type(additional_info).__name__}.")
self.additional_info = additional_info
Expand All @@ -707,6 +710,7 @@ def __call__(
data_shape: bool | None = None,
value_range: bool | None = None,
data_value: bool | None = None,
meta_info: bool | None = None,
additional_info: Callable | None = None,
) -> NdarrayOrTensor:
"""
Expand All @@ -727,6 +731,9 @@ def __call__(
lines.append(f"Value range: (not a PyTorch or Numpy array, type: {type(img)})")
if self.data_value if data_value is None else data_value:
lines.append(f"Value: {img}")
if self.meta_info if meta_info is None else meta_info:
metadata = getattr(img, "meta", "(input is not a MetaTensor)")
lines.append(f"Meta info: {repr(metadata)}")
additional_info = self.additional_info if additional_info is None else additional_info
if additional_info is not None:
lines.append(f"Additional info: {additional_info(img)}")
Expand Down
28 changes: 25 additions & 3 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,7 @@ def __init__(
data_shape: Sequence[bool] | bool = True,
value_range: Sequence[bool] | bool = True,
data_value: Sequence[bool] | bool = False,
meta_info: Sequence[bool] | bool = False,
additional_info: Sequence[Callable] | Callable | None = None,
name: str = "DataStats",
allow_missing_keys: bool = False,
Expand All @@ -812,6 +813,8 @@ def __init__(
data_value: whether to show the raw value of input data.
it also can be a sequence of bool, each element corresponds to a key in ``keys``.
a typical example is to print some properties of Nifti image: affine, pixdim, etc.
meta_info: whether to show the data of MetaTensor.
it also can be a sequence of bool, each element corresponds to a key in ``keys``.
additional_info: user can define callable function to extract
additional info from input data. it also can be a sequence of string, each element
corresponds to a key in ``keys``.
Expand All @@ -825,15 +828,34 @@ def __init__(
self.data_shape = ensure_tuple_rep(data_shape, len(self.keys))
self.value_range = ensure_tuple_rep(value_range, len(self.keys))
self.data_value = ensure_tuple_rep(data_value, len(self.keys))
self.meta_info = ensure_tuple_rep(meta_info, len(self.keys))
self.additional_info = ensure_tuple_rep(additional_info, len(self.keys))
self.printer = DataStats(name=name)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key, prefix, data_type, data_shape, value_range, data_value, additional_info in self.key_iterator(
d, self.prefix, self.data_type, self.data_shape, self.value_range, self.data_value, self.additional_info
for (
key,
prefix,
data_type,
data_shape,
value_range,
data_value,
meta_info,
additional_info,
) in self.key_iterator(
d,
self.prefix,
self.data_type,
self.data_shape,
self.value_range,
self.data_value,
self.meta_info,
self.additional_info,
):
d[key] = self.printer(d[key], prefix, data_type, data_shape, value_range, data_value, additional_info)
d[key] = self.printer(
d[key], prefix, data_type, data_shape, value_range, data_value, meta_info, additional_info
)
return d


Expand Down
41 changes: 39 additions & 2 deletions tests/test_data_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import torch
from parameterized import parameterized

from monai.data.meta_tensor import MetaTensor
from monai.transforms import DataStats

TEST_CASE_1 = [
Expand Down Expand Up @@ -130,20 +131,55 @@
]

TEST_CASE_8 = [
{
"prefix": "test data",
"data_type": True,
"data_shape": True,
"value_range": True,
"data_value": True,
"additional_info": np.mean,
"name": "DataStats",
},
np.array([[0, 1], [1, 2]]),
"test data statistics:\nType: <class 'numpy.ndarray'> int64\nShape: (2, 2)\nValue range: (0, 2)\n"
"Value: [[0 1]\n [1 2]]\nAdditional info: 1.0\n",
]

TEST_CASE_9 = [
np.array([[0, 1], [1, 2]]),
"test data statistics:\nType: <class 'numpy.ndarray'> int64\nShape: (2, 2)\nValue range: (0, 2)\n"
"Value: [[0 1]\n [1 2]]\n"
"Meta info: '(input is not a MetaTensor)'\n"
"Additional info: 1.0\n",
]

TEST_CASE_10 = [
MetaTensor(
torch.tensor([[0, 1], [1, 2]]),
affine=torch.as_tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]], dtype=torch.float64),
meta={"some": "info"},
),
"test data statistics:\nType: <class 'monai.data.meta_tensor.MetaTensor'> torch.int64\n"
"Shape: torch.Size([2, 2])\nValue range: (0, 2)\n"
"Value: tensor([[0, 1],\n [1, 2]])\n"
"Meta info: {'some': 'info', affine: tensor([[2., 0., 0., 0.],\n"
" [0., 2., 0., 0.],\n"
" [0., 0., 2., 0.],\n"
" [0., 0., 0., 1.]], dtype=torch.float64), space: RAS}\n"
"Additional info: 1.0\n",
]


class TestDataStats(unittest.TestCase):

@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
@parameterized.expand(
[TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]
)
def test_value(self, input_param, input_data, expected_print):
transform = DataStats(**input_param)
_ = transform(input_data)

@parameterized.expand([TEST_CASE_8])
@parameterized.expand([TEST_CASE_9, TEST_CASE_10])
def test_file(self, input_data, expected_print):
with tempfile.TemporaryDirectory() as tempdir:
filename = os.path.join(tempdir, "test_data_stats.log")
Expand All @@ -158,6 +194,7 @@ def test_file(self, input_data, expected_print):
"data_shape": True,
"value_range": True,
"data_value": True,
"meta_info": True,
"additional_info": np.mean,
"name": name,
}
Expand Down
54 changes: 52 additions & 2 deletions tests/test_data_statsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch
from parameterized import parameterized

from monai.data.meta_tensor import MetaTensor
from monai.transforms import DataStatsd

TEST_CASE_1 = [
Expand Down Expand Up @@ -150,22 +151,70 @@
]

TEST_CASE_9 = [
{
"keys": "img",
"prefix": "test data",
"data_shape": True,
"value_range": True,
"data_value": True,
"meta_info": False,
"additional_info": np.mean,
"name": "DataStats",
},
{"img": np.array([[0, 1], [1, 2]])},
"test data statistics:\nType: <class 'numpy.ndarray'> int64\nShape: (2, 2)\nValue range: (0, 2)\n"
"Value: [[0 1]\n [1 2]]\nAdditional info: 1.0\n",
]

TEST_CASE_10 = [
{"img": np.array([[0, 1], [1, 2]])},
"test data statistics:\nType: <class 'numpy.ndarray'> int64\nShape: (2, 2)\nValue range: (0, 2)\n"
"Value: [[0 1]\n [1 2]]\n"
"Meta info: '(input is not a MetaTensor)'\n"
"Additional info: 1.0\n",
]

TEST_CASE_11 = [
{
"img": (
MetaTensor(
torch.tensor([[0, 1], [1, 2]]),
affine=torch.as_tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]], dtype=torch.float64),
meta={"some": "info"},
)
)
},
"test data statistics:\nType: <class 'monai.data.meta_tensor.MetaTensor'> torch.int64\n"
"Shape: torch.Size([2, 2])\nValue range: (0, 2)\n"
"Value: tensor([[0, 1],\n [1, 2]])\n"
"Meta info: {'some': 'info', affine: tensor([[2., 0., 0., 0.],\n"
" [0., 2., 0., 0.],\n"
" [0., 0., 2., 0.],\n"
" [0., 0., 0., 1.]], dtype=torch.float64), space: RAS}\n"
"Additional info: 1.0\n",
]


class TestDataStatsd(unittest.TestCase):

@parameterized.expand(
[TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]
[
TEST_CASE_1,
TEST_CASE_2,
TEST_CASE_3,
TEST_CASE_4,
TEST_CASE_5,
TEST_CASE_6,
TEST_CASE_7,
TEST_CASE_8,
TEST_CASE_9,
]
)
def test_value(self, input_param, input_data, expected_print):
transform = DataStatsd(**input_param)
_ = transform(input_data)

@parameterized.expand([TEST_CASE_9])
@parameterized.expand([TEST_CASE_10, TEST_CASE_11])
def test_file(self, input_data, expected_print):
with tempfile.TemporaryDirectory() as tempdir:
filename = os.path.join(tempdir, "test_stats.log")
Expand All @@ -180,6 +229,7 @@ def test_file(self, input_data, expected_print):
"data_shape": True,
"value_range": True,
"data_value": True,
"meta_info": True,
"additional_info": np.mean,
"name": name,
}
Expand Down

0 comments on commit 19cc6f0

Please sign in to comment.