From 4addc5d56c26c5d00828bc8796cea73ef95372b7 Mon Sep 17 00:00:00 2001 From: Wenqi Li <831580+wyli@users.noreply.github.com> Date: Thu, 20 Jul 2023 07:22:05 +0100 Subject: [PATCH 1/5] update base image to 2306 (#6741) Fixes #6740 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Wenqi Li --- .github/workflows/cron.yml | 18 +++++++++--------- .github/workflows/pythonapp-gpu.yml | 8 ++++---- Dockerfile | 2 +- tests/test_grid_distortion.py | 2 +- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml index c1015cd541..e986a27670 100644 --- a/.github/workflows/cron.yml +++ b/.github/workflows/cron.yml @@ -15,8 +15,8 @@ jobs: environment: - "PT191+CUDA113" - "PT110+CUDA113" - - "PT112+CUDA113" - - "PTLATEST+CUDA118" + - "PT113+CUDA113" + - "PTLATEST+CUDA121" include: # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes - environment: PT191+CUDA113 @@ -25,12 +25,12 @@ jobs: - environment: PT110+CUDA113 pytorch: "torch==1.10.2 torchvision==0.11.3 --extra-index-url https://download.pytorch.org/whl/cu113" base: "nvcr.io/nvidia/pytorch:21.06-py3" # CUDA 11.3 - - environment: PT112+CUDA113 - pytorch: "torch==1.12.1 torchvision==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu113" + - environment: PT113+CUDA113 + pytorch: "torch==1.13.1 torchvision==0.14.1 --extra-index-url https://download.pytorch.org/whl/cu113" base: "nvcr.io/nvidia/pytorch:21.06-py3" # CUDA 11.3 - - environment: PTLATEST+CUDA118 + - environment: PTLATEST+CUDA121 pytorch: "-U torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118" - base: "nvcr.io/nvidia/pytorch:23.03-py3" # CUDA 11.8 + base: "nvcr.io/nvidia/pytorch:23.06-py3" # CUDA 12.1 container: image: ${{ matrix.base }} options: "--gpus all" @@ -76,7 +76,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:22.09", "pytorch:22.11", "pytorch:23.03"] + container: ["pytorch:22.10", "pytorch:23.06"] container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -121,7 +121,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:22.09", "pytorch:22.11", "pytorch:23.03"] + container: ["pytorch:23.06"] container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -221,7 +221,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' needs: cron-gpu # so that monai itself is verified first container: - image: nvcr.io/nvidia/pytorch:23.03-py3 # testing with the latest pytorch base image + image: nvcr.io/nvidia/pytorch:23.06-py3 # testing with the latest pytorch base image options: "--gpus all --ipc=host" runs-on: [self-hosted, linux, x64, integration] steps: diff --git a/.github/workflows/pythonapp-gpu.yml b/.github/workflows/pythonapp-gpu.yml index 723e060d11..65ee29f4e3 100644 --- a/.github/workflows/pythonapp-gpu.yml +++ b/.github/workflows/pythonapp-gpu.yml @@ -25,7 +25,7 @@ jobs: - "PT110+CUDA111" - "PT112+CUDA118DOCKER" - "PT113+CUDA116" - - "PT114+CUDA120DOCKER" + - "PT210+CUDA121DOCKER" include: # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes - environment: PT19+CUDA114DOCKER @@ -42,10 +42,10 @@ jobs: - environment: PT113+CUDA116 pytorch: "torch==1.13.1 torchvision==0.14.1" base: "nvcr.io/nvidia/cuda:11.6.1-devel-ubuntu18.04" - - environment: PT114+CUDA120DOCKER - # 23.03: 2.0.0a0+1767026 + - environment: PT210+CUDA121DOCKER + # 23.06: 2.1.0a0+4136153 pytorch: "-h" # we explicitly set pytorch to -h to avoid pip install error - base: "nvcr.io/nvidia/pytorch:23.03-py3" + base: "nvcr.io/nvidia/pytorch:23.06-py3" container: image: ${{ matrix.base }} options: --gpus all --env NVIDIA_DISABLE_REQUIRE=true # workaround for unsatisfied condition: cuda>=11.6 diff --git a/Dockerfile b/Dockerfile index 653dd1571c..adfa5390ed 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,7 +11,7 @@ # To build with a different base image # please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag. -ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:23.03-py3 +ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:23.06-py3 FROM ${PYTORCH_IMAGE} LABEL maintainer="monai.contact@gmail.com" diff --git a/tests/test_grid_distortion.py b/tests/test_grid_distortion.py index d776d49f4d..1a698140af 100644 --- a/tests/test_grid_distortion.py +++ b/tests/test_grid_distortion.py @@ -81,7 +81,7 @@ ) TESTS.append( [ - dict(num_cells=2, distort_steps=[(1.25,) * 3] * 3, mode="nearest", padding_mode="zeros"), + dict(num_cells=2, distort_steps=[(1.26,) * 3] * 3, mode="nearest", padding_mode="zeros"), p(np.indices([3, 3, 3])[:1].astype(np.float32)), p( np.array( From 644c9e5d58082f442e3e8230fcf6ae1d6e8ee5b8 Mon Sep 17 00:00:00 2001 From: Wenqi Li <831580+wyli@users.noreply.github.com> Date: Sat, 22 Jul 2023 16:35:27 +0100 Subject: [PATCH 2/5] enhances auto3dseg data analyzer info (#6758) ### Description - output file name extension to be consistent with `self.fmt` - improve logging messages when writing files - export file format may be spelt as 'yml' instead of 'yaml' ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Wenqi Li --- monai/apps/auto3dseg/data_analyzer.py | 14 ++++++++------ monai/bundle/config_parser.py | 4 ++-- tests/test_auto3dseg.py | 2 +- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/monai/apps/auto3dseg/data_analyzer.py b/monai/apps/auto3dseg/data_analyzer.py index ded6390601..350bb61a34 100644 --- a/monai/apps/auto3dseg/data_analyzer.py +++ b/monai/apps/auto3dseg/data_analyzer.py @@ -70,7 +70,7 @@ class DataAnalyzer: the DataAnalyzer will skip looking for labels and all label-related operations. hist_bins: bins to compute histogram for each image channel. hist_range: ranges to compute histogram for each image channel. - fmt: format used to save the analysis results. Defaults to "yaml". + fmt: format used to save the analysis results. Currently support ``"json"`` and ``"yaml"``, defaults to "yaml". histogram_only: whether to only compute histograms. Defaults to False. extra_params: other optional arguments. Currently supported arguments are : 'allowed_shape_difference' (default 5) can be used to change the default tolerance of @@ -164,6 +164,7 @@ def _check_data_uniformity(keys: list[str], result: dict) -> bool: constant_props = [result[DataStatsKeys.SUMMARY][DataStatsKeys.IMAGE_STATS][key] for key in keys] for prop in constant_props: if "stdev" in prop and np.any(prop["stdev"]): + logger.debug(f"summary image_stats {prop} has non-zero stdev {prop['stdev']}.") return False return True @@ -242,15 +243,16 @@ def get_all_case_stats(self, key="training", transform_list=None): if not self._check_data_uniformity([ImageStatsKeys.SPACING], result): logger.info("Data spacing is not completely uniform. MONAI transforms may provide unexpected result") if self.output_path: + logger.info(f"Writing data stats to {self.output_path}.") ConfigParser.export_config_file( result, self.output_path, fmt=self.fmt, default_flow_style=None, sort_keys=False ) + by_case_path = self.output_path.replace(f".{self.fmt}", f"_by_case.{self.fmt}") + if by_case_path == self.output_path: # self.output_path not ended with self.fmt? + by_case_path += f".by_case.{self.fmt}" + logger.info(f"Writing by-case data stats to {by_case_path}, this may take a while.") ConfigParser.export_config_file( - result_bycase, - self.output_path.replace(".yaml", "_by_case.yaml"), - fmt=self.fmt, - default_flow_style=None, - sort_keys=False, + result_bycase, by_case_path, fmt=self.fmt, default_flow_style=None, sort_keys=False ) # release memory if self.device.type == "cuda": diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py index d03ca8e43b..e2553a5ffd 100644 --- a/monai/bundle/config_parser.py +++ b/monai/bundle/config_parser.py @@ -438,12 +438,12 @@ def export_config_file(cls, config: dict, filepath: PathLike, fmt: str = "json", """ _filepath: str = str(Path(filepath)) - writer = look_up_option(fmt.lower(), {"json", "yaml"}) + writer = look_up_option(fmt.lower(), {"json", "yaml", "yml"}) with open(_filepath, "w") as f: if writer == "json": json.dump(config, f, **kwargs) return - if writer == "yaml": + if writer == "yaml" or writer == "yml": return yaml.safe_dump(config, f, **kwargs) raise ValueError(f"only support JSON or YAML config file so far, got {writer}.") diff --git a/tests/test_auto3dseg.py b/tests/test_auto3dseg.py index 53f25051ec..272fb52f1a 100644 --- a/tests/test_auto3dseg.py +++ b/tests/test_auto3dseg.py @@ -170,7 +170,7 @@ def setUp(self): work_dir = self.test_dir.name self.dataroot_dir = os.path.join(work_dir, "sim_dataroot") self.datalist_file = os.path.join(work_dir, "sim_datalist.json") - self.datastat_file = os.path.join(work_dir, "datastats.yaml") + self.datastat_file = os.path.join(work_dir, "datastats.yml") ConfigParser.export_config_file(sim_datalist, self.datalist_file) @parameterized.expand(SIM_CPU_TEST_CASES) From 3410794132f571deb5aa03aef47668ba901113b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vahit=20Bu=C4=9Fra=20YE=C5=9E=C4=B0LKAYNAK?= Date: Mon, 24 Jul 2023 01:48:46 +0200 Subject: [PATCH 3/5] Add ultrasound confidence map to transforms (#6709) ### Description This transform uses the method introduced by Karamalis et al. in https://doi.org/10.1016/j.media.2012.07.005 to compute confidence maps on ultrasound images. ### Possible Problems - I am not entirely sure if the "transforms" section is the right place for this method but I found it the most suitable since it is not "deep learning" and it is "pre-processing" in a way. - Current version of the implementation requires GNU Octave to be installed and defined in the path. This is an odd dependency, I am aware of that, yet using SciPy does not provide satisfactory results in terms of speed. If this kind of dependency is not suitable, I also have a pure SciPy implementation, yet it runs about x15 slower, and it is slow to work in real-time, I am open to any feedback. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Vahit Bugra YESILKAYNAK --- docs/requirements.txt | 1 + docs/source/installation.md | 8 +- monai/config/deviceconfig.py | 1 + monai/data/__init__.py | 2 + monai/data/ultrasound_confidence_map.py | 352 ++++++++ monai/transforms/__init__.py | 1 + monai/transforms/intensity/array.py | 80 ++ requirements-dev.txt | 2 +- setup.cfg | 3 + tests/min_tests.py | 1 + ...est_ultrasound_confidence_map_transform.py | 757 ++++++++++++++++++ 11 files changed, 1203 insertions(+), 5 deletions(-) create mode 100644 monai/data/ultrasound_confidence_map.py create mode 100644 tests/test_ultrasound_confidence_map_transform.py diff --git a/docs/requirements.txt b/docs/requirements.txt index 07b189dd79..701b7998a9 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -6,6 +6,7 @@ itk>=5.2 nibabel parameterized scikit-image>=0.19.0 +scipy>=1.7.1 tensorboard commonmark==0.9.1 recommonmark==0.6.0 diff --git a/docs/source/installation.md b/docs/source/installation.md index fb0409bf89..bc79040546 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -10,8 +10,8 @@ - [Uninstall the packages](#uninstall-the-packages) - [From conda-forge](#from-conda-forge) - [From GitHub](#from-github) - - [Option 1 (as a part of your system-wide module)](#option-1-as-a-part-of-your-system-wide-module) - - [Option 2 (editable installation)](#option-2-editable-installation) + - [Option 1 (as a part of your system-wide module):](#option-1-as-a-part-of-your-system-wide-module) + - [Option 2 (editable installation):](#option-2-editable-installation) - [Validating the install](#validating-the-install) - [MONAI version string](#monai-version-string) - [From DockerHub](#from-dockerhub) @@ -254,10 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr] +[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr] ``` -which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`, +which correspond to `nibabel`, `scikit-image`, `scipy`, `pillow`, `tensorboard`, `gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, and `zarr` respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index 6ee454ac06..1bd5f1a4cd 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -71,6 +71,7 @@ def get_optional_config_values(): output["ITK"] = get_package_version("itk") output["Nibabel"] = get_package_version("nibabel") output["scikit-image"] = get_package_version("skimage") + output["scipy"] = get_package_version("scipy") output["Pillow"] = get_package_version("PIL") output["Tensorboard"] = get_package_version("tensorboard") output["gdown"] = get_package_version("gdown") diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 0e9759aaf1..9339897d7a 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -150,3 +150,5 @@ def reduce_meta_tensor(meta_tensor): return _rebuild_meta, (type(meta_tensor), storage, dtype, metadata) ForkingPickler.register(MetaTensor, reduce_meta_tensor) + +from .ultrasound_confidence_map import UltrasoundConfidenceMap diff --git a/monai/data/ultrasound_confidence_map.py b/monai/data/ultrasound_confidence_map.py new file mode 100644 index 0000000000..8aff2988ea --- /dev/null +++ b/monai/data/ultrasound_confidence_map.py @@ -0,0 +1,352 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import numpy as np +from numpy.typing import NDArray + +from monai.utils import min_version, optional_import + +__all__ = ["UltrasoundConfidenceMap"] + +cv2, _ = optional_import("cv2") +csc_matrix, _ = optional_import("scipy.sparse", "1.7.1", min_version, "csc_matrix") +spsolve, _ = optional_import("scipy.sparse.linalg", "1.7.1", min_version, "spsolve") +hilbert, _ = optional_import("scipy.signal", "1.7.1", min_version, "hilbert") + + +class UltrasoundConfidenceMap: + """Compute confidence map from an ultrasound image. + This transform uses the method introduced by Karamalis et al. in https://doi.org/10.1016/j.media.2012.07.005. + It generates a confidence map by setting source and sink points in the image and computing the probability + for random walks to reach the source for each pixel. + + Args: + alpha (float, optional): Alpha parameter. Defaults to 2.0. + beta (float, optional): Beta parameter. Defaults to 90.0. + gamma (float, optional): Gamma parameter. Defaults to 0.05. + mode (str, optional): 'RF' or 'B' mode data. Defaults to 'B'. + sink_mode (str, optional): Sink mode. Defaults to 'all'. If 'mask' is selected, a mask must be when calling + the transform. Can be 'all', 'mid', 'min', or 'mask'. + """ + + def __init__(self, alpha: float = 2.0, beta: float = 90.0, gamma: float = 0.05, mode="B", sink_mode="all"): + # The hyperparameters for confidence map estimation + self.alpha = alpha + self.beta = beta + self.gamma = gamma + self.mode = mode + self.sink_mode = sink_mode + + # The precision to use for all computations + self.eps = np.finfo("float64").eps + + # Store sink indices for external use + self._sink_indices = np.array([], dtype="float64") + + def sub2ind(self, size: tuple[int, ...], rows: NDArray, cols: NDArray) -> NDArray: + """Converts row and column subscripts into linear indices, + basically the copy of the MATLAB function of the same name. + https://www.mathworks.com/help/matlab/ref/sub2ind.html + + This function is Pythonic so the indices start at 0. + + Args: + size Tuple[int]: Size of the matrix + rows (NDArray): Row indices + cols (NDArray): Column indices + + Returns: + indices (NDArray): 1-D array of linear indices + """ + indices: NDArray = rows + cols * size[0] + return indices + + def get_seed_and_labels( + self, data: NDArray, sink_mode: str = "all", sink_mask: NDArray | None = None + ) -> tuple[NDArray, NDArray]: + """Get the seed and label arrays for the max-flow algorithm + + Args: + data: Input array + sink_mode (str, optional): Sink mode. Defaults to 'all'. + sink_mask (NDArray, optional): Sink mask. Defaults to None. + + Returns: + Tuple[NDArray, NDArray]: Seed and label arrays + """ + + # Seeds and labels (boundary conditions) + seeds = np.array([], dtype="float64") + labels = np.array([], dtype="float64") + + # Indices for all columns + sc = np.arange(data.shape[1], dtype="float64") + + # SOURCE ELEMENTS - 1st matrix row + # Indices for 1st row, it will be broadcasted with sc + sr_up = np.array([0]) + seed = self.sub2ind(data.shape, sr_up, sc).astype("float64") + seed = np.unique(seed) + seeds = np.concatenate((seeds, seed)) + + # Label 1 + label = np.ones_like(seed) + labels = np.concatenate((labels, label)) + + # Create seeds for sink elements + + if sink_mode == "all": + # All elements in the last row + sr_down = np.ones_like(sc) * (data.shape[0] - 1) + self._sink_indices = np.array([sr_down, sc], dtype="int32") + seed = self.sub2ind(data.shape, sr_down, sc).astype("float64") + + elif sink_mode == "mid": + # Middle element in the last row + sc_down = np.array([data.shape[1] // 2]) + sr_down = np.ones_like(sc_down) * (data.shape[0] - 1) + self._sink_indices = np.array([sr_down, sc_down], dtype="int32") + seed = self.sub2ind(data.shape, sr_down, sc_down).astype("float64") + + elif sink_mode == "min": + # Minimum element in the last row (excluding 10% from the edges) + ten_percent = int(data.shape[1] * 0.1) + min_val = np.min(data[-1, ten_percent:-ten_percent]) + min_idxs = np.where(data[-1, ten_percent:-ten_percent] == min_val)[0] + ten_percent + sc_down = min_idxs + sr_down = np.ones_like(sc_down) * (data.shape[0] - 1) + self._sink_indices = np.array([sr_down, sc_down], dtype="int32") + seed = self.sub2ind(data.shape, sr_down, sc_down).astype("float64") + + elif sink_mode == "mask": + # All elements in the mask + coords = np.where(sink_mask != 0) + sr_down = coords[0] + sc_down = coords[1] + self._sink_indices = np.array([sr_down, sc_down], dtype="int32") + seed = self.sub2ind(data.shape, sr_down, sc_down).astype("float64") + + seed = np.unique(seed) + seeds = np.concatenate((seeds, seed)) + + # Label 2 + label = np.ones_like(seed) * 2 + labels = np.concatenate((labels, label)) + + return seeds, labels + + def normalize(self, inp: NDArray) -> NDArray: + """Normalize an array to [0, 1]""" + normalized_array: NDArray = (inp - np.min(inp)) / (np.ptp(inp) + self.eps) + return normalized_array + + def attenuation_weighting(self, img: NDArray, alpha: float) -> NDArray: + """Compute attenuation weighting + + Args: + img (NDArray): Image + alpha: Attenuation coefficient (see publication) + + Returns: + w (NDArray): Weighting expressing depth-dependent attenuation + """ + + # Create depth vector and repeat it for each column + dw = np.linspace(0, 1, img.shape[0], dtype="float64") + dw = np.tile(dw.reshape(-1, 1), (1, img.shape[1])) + + w: NDArray = 1.0 - np.exp(-alpha * dw) # Compute exp inline + + return w + + def confidence_laplacian(self, padded_index: NDArray, padded_image: NDArray, beta: float, gamma: float): + """Compute 6-Connected Laplacian for confidence estimation problem + + Args: + padded_index (NDArray): The index matrix of the image with boundary padding. + padded_image (NDArray): The padded image. + beta (float): Random walks parameter that defines the sensitivity of the Gaussian weighting function. + gamma (float): Horizontal penalty factor that adjusts the weight of horizontal edges in the Laplacian. + + Returns: + L (csc_matrix): The 6-connected Laplacian matrix used for confidence map estimation. + """ + + m, _ = padded_index.shape + + padded_index = padded_index.T.flatten() + padded_image = padded_image.T.flatten() + + p = np.where(padded_index > 0)[0] + + i = padded_index[p] - 1 # Index vector + j = padded_index[p] - 1 # Index vector + # Entries vector, initially for diagonal + s = np.zeros_like(p, dtype="float64") + + edge_templates = [ + -1, # Vertical edges + 1, + m - 1, # Diagonal edges + m + 1, + -m - 1, + -m + 1, + m, # Horizontal edges + -m, + ] + + vertical_end = None + + for iter_idx, k in enumerate(edge_templates): + neigh_idxs = padded_index[p + k] + + q = np.where(neigh_idxs > 0)[0] + + ii = padded_index[p[q]] - 1 + i = np.concatenate((i, ii)) + jj = neigh_idxs[q] - 1 + j = np.concatenate((j, jj)) + w = np.abs(padded_image[p[ii]] - padded_image[p[jj]]) # Intensity derived weight + s = np.concatenate((s, w)) + + if iter_idx == 1: + vertical_end = s.shape[0] # Vertical edges length + elif iter_idx == 5: + s.shape[0] # Diagonal edges length + + # Normalize weights + s = self.normalize(s) + + # Horizontal penalty + s[:vertical_end] += gamma + # s[vertical_end:diagonal_end] += gamma * np.sqrt(2) # --> In the paper it is sqrt(2) + # since the diagonal edges are longer yet does not exist in the original code + + # Normalize differences + s = self.normalize(s) + + # Gaussian weighting function + s = -( + (np.exp(-beta * s, dtype="float64")) + 1.0e-6 + ) # --> This epsilon changes results drastically default: 1.e-6 + + # Create Laplacian, diagonal missing + lap = csc_matrix((s, (i, j))) + + # Reset diagonal weights to zero for summing + # up the weighted edge degree in the next step + lap.setdiag(0) + + # Weighted edge degree + diag = np.abs(lap.sum(axis=0).A)[0] + + # Finalize Laplacian by completing the diagonal + lap.setdiag(diag) + + return lap + + def _solve_linear_system(self, lap, rhs): + x = spsolve(lap, rhs) + + return x + + def confidence_estimation(self, img, seeds, labels, beta, gamma): + """Compute confidence map + + Args: + img (NDArray): Processed image. + seeds (NDArray): Seeds for the random walks framework. These are indices of the source and sink nodes. + labels (NDArray): Labels for the random walks framework. These represent the classes or groups of the seeds. + beta: Random walks parameter that defines the sensitivity of the Gaussian weighting function. + gamma: Horizontal penalty factor that adjusts the weight of horizontal edges in the Laplacian. + + Returns: + map: Confidence map which shows the probability of each pixel belonging to the source or sink group. + """ + + # Index matrix with boundary padding + idx = np.arange(1, img.shape[0] * img.shape[1] + 1).reshape(img.shape[1], img.shape[0]).T + pad = 1 + + padded_idx = np.pad(idx, (pad, pad), "constant", constant_values=(0, 0)) + padded_img = np.pad(img, (pad, pad), "constant", constant_values=(0, 0)) + + # Laplacian + lap = self.confidence_laplacian(padded_idx, padded_img, beta, gamma) + + # Select marked columns from Laplacian to create L_M and B^T + b = lap[:, seeds] + + # Select marked nodes to create B^T + n = np.sum(padded_idx > 0).item() + i_u = np.setdiff1d(np.arange(n), seeds.astype(int)) # Index of unmarked nodes + b = b[i_u, :] + + # Remove marked nodes from Laplacian by deleting rows and cols + keep_indices = np.setdiff1d(np.arange(lap.shape[0]), seeds) + lap = csc_matrix(lap[keep_indices, :][:, keep_indices]) + + # Define M matrix + m = np.zeros((seeds.shape[0], 1), dtype="float64") + m[:, 0] = labels == 1 + + # Right-handside (-B^T*M) + rhs = -b @ m # type: ignore + + # Solve linear system + x = self._solve_linear_system(lap, rhs) + + # Prepare output + probabilities = np.zeros((n,), dtype="float64") + # Probabilities for unmarked nodes + probabilities[i_u] = x + # Max probability for marked node + probabilities[seeds[labels == 1].astype(int)] = 1.0 + + # Final reshape with same size as input image (no padding) + probabilities = probabilities.reshape((img.shape[1], img.shape[0])).T + + return probabilities + + def __call__(self, data: NDArray, sink_mask: NDArray | None = None) -> NDArray: + """Compute the confidence map + + Args: + data (NDArray): RF ultrasound data (one scanline per column) [H x W] 2D array + + Returns: + map (NDArray): Confidence map [H x W] 2D array + """ + + # Normalize data + data = data.astype("float64") + data = self.normalize(data) + + if self.mode == "RF": + # MATLAB hilbert applies the Hilbert transform to columns + data = np.abs(hilbert(data, axis=0)).astype("float64") # type: ignore + + seeds, labels = self.get_seed_and_labels(data, self.sink_mode, sink_mask) + + # Attenuation with Beer-Lambert + w = self.attenuation_weighting(data, self.alpha) + + # Apply weighting directly to image + # Same as applying it individually during the formation of the + # Laplacian + data = data * w + + # Find condidence values + map_: NDArray = self.confidence_estimation(data, seeds, labels, self.beta, self.gamma) + + return map_ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index cdad6ec6c3..477ec7a8bd 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -129,6 +129,7 @@ ShiftIntensity, StdShiftIntensity, ThresholdIntensity, + UltrasoundConfidenceMapTransform, ) from .intensity.dictionary import ( AdjustContrastd, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index d5a90e1964..56d2778090 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -26,6 +26,7 @@ from monai.config import DtypeLike from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor from monai.data.meta_obj import get_track_meta +from monai.data.ultrasound_confidence_map import UltrasoundConfidenceMap from monai.data.utils import get_random_patch, get_valid_patch_size from monai.networks.layers import GaussianFilter, HilbertTransform, MedianFilter, SavitzkyGolayFilter from monai.transforms.transform import RandomizableTransform, Transform @@ -38,6 +39,7 @@ skimage, _ = optional_import("skimage", "0.19.0", min_version) + __all__ = [ "RandGaussianNoise", "RandRicianNoise", @@ -77,6 +79,7 @@ "RandIntensityRemap", "ForegroundMask", "ComputeHoVerMaps", + "UltrasoundConfidenceMapTransform", ] @@ -2577,3 +2580,80 @@ def __call__(self, mask: NdarrayOrTensor): hv_maps = convert_to_tensor(np.concatenate([h_map, v_map]), track_meta=get_track_meta()) return hv_maps + + +class UltrasoundConfidenceMapTransform(Transform): + """Compute confidence map from an ultrasound image. + This transform uses the method introduced by Karamalis et al. in https://doi.org/10.1016/j.media.2012.07.005. + It generates a confidence map by setting source and sink points in the image and computing the probability + for random walks to reach the source for each pixel. + + Args: + alpha (float, optional): Alpha parameter. Defaults to 2.0. + beta (float, optional): Beta parameter. Defaults to 90.0. + gamma (float, optional): Gamma parameter. Defaults to 0.05. + mode (str, optional): 'RF' or 'B' mode data. Defaults to 'B'. + sink_mode (str, optional): Sink mode. Defaults to 'all'. If 'mask' is selected, a mask must be when + calling the transform. Can be one of 'all', 'mid', 'min', 'mask'. + """ + + def __init__(self, alpha: float = 2.0, beta: float = 90.0, gamma: float = 0.05, mode="B", sink_mode="all") -> None: + self.alpha = alpha + self.beta = beta + self.gamma = gamma + self.mode = mode + self.sink_mode = sink_mode + + if self.mode not in ["B", "RF"]: + raise ValueError(f"Unknown mode: {self.mode}. Supported modes are 'B' and 'RF'.") + + if self.sink_mode not in ["all", "mid", "min", "mask"]: + raise ValueError( + f"Unknown sink mode: {self.sink_mode}. Supported modes are 'all', 'mid', 'min' and 'mask'." + ) + + self._compute_conf_map = UltrasoundConfidenceMap(self.alpha, self.beta, self.gamma, self.mode, self.sink_mode) + + def __call__(self, img: NdarrayOrTensor, mask: NdarrayOrTensor | None = None) -> NdarrayOrTensor: + """Compute confidence map from an ultrasound image. + + Args: + img (ndarray or Tensor): Ultrasound image of shape [1, H, W] or [1, D, H, W]. If the image has channels, + they will be averaged before computing the confidence map. + mask (ndarray or Tensor, optional): Mask of shape [1, H, W]. Defaults to None. Must be + provided when sink mode is 'mask'. The non-zero values of the mask are used as sink points. + + Returns: + ndarray or Tensor: Confidence map of shape [1, H, W]. + """ + + if self.sink_mode == "mask" and mask is None: + raise ValueError("A mask must be provided when sink mode is 'mask'.") + + if img.shape[0] != 1: + raise ValueError("The correct shape of the image is [1, H, W] or [1, D, H, W].") + + _img = convert_to_tensor(img, track_meta=get_track_meta()) + img_np, *_ = convert_data_type(_img, np.ndarray) + img_np = img_np[0] # Remove the first dimension + + mask_np = None + if mask is not None: + mask = convert_to_tensor(mask, dtype=torch.bool, track_meta=get_track_meta()) + mask_np, *_ = convert_data_type(mask, np.ndarray) + mask_np = mask_np[0] # Remove the first dimension + + # If the image is RGB, convert it to grayscale + if len(img_np.shape) == 3: + img_np = np.mean(img_np, axis=0) + + if mask_np is not None and mask_np.shape != img_np.shape: + raise ValueError("The mask must have the same shape as the image.") + + # Compute confidence map + conf_map: NdarrayOrTensor = self._compute_conf_map(img_np, mask_np) + + if type(img) is torch.Tensor: + conf_map = torch.from_numpy(conf_map) + + return conf_map diff --git a/requirements-dev.txt b/requirements-dev.txt index 32b36d457c..0ad08e56d2 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,7 +2,7 @@ -r requirements-min.txt pytorch-ignite==0.4.11 gdown>=4.4.0 -scipy +scipy>=1.7.1 itk>=5.2 nibabel pillow!=8.3.0 # https://github.com/python-pillow/Pillow/issues/5571 diff --git a/setup.cfg b/setup.cfg index e059dced76..a61a42395f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -49,6 +49,7 @@ all = nibabel ninja scikit-image>=0.14.2 + scipy>=1.7.1 pillow tensorboard gdown>=4.4.0 @@ -86,6 +87,8 @@ ninja = ninja skimage = scikit-image>=0.14.2 +scipy = + scipy>=1.7.1 pillow = pillow!=8.3.0 tensorboard = diff --git a/tests/min_tests.py b/tests/min_tests.py index e3b09e7c84..9a7d920a2e 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -204,6 +204,7 @@ def run_testsuit(): "test_spatial_combine_transforms", "test_bundle_workflow", "test_zarr_avg_merger", + "test_ultrasound_confidence_map_transform", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_ultrasound_confidence_map_transform.py b/tests/test_ultrasound_confidence_map_transform.py new file mode 100644 index 0000000000..3325f297f0 --- /dev/null +++ b/tests/test_ultrasound_confidence_map_transform.py @@ -0,0 +1,757 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch + +from monai.transforms import UltrasoundConfidenceMapTransform +from tests.utils import assert_allclose + +TEST_INPUT = np.array( + [ + [1, 2, 3, 23, 13, 22, 5, 1, 2, 3], + [1, 2, 3, 12, 4, 6, 9, 1, 2, 3], + [1, 2, 3, 8, 7, 10, 11, 1, 2, 3], + [1, 2, 3, 14, 15, 16, 17, 1, 2, 3], + [1, 2, 3, 18, 19, 20, 21, 1, 2, 3], + [1, 2, 3, 24, 25, 26, 27, 1, 2, 3], + [1, 2, 3, 28, 29, 30, 31, 1, 2, 3], + [1, 2, 3, 32, 33, 34, 35, 1, 2, 3], + [1, 2, 3, 36, 37, 38, 39, 1, 2, 3], + [1, 2, 3, 40, 41, 42, 43, 1, 2, 3], + ] +) + +TEST_MASK = np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] +) + + +SINK_ALL_OUTPUT = np.array( + [ + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [ + 0.97514489, + 0.96762971, + 0.96164186, + 0.95463443, + 0.9941512, + 0.99023054, + 0.98559401, + 0.98230057, + 0.96601224, + 0.95119599, + ], + [ + 0.92960533, + 0.92638451, + 0.9056675, + 0.9487176, + 0.9546961, + 0.96165853, + 0.96172303, + 0.92686401, + 0.92122613, + 0.89957239, + ], + [ + 0.86490963, + 0.85723665, + 0.83798141, + 0.90816201, + 0.90816097, + 0.90815301, + 0.9081427, + 0.85933627, + 0.85146935, + 0.82948586, + ], + [ + 0.77430346, + 0.76731372, + 0.74372311, + 0.89128774, + 0.89126885, + 0.89125066, + 0.89123521, + 0.76858589, + 0.76106647, + 0.73807776, + ], + [ + 0.66098109, + 0.65327697, + 0.63090644, + 0.33086588, + 0.3308383, + 0.33081937, + 0.33080718, + 0.6557468, + 0.64825099, + 0.62593375, + ], + [ + 0.52526945, + 0.51832586, + 0.49709412, + 0.25985059, + 0.25981009, + 0.25977729, + 0.25975222, + 0.52118958, + 0.51426328, + 0.49323164, + ], + [ + 0.3697845, + 0.36318971, + 0.34424661, + 0.17386804, + 0.17382046, + 0.17377993, + 0.17374668, + 0.36689317, + 0.36036096, + 0.3415582, + ], + [ + 0.19546374, + 0.1909659, + 0.17319999, + 0.08423318, + 0.08417993, + 0.08413242, + 0.08409104, + 0.19393909, + 0.18947485, + 0.17185031, + ], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ] +) + +SINK_MID_OUTPUT = np.array( + [ + [ + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + ], + [ + 9.99996103e-01, + 9.99994823e-01, + 9.99993550e-01, + 9.99930863e-01, + 9.99990782e-01, + 9.99984683e-01, + 9.99979000e-01, + 9.99997804e-01, + 9.99995985e-01, + 9.99994325e-01, + ], + [ + 9.99989344e-01, + 9.99988600e-01, + 9.99984099e-01, + 9.99930123e-01, + 9.99926598e-01, + 9.99824297e-01, + 9.99815032e-01, + 9.99991228e-01, + 9.99990881e-01, + 9.99988462e-01, + ], + [ + 9.99980787e-01, + 9.99979264e-01, + 9.99975828e-01, + 9.59669286e-01, + 9.59664779e-01, + 9.59656566e-01, + 9.59648332e-01, + 9.99983882e-01, + 9.99983038e-01, + 9.99980732e-01, + ], + [ + 9.99970181e-01, + 9.99969032e-01, + 9.99965730e-01, + 9.45197806e-01, + 9.45179593e-01, + 9.45163629e-01, + 9.45151458e-01, + 9.99973352e-01, + 9.99973254e-01, + 9.99971098e-01, + ], + [ + 9.99958608e-01, + 9.99957307e-01, + 9.99953444e-01, + 4.24743523e-01, + 4.24713305e-01, + 4.24694646e-01, + 4.24685271e-01, + 9.99960948e-01, + 9.99961829e-01, + 9.99960347e-01, + ], + [ + 9.99946675e-01, + 9.99945139e-01, + 9.99940312e-01, + 3.51353224e-01, + 3.51304003e-01, + 3.51268260e-01, + 3.51245366e-01, + 9.99947688e-01, + 9.99950165e-01, + 9.99949512e-01, + ], + [ + 9.99935877e-01, + 9.99934088e-01, + 9.99928982e-01, + 2.51197134e-01, + 2.51130273e-01, + 2.51080014e-01, + 2.51045852e-01, + 9.99936187e-01, + 9.99939716e-01, + 9.99940022e-01, + ], + [ + 9.99927846e-01, + 9.99925911e-01, + 9.99920188e-01, + 1.31550973e-01, + 1.31462736e-01, + 1.31394558e-01, + 1.31346069e-01, + 9.99927275e-01, + 9.99932142e-01, + 9.99933313e-01, + ], + [ + 9.99924204e-01, + 9.99922004e-01, + 9.99915767e-01, + 3.04861147e-04, + 1.95998056e-04, + 0.00000000e00, + 2.05182682e-05, + 9.99923115e-01, + 9.99928835e-01, + 9.99930535e-01, + ], + ] +) + +SINK_MIN_OUTPUT = np.array( + [ + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [ + 0.99997545, + 0.99996582, + 0.99995245, + 0.99856594, + 0.99898314, + 0.99777223, + 0.99394423, + 0.98588113, + 0.97283215, + 0.96096504, + ], + [ + 0.99993872, + 0.99993034, + 0.9998832, + 0.9986147, + 0.99848741, + 0.9972981, + 0.99723719, + 0.94157173, + 0.9369832, + 0.91964243, + ], + [ + 0.99990802, + 0.99989475, + 0.99986873, + 0.98610197, + 0.98610047, + 0.98609749, + 0.98609423, + 0.88741275, + 0.88112911, + 0.86349156, + ], + [ + 0.99988924, + 0.99988509, + 0.99988698, + 0.98234089, + 0.98233591, + 0.98233065, + 0.98232562, + 0.81475172, + 0.80865978, + 0.79033138, + ], + [ + 0.99988418, + 0.99988484, + 0.99988323, + 0.86796555, + 0.86795874, + 0.86795283, + 0.86794756, + 0.72418193, + 0.71847704, + 0.70022037, + ], + [ + 0.99988241, + 0.99988184, + 0.99988103, + 0.85528225, + 0.85527303, + 0.85526389, + 0.85525499, + 0.61716519, + 0.61026209, + 0.59503671, + ], + [ + 0.99988015, + 0.99987985, + 0.99987875, + 0.84258114, + 0.84257121, + 0.84256042, + 0.84254897, + 0.48997924, + 0.49083978, + 0.46891561, + ], + [ + 0.99987865, + 0.99987827, + 0.9998772, + 0.83279589, + 0.83278624, + 0.83277384, + 0.83275897, + 0.36345545, + 0.33690244, + 0.35696828, + ], + [ + 0.99987796, + 0.99987756, + 0.99987643, + 0.82873223, + 0.82872648, + 0.82871803, + 0.82870711, + 0.0, + 0.26106012, + 0.29978657, + ], + ] +) + +SINK_MASK_OUTPUT = np.array( + [ + [ + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + ], + [ + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 2.86416400e-01, + 7.93271181e-01, + 5.81341234e-01, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 1.98395623e-01, + ], + [ + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 2.66733297e-01, + 2.80741490e-01, + 4.14078784e-02, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 7.91676486e-04, + ], + [ + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 1.86244537e-04, + 1.53413401e-04, + 7.85806495e-05, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 5.09797387e-06, + ], + [ + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 9.62904581e-07, + 7.23946225e-07, + 3.68824440e-07, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 4.79525316e-08, + ], + [ + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 1.50939343e-10, + 1.17724874e-10, + 6.21760843e-11, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 6.08922784e-10, + ], + [ + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 2.57593754e-13, + 1.94066716e-13, + 9.83784370e-14, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 9.80828665e-12, + ], + [ + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 4.22323494e-16, + 3.17556633e-16, + 1.60789400e-16, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 1.90789819e-13, + ], + [ + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 7.72677888e-19, + 5.83029424e-19, + 2.95946659e-19, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 4.97038275e-15, + ], + [ + 2.71345908e-24, + 5.92006757e-24, + 2.25580089e-23, + 3.82601970e-18, + 3.82835349e-18, + 3.83302158e-18, + 3.84002606e-18, + 8.40760586e-16, + 1.83433696e-15, + 1.11629633e-15, + ], + ] +) + + +class TestUltrasoundConfidenceMapTransform(unittest.TestCase): + def setUp(self): + self.input_img_np = np.expand_dims(TEST_INPUT, axis=0) # mock image (numpy array) + self.input_mask_np = np.expand_dims(TEST_MASK, axis=0) # mock mask (numpy array) + + self.input_img_torch = torch.from_numpy(TEST_INPUT).unsqueeze(0) # mock image (torch tensor) + self.input_mask_torch = torch.from_numpy(TEST_MASK).unsqueeze(0) # mock mask (torch tensor) + + def test_parameters(self): + # Unknown mode + with self.assertRaises(ValueError): + UltrasoundConfidenceMapTransform(mode="unknown") + + # Unknown sink_mode + with self.assertRaises(ValueError): + UltrasoundConfidenceMapTransform(sink_mode="unknown") + + def test_rgb(self): + # RGB image + input_img_rgb = np.expand_dims(np.repeat(self.input_img_np, 3, axis=0), axis=0) + input_img_rgb_torch = torch.from_numpy(input_img_rgb) + + transform = UltrasoundConfidenceMapTransform(sink_mode="all") + result_torch = transform(input_img_rgb_torch) + self.assertIsInstance(result_torch, torch.Tensor) + assert_allclose(result_torch, torch.tensor(SINK_ALL_OUTPUT), rtol=1e-4, atol=1e-4) + result_np = transform(input_img_rgb) + self.assertIsInstance(result_np, np.ndarray) + assert_allclose(result_np, SINK_ALL_OUTPUT, rtol=1e-4, atol=1e-4) + + transform = UltrasoundConfidenceMapTransform(sink_mode="mid") + result_torch = transform(input_img_rgb_torch) + self.assertIsInstance(result_torch, torch.Tensor) + assert_allclose(result_torch, torch.tensor(SINK_MID_OUTPUT), rtol=1e-4, atol=1e-4) + result_np = transform(input_img_rgb) + self.assertIsInstance(result_np, np.ndarray) + assert_allclose(result_np, SINK_MID_OUTPUT, rtol=1e-4, atol=1e-4) + + transform = UltrasoundConfidenceMapTransform(sink_mode="min") + result_torch = transform(input_img_rgb_torch) + self.assertIsInstance(result_torch, torch.Tensor) + assert_allclose(result_torch, torch.tensor(SINK_MIN_OUTPUT), rtol=1e-4, atol=1e-4) + result_np = transform(input_img_rgb) + self.assertIsInstance(result_np, np.ndarray) + assert_allclose(result_np, SINK_MIN_OUTPUT, rtol=1e-4, atol=1e-4) + + transform = UltrasoundConfidenceMapTransform(sink_mode="mask") + result_torch = transform(input_img_rgb_torch, self.input_mask_torch) + self.assertIsInstance(result_torch, torch.Tensor) + assert_allclose(result_torch, torch.tensor(SINK_MASK_OUTPUT), rtol=1e-4, atol=1e-4) + result_np = transform(input_img_rgb, self.input_mask_np) + self.assertIsInstance(result_np, np.ndarray) + assert_allclose(result_np, SINK_MASK_OUTPUT, rtol=1e-4, atol=1e-4) + + def test_multi_channel_2d(self): + # 2D multi-channel image + input_img_rgb = np.expand_dims(np.repeat(self.input_img_np, 17, axis=0), axis=0) + input_img_rgb_torch = torch.from_numpy(input_img_rgb) + + transform = UltrasoundConfidenceMapTransform(sink_mode="all") + result_torch = transform(input_img_rgb_torch) + self.assertIsInstance(result_torch, torch.Tensor) + assert_allclose(result_torch, torch.tensor(SINK_ALL_OUTPUT), rtol=1e-4, atol=1e-4) + result_np = transform(input_img_rgb) + self.assertIsInstance(result_np, np.ndarray) + assert_allclose(result_np, SINK_ALL_OUTPUT, rtol=1e-4, atol=1e-4) + + transform = UltrasoundConfidenceMapTransform(sink_mode="mid") + result_torch = transform(input_img_rgb_torch) + self.assertIsInstance(result_torch, torch.Tensor) + assert_allclose(result_torch, torch.tensor(SINK_MID_OUTPUT), rtol=1e-4, atol=1e-4) + result_np = transform(input_img_rgb) + self.assertIsInstance(result_np, np.ndarray) + assert_allclose(result_np, SINK_MID_OUTPUT, rtol=1e-4, atol=1e-4) + + transform = UltrasoundConfidenceMapTransform(sink_mode="min") + result_torch = transform(input_img_rgb_torch) + self.assertIsInstance(result_torch, torch.Tensor) + assert_allclose(result_torch, torch.tensor(SINK_MIN_OUTPUT), rtol=1e-4, atol=1e-4) + result_np = transform(input_img_rgb) + self.assertIsInstance(result_np, np.ndarray) + assert_allclose(result_np, SINK_MIN_OUTPUT, rtol=1e-4, atol=1e-4) + + transform = UltrasoundConfidenceMapTransform(sink_mode="mask") + result_torch = transform(input_img_rgb_torch, self.input_mask_torch) + self.assertIsInstance(result_torch, torch.Tensor) + assert_allclose(result_torch, torch.tensor(SINK_MASK_OUTPUT), rtol=1e-4, atol=1e-4) + result_np = transform(input_img_rgb, self.input_mask_np) + self.assertIsInstance(result_np, np.ndarray) + assert_allclose(result_np, SINK_MASK_OUTPUT, rtol=1e-4, atol=1e-4) + + def test_non_one_first_dim(self): + # Image without first dimension as 1 + input_img_rgb = np.repeat(self.input_img_np, 3, axis=0) + input_img_rgb_torch = torch.from_numpy(input_img_rgb) + + transform = UltrasoundConfidenceMapTransform(sink_mode="all") + with self.assertRaises(ValueError): + transform(input_img_rgb_torch) + with self.assertRaises(ValueError): + transform(input_img_rgb) + + transform = UltrasoundConfidenceMapTransform(sink_mode="mid") + with self.assertRaises(ValueError): + transform(input_img_rgb_torch) + with self.assertRaises(ValueError): + transform(input_img_rgb) + + transform = UltrasoundConfidenceMapTransform(sink_mode="min") + with self.assertRaises(ValueError): + transform(input_img_rgb_torch) + with self.assertRaises(ValueError): + transform(input_img_rgb) + + transform = UltrasoundConfidenceMapTransform(sink_mode="mask") + with self.assertRaises(ValueError): + transform(input_img_rgb_torch, self.input_mask_torch) + with self.assertRaises(ValueError): + transform(input_img_rgb, self.input_mask_np) + + def test_no_first_dim(self): + # Image without first dimension + input_img_rgb = self.input_img_np[0] + input_img_rgb_torch = torch.from_numpy(input_img_rgb) + + transform = UltrasoundConfidenceMapTransform(sink_mode="all") + with self.assertRaises(ValueError): + transform(input_img_rgb_torch) + with self.assertRaises(ValueError): + transform(input_img_rgb) + + transform = UltrasoundConfidenceMapTransform(sink_mode="mid") + with self.assertRaises(ValueError): + transform(input_img_rgb_torch) + with self.assertRaises(ValueError): + transform(input_img_rgb) + + transform = UltrasoundConfidenceMapTransform(sink_mode="min") + with self.assertRaises(ValueError): + transform(input_img_rgb_torch) + with self.assertRaises(ValueError): + transform(input_img_rgb) + + transform = UltrasoundConfidenceMapTransform(sink_mode="mask") + with self.assertRaises(ValueError): + transform(input_img_rgb_torch, self.input_mask_torch) + with self.assertRaises(ValueError): + transform(input_img_rgb, self.input_mask_np) + + def test_sink_all(self): + transform = UltrasoundConfidenceMapTransform(sink_mode="all") + + # This should not raise an exception for torch tensor + result_torch = transform(self.input_img_torch) + self.assertIsInstance(result_torch, torch.Tensor) + + # This should not raise an exception for numpy array + result_np = transform(self.input_img_np) + self.assertIsInstance(result_np, np.ndarray) + + def test_sink_mid(self): + transform = UltrasoundConfidenceMapTransform(sink_mode="mid") + + # This should not raise an exception for torch tensor + result_torch = transform(self.input_img_torch) + self.assertIsInstance(result_torch, torch.Tensor) + + # This should not raise an exception for numpy array + result_np = transform(self.input_img_np) + self.assertIsInstance(result_np, np.ndarray) + + def test_sink_min(self): + transform = UltrasoundConfidenceMapTransform(sink_mode="min") + + # This should not raise an exception for torch tensor + result_torch = transform(self.input_img_torch) + self.assertIsInstance(result_torch, torch.Tensor) + + # This should not raise an exception for numpy array + result_np = transform(self.input_img_np) + self.assertIsInstance(result_np, np.ndarray) + + def test_sink_mask(self): + transform = UltrasoundConfidenceMapTransform(sink_mode="mask") + + # This should not raise an exception for torch tensor with mask + result_torch = transform(self.input_img_torch, self.input_mask_torch) + self.assertIsInstance(result_torch, torch.Tensor) + + # This should not raise an exception for numpy array with mask + result_np = transform(self.input_img_np, self.input_mask_np) + self.assertIsInstance(result_np, np.ndarray) + + # This should raise an exception for torch tensor without mask + with self.assertRaises(ValueError): + transform(self.input_img_torch) + + # This should raise an exception for numpy array without mask + with self.assertRaises(ValueError): + transform(self.input_img_np) + + def test_func(self): + transform = UltrasoundConfidenceMapTransform(alpha=2.0, beta=90.0, gamma=0.05, mode="B", sink_mode="all") + output = transform(self.input_img_np) + assert_allclose(output, SINK_ALL_OUTPUT, rtol=1e-4, atol=1e-4) + + transform = UltrasoundConfidenceMapTransform(alpha=2.0, beta=90.0, gamma=0.05, mode="B", sink_mode="mid") + output = transform(self.input_img_np) + assert_allclose(output, SINK_MID_OUTPUT, rtol=1e-4, atol=1e-4) + + transform = UltrasoundConfidenceMapTransform(alpha=2.0, beta=90.0, gamma=0.05, mode="B", sink_mode="min") + output = transform(self.input_img_np) + assert_allclose(output, SINK_MIN_OUTPUT, rtol=1e-4, atol=1e-4) + + transform = UltrasoundConfidenceMapTransform(alpha=2.0, beta=90.0, gamma=0.05, mode="B", sink_mode="mask") + output = transform(self.input_img_np, self.input_mask_np) + assert_allclose(output, SINK_MASK_OUTPUT, rtol=1e-4, atol=1e-4) + + transform = UltrasoundConfidenceMapTransform(alpha=2.0, beta=90.0, gamma=0.05, mode="B", sink_mode="all") + output = transform(self.input_img_torch) + assert_allclose(output, torch.tensor(SINK_ALL_OUTPUT), rtol=1e-4, atol=1e-4) + + transform = UltrasoundConfidenceMapTransform(alpha=2.0, beta=90.0, gamma=0.05, mode="B", sink_mode="mid") + output = transform(self.input_img_torch) + assert_allclose(output, torch.tensor(SINK_MID_OUTPUT), rtol=1e-4, atol=1e-4) + + transform = UltrasoundConfidenceMapTransform(alpha=2.0, beta=90.0, gamma=0.05, mode="B", sink_mode="min") + output = transform(self.input_img_torch) + assert_allclose(output, torch.tensor(SINK_MIN_OUTPUT), rtol=1e-4, atol=1e-4) + + transform = UltrasoundConfidenceMapTransform(alpha=2.0, beta=90.0, gamma=0.05, mode="B", sink_mode="mask") + output = transform(self.input_img_torch, self.input_mask_torch) + assert_allclose(output, torch.tensor(SINK_MASK_OUTPUT), rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() From 28c9083f0e94934a72f9dc04d600ce2f2bd3f12b Mon Sep 17 00:00:00 2001 From: monai-bot <64792179+monai-bot@users.noreply.github.com> Date: Mon, 24 Jul 2023 11:48:36 +0100 Subject: [PATCH 4/5] auto updates (#6760) --- .github/workflows/docker.yml | 5 ++--- monai/apps/deepgrow/transforms.py | 4 ++-- monai/data/ultrasound_confidence_map.py | 4 ++-- monai/transforms/intensity/array.py | 1 - tests/test_ultrasound_confidence_map_transform.py | 1 - 5 files changed, 6 insertions(+), 9 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index d498386d1d..2c809b9817 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -85,12 +85,11 @@ jobs: container: image: docker://projectmonai/monai:latest options: "--shm-size=4g --ipc=host" - runs-on: ubuntu-latest + runs-on: [self-hosted, linux, X64, docker] steps: - name: Import run: | - export CUDA_VISIBLE_DEVICES=$(python -m tests.utils | tail -n 1) - echo $CUDA_VISIBLE_DEVICES + export CUDA_VISIBLE_DEVICES= # cpu-only python -c 'import monai; monai.config.print_debug_info()' cd /opt/monai ls -al diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index 7078777f92..6b9894227c 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -441,8 +441,8 @@ def __call__(self, data): if np.all(np.less(current_size, self.spatial_size)): cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size) - box_start = np.array([s.start for s in cropper.slices]) # type: ignore - box_end = np.array([s.stop for s in cropper.slices]) # type: ignore + box_start = np.array([s.start for s in cropper.slices]) + box_end = np.array([s.stop for s in cropper.slices]) else: cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) diff --git a/monai/data/ultrasound_confidence_map.py b/monai/data/ultrasound_confidence_map.py index 8aff2988ea..03813e7559 100644 --- a/monai/data/ultrasound_confidence_map.py +++ b/monai/data/ultrasound_confidence_map.py @@ -301,7 +301,7 @@ def confidence_estimation(self, img, seeds, labels, beta, gamma): m[:, 0] = labels == 1 # Right-handside (-B^T*M) - rhs = -b @ m # type: ignore + rhs = -b @ m # Solve linear system x = self._solve_linear_system(lap, rhs) @@ -334,7 +334,7 @@ def __call__(self, data: NDArray, sink_mask: NDArray | None = None) -> NDArray: if self.mode == "RF": # MATLAB hilbert applies the Hilbert transform to columns - data = np.abs(hilbert(data, axis=0)).astype("float64") # type: ignore + data = np.abs(hilbert(data, axis=0)).astype("float64") seeds, labels = self.get_seed_and_labels(data, self.sink_mode, sink_mask) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 56d2778090..f8eadcfb1b 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -39,7 +39,6 @@ skimage, _ = optional_import("skimage", "0.19.0", min_version) - __all__ = [ "RandGaussianNoise", "RandRicianNoise", diff --git a/tests/test_ultrasound_confidence_map_transform.py b/tests/test_ultrasound_confidence_map_transform.py index 3325f297f0..fbf0c4fe97 100644 --- a/tests/test_ultrasound_confidence_map_transform.py +++ b/tests/test_ultrasound_confidence_map_transform.py @@ -49,7 +49,6 @@ ] ) - SINK_ALL_OUTPUT = np.array( [ [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], From 2800a76764bc2f58402718fcef8b38f3a60776d0 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Tue, 25 Jul 2023 14:58:25 +0530 Subject: [PATCH 5/5] feat: add `clDice` loss (#6763) Fixes #5938 ### Description This PR aims to add the `SoftclDiceLoss` and the `SoftDiceclDiceLoss` from [clDice - a Novel Topology-Preserving Loss Function for Tubular Structure Segmentation](https://openaccess.thecvf.com/content/CVPR2021/papers/Shit_clDice_-_A_Novel_Topology-Preserving_Loss_Function_for_Tubular_Structure_CVPR_2021_paper.pdf) ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Saurav Maheshkar --- monai/losses/__init__.py | 1 + monai/losses/cldice.py | 184 ++++++++++++++++++++++++++++++++++++++ tests/test_cldice_loss.py | 56 ++++++++++++ 3 files changed, 241 insertions(+) create mode 100644 monai/losses/cldice.py create mode 100644 tests/test_cldice_loss.py diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 9e09b0b123..db6b133ef0 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -11,6 +11,7 @@ from __future__ import annotations +from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss from .contrastive import ContrastiveLoss from .deform import BendingEnergyLoss from .dice import ( diff --git a/monai/losses/cldice.py b/monai/losses/cldice.py new file mode 100644 index 0000000000..5c6a721e1d --- /dev/null +++ b/monai/losses/cldice.py @@ -0,0 +1,184 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import torch.nn.functional as F +from torch.nn.modules.loss import _Loss + + +def soft_erode(img: torch.Tensor) -> torch.Tensor: # type: ignore + """ + Perform soft erosion on the input image + + Args: + img: the shape should be BCH(WD) + + Adapted from: + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L6 + """ + if len(img.shape) == 4: + p1 = -(F.max_pool2d(-img, (3, 1), (1, 1), (1, 0))) + p2 = -(F.max_pool2d(-img, (1, 3), (1, 1), (0, 1))) + return torch.min(p1, p2) # type: ignore + elif len(img.shape) == 5: + p1 = -(F.max_pool3d(-img, (3, 1, 1), (1, 1, 1), (1, 0, 0))) + p2 = -(F.max_pool3d(-img, (1, 3, 1), (1, 1, 1), (0, 1, 0))) + p3 = -(F.max_pool3d(-img, (1, 1, 3), (1, 1, 1), (0, 0, 1))) + return torch.min(torch.min(p1, p2), p3) # type: ignore + + +def soft_dilate(img: torch.Tensor) -> torch.Tensor: # type: ignore + """ + Perform soft dilation on the input image + + Args: + img: the shape should be BCH(WD) + + Adapted from: + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L18 + """ + if len(img.shape) == 4: + return F.max_pool2d(img, (3, 3), (1, 1), (1, 1)) # type: ignore + elif len(img.shape) == 5: + return F.max_pool3d(img, (3, 3, 3), (1, 1, 1), (1, 1, 1)) # type: ignore + + +def soft_open(img: torch.Tensor) -> torch.Tensor: + """ + Wrapper function to perform soft opening on the input image + + Args: + img: the shape should be BCH(WD) + + Adapted from: + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L25 + """ + eroded_image = soft_erode(img) + dilated_image = soft_dilate(eroded_image) + return dilated_image + + +def soft_skel(img: torch.Tensor, iter_: int) -> torch.Tensor: + """ + Perform soft skeletonization on the input image + + Adapted from: + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L29 + + Args: + img: the shape should be BCH(WD) + iter_: number of iterations for skeletonization + + Returns: + skeletonized image + """ + img1 = soft_open(img) + skel = F.relu(img - img1) + for _ in range(iter_): + img = soft_erode(img) + img1 = soft_open(img) + delta = F.relu(img - img1) + skel = skel + F.relu(delta - skel * delta) + return skel + + +def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: float = 1.0) -> torch.Tensor: + """ + Function to compute soft dice loss + + Adapted from: + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L22 + + Args: + y_true: the shape should be BCH(WD) + y_pred: the shape should be BCH(WD) + + Returns: + dice loss + """ + intersection = torch.sum((y_true * y_pred)[:, 1:, ...]) + coeff = (2.0 * intersection + smooth) / (torch.sum(y_true[:, 1:, ...]) + torch.sum(y_pred[:, 1:, ...]) + smooth) + soft_dice: torch.Tensor = 1.0 - coeff + return soft_dice + + +class SoftclDiceLoss(_Loss): + """ + Compute the Soft clDice loss defined in: + + Shit et al. (2021) clDice -- A Novel Topology-Preserving Loss Function + for Tubular Structure Segmentation. (https://arxiv.org/abs/2003.07311) + + Adapted from: + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L7 + """ + + def __init__(self, iter_: int = 3, smooth: float = 1.0) -> None: + """ + Args: + iter_: Number of iterations for skeletonization + smooth: Smoothing parameter + """ + super().__init__() + self.iter = iter_ + self.smooth = smooth + + def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + skel_pred = soft_skel(y_pred, self.iter) + skel_true = soft_skel(y_true, self.iter) + tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / ( + torch.sum(skel_pred[:, 1:, ...]) + self.smooth + ) + tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / ( + torch.sum(skel_true[:, 1:, ...]) + self.smooth + ) + cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens) + return cl_dice + + +class SoftDiceclDiceLoss(_Loss): + """ + Compute the Soft clDice loss defined in: + + Shit et al. (2021) clDice -- A Novel Topology-Preserving Loss Function + for Tubular Structure Segmentation. (https://arxiv.org/abs/2003.07311) + + Adapted from: + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L38 + """ + + def __init__(self, iter_: int = 3, alpha: float = 0.5, smooth: float = 1.0) -> None: + """ + Args: + iter_: Number of iterations for skeletonization + smooth: Smoothing parameter + alpha: Weighing factor for cldice + """ + super().__init__() + self.iter = iter_ + self.smooth = smooth + self.alpha = alpha + + def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + dice = soft_dice(y_true, y_pred, self.smooth) + skel_pred = soft_skel(y_pred, self.iter) + skel_true = soft_skel(y_true, self.iter) + tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / ( + torch.sum(skel_pred[:, 1:, ...]) + self.smooth + ) + tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / ( + torch.sum(skel_true[:, 1:, ...]) + self.smooth + ) + cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens) + total_loss: torch.Tensor = (1.0 - self.alpha) * dice + self.alpha * cl_dice + return total_loss diff --git a/tests/test_cldice_loss.py b/tests/test_cldice_loss.py new file mode 100644 index 0000000000..109186b5d1 --- /dev/null +++ b/tests/test_cldice_loss.py @@ -0,0 +1,56 @@ +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses import SoftclDiceLoss, SoftDiceclDiceLoss + +TEST_CASES = [ + [ # shape: (1, 4), (1, 4) + {"y_pred": torch.ones((100, 3, 256, 256)), "y_true": torch.ones((100, 3, 256, 256))}, + 0.0, + ], + [ # shape: (1, 5), (1, 5) + {"y_pred": torch.ones((100, 3, 256, 256, 5)), "y_true": torch.ones((100, 3, 256, 256, 5))}, + 0.0, + ], +] + + +class TestclDiceLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_result(self, y_pred_data, expected_val): + loss = SoftclDiceLoss() + loss_dice = SoftDiceclDiceLoss() + result = loss(**y_pred_data) + result_dice = loss_dice(**y_pred_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) + np.testing.assert_allclose(result_dice.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) + + def test_with_cuda(self): + loss = SoftclDiceLoss() + loss_dice = SoftDiceclDiceLoss() + i = torch.ones((100, 3, 256, 256)) + j = torch.ones((100, 3, 256, 256)) + if torch.cuda.is_available(): + i = i.cuda() + j = j.cuda() + output = loss(i, j) + output_dice = loss_dice(i, j) + np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4) + np.testing.assert_allclose(output_dice.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main()