From b3d063cc8bf6cb1a293c47cb4aa5300370b2da36 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 29 Jan 2021 21:14:08 +0800 Subject: [PATCH] 886 add MetricsSaver handler to save metrics and details into files (#1497) * [DLMED] add IterationHandler refer to the EpochHandler in ignite Signed-off-by: Nic Ma * [MONAI] python code formatting Signed-off-by: monai-bot * [DLMED] fix flake8 issue Signed-off-by: Nic Ma * [DLMED] fix the multi-gpu issue Signed-off-by: Nic Ma * [DLMED] fix typo Signed-off-by: Nic Ma * [DLMED] fix distributed tests Signed-off-by: Nic Ma * [DLMED] fix flake8 issue Signed-off-by: Nic Ma * [DLMED] add engine to metrics Signed-off-by: Nic Ma * [DLMED] share metric details in engine Signed-off-by: Nic Ma * [DLMED] add metrics report Signed-off-by: Nic Ma * [DLMED] add average value to report Signed-off-by: Nic Ma * [DLMED] add summary report Signed-off-by: Nic Ma * [DLMED] add docs Signed-off-by: Nic Ma * [MONAI] python code formatting Signed-off-by: monai-bot * [DLMED] fix flake8 issue Signed-off-by: Nic Ma * [MONAI] python code formatting Signed-off-by: monai-bot * [DLMED] add unit tests and distributed tests Signed-off-by: Nic Ma * [MONAI] python code formatting Signed-off-by: monai-bot * [DLMED] fix flake8 issue Signed-off-by: Nic Ma * [DLMED] fix typo Signed-off-by: Nic Ma * [DLMED] remove from min_tests Signed-off-by: Nic Ma * [DLMED] remove useless var Signed-off-by: Nic Ma * [DLMED] add skip flag Signed-off-by: Nic Ma * [DLMED] update according to comments Signed-off-by: Nic Ma * [DLMED] add dist tests Signed-off-by: Nic Ma * [MONAI] python code formatting Signed-off-by: monai-bot * [DLMED] fix flake8 issue Signed-off-by: Nic Ma * [DLMED] enhance some unit tests Signed-off-by: Nic Ma * [MONAI] python code formatting Signed-off-by: monai-bot * [DLMED] remove from min_tests Signed-off-by: Nic Ma * [DLMED] change to standlone APIs to write files Signed-off-by: Nic Ma * [MONAI] python code formatting Signed-off-by: monai-bot * [DLMED] add file type check Signed-off-by: Nic Ma * [DLMED] add output_type arg Signed-off-by: Nic Ma * [DLMED] develop standlone API Signed-off-by: Nic Ma * [MONAI] python code formatting Signed-off-by: monai-bot * [DLMED] fix flake8 issue Signed-off-by: Nic Ma * [DLMED] fix flake8 error Signed-off-by: Nic Ma * [DLMED] fix min test Signed-off-by: Nic Ma Co-authored-by: monai-bot --- docs/source/handlers.rst | 7 + monai/engines/workflow.py | 1 + monai/handlers/__init__.py | 3 +- monai/handlers/confusion_matrix.py | 10 +- monai/handlers/hausdorff_distance.py | 10 +- monai/handlers/iteration_metric.py | 46 ++++-- monai/handlers/mean_dice.py | 10 +- monai/handlers/metrics_saver.py | 137 ++++++++++++++++++ monai/handlers/surface_distance.py | 10 +- monai/handlers/utils.py | 132 +++++++++++++++-- tests/min_tests.py | 3 + .../test_evenly_divisible_all_gather_dist.py | 42 ++++++ tests/test_handler_confusion_matrix.py | 11 +- tests/test_handler_confusion_matrix_dist.py | 6 + tests/test_handler_hausdorff_distance.py | 8 + tests/test_handler_mean_dice.py | 17 ++- tests/test_handler_metrics_saver.py | 84 +++++++++++ tests/test_handler_metrics_saver_dist.py | 106 ++++++++++++++ tests/test_handler_surface_distance.py | 8 + tests/test_write_metrics_reports.py | 64 ++++++++ 20 files changed, 682 insertions(+), 33 deletions(-) create mode 100644 monai/handlers/metrics_saver.py create mode 100644 tests/test_evenly_divisible_all_gather_dist.py create mode 100644 tests/test_handler_metrics_saver.py create mode 100644 tests/test_handler_metrics_saver_dist.py create 
mode 100644 tests/test_write_metrics_reports.py diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index d1ce257cb7..81d28fb4ac 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -16,6 +16,13 @@ Model checkpoint saver .. autoclass:: CheckpointSaver :members: + +Metrics saver +------------- +.. autoclass:: MetricsSaver + :members: + + CSV saver --------- .. autoclass:: ClassificationSaver diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 67fdacad4a..d6415c1966 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -110,6 +110,7 @@ def set_sampler_epoch(engine: Engine): output=None, batch=None, metrics={}, + metric_details={}, dataloader=None, device=device, key_metric_name=None, # we can set many metrics, only use key_metric to compare and save the best model diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index a873cd8b15..6b190518fb 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -18,11 +18,12 @@ from .lr_schedule_handler import LrScheduleHandler from .mean_dice import MeanDice from .metric_logger import MetricLogger +from .metrics_saver import MetricsSaver from .roc_auc import ROCAUC from .segmentation_saver import SegmentationSaver from .smartcache_handler import SmartCacheHandler from .stats_handler import StatsHandler from .surface_distance import SurfaceDistance from .tensorboard_handlers import TensorBoardImageHandler, TensorBoardStatsHandler -from .utils import all_gather, stopping_fn_from_loss, stopping_fn_from_metric +from .utils import evenly_divisible_all_gather, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports from .validation_handler import ValidationHandler diff --git a/monai/handlers/confusion_matrix.py b/monai/handlers/confusion_matrix.py index 46226f530b..1741aa305a 100644 --- a/monai/handlers/confusion_matrix.py +++ b/monai/handlers/confusion_matrix.py @@ -29,6 +29,7 @@ def __init__( metric_name: str = "hit_rate", output_transform: Callable = lambda x: x, device: Optional[torch.device] = None, + save_details: bool = True, ) -> None: """ @@ -44,6 +45,8 @@ def __init__( and you can also input those names instead. output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. device: device specification in case of distributed computation usage. + save_details: whether to save metric computation details per image, for example: TP/TN/FP/FN of every image. + default to True, will save to `engine.state.metric_details` dict with the metric name as key. See also: :py:meth:`monai.metrics.confusion_matrix` @@ -55,7 +58,12 @@ def __init__( reduction=MetricReduction.NONE, ) self.metric_name = metric_name - super().__init__(metric_fn=metric_fn, output_transform=output_transform, device=device) + super().__init__( + metric_fn=metric_fn, + output_transform=output_transform, + device=device, + save_details=save_details, + ) def _reduce(self, scores) -> Any: confusion_matrix, _ = do_metric_reduction(scores, MetricReduction.MEAN) diff --git a/monai/handlers/hausdorff_distance.py b/monai/handlers/hausdorff_distance.py index 3e4a3d70ba..7ac52d642a 100644 --- a/monai/handlers/hausdorff_distance.py +++ b/monai/handlers/hausdorff_distance.py @@ -31,6 +31,7 @@ def __init__( directed: bool = False, output_transform: Callable = lambda x: x, device: Optional[torch.device] = None, + save_details: bool = True, ) -> None: """ @@ -45,6 +46,8 @@ def __init__( directed: whether to calculate directed Hausdorff distance. 
Defaults to ``False``. output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. device: device specification in case of distributed computation usage. + save_details: whether to save metric computation details per image, for example: hausdorff distance + of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. """ super().__init__(output_transform, device=device) metric_fn = HausdorffDistanceMetric( @@ -55,4 +58,9 @@ def __init__( directed=directed, reduction=MetricReduction.NONE, ) - super().__init__(metric_fn=metric_fn, output_transform=output_transform, device=device) + super().__init__( + metric_fn=metric_fn, + output_transform=output_transform, + device=device, + save_details=save_details, + ) diff --git a/monai/handlers/iteration_metric.py b/monai/handlers/iteration_metric.py index 4d555b9dcb..bfc7252b2f 100644 --- a/monai/handlers/iteration_metric.py +++ b/monai/handlers/iteration_metric.py @@ -9,17 +9,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence import torch +from monai.handlers.utils import evenly_divisible_all_gather from monai.metrics import do_metric_reduction from monai.utils import MetricReduction, exact_version, optional_import -NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError") idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced") +if TYPE_CHECKING: + from ignite.engine import Engine +else: + Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") class IterationMetric(Metric): # type: ignore[valid-type, misc] # due to optional_import @@ -33,6 +37,8 @@ class IterationMetric(Metric): # type: ignore[valid-type, misc] # due to option expect to return a Tensor with shape (batch, channel, ...) or tuple (Tensor, not_nans). output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. device: device specification in case of distributed computation usage. + save_details: whether to save metric computation details per image, for example: mean_dice of every image. + default to True, will save to `engine.state.metric_details` dict with the metric name as key. """ @@ -41,10 +47,14 @@ def __init__( metric_fn: Callable, output_transform: Callable = lambda x: x, device: Optional[torch.device] = None, + save_details: bool = True, ) -> None: self._is_reduced: bool = False self.metric_fn = metric_fn + self.save_details = save_details self._scores: List = [] + self._engine: Optional[Engine] = None + self._name: Optional[str] = None super().__init__(output_transform, device=device) @reinit__is_reduced @@ -79,17 +89,16 @@ def compute(self) -> Any: ws = idist.get_world_size() if ws > 1 and not self._is_reduced: - # make sure the _scores is evenly-divisible on multi-GPUs - length = _scores.shape[0] - max_len = max(idist.all_gather(length)).item() - if length < max_len: - size = [max_len - length] + list(_scores.shape[1:]) - _scores = torch.cat([_scores, _scores.new_full(size, float("NaN"))], dim=0) - # all gather across all processes - _scores = idist.all_gather(_scores) + _scores = evenly_divisible_all_gather(data=_scores) self._is_reduced = True + # save score of every image into engine.state for other components + if self.save_details: + if self._engine is None or self._name is None: + raise RuntimeError("please call the attach() function to connect the expected engine first.") + self._engine.state.metric_details[self._name] = _scores + result: torch.Tensor = torch.zeros(1) if idist.get_rank() == 0: # run compute_fn on zero rank only @@ -103,3 +112,20 @@ def _reduce(self, scores) -> Any: return do_metric_reduction(scores, MetricReduction.MEAN)[0] + + def attach(self, engine: Engine, name: str) -> None: + """ + Attaches the current metric to the provided engine. At the end of the engine's run, + the `engine.state.metrics` dictionary will contain the computed metric's value under the provided name. + + Args: + engine: the engine to which the metric must be attached. + name: the name of the metric to attach. + + """ + super().attach(engine=engine, name=name) + # FIXME: record the engine for communication; Ignite will support this in a future version + self._engine = engine + self._name = name + if self.save_details and not hasattr(engine.state, "metric_details"): + engine.state.metric_details = {} diff --git a/monai/handlers/mean_dice.py b/monai/handlers/mean_dice.py index 057acbee97..7decc3ab9b 100644 --- a/monai/handlers/mean_dice.py +++ b/monai/handlers/mean_dice.py @@ -28,6 +28,7 @@ def __init__( include_background: bool = True, output_transform: Callable = lambda x: x, device: Optional[torch.device] = None, + save_details: bool = True, ) -> None: """ @@ -36,6 +37,8 @@ def __init__( Defaults to True. output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. device: device specification in case of distributed computation usage. + save_details: whether to save metric computation details per image, for example: mean dice of every image. + default to True, will save to `engine.state.metric_details` dict with the metric name as key.
See also: :py:meth:`monai.metrics.meandice.compute_meandice` @@ -44,4 +47,9 @@ def __init__( include_background=include_background, reduction=MetricReduction.NONE, ) - super().__init__(metric_fn=metric_fn, output_transform=output_transform, device=device) + super().__init__( + metric_fn=metric_fn, + output_transform=output_transform, + device=device, + save_details=save_details, + ) diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py new file mode 100644 index 0000000000..f9deea35df --- /dev/null +++ b/monai/handlers/metrics_saver.py @@ -0,0 +1,137 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Union + +from monai.handlers.utils import write_metrics_reports +from monai.utils import ensure_tuple, exact_version, optional_import +from monai.utils.module import get_torch_version_tuple + +Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") +idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") +if TYPE_CHECKING: + from ignite.engine import Engine +else: + Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") + + +class MetricsSaver: + """ + Ignite handler to save metric values and computation details into expected files. + + Args: + save_dir: directory to save the metrics and metric details. + metrics: expected final metrics to save into files, can be: None, "*" or list of strings. + None - don't save any metrics into files. + "*" - save all the existing metrics in `engine.state.metrics` dict into `metrics.csv`. + list of strings - specify the expected metrics to save. + default to "*" to save all the metrics into `metrics.csv`. + metric_details: expected metric details to save into files, for example: mean dice + of every channel of every image in the validation dataset. + the data in `engine.state.metric_details` must contain at least 2 dims: (batch, classes, ...), + if not, will unsqueeze to 2 dims. + this arg can be: None, "*" or list of strings. + None - don't save any metric details into files. + "*" - save all the existing metric details in `engine.state.metric_details` dict into separate files. + list of strings - specify the expected metric details to save. + if not None, every metric will save a separate `{metric name}_raw.csv` file. + batch_transform: callable function to extract the meta_dict from input batch data if saving metric details. + used to extract filenames from input dict data. + summary_ops: expected computation operations to generate the summary report. + it can be: None, "*" or list of strings. + None - don't generate summary report for any metric_details. + "*" - generate summary report for every metric_details with all the supported operations. + list of strings - generate summary report for every metric_details with specified operations, they + should be within this list: [`mean`, `median`, `max`, `min`, `90percent`, `std`]. + default to None.
+ save_rank: only the handler on specified rank will save to files in multi-GPU validation, default to 0. + delimiter: the delimiter character in the CSV file, default to "\t". + output_type: expected output file type, supported types: ["csv"], default to "csv". + + """ + + def __init__( + self, + save_dir: str, + metrics: Optional[Union[str, Sequence[str]]] = "*", + metric_details: Optional[Union[str, Sequence[str]]] = None, + batch_transform: Callable = lambda x: x, + summary_ops: Optional[Union[str, Sequence[str]]] = None, + save_rank: int = 0, + delimiter: str = "\t", + output_type: str = "csv", + ) -> None: + self.save_dir = save_dir + self.metrics = ensure_tuple(metrics) if metrics is not None else None + self.metric_details = ensure_tuple(metric_details) if metric_details is not None else None + self.batch_transform = batch_transform + self.summary_ops = ensure_tuple(summary_ops) if summary_ops is not None else None + self.save_rank = save_rank + self.deli = delimiter + self.output_type = output_type + self._filenames: List[str] = [] + + def attach(self, engine: Engine) -> None: + """ + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + """ + engine.add_event_handler(Events.STARTED, self._started) + engine.add_event_handler(Events.ITERATION_COMPLETED, self._get_filenames) + engine.add_event_handler(Events.EPOCH_COMPLETED, self) + + def _started(self, engine: Engine) -> None: + self._filenames = [] + + def _get_filenames(self, engine: Engine) -> None: + if self.metric_details is not None: + _filenames = list(ensure_tuple(self.batch_transform(engine.state.batch)["filename_or_obj"])) + self._filenames += _filenames + + def __call__(self, engine: Engine) -> None: + """ + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + """ + ws = idist.get_world_size() + if self.save_rank >= ws: + raise ValueError("target rank is greater than the distributed group size.") + + _images = self._filenames + if ws > 1: + _filenames = self.deli.join(_images) + if get_torch_version_tuple() > (1, 6, 0): + # all gather across all processes + _filenames = self.deli.join(idist.all_gather(_filenames)) + else: + raise RuntimeError("MetricsSaver cannot save metric details in distributed mode with PyTorch < 1.7.0.") + _images = _filenames.split(self.deli) + + # only save metrics to file in specified rank + if idist.get_rank() == self.save_rank: + _metrics = {} + if self.metrics is not None and len(engine.state.metrics) > 0: + _metrics = {k: v for k, v in engine.state.metrics.items() if k in self.metrics or "*" in self.metrics} + _metric_details = {} + if self.metric_details is not None and len(engine.state.metric_details) > 0: + for k, v in engine.state.metric_details.items(): + if k in self.metric_details or "*" in self.metric_details: + _metric_details[k] = v + + write_metrics_reports( + save_dir=self.save_dir, + images=_images, + metrics=_metrics, + metric_details=_metric_details, + summary_ops=self.summary_ops, + deli=self.deli, + output_type=self.output_type, + ) diff --git a/monai/handlers/surface_distance.py b/monai/handlers/surface_distance.py index 17b667ab46..d3fa69bfce 100644 --- a/monai/handlers/surface_distance.py +++ b/monai/handlers/surface_distance.py @@ -30,6 +30,7 @@ def __init__( distance_metric: str = "euclidean", output_transform: Callable = lambda x: x, device: Optional[torch.device] = None, + save_details: bool = True, ) -> None: """ @@ -42,6 +43,8 @@ def __init__( the metric used to compute surface distance. Defaults to ``"euclidean"``.
output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. device: device specification in case of distributed computation usage. + save_details: whether to save metric computation details per image, for example: surface distance + of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. """ metric_fn = SurfaceDistanceMetric( @@ -50,4 +53,9 @@ def __init__( include_background=include_background, distance_metric=distance_metric, reduction=MetricReduction.NONE, ) - super().__init__(metric_fn=metric_fn, output_transform=output_transform, device=device) + super().__init__( + metric_fn=metric_fn, + output_transform=output_transform, + device=device, + save_details=save_details, + ) diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 8f22501737..ef652efe0a 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -9,19 +9,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Callable +import os +from collections import OrderedDict +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Union +import numpy as np import torch -import torch.distributed as dist -from monai.utils import exact_version, optional_import +from monai.utils import ensure_tuple, exact_version, optional_import +idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") if TYPE_CHECKING: from ignite.engine import Engine else: Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") -__all__ = ["stopping_fn_from_metric", "stopping_fn_from_loss", "all_gather"] +__all__ = [ + "stopping_fn_from_metric", + "stopping_fn_from_loss", + "evenly_divisible_all_gather", + "write_metrics_reports", +] def stopping_fn_from_metric(metric_name: str) -> Callable[[Engine], Any]: @@ -46,13 +54,113 @@ def stopping_fn(engine: Engine): return stopping_fn -def all_gather(tensor): +def evenly_divisible_all_gather(data: torch.Tensor) -> torch.Tensor: """ - All gather the data of tensor value in distributed data parallel. + Utility function for distributed data parallel to pad at first dim to make it evenly divisible and all_gather. + + Args: + data: source tensor to pad and execute all_gather in distributed data parallel. + """ - if not dist.is_available() or not dist.is_initialized(): - raise RuntimeError("should not execute all_gather operation before torch.distributed is ready.") - # create placeholder to collect the data from all processes - output = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())] - dist.all_gather(output, tensor) - return torch.cat(output, dim=0) + if not torch.is_tensor(data): + raise ValueError("input data must be a PyTorch Tensor.") + + if idist.get_world_size() <= 1: + return data + + # make sure the data is evenly-divisible on multi-GPUs + length = data.shape[0] + all_lens = idist.all_gather(length) + max_len = max(all_lens).item() + if length < max_len: + size = [max_len - length] + list(data.shape[1:]) + data = torch.cat([data, data.new_full(size, 0)], dim=0) + # all gather across all processes + data = idist.all_gather(data) + # remove the padding items + return torch.cat([data[i * max_len : i * max_len + l, ...]
for i, l in enumerate(all_lens)], dim=0) + + +def write_metrics_reports( + save_dir: str, + images: Optional[Sequence[str]], + metrics: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]], + metric_details: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]], + summary_ops: Optional[Union[str, Sequence[str]]], + deli: str = "\t", + output_type: str = "csv", +): + """ + Utility function to write the metrics into files; it contains 3 parts: + 1. if `metrics` dict is not None, write overall metrics into file, every line is a metric name and value pair. + 2. if `metric_details` dict is not None, write raw metric data of every image into file, every line for 1 image. + 3. if `summary_ops` is not None, compute summary based on operations on `metric_details` and write to file. + + Args: + save_dir: directory to save all the metrics reports. + images: name or path of every input image corresponding to the metric_details data. + if None, will use index number as the filename of every input image. + metrics: a dictionary of (metric name, metric value) pairs. + metric_details: a dictionary of (metric name, metric raw values) pairs, + for example, the raw value can be the mean_dice of every channel of every input image. + summary_ops: expected computation operations to generate the summary report. + it can be: None, "*" or list of strings. + None - don't generate summary report for any metric_details. + "*" - generate summary report for every metric_details with all the supported operations. + list of strings - generate summary report for every metric_details with specified operations, they + should be within this list: [`mean`, `median`, `max`, `min`, `90percent`, `std`]. + default to None. + deli: the delimiter character in the file, default to "\t". + output_type: expected output file type, supported types: ["csv"], default to "csv".
+ + """ + if output_type.lower() != "csv": + raise ValueError(f"unsupported output type: {output_type}.") + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + if metrics is not None and len(metrics) > 0: + with open(os.path.join(save_dir, "metrics.csv"), "w") as f: + for k, v in metrics.items(): + f.write(f"{k}{deli}{str(v)}\n") + + if metric_details is not None and len(metric_details) > 0: + for k, v in metric_details.items(): + if torch.is_tensor(v): + v = v.cpu().numpy() + if v.ndim == 0: + # reshape to [1, 1] if no batch and class dims + v = v.reshape((1, 1)) + elif v.ndim == 1: + # reshape to [N, 1] if no class dim + v = v.reshape((-1, 1)) + + # add the average value of all classes to v + class_labels = ["class" + str(i) for i in range(v.shape[1])] + ["mean"] + v = np.concatenate([v, np.nanmean(v, axis=1, keepdims=True)], axis=1) + + with open(os.path.join(save_dir, f"{k}_raw.csv"), "w") as f: + f.write(f"filename{deli}{deli.join(class_labels)}\n") + for i, b in enumerate(v): + f.write(f"{images[i] if images is not None else str(i)}{deli}{deli.join([str(c) for c in b])}\n") + + if summary_ops is not None: + supported_ops = OrderedDict( + { + "mean": np.nanmean, + "median": np.nanmedian, + "max": np.nanmax, + "min": np.nanmin, + "90percent": lambda x: np.nanpercentile(x, 10), + "std": np.nanstd, + } + ) + ops = ensure_tuple(summary_ops) + if "*" in ops: + ops = tuple(supported_ops.keys()) + + with open(os.path.join(save_dir, f"{k}_summary.csv"), "w") as f: + f.write(f"class{deli}{deli.join(ops)}\n") + for i, c in enumerate(v.transpose()): + f.write(f"{class_labels[i]}{deli}{deli.join([f'{supported_ops[k](c):.4f}' for k in ops])}\n") diff --git a/tests/min_tests.py b/tests/min_tests.py index 9a2dc0f05f..665ead6cc6 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -100,6 +100,9 @@ def run_testsuit(): "test_occlusion_sensitivity", "test_torchvision", "test_torchvisiond", + "test_handler_metrics_saver", + "test_handler_metrics_saver_dist", + "test_evenly_divisible_all_gather_dist", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_evenly_divisible_all_gather_dist.py b/tests/test_evenly_divisible_all_gather_dist.py new file mode 100644 index 0000000000..70dcd7ca6a --- /dev/null +++ b/tests/test_evenly_divisible_all_gather_dist.py @@ -0,0 +1,42 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch +import torch.distributed as dist + +from monai.handlers.utils import evenly_divisible_all_gather +from tests.utils import DistCall, DistTestCase + + +class DistributedEvenlyDivisibleAllGather(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) + def test_data(self): + self._run() + + def _run(self): + if dist.get_rank() == 0: + data1 = torch.tensor([[1, 2], [3, 4]]) + data2 = torch.tensor([[1.0, 2.0]]) + + if dist.get_rank() == 1: + data1 = torch.tensor([[5, 6]]) + data2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]]) + + result1 = evenly_divisible_all_gather(data=data1) + torch.testing.assert_allclose(result1, torch.tensor([[1, 2], [3, 4], [5, 6]])) + result2 = evenly_divisible_all_gather(data=data2) + torch.testing.assert_allclose(result2, torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_handler_confusion_matrix.py b/tests/test_handler_confusion_matrix.py index cc231b82db..0524676763 100644 --- a/tests/test_handler_confusion_matrix.py +++ b/tests/test_handler_confusion_matrix.py @@ -13,12 +13,13 @@ from typing import Any, Dict import torch +from ignite.engine import Engine from parameterized import parameterized from monai.handlers import ConfusionMatrix -TEST_CASE_1 = [{"include_background": True, "metric_name": "f1"}, 0.75] -TEST_CASE_2 = [{"include_background": False, "metric_name": "ppv"}, 1.0] +TEST_CASE_1 = [{"include_background": True, "save_details": False, "metric_name": "f1"}, 0.75] +TEST_CASE_2 = [{"include_background": False, "save_details": False, "metric_name": "ppv"}, 1.0] TEST_CASE_SEG_1 = [{"include_background": True, "metric_name": "tpr"}, 0.7] @@ -73,6 +74,12 @@ def test_compute(self, input_params, expected_avg): def test_compute_seg(self, input_params, expected_avg): metric = ConfusionMatrix(**input_params) + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + metric.attach(engine, "confusion_matrix") + y_pred = data_1["y_pred"] y = data_1["y"] metric.update([y_pred, y]) diff --git a/tests/test_handler_confusion_matrix_dist.py b/tests/test_handler_confusion_matrix_dist.py index ebe0eb9ca7..40245bce2e 100644 --- a/tests/test_handler_confusion_matrix_dist.py +++ b/tests/test_handler_confusion_matrix_dist.py @@ -15,6 +15,7 @@ import numpy as np import torch import torch.distributed as dist +from ignite.engine import Engine from monai.handlers import ConfusionMatrix from tests.utils import DistCall, DistTestCase @@ -29,6 +30,11 @@ def _compute(self): device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu" metric = ConfusionMatrix(include_background=True, metric_name="tpr") + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + metric.attach(engine, "confusion_matrix") if dist.get_rank() == 0: y_pred = torch.tensor( [ diff --git a/tests/test_handler_hausdorff_distance.py b/tests/test_handler_hausdorff_distance.py index edf59320ea..c0d2e723ca 100644 --- a/tests/test_handler_hausdorff_distance.py +++ b/tests/test_handler_hausdorff_distance.py @@ -14,6 +14,7 @@ import numpy as np import torch +from ignite.engine import Engine from monai.handlers import HausdorffDistance @@ -62,6 +63,13 @@ class TestHandlerHausdorffDistance(unittest.TestCase): def test_compute(self): hd_metric = HausdorffDistance(include_background=True) + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + hd_metric.attach(engine, "hausdorff_distance") + y_pred, y = TEST_SAMPLE_1 hd_metric.update([y_pred, y]) 
self.assertEqual(hd_metric.compute(), 10) diff --git a/tests/test_handler_mean_dice.py b/tests/test_handler_mean_dice.py index 9983918f2d..d15b549d86 100644 --- a/tests/test_handler_mean_dice.py +++ b/tests/test_handler_mean_dice.py @@ -12,20 +12,28 @@ import unittest import torch +from ignite.engine import Engine from parameterized import parameterized from monai.handlers import MeanDice -TEST_CASE_1 = [{"include_background": True}, 0.75] -TEST_CASE_2 = [{"include_background": False}, 0.66666] +TEST_CASE_1 = [{"include_background": True}, 0.75, (4, 2)] +TEST_CASE_2 = [{"include_background": False}, 0.66666, (4, 1)] class TestHandlerMeanDice(unittest.TestCase): # TODO test multi node averaged dice @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_compute(self, input_params, expected_avg): + def test_compute(self, input_params, expected_avg, details_shape): dice_metric = MeanDice(**input_params) + # set up engine + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + dice_metric.attach(engine=engine, name="mean_dice") y_pred = torch.Tensor([[[0], [1]], [[1], [0]]]) y = torch.Tensor([[[0], [1]], [[0], [1]]]) @@ -37,9 +45,10 @@ def test_compute(self, input_params, expected_avg): avg_dice = dice_metric.compute() self.assertAlmostEqual(avg_dice, expected_avg, places=4) + self.assertTupleEqual(tuple(engine.state.metric_details["mean_dice"].shape), details_shape) @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_shape_mismatch(self, input_params, _expected): + def test_shape_mismatch(self, input_params, _expected_avg, _details_shape): dice_metric = MeanDice(**input_params) with self.assertRaises((AssertionError, ValueError)): y_pred = torch.Tensor([[0, 1], [1, 0]]) diff --git a/tests/test_handler_metrics_saver.py b/tests/test_handler_metrics_saver.py new file mode 100644 index 0000000000..58a6f10d33 --- /dev/null +++ b/tests/test_handler_metrics_saver.py @@ -0,0 +1,84 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import csv +import os +import tempfile +import unittest + +import torch +from ignite.engine import Engine, Events + +from monai.handlers import MetricsSaver + + +class TestHandlerMetricsSaver(unittest.TestCase): + def test_content(self): + with tempfile.TemporaryDirectory() as tempdir: + metrics_saver = MetricsSaver( + save_dir=tempdir, + metrics=["metric1", "metric2"], + metric_details=["metric3", "metric4"], + batch_transform=lambda x: x["image_meta_dict"], + summary_ops=["mean", "median", "max", "90percent"], + ) + # set up engine + data = [ + {"image_meta_dict": {"filename_or_obj": ["filepath1"]}}, + {"image_meta_dict": {"filename_or_obj": ["filepath2"]}}, + ] + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + + @engine.on(Events.EPOCH_COMPLETED) + def _save_metrics(engine): + engine.state.metrics = {"metric1": 1, "metric2": 2} + engine.state.metric_details = { + "metric3": torch.tensor([[1, 2], [2, 3]]), + "metric4": torch.tensor([[5, 6], [7, 8]]), + } + + metrics_saver.attach(engine) + engine.run(data, max_epochs=1) + + # check the metrics.csv and content + self.assertTrue(os.path.exists(os.path.join(tempdir, "metrics.csv"))) + with open(os.path.join(tempdir, "metrics.csv")) as f: + f_csv = csv.reader(f) + for i, row in enumerate(f_csv): + self.assertEqual(row, [f"metric{i + 1}\t{i + 1}"]) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric3_raw.csv"))) + # check the metric_raw.csv and content + with open(os.path.join(tempdir, "metric3_raw.csv")) as f: + f_csv = csv.reader(f) + for i, row in enumerate(f_csv): + if i > 0: + self.assertEqual(row, [f"filepath{i}\t{float(i)}\t{float(i + 1)}\t{i + 0.5}"]) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric3_summary.csv"))) + # check the metric_summary.csv and content + with open(os.path.join(tempdir, "metric3_summary.csv")) as f: + f_csv = csv.reader(f) + for i, row in enumerate(f_csv): + if i == 1: + self.assertEqual(row, ["class0\t1.5000\t1.5000\t2.0000\t1.1000"]) + elif i == 2: + self.assertEqual(row, ["class1\t2.5000\t2.5000\t3.0000\t2.1000"]) + elif i == 3: + self.assertEqual(row, ["mean\t2.0000\t2.0000\t2.5000\t1.6000"]) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric4_raw.csv"))) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric4_summary.csv"))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_handler_metrics_saver_dist.py b/tests/test_handler_metrics_saver_dist.py new file mode 100644 index 0000000000..1b17d0adb4 --- /dev/null +++ b/tests/test_handler_metrics_saver_dist.py @@ -0,0 +1,106 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import csv +import os +import tempfile +import unittest + +import torch +import torch.distributed as dist +from ignite.engine import Engine, Events + +from monai.handlers import MetricsSaver +from tests.utils import DistCall, DistTestCase, SkipIfBeforePyTorchVersion + + +@SkipIfBeforePyTorchVersion((1, 7)) +class DistributedMetricsSaver(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) + def test_content(self): + self._run() + + def _run(self): + with tempfile.TemporaryDirectory() as tempdir: + metrics_saver = MetricsSaver( + save_dir=tempdir, + metrics=["metric1", "metric2"], + metric_details=["metric3", "metric4"], + batch_transform=lambda x: x["image_meta_dict"], + summary_ops="*", + ) + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + + if dist.get_rank() == 0: + data = [{"image_meta_dict": {"filename_or_obj": ["filepath1"]}}] + + @engine.on(Events.EPOCH_COMPLETED) + def _save_metrics0(engine): + engine.state.metrics = {"metric1": 1, "metric2": 2} + engine.state.metric_details = { + "metric3": torch.tensor([[1, 2]]), + "metric4": torch.tensor([[5, 6]]), + } + + if dist.get_rank() == 1: + # different ranks have different data length + data = [ + {"image_meta_dict": {"filename_or_obj": ["filepath2"]}}, + {"image_meta_dict": {"filename_or_obj": ["filepath3"]}}, + ] + + @engine.on(Events.EPOCH_COMPLETED) + def _save_metrics1(engine): + engine.state.metrics = {"metric1": 1, "metric2": 2} + engine.state.metric_details = { + "metric3": torch.tensor([[2, 3], [3, 4]]), + "metric4": torch.tensor([[6, 7], [7, 8]]), + } + + metrics_saver.attach(engine) + engine.run(data, max_epochs=1) + + if dist.get_rank() == 0: + # check the metrics.csv and content + self.assertTrue(os.path.exists(os.path.join(tempdir, "metrics.csv"))) + with open(os.path.join(tempdir, "metrics.csv")) as f: + f_csv = csv.reader(f) + for i, row in enumerate(f_csv): + self.assertEqual(row, [f"metric{i + 1}\t{i + 1}"]) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric3_raw.csv"))) + # check the metric_raw.csv and content + with open(os.path.join(tempdir, "metric3_raw.csv")) as f: + f_csv = csv.reader(f) + for i, row in enumerate(f_csv): + if i > 0: + self.assertEqual(row, [f"filepath{i}\t{float(i)}\t{float(i + 1)}\t{i + 0.5}"]) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric3_summary.csv"))) + # check the metric_summary.csv and content + with open(os.path.join(tempdir, "metric3_summary.csv")) as f: + f_csv = csv.reader(f) + for i, row in enumerate(f_csv): + if i == 1: + self.assertEqual(row, ["class0\t1.0000\t1.0000\t1.0000\t1.0000\t1.0000\t0.0000"]) + elif i == 2: + self.assertEqual(row, ["class1\t2.0000\t2.0000\t2.0000\t2.0000\t2.0000\t0.0000"]) + elif i == 3: + self.assertEqual(row, ["mean\t1.5000\t1.5000\t1.5000\t1.5000\t1.5000\t0.0000"]) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric4_raw.csv"))) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric4_summary.csv"))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_handler_surface_distance.py b/tests/test_handler_surface_distance.py index 656b0d64b2..fbd86edb03 100644 --- a/tests/test_handler_surface_distance.py +++ b/tests/test_handler_surface_distance.py @@ -14,6 +14,7 @@ import numpy as np import torch +from ignite.engine import Engine from monai.handlers import SurfaceDistance @@ -62,6 +63,13 @@ class TestHandlerSurfaceDistance(unittest.TestCase): def test_compute(self): sur_metric = SurfaceDistance(include_background=True) + + def _val_func(engine, batch): + pass 
+ + engine = Engine(_val_func) + sur_metric.attach(engine, "surface_distance") + y_pred, y = TEST_SAMPLE_1 sur_metric.update([y_pred, y]) self.assertAlmostEqual(sur_metric.compute(), 4.17133, places=4) diff --git a/tests/test_write_metrics_reports.py b/tests/test_write_metrics_reports.py new file mode 100644 index 0000000000..72625ddd9a --- /dev/null +++ b/tests/test_write_metrics_reports.py @@ -0,0 +1,64 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +import os +import tempfile +import unittest + +import torch + +from monai.handlers.utils import write_metrics_reports + + +class TestWriteMetricsReports(unittest.TestCase): + def test_content(self): + with tempfile.TemporaryDirectory() as tempdir: + write_metrics_reports( + save_dir=tempdir, + images=["filepath1", "filepath2"], + metrics={"metric1": 1, "metric2": 2}, + metric_details={"metric3": torch.tensor([[1, 2], [2, 3]]), "metric4": torch.tensor([[5, 6], [7, 8]])}, + summary_ops=["mean", "median", "max", "90percent"], + deli="\t", + output_type="csv", + ) + + # check the metrics.csv and content + self.assertTrue(os.path.exists(os.path.join(tempdir, "metrics.csv"))) + with open(os.path.join(tempdir, "metrics.csv")) as f: + f_csv = csv.reader(f) + for i, row in enumerate(f_csv): + self.assertEqual(row, [f"metric{i + 1}\t{i + 1}"]) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric3_raw.csv"))) + # check the metric_raw.csv and content + with open(os.path.join(tempdir, "metric3_raw.csv")) as f: + f_csv = csv.reader(f) + for i, row in enumerate(f_csv): + if i > 0: + self.assertEqual(row, [f"filepath{i}\t{float(i)}\t{float(i + 1)}\t{i + 0.5}"]) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric3_summary.csv"))) + # check the metric_summary.csv and content + with open(os.path.join(tempdir, "metric3_summary.csv")) as f: + f_csv = csv.reader(f) + for i, row in enumerate(f_csv): + if i == 1: + self.assertEqual(row, ["class0\t1.5000\t1.5000\t2.0000\t1.1000"]) + elif i == 2: + self.assertEqual(row, ["class1\t2.5000\t2.5000\t3.0000\t2.1000"]) + elif i == 3: + self.assertEqual(row, ["mean\t2.0000\t2.0000\t2.5000\t1.6000"]) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric4_raw.csv"))) + self.assertTrue(os.path.exists(os.path.join(tempdir, "metric4_summary.csv"))) + + +if __name__ == "__main__": + unittest.main()
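
For reference, a minimal usage sketch of the new `save_details` flag and the `MetricsSaver` handler attached to an ignite evaluation engine. The step function, the toy tensors, the output directory and the "mean_dice" metric name are illustrative assumptions based on the docstrings and tests in this patch, not code added by it.

import torch
from ignite.engine import Engine

from monai.handlers import MeanDice, MetricsSaver


def _eval_step(engine, batch):
    # assume predictions and labels are already one-hot tensors of shape (batch, channel, ...)
    return batch["pred"], batch["label"]


evaluator = Engine(_eval_step)

# save_details=True keeps the per-image dice values in engine.state.metric_details["mean_dice"]
MeanDice(include_background=True, save_details=True).attach(evaluator, name="mean_dice")

MetricsSaver(
    save_dir="./eval_reports",                # hypothetical output directory
    metrics="*",                              # every value in engine.state.metrics -> metrics.csv
    metric_details=["mean_dice"],             # -> mean_dice_raw.csv and mean_dice_summary.csv
    batch_transform=lambda batch: batch["image_meta_dict"],  # must provide "filename_or_obj"
    summary_ops=["mean", "median", "max", "min", "90percent", "std"],
    save_rank=0,
).attach(evaluator)

data = [
    {
        "pred": torch.tensor([[[1.0], [0.0]], [[0.0], [1.0]]]),
        "label": torch.tensor([[[1.0], [0.0]], [[1.0], [0.0]]]),
        "image_meta_dict": {"filename_or_obj": ["case1.nii.gz", "case2.nii.gz"]},
    },
]
evaluator.run(data, max_epochs=1)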
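
The report writing itself is exposed as the standalone `write_metrics_reports` API, so it can also be called outside of an ignite workflow; the sketch below mirrors the unit test above and only assumes a writable temporary directory.

import tempfile

import torch

from monai.handlers.utils import write_metrics_reports

with tempfile.TemporaryDirectory() as tempdir:
    write_metrics_reports(
        save_dir=tempdir,
        images=["filepath1", "filepath2"],     # one name per row of the detail tensor
        metrics={"metric1": 1, "metric2": 2},  # written to metrics.csv as name/value pairs
        metric_details={"metric3": torch.tensor([[1, 2], [2, 3]])},  # metric3_raw.csv + metric3_summary.csv
        summary_ops=["mean", "median", "max", "90percent"],
        deli="\t",
        output_type="csv",
    )

Each detail tensor is interpreted as (batch, classes): a per-image mean column is appended to the raw report, and the summary file holds one row per class plus the mean, with the requested operations as columns.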
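
Finally, `evenly_divisible_all_gather` replaces the old `all_gather` helper: it pads the first dimension so that ranks holding different numbers of items can still be gathered, then strips the padding again. A sketch of the behaviour under an already-initialised process group with 2 ranks, using the rank-specific tensors from the distributed test above (launch code omitted).

import torch
import torch.distributed as dist

from monai.handlers.utils import evenly_divisible_all_gather

# assumes torch.distributed is already initialised with 2 ranks (e.g. via torch.distributed.launch)
if dist.get_rank() == 0:
    data = torch.tensor([[1, 2], [3, 4]])  # 2 items on rank 0
else:
    data = torch.tensor([[5, 6]])          # 1 item on rank 1

# rank 1 is padded to 2 items, all ranks gather, then the padding is dropped
result = evenly_divisible_all_gather(data=data)
# result is tensor([[1, 2], [3, 4], [5, 6]]) on every rank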